mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-01 00:17:44 +03:00 
			
		
		
		
	wire up different make_spangroups for single and multilabel
This commit is contained in:
		
							parent
							
								
									52e7324df4
								
							
						
					
					
						commit
						dceeb02b94
					
				|  | @ -237,6 +237,7 @@ def make_spancat_singlelabel( | ||||||
|         allow_overlap=allow_overlap, |         allow_overlap=allow_overlap, | ||||||
|         name=name, |         name=name, | ||||||
|         scorer=scorer, |         scorer=scorer, | ||||||
|  |         single_label=True | ||||||
|     ) |     ) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @ -463,9 +464,23 @@ class SpanCategorizer(TrainablePipe): | ||||||
|         offset = 0 |         offset = 0 | ||||||
|         for i, doc in enumerate(docs): |         for i, doc in enumerate(docs): | ||||||
|             indices_i = indices[i].dataXd |             indices_i = indices[i].dataXd | ||||||
|             doc.spans[self.key] = self._make_span_group( |             if self.single_label: | ||||||
|                 doc, indices_i, scores[offset : offset + indices.lengths[i]], labels  # type: ignore[arg-type] |                 allow_overlap = cast(bool, self.cfg["allow_overlap"]) | ||||||
|  |                 doc.spans[self.key] = self._make_span_group_singlelabel( | ||||||
|  |                     doc, | ||||||
|  |                     indices_i, | ||||||
|  |                     scores[offset : offset + indices.lengths[i]], | ||||||
|  |                     labels,  # type: ignore[arg-type] | ||||||
|  |                     allow_overlap | ||||||
|                 ) |                 ) | ||||||
|  |             else: | ||||||
|  |                 doc.spans[self.key] = self._make_span_group_multilabel( | ||||||
|  |                     doc, | ||||||
|  |                     indices_i, | ||||||
|  |                     scores[offset : offset + indices.lengths[i]], | ||||||
|  |                     labels,  # type: ignore[arg-type] | ||||||
|  |                 ) | ||||||
|  | 
 | ||||||
|             offset += indices.lengths[i] |             offset += indices.lengths[i] | ||||||
| 
 | 
 | ||||||
|     def update( |     def update( | ||||||
|  |  | ||||||
|  | @ -129,7 +129,7 @@ def test_make_spangroup(max_positive, nr_results): | ||||||
|     scores = numpy.asarray( |     scores = numpy.asarray( | ||||||
|         [[0.2, 0.4, 0.3, 0.1], [0.1, 0.6, 0.2, 0.4], [0.8, 0.7, 0.3, 0.9]], dtype="f" |         [[0.2, 0.4, 0.3, 0.1], [0.1, 0.6, 0.2, 0.4], [0.8, 0.7, 0.3, 0.9]], dtype="f" | ||||||
|     ) |     ) | ||||||
|     spangroup = spancat._make_span_group(doc, indices, scores, labels) |     spangroup = spancat._make_span_group_multilabel(doc, indices, scores, labels) | ||||||
|     assert len(spangroup) == nr_results |     assert len(spangroup) == nr_results | ||||||
| 
 | 
 | ||||||
|     # first span is always the second token "London" |     # first span is always the second token "London" | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user