mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-30 23:47:31 +03:00 
			
		
		
		
	Fix Tok2Vec for empty batches (#10324)
* Add test for tok2vec with vectors and empty docs * Add shortcut for empty batch in Tok2Vec.predict * Avoid types
This commit is contained in:
		
							parent
							
								
									034ac0acf4
								
							
						
					
					
						commit
						749631ad28
					
				|  | @ -118,6 +118,10 @@ class Tok2Vec(TrainablePipe): | ||||||
| 
 | 
 | ||||||
|         DOCS: https://spacy.io/api/tok2vec#predict |         DOCS: https://spacy.io/api/tok2vec#predict | ||||||
|         """ |         """ | ||||||
|  |         if not any(len(doc) for doc in docs): | ||||||
|  |             # Handle cases where there are no tokens in any docs. | ||||||
|  |             width = self.model.get_dim("nO") | ||||||
|  |             return [self.model.ops.alloc((0, width)) for doc in docs] | ||||||
|         tokvecs = self.model.predict(docs) |         tokvecs = self.model.predict(docs) | ||||||
|         batch_id = Tok2VecListener.get_batch_id(docs) |         batch_id = Tok2VecListener.get_batch_id(docs) | ||||||
|         for listener in self.listeners: |         for listener in self.listeners: | ||||||
|  |  | ||||||
|  | @ -11,7 +11,7 @@ from spacy.lang.en import English | ||||||
| from thinc.api import Config, get_current_ops | from thinc.api import Config, get_current_ops | ||||||
| from numpy.testing import assert_array_equal | from numpy.testing import assert_array_equal | ||||||
| 
 | 
 | ||||||
| from ..util import get_batch, make_tempdir | from ..util import get_batch, make_tempdir, add_vecs_to_vocab | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def test_empty_doc(): | def test_empty_doc(): | ||||||
|  | @ -134,9 +134,25 @@ TRAIN_DATA = [ | ||||||
| ] | ] | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def test_tok2vec_listener(): | @pytest.mark.parametrize("with_vectors", (False, True)) | ||||||
|  | def test_tok2vec_listener(with_vectors): | ||||||
|     orig_config = Config().from_str(cfg_string) |     orig_config = Config().from_str(cfg_string) | ||||||
|  |     orig_config["components"]["tok2vec"]["model"]["embed"][ | ||||||
|  |         "include_static_vectors" | ||||||
|  |     ] = with_vectors | ||||||
|     nlp = util.load_model_from_config(orig_config, auto_fill=True, validate=True) |     nlp = util.load_model_from_config(orig_config, auto_fill=True, validate=True) | ||||||
|  | 
 | ||||||
|  |     if with_vectors: | ||||||
|  |         ops = get_current_ops() | ||||||
|  |         vectors = [ | ||||||
|  |             ("apple", ops.asarray([1, 2, 3])), | ||||||
|  |             ("orange", ops.asarray([-1, -2, -3])), | ||||||
|  |             ("and", ops.asarray([-1, -1, -1])), | ||||||
|  |             ("juice", ops.asarray([5, 5, 10])), | ||||||
|  |             ("pie", ops.asarray([7, 6.3, 8.9])), | ||||||
|  |         ] | ||||||
|  |         add_vecs_to_vocab(nlp.vocab, vectors) | ||||||
|  | 
 | ||||||
|     assert nlp.pipe_names == ["tok2vec", "tagger"] |     assert nlp.pipe_names == ["tok2vec", "tagger"] | ||||||
|     tagger = nlp.get_pipe("tagger") |     tagger = nlp.get_pipe("tagger") | ||||||
|     tok2vec = nlp.get_pipe("tok2vec") |     tok2vec = nlp.get_pipe("tok2vec") | ||||||
|  | @ -163,6 +179,9 @@ def test_tok2vec_listener(): | ||||||
|     ops = get_current_ops() |     ops = get_current_ops() | ||||||
|     assert_array_equal(ops.to_numpy(doc.tensor), ops.to_numpy(doc_tensor)) |     assert_array_equal(ops.to_numpy(doc.tensor), ops.to_numpy(doc_tensor)) | ||||||
| 
 | 
 | ||||||
|  |     # test with empty doc | ||||||
|  |     doc = nlp("") | ||||||
|  | 
 | ||||||
|     # TODO: should this warn or error? |     # TODO: should this warn or error? | ||||||
|     nlp.select_pipes(disable="tok2vec") |     nlp.select_pipes(disable="tok2vec") | ||||||
|     assert nlp.pipe_names == ["tagger"] |     assert nlp.pipe_names == ["tagger"] | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user