mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-29 06:57:49 +03:00 
			
		
		
		
	single label make_spangroup test
This commit is contained in:
		
							parent
							
								
									6fc25f64dd
								
							
						
					
					
						commit
						ec941a128d
					
				|  | @ -108,13 +108,13 @@ def test_doc_gc(): | |||
|             # XXX This fails with length 0 sometimes | ||||
|             assert len(spangroup) > 0 | ||||
|             with pytest.raises(RuntimeError): | ||||
|                 span = spangroup[0] | ||||
|                 spangroup[0] | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.parametrize( | ||||
|     "max_positive,nr_results", [(None, 4), (1, 2), (2, 3), (3, 4), (4, 4)] | ||||
| ) | ||||
| def test_make_spangroup(max_positive, nr_results): | ||||
| def test_make_spangroup_multilabel(max_positive, nr_results): | ||||
|     fix_random_seed(0) | ||||
|     nlp = Language() | ||||
|     spancat = nlp.add_pipe( | ||||
|  | @ -160,6 +160,55 @@ def test_make_spangroup(max_positive, nr_results): | |||
|     assert_almost_equal(0.9, spangroup.attrs["scores"][-1], 5) | ||||
| 
 | ||||
| 
 | ||||
@pytest.mark.parametrize(
    "threshold,allow_overlap,nr_results",
    [(0.05, True, 3), (0.05, False, 1), (0.5, True, 2), (0.5, False, 1)],
)
def test_make_spangroup_singlelabel(threshold, allow_overlap, nr_results):
    """Check `_make_span_group_singlelabel` for each threshold/overlap combination.

    The doc "Greater London" yields three candidate spans from the 1- and
    2-gram suggester, in the order asserted below: "Greater", "London" and
    "Greater London". The rows of `scores` are passed alongside those
    candidates, one row per candidate and one column per entry in `labels`.

    Expected outcomes pinned by this test:
      * threshold 0.05, overlap allowed: all three spans are kept, each with
        its highest-scoring label.
      * threshold 0.5, overlap allowed: "Greater" is no longer returned (its
        best score in the matrix is 0.4), leaving two spans.
      * overlap disallowed (either threshold): only "Greater London" /
        "GreatCity" remains.
    """
    fix_random_seed(0)
    nlp = Language()
    spancat = nlp.add_pipe(
        "spancat",
        config={
            "spans_key": SPAN_KEY,
            "threshold": threshold,
            "max_positive": 1,
        },
    )
    doc = nlp.make_doc("Greater London")
    ngram_suggester = registry.misc.get("spacy.ngram_suggester.v1")(sizes=[1, 2])
    indices = ngram_suggester([doc])[0].dataXd
    # Guard the candidate order so the score rows below line up with the spans.
    assert_array_equal(OPS.to_numpy(indices), numpy.asarray([[0, 1], [1, 2], [0, 2]]))
    labels = ["Thing", "City", "Person", "GreatCity"]
    scores = numpy.asarray(
        [[0.2, 0.4, 0.3, 0.1], [0.1, 0.6, 0.2, 0.4], [0.8, 0.7, 0.3, 0.9]], dtype="f"
    )
    spangroup = spancat._make_span_group_singlelabel(
        doc, indices, scores, labels, allow_overlap
    )
    assert len(spangroup) == nr_results
    if threshold > 0.4:
        if allow_overlap:
            assert spangroup[0].text == "London"
            assert spangroup[0].label_ == "City"
            assert spangroup[1].text == "Greater London"
            assert spangroup[1].label_ == "GreatCity"
        else:
            assert spangroup[0].text == "Greater London"
            assert spangroup[0].label_ == "GreatCity"
    else:
        if allow_overlap:
            assert spangroup[0].text == "Greater"
            assert spangroup[0].label_ == "City"
            assert spangroup[1].text == "London"
            assert spangroup[1].label_ == "City"
            assert spangroup[2].text == "Greater London"
            assert spangroup[2].label_ == "GreatCity"
        else:
            assert spangroup[0].text == "Greater London"
            # Also pin the label here, matching the other three branches.
            assert spangroup[0].label_ == "GreatCity"
| 
 | ||||
| 
 | ||||
| def test_ngram_suggester(en_tokenizer): | ||||
|     # test different n-gram lengths | ||||
|     for size in [1, 2, 3]: | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user