mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-31 07:57:35 +03:00 
			
		
		
		
	Fix tok2vec
This commit is contained in:
		
							parent
							
								
									fe0cdcd461
								
							
						
					
					
						commit
						099e9331c5
					
				|  | @ -1,13 +1,15 @@ | ||||||
| from typing import Optional, List | from typing import Optional, List | ||||||
| from thinc.api import chain, clone, concatenate, with_array, with_padded | from thinc.api import chain, clone, concatenate, with_array, with_padded | ||||||
| from thinc.api import Model, noop | from thinc.api import Model, noop, list2ragged, ragged2list | ||||||
| from thinc.api import FeatureExtractor, HashEmbed, StaticVectors | from thinc.api import FeatureExtractor, HashEmbed | ||||||
| from thincapi import expand_window, residual, Maxout, Mish | from thinc.api import expand_window, residual, Maxout, Mish | ||||||
| from thinc.types import Floats2d | from thinc.types import Floats2d | ||||||
| 
 | 
 | ||||||
|  | from ...tokens import Doc | ||||||
| from ... import util | from ... import util | ||||||
| from ...util import registry | from ...util import registry | ||||||
| from ...ml import _character_embed | from ...ml import _character_embed | ||||||
|  | from ..staticvectors import StaticVectors | ||||||
| from ...pipeline.tok2vec import Tok2VecListener | from ...pipeline.tok2vec import Tok2VecListener | ||||||
| from ...attrs import ID, ORTH, NORM, PREFIX, SUFFIX, SHAPE | from ...attrs import ID, ORTH, NORM, PREFIX, SUFFIX, SHAPE | ||||||
| 
 | 
 | ||||||
|  | @ -21,20 +23,19 @@ def tok2vec_listener_v1(width, upstream="*"): | ||||||
| @registry.architectures.register("spacy.Tok2Vec.v1") | @registry.architectures.register("spacy.Tok2Vec.v1") | ||||||
| def Tok2Vec( | def Tok2Vec( | ||||||
|     embed: Model[List[Doc], List[Floats2d]], |     embed: Model[List[Doc], List[Floats2d]], | ||||||
|     encode: Model[List[Floats2d], List[Floats2d] |     encode: Model[List[Floats2d], List[Floats2d]] | ||||||
| ) -> Model[List[Doc], List[Floats2d]]: | ) -> Model[List[Doc], List[Floats2d]]: | ||||||
|     tok2vec = with_array( | 
 | ||||||
|         chain(embed, encode), |     receptive_field = encode.attrs.get("receptive_field", 0) | ||||||
|         pad=encode.attrs.get("receptive_field", 0) |     tok2vec = chain(embed, with_array(encode, pad=receptive_field)) | ||||||
|     ) |  | ||||||
|     tok2vec.set_dim("nO", encode.get_dim("nO")) |     tok2vec.set_dim("nO", encode.get_dim("nO")) | ||||||
|     tok2vec.set_ref("embed", embed) |     tok2vec.set_ref("embed", embed) | ||||||
|     tok2vec.set_ref("encode", encode) |     tok2vec.set_ref("encode", encode) | ||||||
|     return tok2vec |     return tok2vec | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @registry.architectures.register("spacy.HashEmbed.v1") | @registry.architectures.register("spacy.MultiHashEmbed.v1") | ||||||
| def HashEmbed( | def MultiHashEmbed( | ||||||
|     width: int, |     width: int, | ||||||
|     rows: int, |     rows: int, | ||||||
|     also_embed_subwords: bool, |     also_embed_subwords: bool, | ||||||
|  | @ -56,9 +57,9 @@ def HashEmbed( | ||||||
|      |      | ||||||
|     if also_embed_subwords: |     if also_embed_subwords: | ||||||
|         embeddings = [ |         embeddings = [ | ||||||
|             make_hash_embed(NORM) |             make_hash_embed(NORM), | ||||||
|             make_hash_embed(PREFIX) |             make_hash_embed(PREFIX), | ||||||
|             make_hash_embed(SUFFIX) |             make_hash_embed(SUFFIX), | ||||||
|             make_hash_embed(SHAPE) |             make_hash_embed(SHAPE) | ||||||
|         ] |         ] | ||||||
|     else: |     else: | ||||||
|  | @ -67,15 +68,25 @@ def HashEmbed( | ||||||
|     if also_use_static_vectors: |     if also_use_static_vectors: | ||||||
|         model = chain( |         model = chain( | ||||||
|             concatenate( |             concatenate( | ||||||
|                 chain(FeatureExtractor(cols), concatenate(*embeddings)), |                 chain( | ||||||
|  |                     FeatureExtractor(cols), | ||||||
|  |                     list2ragged(), | ||||||
|  |                     with_array(concatenate(*embeddings)) | ||||||
|  |                 ), | ||||||
|                 StaticVectors(width, dropout=0.0) |                 StaticVectors(width, dropout=0.0) | ||||||
|             ), |             ), | ||||||
|             Maxout(width, pieces=3, dropout=0.0, normalize=True) |             with_array(Maxout(width, nP=3, dropout=0.0, normalize=True)), | ||||||
|  |             ragged2list() | ||||||
|         ) |         ) | ||||||
|     else: |     else: | ||||||
|         model = chain( |         model = chain( | ||||||
|             chain(FeatureExtractor(cols), concatenate(*embeddings)), |             chain( | ||||||
|             Maxout(width, pieces=3, dropout=0.0, normalize=True) |                 FeatureExtractor(cols), | ||||||
|  |                 list2ragged(), | ||||||
|  |                 with_array(concatenate(*embeddings)) | ||||||
|  |             ), | ||||||
|  |             with_array(Maxout(width, nP=3, dropout=0.0, normalize=True)), | ||||||
|  |             ragged2list() | ||||||
|         ) |         ) | ||||||
|     return model |     return model | ||||||
|   |   | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user