mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-31 07:57:35 +03:00 
			
		
		
		
	Fix tok2vec
This commit is contained in:
		
							parent
							
								
									fe0cdcd461
								
							
						
					
					
						commit
						099e9331c5
					
				|  | @ -1,13 +1,15 @@ | |||
| from typing import Optional, List | ||||
| from thinc.api import chain, clone, concatenate, with_array, with_padded | ||||
| from thinc.api import Model, noop | ||||
| from thinc.api import FeatureExtractor, HashEmbed, StaticVectors | ||||
| from thincapi import expand_window, residual, Maxout, Mish | ||||
| from thinc.api import Model, noop, list2ragged, ragged2list | ||||
| from thinc.api import FeatureExtractor, HashEmbed | ||||
| from thinc.api import expand_window, residual, Maxout, Mish | ||||
| from thinc.types import Floats2d | ||||
| 
 | ||||
| from ...tokens import Doc | ||||
| from ... import util | ||||
| from ...util import registry | ||||
| from ...ml import _character_embed | ||||
| from ..staticvectors import StaticVectors | ||||
| from ...pipeline.tok2vec import Tok2VecListener | ||||
| from ...attrs import ID, ORTH, NORM, PREFIX, SUFFIX, SHAPE | ||||
| 
 | ||||
|  | @ -21,20 +23,19 @@ def tok2vec_listener_v1(width, upstream="*"): | |||
| @registry.architectures.register("spacy.Tok2Vec.v1") | ||||
| def Tok2Vec( | ||||
|     embed: Model[List[Doc], List[Floats2d]], | ||||
|     encode: Model[List[Floats2d], List[Floats2d] | ||||
|     encode: Model[List[Floats2d], List[Floats2d]] | ||||
| ) -> Model[List[Doc], List[Floats2d]]: | ||||
|     tok2vec = with_array( | ||||
|         chain(embed, encode), | ||||
|         pad=encode.attrs.get("receptive_field", 0) | ||||
|     ) | ||||
| 
 | ||||
|     receptive_field = encode.attrs.get("receptive_field", 0) | ||||
|     tok2vec = chain(embed, with_array(encode, pad=receptive_field)) | ||||
|     tok2vec.set_dim("nO", encode.get_dim("nO")) | ||||
|     tok2vec.set_ref("embed", embed) | ||||
|     tok2vec.set_ref("encode", encode) | ||||
|     return tok2vec | ||||
| 
 | ||||
| 
 | ||||
| @registry.architectures.register("spacy.HashEmbed.v1") | ||||
| def HashEmbed( | ||||
| @registry.architectures.register("spacy.MultiHashEmbed.v1") | ||||
| def MultiHashEmbed( | ||||
|     width: int, | ||||
|     rows: int, | ||||
|     also_embed_subwords: bool, | ||||
|  | @ -56,9 +57,9 @@ def HashEmbed( | |||
|      | ||||
|     if also_embed_subwords: | ||||
|         embeddings = [ | ||||
|             make_hash_embed(NORM) | ||||
|             make_hash_embed(PREFIX) | ||||
|             make_hash_embed(SUFFIX) | ||||
|             make_hash_embed(NORM), | ||||
|             make_hash_embed(PREFIX), | ||||
|             make_hash_embed(SUFFIX), | ||||
|             make_hash_embed(SHAPE) | ||||
|         ] | ||||
|     else: | ||||
|  | @ -67,15 +68,25 @@ def HashEmbed( | |||
|     if also_use_static_vectors: | ||||
|         model = chain( | ||||
|             concatenate( | ||||
|                 chain(FeatureExtractor(cols), concatenate(*embeddings)), | ||||
|                 chain( | ||||
|                     FeatureExtractor(cols), | ||||
|                     list2ragged(), | ||||
|                     with_array(concatenate(*embeddings)) | ||||
|                 ), | ||||
|                 StaticVectors(width, dropout=0.0) | ||||
|             ), | ||||
|             Maxout(width, pieces=3, dropout=0.0, normalize=True) | ||||
|             with_array(Maxout(width, nP=3, dropout=0.0, normalize=True)), | ||||
|             ragged2list() | ||||
|         ) | ||||
|     else: | ||||
|         model = chain( | ||||
|             chain(FeatureExtractor(cols), concatenate(*embeddings)), | ||||
|             Maxout(width, pieces=3, dropout=0.0, normalize=True) | ||||
|             chain( | ||||
|                 FeatureExtractor(cols), | ||||
|                 list2ragged(), | ||||
|                 with_array(concatenate(*embeddings)) | ||||
|             ), | ||||
|             with_array(Maxout(width, nP=3, dropout=0.0, normalize=True)), | ||||
|             ragged2list() | ||||
|         ) | ||||
|     return model | ||||
|   | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user