mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-04 01:48:04 +03:00 
			
		
		
		
	
							parent
							
								
									593a22cf2d
								
							
						
					
					
						commit
						0a6b68848f
					
				| 
						 | 
				
			
			@ -411,7 +411,10 @@ class SpanCategorizer(TrainablePipe):
 | 
			
		|||
 | 
			
		||||
        keeps = scores >= threshold
 | 
			
		||||
        ranked = (scores * -1).argsort()
 | 
			
		||||
        keeps[ranked[:, max_positive:]] = False
 | 
			
		||||
        if max_positive is not None:
 | 
			
		||||
            filter = ranked[:, max_positive:]
 | 
			
		||||
            for i, row in enumerate(filter):
 | 
			
		||||
                keeps[i, row] = False
 | 
			
		||||
        spans.attrs["scores"] = scores[keeps].flatten()
 | 
			
		||||
 | 
			
		||||
        indices = self.model.ops.to_numpy(indices)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,5 +1,6 @@
 | 
			
		|||
import pytest
 | 
			
		||||
from numpy.testing import assert_equal, assert_array_equal
 | 
			
		||||
import numpy
 | 
			
		||||
from numpy.testing import assert_equal, assert_array_equal, assert_almost_equal
 | 
			
		||||
from thinc.api import get_current_ops
 | 
			
		||||
from spacy.language import Language
 | 
			
		||||
from spacy.training import Example
 | 
			
		||||
| 
						 | 
				
			
			@ -71,6 +72,55 @@ def test_explicit_labels():
 | 
			
		|||
    assert spancat.labels == ("PERSON", "LOC")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.parametrize(
    "max_positive,nr_results", [(None, 4), (1, 2), (2, 3), (3, 4), (4, 4)]
)
def test_make_spangroup(max_positive, nr_results):
    """Check the span group built by `_make_span_group` for several `max_positive` values.

    The pipe is configured with a threshold of 0.5 and the parametrized
    `max_positive`. The hand-written score matrix has one row per suggested
    span ("Greater", "London", "Greater London") and one column per label.
    In that matrix no score reaches 0.5 for "Greater", one does for "London"
    and three do for "Greater London".
    """
    fix_random_seed(0)
    nlp = Language()
    pipe_config = {"spans_key": SPAN_KEY, "threshold": 0.5, "max_positive": max_positive}
    spancat = nlp.add_pipe("spancat", config=pipe_config)
    doc = nlp.make_doc("Greater London")

    # Suggest all unigrams and bigrams of the two-token doc.
    make_suggester = registry.misc.get("spacy.ngram_suggester.v1")
    ngram_suggester = make_suggester(sizes=[1, 2])
    indices = ngram_suggester([doc])[0].dataXd
    assert_array_equal(indices, numpy.asarray([[0, 1], [1, 2], [0, 2]]))

    labels = ["Thing", "City", "Person", "GreatCity"]
    score_rows = [
        [0.2, 0.4, 0.3, 0.1],
        [0.1, 0.6, 0.2, 0.4],
        [0.8, 0.7, 0.3, 0.9],
    ]
    scores = numpy.asarray(score_rows, dtype="f")
    spangroup = spancat._make_span_group(doc, indices, scores, labels)
    assert len(spangroup) == nr_results

    def expect_span(position, text, label, score):
        # Compare text, label and stored score of the span at `position`.
        span = spangroup[position]
        assert span.text == text
        assert span.label_ == label
        assert_almost_equal(score, spangroup.attrs["scores"][position], 5)

    # first span is always the second token "London"
    expect_span(0, "London", "City", 0.6)

    # second span depends on the number of positives that were allowed
    second_label, second_score = (
        ("GreatCity", 0.9) if max_positive == 1 else ("Thing", 0.8)
    )
    expect_span(1, "Greater London", second_label, second_score)

    # a third span only exists when more than two results are expected
    if nr_results > 2:
        third_label, third_score = (
            ("GreatCity", 0.9) if max_positive == 2 else ("City", 0.7)
        )
        expect_span(2, "Greater London", third_label, third_score)

    # the last span is the same in every parametrization
    expect_span(-1, "Greater London", "GreatCity", 0.9)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_simple_train():
 | 
			
		||||
    fix_random_seed(0)
 | 
			
		||||
    nlp = Language()
 | 
			
		||||
| 
						 | 
				
			
			@ -90,6 +140,9 @@ def test_simple_train():
 | 
			
		|||
    scores = nlp.evaluate(get_examples())
 | 
			
		||||
    assert f"spans_{SPAN_KEY}_f" in scores
 | 
			
		||||
    assert scores[f"spans_{SPAN_KEY}_f"] == 1.0
 | 
			
		||||
    # also test that the spancat works for just a single entity in a sentence
 | 
			
		||||
    doc = nlp("London")
 | 
			
		||||
    assert len(doc.spans[spancat.key]) == 1
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_ngram_suggester(en_tokenizer):
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue
	
	Block a user