mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-25 05:01:02 +03:00 
			
		
		
		
	Fix spancat for empty docs and zero suggestions (#9654)
* Fix spancat for empty docs and zero suggestions * Use ops.xp.zeros in test
This commit is contained in:
		
							parent
							
								
									67d8c8a081
								
							
						
					
					
						commit
						c9baf9d196
					
				|  | @ -28,7 +28,13 @@ def forward( | |||
|     X, spans = source_spans | ||||
|     assert spans.dataXd.ndim == 2 | ||||
|     indices = _get_span_indices(ops, spans, X.lengths) | ||||
|     Y = Ragged(X.dataXd[indices], spans.dataXd[:, 1] - spans.dataXd[:, 0])  # type: ignore[arg-type, index] | ||||
|     if len(indices) > 0: | ||||
|         Y = Ragged(X.dataXd[indices], spans.dataXd[:, 1] - spans.dataXd[:, 0])  # type: ignore[arg-type, index] | ||||
|     else: | ||||
|         Y = Ragged( | ||||
|             ops.xp.zeros(X.dataXd.shape, dtype=X.dataXd.dtype), | ||||
|             ops.xp.zeros((len(X.lengths),), dtype="i"), | ||||
|         ) | ||||
|     x_shape = X.dataXd.shape | ||||
|     x_lengths = X.lengths | ||||
| 
 | ||||
|  | @ -53,7 +59,7 @@ def _get_span_indices(ops, spans: Ragged, lengths: Ints1d) -> Ints1d: | |||
|         for j in range(spans_i.shape[0]): | ||||
|             indices.append(ops.xp.arange(spans_i[j, 0], spans_i[j, 1]))  # type: ignore[call-overload, index] | ||||
|         offset += length | ||||
|     return ops.flatten(indices) | ||||
|     return ops.flatten(indices, dtype="i", ndim_if_empty=1) | ||||
| 
 | ||||
| 
 | ||||
| def _ensure_cpu(spans: Ragged, lengths: Ints1d) -> Tuple[Ragged, Ints1d]: | ||||
|  |  | |||
|  | @ -78,7 +78,7 @@ def build_ngram_suggester(sizes: List[int]) -> Suggester: | |||
|         if len(spans) > 0: | ||||
|             output = Ragged(ops.xp.vstack(spans), lengths_array) | ||||
|         else: | ||||
|             output = Ragged(ops.xp.zeros((0, 0)), lengths_array) | ||||
|             output = Ragged(ops.xp.zeros((0, 0), dtype="i"), lengths_array) | ||||
| 
 | ||||
|         assert output.dataXd.ndim == 2 | ||||
|         return output | ||||
|  |  | |||
|  | @ -1,7 +1,7 @@ | |||
| import pytest | ||||
| import numpy | ||||
| from numpy.testing import assert_array_equal, assert_almost_equal | ||||
| from thinc.api import get_current_ops | ||||
| from thinc.api import get_current_ops, Ragged | ||||
| 
 | ||||
| from spacy import util | ||||
| from spacy.lang.en import English | ||||
|  | @ -29,6 +29,7 @@ TRAIN_DATA_OVERLAPPING = [ | |||
|         "I like London and Berlin", | ||||
|         {"spans": {SPAN_KEY: [(7, 13, "LOC"), (18, 24, "LOC"), (7, 24, "DOUBLE_LOC")]}}, | ||||
|     ), | ||||
|     ("", {"spans": {SPAN_KEY: []}}), | ||||
| ] | ||||
| 
 | ||||
| 
 | ||||
|  | @ -365,3 +366,31 @@ def test_overfitting_IO_overlapping(): | |||
|             "London and Berlin", | ||||
|         } | ||||
|         assert set([span.label_ for span in spans2]) == {"LOC", "DOUBLE_LOC"} | ||||
| 
 | ||||
| 
 | ||||
| def test_zero_suggestions(): | ||||
|     # Test with a suggester that returns 0 suggestions | ||||
| 
 | ||||
|     @registry.misc("test_zero_suggester") | ||||
|     def make_zero_suggester(): | ||||
|         def zero_suggester(docs, *, ops=None): | ||||
|             if ops is None: | ||||
|                 ops = get_current_ops() | ||||
|             return Ragged( | ||||
|                 ops.xp.zeros((0, 0), dtype="i"), ops.xp.zeros((len(docs),), dtype="i") | ||||
|             ) | ||||
| 
 | ||||
|         return zero_suggester | ||||
| 
 | ||||
|     fix_random_seed(0) | ||||
|     nlp = English() | ||||
|     spancat = nlp.add_pipe( | ||||
|         "spancat", | ||||
|         config={"suggester": {"@misc": "test_zero_suggester"}, "spans_key": SPAN_KEY}, | ||||
|     ) | ||||
|     train_examples = make_examples(nlp) | ||||
|     optimizer = nlp.initialize(get_examples=lambda: train_examples) | ||||
|     assert spancat.model.get_dim("nO") == 2 | ||||
|     assert set(spancat.labels) == {"LOC", "PERSON"} | ||||
| 
 | ||||
|     nlp.update(train_examples, sgd=optimizer) | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user