Mirror of https://github.com/explosion/spaCy.git (synced 2025-10-31 07:57:35 +03:00)
			
		
		
		
	Small spancat fixes (#8614)
* two small fixes + additional tests
* rename
This commit is contained in:
		
parent 327f83573a
commit 3daf57d70c
					
				|  | @ -182,6 +182,7 @@ class SpanCategorizer(TrainablePipe): | ||||||
|             raise ValueError(Errors.E187) |             raise ValueError(Errors.E187) | ||||||
|         if label in self.labels: |         if label in self.labels: | ||||||
|             return 0 |             return 0 | ||||||
|  |         self._allow_extra_label() | ||||||
|         self.cfg["labels"].append(label) |         self.cfg["labels"].append(label) | ||||||
|         self.vocab.strings.add(label) |         self.vocab.strings.add(label) | ||||||
|         return 1 |         return 1 | ||||||
|  | @ -348,7 +349,7 @@ class SpanCategorizer(TrainablePipe): | ||||||
|                 self.add_label(label) |                 self.add_label(label) | ||||||
|         for eg in get_examples(): |         for eg in get_examples(): | ||||||
|             if labels is None: |             if labels is None: | ||||||
|                 for span in eg.reference.spans[self.key]: |                 for span in eg.reference.spans.get(self.key, []): | ||||||
|                     self.add_label(span.label_) |                     self.add_label(span.label_) | ||||||
|             if len(subbatch) < 10: |             if len(subbatch) < 10: | ||||||
|                 subbatch.append(eg) |                 subbatch.append(eg) | ||||||
|  |  | ||||||
|  | @ -1,3 +1,4 @@ | ||||||
|  | import pytest | ||||||
| from numpy.testing import assert_equal | from numpy.testing import assert_equal | ||||||
| from spacy.language import Language | from spacy.language import Language | ||||||
| from spacy.training import Example | from spacy.training import Example | ||||||
|  | @ -27,6 +28,47 @@ def make_get_examples(nlp): | ||||||
|     return get_examples |     return get_examples | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | def test_no_label(): | ||||||
|  |     nlp = Language() | ||||||
|  |     nlp.add_pipe("spancat", config={"spans_key": SPAN_KEY}) | ||||||
|  |     with pytest.raises(ValueError): | ||||||
|  |         nlp.initialize() | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def test_no_resize(): | ||||||
|  |     nlp = Language() | ||||||
|  |     spancat = nlp.add_pipe("spancat", config={"spans_key": SPAN_KEY}) | ||||||
|  |     spancat.add_label("Thing") | ||||||
|  |     spancat.add_label("Phrase") | ||||||
|  |     assert spancat.labels == ("Thing", "Phrase") | ||||||
|  |     nlp.initialize() | ||||||
|  |     assert spancat.model.get_dim("nO") == 2 | ||||||
|  |     # this throws an error because the spancat can't be resized after initialization | ||||||
|  |     with pytest.raises(ValueError): | ||||||
|  |         spancat.add_label("Stuff") | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def test_implicit_labels(): | ||||||
|  |     nlp = Language() | ||||||
|  |     spancat = nlp.add_pipe("spancat", config={"spans_key": SPAN_KEY}) | ||||||
|  |     assert len(spancat.labels) == 0 | ||||||
|  |     train_examples = [] | ||||||
|  |     for t in TRAIN_DATA: | ||||||
|  |         train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1])) | ||||||
|  |     nlp.initialize(get_examples=lambda: train_examples) | ||||||
|  |     assert spancat.labels == ("PERSON", "LOC") | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def test_explicit_labels(): | ||||||
|  |     nlp = Language() | ||||||
|  |     spancat = nlp.add_pipe("spancat", config={"spans_key": SPAN_KEY}) | ||||||
|  |     assert len(spancat.labels) == 0 | ||||||
|  |     spancat.add_label("PERSON") | ||||||
|  |     spancat.add_label("LOC") | ||||||
|  |     nlp.initialize() | ||||||
|  |     assert spancat.labels == ("PERSON", "LOC") | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| def test_simple_train(): | def test_simple_train(): | ||||||
|     fix_random_seed(0) |     fix_random_seed(0) | ||||||
|     nlp = Language() |     nlp = Language() | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user