mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-04 09:57:26 +03:00 
			
		
		
		
	Allow CharacterEmbed to specify feature
This commit is contained in:
		
							parent
							
								
									0a8a124a6e
								
							
						
					
					
						commit
						684a77870b
					
				| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
from typing import Optional, List
 | 
					from typing import Optional, List, Union
 | 
				
			||||||
from thinc.api import chain, clone, concatenate, with_array, with_padded
 | 
					from thinc.api import chain, clone, concatenate, with_array, with_padded
 | 
				
			||||||
from thinc.api import Model, noop, list2ragged, ragged2list
 | 
					from thinc.api import Model, noop, list2ragged, ragged2list
 | 
				
			||||||
from thinc.api import FeatureExtractor, HashEmbed
 | 
					from thinc.api import FeatureExtractor, HashEmbed
 | 
				
			||||||
| 
						 | 
					@ -165,7 +165,8 @@ def MultiHashEmbed(
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@registry.architectures.register("spacy.CharacterEmbed.v1")
 | 
					@registry.architectures.register("spacy.CharacterEmbed.v1")
 | 
				
			||||||
def CharacterEmbed(
 | 
					def CharacterEmbed(
 | 
				
			||||||
    width: int, rows: int, nM: int, nC: int, also_use_static_vectors: bool
 | 
					    width: int, rows: int, nM: int, nC: int, also_use_static_vectors: bool,
 | 
				
			||||||
 | 
					    feature: Union[int, str]="NORM"
 | 
				
			||||||
):
 | 
					):
 | 
				
			||||||
    """Construct an embedded representation based on character embeddings, using
 | 
					    """Construct an embedded representation based on character embeddings, using
 | 
				
			||||||
    a feed-forward network. A fixed number of UTF-8 byte characters are used for
 | 
					    a feed-forward network. A fixed number of UTF-8 byte characters are used for
 | 
				
			||||||
| 
						 | 
					@ -183,7 +184,8 @@ def CharacterEmbed(
 | 
				
			||||||
    also concatenated on, and the result is then passed through a feed-forward
 | 
					    also concatenated on, and the result is then passed through a feed-forward
 | 
				
			||||||
    network to construct a single vector to represent the information.
 | 
					    network to construct a single vector to represent the information.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    width (int): The width of the output vector and the NORM hash embedding.
 | 
					    feature (int or str): An attribute to embed, to concatenate with the characters.
 | 
				
			||||||
 | 
					    width (int): The width of the output vector and the feature embedding.
 | 
				
			||||||
    rows (int): The number of rows in the NORM hash embedding table.
 | 
					    rows (int): The number of rows in the NORM hash embedding table.
 | 
				
			||||||
    nM (int): The dimensionality of the character embeddings. Recommended values
 | 
					    nM (int): The dimensionality of the character embeddings. Recommended values
 | 
				
			||||||
        are between 16 and 64.
 | 
					        are between 16 and 64.
 | 
				
			||||||
| 
						 | 
					@ -193,12 +195,15 @@ def CharacterEmbed(
 | 
				
			||||||
    also_use_static_vectors (bool): Whether to also use static word vectors.
 | 
					    also_use_static_vectors (bool): Whether to also use static word vectors.
 | 
				
			||||||
        Requires a vectors table to be loaded in the Doc objects' vocab.
 | 
					        Requires a vectors table to be loaded in the Doc objects' vocab.
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
 | 
					    feature = intify_attr(feature)
 | 
				
			||||||
 | 
					    if feature is None:
 | 
				
			||||||
 | 
					        raise ValueError("Invalid feature: Must be a token attribute.")
 | 
				
			||||||
    if also_use_static_vectors:
 | 
					    if also_use_static_vectors:
 | 
				
			||||||
        model = chain(
 | 
					        model = chain(
 | 
				
			||||||
            concatenate(
 | 
					            concatenate(
 | 
				
			||||||
                chain(_character_embed.CharacterEmbed(nM=nM, nC=nC), list2ragged()),
 | 
					                chain(_character_embed.CharacterEmbed(nM=nM, nC=nC), list2ragged()),
 | 
				
			||||||
                chain(
 | 
					                chain(
 | 
				
			||||||
                    FeatureExtractor([NORM]),
 | 
					                    FeatureExtractor([feature]),
 | 
				
			||||||
                    list2ragged(),
 | 
					                    list2ragged(),
 | 
				
			||||||
                    with_array(HashEmbed(nO=width, nV=rows, column=0, seed=5)),
 | 
					                    with_array(HashEmbed(nO=width, nV=rows, column=0, seed=5)),
 | 
				
			||||||
                ),
 | 
					                ),
 | 
				
			||||||
| 
						 | 
					@ -214,7 +219,7 @@ def CharacterEmbed(
 | 
				
			||||||
            concatenate(
 | 
					            concatenate(
 | 
				
			||||||
                chain(_character_embed.CharacterEmbed(nM=nM, nC=nC), list2ragged()),
 | 
					                chain(_character_embed.CharacterEmbed(nM=nM, nC=nC), list2ragged()),
 | 
				
			||||||
                chain(
 | 
					                chain(
 | 
				
			||||||
                    FeatureExtractor([NORM]),
 | 
					                    FeatureExtractor([feature]),
 | 
				
			||||||
                    list2ragged(),
 | 
					                    list2ragged(),
 | 
				
			||||||
                    with_array(HashEmbed(nO=width, nV=rows, column=0, seed=5)),
 | 
					                    with_array(HashEmbed(nO=width, nV=rows, column=0, seed=5)),
 | 
				
			||||||
                ),
 | 
					                ),
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user