Mirror of https://github.com/explosion/spaCy.git (synced 2025-10-31 07:57:35 +03:00)
			
		
		
		
	Wire up separate `_make_span_group_singlelabel` and `_make_span_group_multilabel` methods for single-label and multilabel span categorization
This commit is contained in:
		
parent 52e7324df4
commit dceeb02b94
					
				|  | @@ -237,6 +237,7 @@ def make_spancat_singlelabel( | |||
|         allow_overlap=allow_overlap, | ||||
|         name=name, | ||||
|         scorer=scorer, | ||||
|         single_label=True | ||||
|     ) | ||||
| 
 | ||||
| 
 | ||||
|  | @@ -463,9 +464,23 @@ class SpanCategorizer(TrainablePipe): | |||
|         offset = 0 | ||||
|         for i, doc in enumerate(docs): | ||||
|             indices_i = indices[i].dataXd | ||||
|             doc.spans[self.key] = self._make_span_group( | ||||
|                 doc, indices_i, scores[offset : offset + indices.lengths[i]], labels  # type: ignore[arg-type] | ||||
|             if self.single_label: | ||||
|                 allow_overlap = cast(bool, self.cfg["allow_overlap"]) | ||||
|                 doc.spans[self.key] = self._make_span_group_singlelabel( | ||||
|                     doc, | ||||
|                     indices_i, | ||||
|                     scores[offset : offset + indices.lengths[i]], | ||||
|                     labels,  # type: ignore[arg-type] | ||||
|                     allow_overlap | ||||
|                 ) | ||||
|             else: | ||||
|                 doc.spans[self.key] = self._make_span_group_multilabel( | ||||
|                     doc, | ||||
|                     indices_i, | ||||
|                     scores[offset : offset + indices.lengths[i]], | ||||
|                     labels,  # type: ignore[arg-type] | ||||
|                 ) | ||||
| 
 | ||||
|             offset += indices.lengths[i] | ||||
| 
 | ||||
|     def update( | ||||
|  |  | |||
|  | @@ -129,7 +129,7 @@ def test_make_spangroup(max_positive, nr_results): | |||
|     scores = numpy.asarray( | ||||
|         [[0.2, 0.4, 0.3, 0.1], [0.1, 0.6, 0.2, 0.4], [0.8, 0.7, 0.3, 0.9]], dtype="f" | ||||
|     ) | ||||
|     spangroup = spancat._make_span_group(doc, indices, scores, labels) | ||||
|     spangroup = spancat._make_span_group_multilabel(doc, indices, scores, labels) | ||||
|     assert len(spangroup) == nr_results | ||||
| 
 | ||||
|     # first span is always the second token "London" | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user