Mirror of https://github.com/explosion/spaCy.git (synced 2024-12-25 17:36:30 +03:00)

Commit 7a21775cd0: Merge pull request #5834 from explosion/feature/vectors
@@ -20,20 +20,20 @@ seed = 0
 accumulate_gradient = 1
 use_pytorch_for_gpu_memory = false
 # Control how scores are printed and checkpoints are evaluated.
-scores = ["speed", "tags_acc", "uas", "las", "ents_f"]
+eval_batch_size = 128
-score_weights = {"las": 0.4, "ents_f": 0.4, "tags_acc": 0.2}
+score_weights = {"dep_las": 0.4, "ents_f": 0.4, "tag_acc": 0.2}
-# These settings are invalid for the transformer models.
 init_tok2vec = null
 discard_oversize = false
-omit_extra_lookups = false
 batch_by = "words"
-use_gpu = -1
 raw_text = null
 tag_map = null
+vectors = null
+base_model = null
+morph_rules = null

 [training.batch_size]
 @schedules = "compounding.v1"
-start = 1000
+start = 100
 stop = 1000
 compound = 1.001
@@ -46,74 +46,79 @@ L2 = 0.01
 grad_clip = 1.0
 use_averages = false
 eps = 1e-8
-#learn_rate = 0.001
+learn_rate = 0.001

-[training.optimizer.learn_rate]
-@schedules = "warmup_linear.v1"
-warmup_steps = 250
-total_steps = 20000
-initial_rate = 0.001

 [nlp]
 lang = "en"
-base_model = null
+load_vocab_data = false
-vectors = null
+pipeline = ["tok2vec", "ner", "tagger", "parser"]

-[nlp.pipeline]
+[nlp.tokenizer]
+@tokenizers = "spacy.Tokenizer.v1"

-[nlp.pipeline.tok2vec]
+[nlp.lemmatizer]
+@lemmatizers = "spacy.Lemmatizer.v1"

+[components]

+[components.tok2vec]
 factory = "tok2vec"

+[components.ner]
-[nlp.pipeline.ner]
 factory = "ner"
 learn_tokens = false
 min_action_freq = 1

-[nlp.pipeline.tagger]
+[components.tagger]
 factory = "tagger"

-[nlp.pipeline.parser]
+[components.parser]
 factory = "parser"
 learn_tokens = false
 min_action_freq = 30

-[nlp.pipeline.tagger.model]
+[components.tagger.model]
 @architectures = "spacy.Tagger.v1"

-[nlp.pipeline.tagger.model.tok2vec]
+[components.tagger.model.tok2vec]
-@architectures = "spacy.Tok2VecTensors.v1"
+@architectures = "spacy.Tok2VecListener.v1"
-width = ${nlp.pipeline.tok2vec.model:width}
+width = ${components.tok2vec.model.encode:width}

-[nlp.pipeline.parser.model]
+[components.parser.model]
 @architectures = "spacy.TransitionBasedParser.v1"
 nr_feature_tokens = 8
 hidden_width = 128
 maxout_pieces = 2
 use_upper = true

-[nlp.pipeline.parser.model.tok2vec]
+[components.parser.model.tok2vec]
-@architectures = "spacy.Tok2VecTensors.v1"
+@architectures = "spacy.Tok2VecListener.v1"
-width = ${nlp.pipeline.tok2vec.model:width}
+width = ${components.tok2vec.model.encode:width}

-[nlp.pipeline.ner.model]
+[components.ner.model]
 @architectures = "spacy.TransitionBasedParser.v1"
 nr_feature_tokens = 3
 hidden_width = 128
 maxout_pieces = 2
 use_upper = true

-[nlp.pipeline.ner.model.tok2vec]
+[components.ner.model.tok2vec]
-@architectures = "spacy.Tok2VecTensors.v1"
+@architectures = "spacy.Tok2VecListener.v1"
-width = ${nlp.pipeline.tok2vec.model:width}
+width = ${components.tok2vec.model.encode:width}

-[nlp.pipeline.tok2vec.model]
+[components.tok2vec.model]
-@architectures = "spacy.HashEmbedCNN.v1"
+@architectures = "spacy.Tok2Vec.v1"
-pretrained_vectors = ${nlp:vectors}
-width = 128
+[components.tok2vec.model.embed]
+@architectures = "spacy.MultiHashEmbed.v1"
+width = ${components.tok2vec.model.encode:width}
+rows = 2000
+also_embed_subwords = true
+also_use_static_vectors = false

+[components.tok2vec.model.encode]
+@architectures = "spacy.MaxoutWindowEncoder.v1"
+width = 96
 depth = 4
 window_size = 1
-embed_size = 7000
 maxout_pieces = 3
-subword_features = true
-dropout = ${training:dropout}
@@ -9,11 +9,11 @@ max_epochs = 100
 orth_variant_level = 0.0
 gold_preproc = true
 max_length = 0
-scores = ["tag_acc", "dep_uas", "dep_las"]
+scores = ["tag_acc", "dep_uas", "dep_las", "speed"]
 score_weights = {"dep_las": 0.8, "tag_acc": 0.2}
 limit = 0
 seed = 0
-accumulate_gradient = 2
+accumulate_gradient = 1
 discard_oversize = false
 raw_text = null
 tag_map = null

@@ -22,7 +22,7 @@ base_model = null

 eval_batch_size = 128
 use_pytorch_for_gpu_memory = false
-batch_by = "padded"
+batch_by = "words"

 [training.batch_size]
 @schedules = "compounding.v1"

@@ -64,8 +64,8 @@ min_action_freq = 1
 @architectures = "spacy.Tagger.v1"

 [components.tagger.model.tok2vec]
-@architectures = "spacy.Tok2VecTensors.v1"
+@architectures = "spacy.Tok2VecListener.v1"
-width = ${components.tok2vec.model:width}
+width = ${components.tok2vec.model.encode:width}

 [components.parser.model]
 @architectures = "spacy.TransitionBasedParser.v1"

@@ -74,16 +74,22 @@ hidden_width = 64
 maxout_pieces = 3

 [components.parser.model.tok2vec]
-@architectures = "spacy.Tok2VecTensors.v1"
+@architectures = "spacy.Tok2VecListener.v1"
-width = ${components.tok2vec.model:width}
+width = ${components.tok2vec.model.encode:width}

 [components.tok2vec.model]
-@architectures = "spacy.HashEmbedCNN.v1"
+@architectures = "spacy.Tok2Vec.v1"
-pretrained_vectors = ${training:vectors}
+[components.tok2vec.model.embed]
+@architectures = "spacy.MultiHashEmbed.v1"
+width = ${components.tok2vec.model.encode:width}
+rows = 2000
+also_embed_subwords = true
+also_use_static_vectors = false

+[components.tok2vec.model.encode]
+@architectures = "spacy.MaxoutWindowEncoder.v1"
 width = 96
 depth = 4
 window_size = 1
-embed_size = 2000
 maxout_pieces = 3
-subword_features = true
-dropout = null
@@ -11,7 +11,6 @@ from ...util import ensure_path, working_dir
 from .._util import project_cli, Arg, PROJECT_FILE, load_project_config, get_checksum


 # TODO: find a solution for caches
 # CACHES = [
 #     Path.home() / ".torch",

@@ -80,16 +80,20 @@ def train(
 msg.info("Using CPU")
 msg.info(f"Loading config and nlp from: {config_path}")
 config = Config().from_disk(config_path)
+if config.get("training", {}).get("seed") is not None:
+    fix_random_seed(config["training"]["seed"])
 with show_validation_error():
     nlp, config = util.load_model_from_config(config, overrides=config_overrides)
 if config["training"]["base_model"]:
-    base_nlp = util.load_model(config["training"]["base_model"])
     # TODO: do something to check base_nlp against regular nlp described in config?
-    nlp = base_nlp
+    # If everything matches it will look something like:
+    # base_nlp = util.load_model(config["training"]["base_model"])
+    # nlp = base_nlp
+    raise NotImplementedError("base_model not supported yet.")
+if config["training"]["vectors"] is not None:
+    util.load_vectors_into_model(nlp, config["training"]["vectors"])
 verify_config(nlp)
 raw_text, tag_map, morph_rules, weights_data = load_from_paths(config)
-if config["training"]["seed"] is not None:
-    fix_random_seed(config["training"]["seed"])
 if config["training"]["use_pytorch_for_gpu_memory"]:
     # It feels kind of weird to not have a default for this.
     use_pytorch_for_gpu_memory()

@@ -242,7 +246,7 @@ def create_evaluation_callback(
 ) -> Callable[[], Tuple[float, Dict[str, float]]]:
     def evaluate() -> Tuple[float, Dict[str, float]]:
         dev_examples = corpus.dev_dataset(
-            nlp, gold_preproc=cfg["gold_preproc"], ignore_misaligned=True
+            nlp, gold_preproc=cfg["gold_preproc"]
         )
         dev_examples = list(dev_examples)
         n_words = sum(len(ex.predicted) for ex in dev_examples)

@@ -21,7 +21,7 @@ from .vocab import Vocab, create_vocab
 from .pipe_analysis import analyze_pipes, analyze_all_pipes, validate_attrs
 from .gold import Example
 from .scorer import Scorer
-from .util import link_vectors_to_models, create_default_optimizer, registry
+from .util import create_default_optimizer, registry
 from .util import SimpleFrozenDict, combine_score_weights
 from .lang.tokenizer_exceptions import URL_MATCH, BASE_EXCEPTIONS
 from .lang.punctuation import TOKENIZER_PREFIXES, TOKENIZER_SUFFIXES

@@ -1051,7 +1051,6 @@ class Language:
 if self.vocab.vectors.data.shape[1] >= 1:
     ops = get_current_ops()
     self.vocab.vectors.data = ops.asarray(self.vocab.vectors.data)
-    link_vectors_to_models(self.vocab)
 if sgd is None:
     sgd = create_default_optimizer()
 self._optimizer = sgd

@@ -1084,7 +1083,6 @@ class Language:
 ops = get_current_ops()
 if self.vocab.vectors.data.shape[1] >= 1:
     self.vocab.vectors.data = ops.asarray(self.vocab.vectors.data)
-    link_vectors_to_models(self.vocab)
 if sgd is None:
     sgd = create_default_optimizer()
 self._optimizer = sgd

@@ -1410,6 +1408,10 @@ class Language:
 nlp = cls(
     create_tokenizer=create_tokenizer, create_lemmatizer=create_lemmatizer,
 )
+# Note that we don't load vectors here, instead they get loaded explicitly
+# inside stuff like the spacy train function. If we loaded them here,
+# then we would load them twice at runtime: once when we make from config,
+# and then again when we load from disk.
 pipeline = config.get("components", {})
 for pipe_name in config["nlp"]["pipeline"]:
     if pipe_name not in pipeline:
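The comment added in the `from_config` hunk above describes the intended split: `from_config` only builds the pipeline, and vectors are attached once, explicitly, at training time. A minimal sketch of that flow, assuming the `util.load_vectors_into_model` helper introduced in the `spacy train` hunk earlier; the config path is a placeholder, not taken from the commit:

from thinc.api import Config
from spacy import util

# Build the pipeline from its config; no vectors are loaded at this point.
config = Config().from_disk("config.cfg")  # placeholder path
nlp, config = util.load_model_from_config(config, overrides={})

# Attach static vectors explicitly, the way `spacy train` now does it.
if config["training"]["vectors"] is not None:
    util.load_vectors_into_model(nlp, config["training"]["vectors"])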
@@ -1618,8 +1620,6 @@ def _fix_pretrained_vectors_name(nlp: Language) -> None:
     nlp.vocab.vectors.name = vectors_name
 else:
     raise ValueError(Errors.E092)
-if nlp.vocab.vectors.size != 0:
-    link_vectors_to_models(nlp.vocab)
 for name, proc in nlp.pipeline:
     if not hasattr(proc, "cfg"):
         continue
@@ -1,16 +1,18 @@
+from typing import List
 from thinc.api import Model
+from thinc.types import Floats2d
+from ..tokens import Doc


-def CharacterEmbed(nM, nC):
+def CharacterEmbed(nM: int, nC: int) -> Model[List[Doc], List[Floats2d]]:
     # nM: Number of dimensions per character. nC: Number of characters.
-    nO = nM * nC if (nM is not None and nC is not None) else None
     return Model(
         "charembed",
         forward,
         init=init,
-        dims={"nM": nM, "nC": nC, "nO": nO, "nV": 256},
+        dims={"nM": nM, "nC": nC, "nO": nM * nC, "nV": 256},
         params={"E": None},
-    ).initialize()
+    )


 def init(model, X=None, Y=None):

@@ -5,11 +5,11 @@ from thinc.api import SparseLinear, Softmax, softmax_activation, Maxout, reduce_
 from thinc.api import HashEmbed, with_ragged, with_array, with_cpu, uniqued
 from thinc.api import Relu, residual, expand_window, FeatureExtractor

-from ..spacy_vectors import SpacyVectors
 from ... import util
 from ...attrs import ID, ORTH, PREFIX, SUFFIX, SHAPE, LOWER
 from ...util import registry
 from ..extract_ngrams import extract_ngrams
+from ..staticvectors import StaticVectors


 @registry.architectures.register("spacy.TextCatCNN.v1")

@@ -102,13 +102,7 @@ def build_text_classifier(
 )

 if pretrained_vectors:
-    nlp = util.load_model(pretrained_vectors)
+    static_vectors = StaticVectors(width)
-    vectors = nlp.vocab.vectors
-    vector_dim = vectors.data.shape[1]

-    static_vectors = SpacyVectors(vectors) >> with_array(
-        Linear(width, vector_dim)
-    )
     vector_layer = trained_vectors | static_vectors
     vectors_width = width * 2
 else:

@@ -159,16 +153,11 @@ def build_text_classifier(

 @registry.architectures.register("spacy.TextCatLowData.v1")
 def build_text_classifier_lowdata(width, pretrained_vectors, dropout, nO=None):
-    nlp = util.load_model(pretrained_vectors)
-    vectors = nlp.vocab.vectors
-    vector_dim = vectors.data.shape[1]

     # Note, before v.3, this was the default if setting "low_data" and "pretrained_dims"
     with Model.define_operators({">>": chain, "**": clone}):
         model = (
-            SpacyVectors(vectors)
+            StaticVectors(width)
             >> list2ragged()
-            >> with_ragged(0, Linear(width, vector_dim))
             >> ParametricAttention(width)
             >> reduce_sum()
             >> residual(Relu(width, width)) ** 2
@@ -1,223 +1,140 @@
-from thinc.api import chain, clone, concatenate, with_array, uniqued
+from typing import Optional, List
-from thinc.api import Model, noop, with_padded, Maxout, expand_window
+from thinc.api import chain, clone, concatenate, with_array, with_padded
-from thinc.api import HashEmbed, StaticVectors, PyTorchLSTM
+from thinc.api import Model, noop, list2ragged, ragged2list
-from thinc.api import residual, LayerNorm, FeatureExtractor, Mish
+from thinc.api import FeatureExtractor, HashEmbed
+from thinc.api import expand_window, residual, Maxout, Mish, PyTorchLSTM
+from thinc.types import Floats2d

+from ...tokens import Doc
 from ... import util
 from ...util import registry
 from ...ml import _character_embed
+from ..staticvectors import StaticVectors
 from ...pipeline.tok2vec import Tok2VecListener
 from ...attrs import ID, ORTH, NORM, PREFIX, SUFFIX, SHAPE


-@registry.architectures.register("spacy.Tok2VecTensors.v1")
+@registry.architectures.register("spacy.Tok2VecListener.v1")
-def tok2vec_tensors_v1(width, upstream="*"):
+def tok2vec_listener_v1(width, upstream="*"):
     tok2vec = Tok2VecListener(upstream_name=upstream, width=width)
     return tok2vec


-@registry.architectures.register("spacy.VocabVectors.v1")
+@registry.architectures.register("spacy.HashEmbedCNN.v1")
-def get_vocab_vectors(name):
+def build_hash_embed_cnn_tok2vec(
-    nlp = util.load_model(name)
-    return nlp.vocab.vectors
+    *,
+    width: int,
+    depth: int,
+    embed_size: int,
+    window_size: int,
+    maxout_pieces: int,
+    subword_features: bool,
+    dropout: Optional[float],
+    pretrained_vectors: Optional[bool]
+) -> Model[List[Doc], List[Floats2d]]:
+    """Build spaCy's 'standard' tok2vec layer, which uses hash embedding
+    with subword features and a CNN with layer-normalized maxout."""
+    return build_Tok2Vec_model(
+        embed=MultiHashEmbed(
+            width=width,
+            rows=embed_size,
+            also_embed_subwords=subword_features,
+            also_use_static_vectors=bool(pretrained_vectors),
+        ),
+        encode=MaxoutWindowEncoder(
+            width=width,
+            depth=depth,
+            window_size=window_size,
+            maxout_pieces=maxout_pieces
+        )
+    )

 @registry.architectures.register("spacy.Tok2Vec.v1")
-def Tok2Vec(extract, embed, encode):
+def build_Tok2Vec_model(
-    field_size = 0
+    embed: Model[List[Doc], List[Floats2d]],
-    if encode.attrs.get("receptive_field", None):
+    encode: Model[List[Floats2d], List[Floats2d]],
-        field_size = encode.attrs["receptive_field"]
+) -> Model[List[Doc], List[Floats2d]]:
-    with Model.define_operators({">>": chain, "|": concatenate}):
-        tok2vec = extract >> with_array(embed >> encode, pad=field_size)
+    receptive_field = encode.attrs.get("receptive_field", 0)
+    tok2vec = chain(embed, with_array(encode, pad=receptive_field))
     tok2vec.set_dim("nO", encode.get_dim("nO"))
     tok2vec.set_ref("embed", embed)
     tok2vec.set_ref("encode", encode)
     return tok2vec


-@registry.architectures.register("spacy.Doc2Feats.v1")
-def Doc2Feats(columns):
-    return FeatureExtractor(columns)


-@registry.architectures.register("spacy.HashEmbedCNN.v1")
-def hash_embed_cnn(
-    pretrained_vectors,
-    width,
-    depth,
-    embed_size,
-    maxout_pieces,
-    window_size,
-    subword_features,
-    dropout,
-):
-    # Does not use character embeddings: set to False by default
-    return build_Tok2Vec_model(
-        width=width,
-        embed_size=embed_size,
-        pretrained_vectors=pretrained_vectors,
-        conv_depth=depth,
-        bilstm_depth=0,
-        maxout_pieces=maxout_pieces,
-        window_size=window_size,
-        subword_features=subword_features,
-        char_embed=False,
-        nM=0,
-        nC=0,
-        dropout=dropout,
-    )


-@registry.architectures.register("spacy.HashCharEmbedCNN.v1")
-def hash_charembed_cnn(
-    pretrained_vectors,
-    width,
-    depth,
-    embed_size,
-    maxout_pieces,
-    window_size,
-    nM,
-    nC,
-    dropout,
-):
-    # Allows using character embeddings by setting nC, nM and char_embed=True
-    return build_Tok2Vec_model(
-        width=width,
-        embed_size=embed_size,
-        pretrained_vectors=pretrained_vectors,
-        conv_depth=depth,
-        bilstm_depth=0,
-        maxout_pieces=maxout_pieces,
-        window_size=window_size,
-        subword_features=False,
-        char_embed=True,
-        nM=nM,
-        nC=nC,
-        dropout=dropout,
-    )


-@registry.architectures.register("spacy.HashEmbedBiLSTM.v1")
-def hash_embed_bilstm_v1(
-    pretrained_vectors,
-    width,
-    depth,
-    embed_size,
-    subword_features,
-    maxout_pieces,
-    dropout,
-):
-    # Does not use character embeddings: set to False by default
-    return build_Tok2Vec_model(
-        width=width,
-        embed_size=embed_size,
-        pretrained_vectors=pretrained_vectors,
-        bilstm_depth=depth,
-        conv_depth=0,
-        maxout_pieces=maxout_pieces,
-        window_size=1,
-        subword_features=subword_features,
-        char_embed=False,
-        nM=0,
-        nC=0,
-        dropout=dropout,
-    )


-@registry.architectures.register("spacy.HashCharEmbedBiLSTM.v1")
-def hash_char_embed_bilstm_v1(
-    pretrained_vectors, width, depth, embed_size, maxout_pieces, nM, nC, dropout
-):
-    # Allows using character embeddings by setting nC, nM and char_embed=True
-    return build_Tok2Vec_model(
-        width=width,
-        embed_size=embed_size,
-        pretrained_vectors=pretrained_vectors,
-        bilstm_depth=depth,
-        conv_depth=0,
-        maxout_pieces=maxout_pieces,
-        window_size=1,
-        subword_features=False,
-        char_embed=True,
-        nM=nM,
-        nC=nC,
-        dropout=dropout,
-    )


-@registry.architectures.register("spacy.LayerNormalizedMaxout.v1")
-def LayerNormalizedMaxout(width, maxout_pieces):
-    return Maxout(nO=width, nP=maxout_pieces, dropout=0.0, normalize=True)


 @registry.architectures.register("spacy.MultiHashEmbed.v1")
 def MultiHashEmbed(
-    columns, width, rows, use_subwords, pretrained_vectors, mix, dropout
+    width: int, rows: int, also_embed_subwords: bool, also_use_static_vectors: bool
 ):
-    norm = HashEmbed(
+    cols = [NORM, PREFIX, SUFFIX, SHAPE, ORTH]
-        nO=width, nV=rows, column=columns.index("NORM"), dropout=dropout, seed=6
-    )
+    seed = 7
-    if use_subwords:
-        prefix = HashEmbed(
+    def make_hash_embed(feature):
-            nO=width,
+        nonlocal seed
-            nV=rows // 2,
+        seed += 1
-            column=columns.index("PREFIX"),
+        return HashEmbed(
-            dropout=dropout,
+            width,
-            seed=7,
+            rows if feature == NORM else rows // 2,
-        )
+            column=cols.index(feature),
-        suffix = HashEmbed(
+            seed=seed,
-            nO=width,
+            dropout=0.0,
-            nV=rows // 2,
-            column=columns.index("SUFFIX"),
-            dropout=dropout,
-            seed=8,
-        )
-        shape = HashEmbed(
-            nO=width,
-            nV=rows // 2,
-            column=columns.index("SHAPE"),
-            dropout=dropout,
-            seed=9,
-        )
         )

-    if pretrained_vectors:
+    if also_embed_subwords:
-        glove = StaticVectors(
+        embeddings = [
-            vectors=pretrained_vectors.data,
+            make_hash_embed(NORM),
-            nO=width,
+            make_hash_embed(PREFIX),
-            column=columns.index(ID),
+            make_hash_embed(SUFFIX),
-            dropout=dropout,
+            make_hash_embed(SHAPE),
+        ]
+    else:
+        embeddings = [make_hash_embed(NORM)]
+    concat_size = width * (len(embeddings) + also_use_static_vectors)
+    if also_use_static_vectors:
+        model = chain(
+            concatenate(
+                chain(
+                    FeatureExtractor(cols),
+                    list2ragged(),
+                    with_array(concatenate(*embeddings)),
+                ),
+                StaticVectors(width, dropout=0.0),
+            ),
+            with_array(Maxout(width, concat_size, nP=3, dropout=0.0, normalize=True)),
+            ragged2list(),
         )
+    else:
-    with Model.define_operators({">>": chain, "|": concatenate}):
+        model = chain(
-        if not use_subwords and not pretrained_vectors:
+            FeatureExtractor(cols),
-            embed_layer = norm
+            list2ragged(),
-        else:
+            with_array(concatenate(*embeddings)),
-            if use_subwords and pretrained_vectors:
+            with_array(Maxout(width, concat_size, nP=3, dropout=0.0, normalize=True)),
-                concat_columns = glove | norm | prefix | suffix | shape
+            ragged2list(),
-            elif use_subwords:
+        )
-                concat_columns = norm | prefix | suffix | shape
+    return model
-            else:
-                concat_columns = glove | norm

-            embed_layer = uniqued(concat_columns >> mix, column=columns.index("ORTH"))

-        return embed_layer


 @registry.architectures.register("spacy.CharacterEmbed.v1")
-def CharacterEmbed(columns, width, rows, nM, nC, features, dropout):
+def CharacterEmbed(width: int, rows: int, nM: int, nC: int):
-    norm = HashEmbed(
+    model = chain(
-        nO=width, nV=rows, column=columns.index("NORM"), dropout=dropout, seed=5
+        concatenate(
+            chain(_character_embed.CharacterEmbed(nM=nM, nC=nC), list2ragged()),
+            chain(
+                FeatureExtractor([NORM]),
+                list2ragged(),
+                with_array(HashEmbed(nO=width, nV=rows, column=0, seed=5))
+            )
+        ),
+        with_array(Maxout(width, nM * nC + width, nP=3, normalize=True, dropout=0.0)),
+        ragged2list()
     )
-    chr_embed = _character_embed.CharacterEmbed(nM=nM, nC=nC)
+    return model
-    with Model.define_operators({">>": chain, "|": concatenate}):
-        embed_layer = chr_embed | features >> with_array(norm)
-    embed_layer.set_dim("nO", nM * nC + width)
-    return embed_layer


 @registry.architectures.register("spacy.MaxoutWindowEncoder.v1")
-def MaxoutWindowEncoder(width, window_size, maxout_pieces, depth):
+def MaxoutWindowEncoder(width: int, window_size: int, maxout_pieces: int, depth: int):
     cnn = chain(
         expand_window(window_size=window_size),
         Maxout(

@@ -238,8 +155,12 @@ def MaxoutWindowEncoder(width, window_size, maxout_pieces, depth):
 def MishWindowEncoder(width, window_size, depth):
     cnn = chain(
         expand_window(window_size=window_size),
-        Mish(nO=width, nI=width * ((window_size * 2) + 1)),
+        Mish(
-        LayerNorm(width),
+            nO=width,
+            nI=width * ((window_size * 2) + 1),
+            dropout=0.0,
+            normalize=True
+        ),
     )
     model = clone(residual(cnn), depth)
     model.set_dim("nO", width)

@@ -247,133 +168,7 @@ def MishWindowEncoder(width, window_size, depth):


 @registry.architectures.register("spacy.TorchBiLSTMEncoder.v1")
-def TorchBiLSTMEncoder(width, depth):
+def BiLSTMEncoder(width, depth, dropout):
-    import torch.nn

-    # TODO FIX
-    from thinc.api import PyTorchRNNWrapper

     if depth == 0:
         return noop()
-    return with_padded(
+    return with_padded(PyTorchLSTM(width, width, bi=True, depth=depth, dropout=dropout))
-        PyTorchRNNWrapper(torch.nn.LSTM(width, width // 2, depth, bidirectional=True))
-    )


-def build_Tok2Vec_model(
-    width,
-    embed_size,
-    pretrained_vectors,
-    window_size,
-    maxout_pieces,
-    subword_features,
-    char_embed,
-    nM,
-    nC,
-    conv_depth,
-    bilstm_depth,
-    dropout,
-) -> Model:
-    if char_embed:
-        subword_features = False
-    cols = [ID, NORM, PREFIX, SUFFIX, SHAPE, ORTH]
-    with Model.define_operators({">>": chain, "|": concatenate, "**": clone}):
-        norm = HashEmbed(
-            nO=width, nV=embed_size, column=cols.index(NORM), dropout=None, seed=0
-        )
-        if subword_features:
-            prefix = HashEmbed(
-                nO=width,
-                nV=embed_size // 2,
-                column=cols.index(PREFIX),
-                dropout=None,
-                seed=1,
-            )
-            suffix = HashEmbed(
-                nO=width,
-                nV=embed_size // 2,
-                column=cols.index(SUFFIX),
-                dropout=None,
-                seed=2,
-            )
-            shape = HashEmbed(
-                nO=width,
-                nV=embed_size // 2,
-                column=cols.index(SHAPE),
-                dropout=None,
-                seed=3,
-            )
-        else:
-            prefix, suffix, shape = (None, None, None)
-        if pretrained_vectors is not None:
-            glove = StaticVectors(
-                vectors=pretrained_vectors.data,
-                nO=width,
-                column=cols.index(ID),
-                dropout=dropout,
-            )

-            if subword_features:
-                columns = 5
-                embed = uniqued(
-                    (glove | norm | prefix | suffix | shape)
-                    >> Maxout(
-                        nO=width, nI=width * columns, nP=3, dropout=0.0, normalize=True,
-                    ),
-                    column=cols.index(ORTH),
-                )
-            else:
-                columns = 2
-                embed = uniqued(
-                    (glove | norm)
-                    >> Maxout(
-                        nO=width, nI=width * columns, nP=3, dropout=0.0, normalize=True,
-                    ),
-                    column=cols.index(ORTH),
-                )
-        elif subword_features:
-            columns = 4
-            embed = uniqued(
-                concatenate(norm, prefix, suffix, shape)
-                >> Maxout(
-                    nO=width, nI=width * columns, nP=3, dropout=0.0, normalize=True,
-                ),
-                column=cols.index(ORTH),
-            )
-        elif char_embed:
-            embed = _character_embed.CharacterEmbed(nM=nM, nC=nC) | FeatureExtractor(
-                cols
-            ) >> with_array(norm)
-            reduce_dimensions = Maxout(
-                nO=width, nI=nM * nC + width, nP=3, dropout=0.0, normalize=True,
-            )
-        else:
-            embed = norm

-        convolution = residual(
-            expand_window(window_size=window_size)
-            >> Maxout(
-                nO=width,
-                nI=width * ((window_size * 2) + 1),
-                nP=maxout_pieces,
-                dropout=0.0,
-                normalize=True,
-            )
-        )
-        if char_embed:
-            tok2vec = embed >> with_array(
-                reduce_dimensions >> convolution ** conv_depth, pad=conv_depth
-            )
-        else:
-            tok2vec = FeatureExtractor(cols) >> with_array(
-                embed >> convolution ** conv_depth, pad=conv_depth
-            )

-        if bilstm_depth >= 1:
-            tok2vec = tok2vec >> PyTorchLSTM(
-                nO=width, nI=width, depth=bilstm_depth, bi=True
-            )
-        if tok2vec.has_dim("nO") is not False:
-            tok2vec.set_dim("nO", width)
-        tok2vec.set_ref("embed", embed)
-        return tok2vec
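The rewritten tok2vec module above splits the old monolithic builder into an embedding layer (MultiHashEmbed), an encoding layer (MaxoutWindowEncoder), and a small build_Tok2Vec_model(embed, encode) composer. A short sketch of the new composition, mirroring the calls used in the updated tests further down; the widths and row counts are the defaults from the configs in this commit, not requirements:

from spacy.ml.models.tok2vec import (
    build_Tok2Vec_model,
    MultiHashEmbed,
    MaxoutWindowEncoder,
)

# Hash-embed NORM/PREFIX/SUFFIX/SHAPE features, without static vectors.
embed = MultiHashEmbed(
    width=96, rows=2000, also_embed_subwords=True, also_use_static_vectors=False
)
# Encode with a depth-4 CNN (window size 1) and maxout units.
encode = MaxoutWindowEncoder(width=96, depth=4, window_size=1, maxout_pieces=3)

tok2vec = build_Tok2Vec_model(embed, encode)
tok2vec.initialize()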
@@ -1,27 +0,0 @@
-import numpy
-from thinc.api import Model, Unserializable


-def SpacyVectors(vectors) -> Model:
-    attrs = {"vectors": Unserializable(vectors)}
-    model = Model("spacy_vectors", forward, attrs=attrs)
-    return model


-def forward(model, docs, is_train: bool):
-    batch = []
-    vectors = model.attrs["vectors"].obj
-    for doc in docs:
-        indices = numpy.zeros((len(doc),), dtype="i")
-        for i, word in enumerate(doc):
-            if word.orth in vectors.key2row:
-                indices[i] = vectors.key2row[word.orth]
-            else:
-                indices[i] = 0
-        batch_vectors = vectors.data[indices]
-        batch.append(batch_vectors)

-    def backprop(dY):
-        return None

-    return batch, backprop

new file: spacy/ml/staticvectors.py (100 lines)
@@ -0,0 +1,100 @@
+from typing import List, Tuple, Callable, Optional, cast

+from thinc.initializers import glorot_uniform_init
+from thinc.util import partial
+from thinc.types import Ragged, Floats2d, Floats1d
+from thinc.api import Model, Ops, registry

+from ..tokens import Doc


+@registry.layers("spacy.StaticVectors.v1")
+def StaticVectors(
+    nO: Optional[int] = None,
+    nM: Optional[int] = None,
+    *,
+    dropout: Optional[float] = None,
+    init_W: Callable = glorot_uniform_init,
+    key_attr: str = "ORTH"
+) -> Model[List[Doc], Ragged]:
+    """Embed Doc objects with their vocab's vectors table, applying a learned
+    linear projection to control the dimensionality. If a dropout rate is
+    specified, the dropout is applied per dimension over the whole batch.
+    """
+    return Model(
+        "static_vectors",
+        forward,
+        init=partial(init, init_W),
+        params={"W": None},
+        attrs={"key_attr": key_attr, "dropout_rate": dropout},
+        dims={"nO": nO, "nM": nM},
+    )


+def forward(
+    model: Model[List[Doc], Ragged], docs: List[Doc], is_train: bool
+) -> Tuple[Ragged, Callable]:
+    if not len(docs):
+        return _handle_empty(model.ops, model.get_dim("nO"))
+    key_attr = model.attrs["key_attr"]
+    W = cast(Floats2d, model.ops.as_contig(model.get_param("W")))
+    V = cast(Floats2d, docs[0].vocab.vectors.data)
+    mask = _get_drop_mask(model.ops, W.shape[0], model.attrs.get("dropout_rate"))
+    rows = model.ops.flatten(
+        [doc.vocab.vectors.find(keys=doc.to_array(key_attr)) for doc in docs]
+    )
+    output = Ragged(
+        model.ops.gemm(model.ops.as_contig(V[rows]), W, trans2=True),
+        model.ops.asarray([len(doc) for doc in docs], dtype="i"),
+    )
+    if mask is not None:
+        output.data *= mask

+    def backprop(d_output: Ragged) -> List[Doc]:
+        if mask is not None:
+            d_output.data *= mask
+        model.inc_grad(
+            "W",
+            model.ops.gemm(d_output.data, model.ops.as_contig(V[rows]), trans1=True),
+        )
+        return []

+    return output, backprop


+def init(
+    init_W: Callable,
+    model: Model[List[Doc], Ragged],
+    X: Optional[List[Doc]] = None,
+    Y: Optional[Ragged] = None,
+) -> Model[List[Doc], Ragged]:
+    nM = model.get_dim("nM") if model.has_dim("nM") else None
+    nO = model.get_dim("nO") if model.has_dim("nO") else None
+    if X is not None and len(X):
+        nM = X[0].vocab.vectors.data.shape[1]
+    if Y is not None:
+        nO = Y.data.shape[1]

+    if nM is None:
+        raise ValueError(
+            "Cannot initialize StaticVectors layer: nM dimension unset. "
+            "This dimension refers to the width of the vectors table."
+        )
+    if nO is None:
+        raise ValueError(
+            "Cannot initialize StaticVectors layer: nO dimension unset. "
+            "This dimension refers to the output width, after the linear "
+            "projection has been applied."
+        )
+    model.set_dim("nM", nM)
+    model.set_dim("nO", nO)
+    model.set_param("W", init_W(model.ops, (nO, nM)))
+    return model


+def _handle_empty(ops: Ops, nO: int):
+    return Ragged(ops.alloc2f(0, nO), ops.alloc1i(0)), lambda d_ragged: []


+def _get_drop_mask(ops: Ops, nO: int, rate: Optional[float]) -> Optional[Floats1d]:
+    return ops.get_dropout_mask((nO,), rate) if rate is not None else None
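The docstring of the new StaticVectors layer above states its contract: a list of Doc objects in, a Ragged array of linearly projected vectors out. A rough usage sketch under that contract; the en_core_web_md pipeline name and the output width of 96 are illustrative assumptions, not part of the commit:

import spacy
from spacy.ml.staticvectors import StaticVectors

nlp = spacy.load("en_core_web_md")  # assumed: any pipeline whose vocab has vectors
docs = [nlp.make_doc("green eggs and ham")]

static = StaticVectors(nO=96)  # project the vectors table down to 96 dimensions
static.initialize(X=docs)      # nM is inferred from the vectors table width
ragged, backprop = static(docs, is_train=False)
print(ragged.data.shape)       # (number of tokens, 96)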
@@ -22,17 +22,23 @@ default_model_config = """
 @architectures = "spacy.Tagger.v1"

 [model.tok2vec]
-@architectures = "spacy.HashCharEmbedCNN.v1"
+@architectures = "spacy.Tok2Vec.v1"
-pretrained_vectors = null
+[model.tok2vec.embed]
+@architectures = "spacy.CharacterEmbed.v1"
 width = 128
-depth = 4
+rows = 7000
-embed_size = 7000
-window_size = 1
-maxout_pieces = 3
 nM = 64
 nC = 8
-dropout = null
+[model.tok2vec.encode]
+@architectures = "spacy.MaxoutWindowEncoder.v1"
+width = 128
+depth = 4
+window_size = 1
+maxout_pieces = 3
 """

 DEFAULT_MORPH_MODEL = Config().from_str(default_model_config)["model"]


@@ -149,7 +155,6 @@ class Morphologizer(Tagger):
 self.cfg["labels_pos"][norm_label] = POS_IDS[pos]
 self.set_output(len(self.labels))
 self.model.initialize()
-util.link_vectors_to_models(self.vocab)
 if sgd is None:
     sgd = self.create_optimizer()
 return sgd


@@ -11,7 +11,6 @@ from .tagger import Tagger
 from ..language import Language
 from ..syntax import nonproj
 from ..attrs import POS, ID
-from ..util import link_vectors_to_models
 from ..errors import Errors


@@ -91,7 +90,6 @@ class MultitaskObjective(Tagger):
 if label is not None and label not in self.labels:
     self.labels[label] = len(self.labels)
 self.model.initialize()
-link_vectors_to_models(self.vocab)
 if sgd is None:
     sgd = self.create_optimizer()
 return sgd

@@ -179,7 +177,6 @@ class ClozeMultitask(Pipe):
 pass

 def begin_training(self, get_examples=lambda: [], pipeline=None, sgd=None):
-    link_vectors_to_models(self.vocab)
     self.model.initialize()
     X = self.model.ops.alloc((5, self.model.get_ref("tok2vec").get_dim("nO")))
     self.model.output_layer.begin_training(X)


@@ -3,7 +3,7 @@ import srsly

 from ..tokens.doc cimport Doc

-from ..util import link_vectors_to_models, create_default_optimizer
+from ..util import create_default_optimizer
 from ..errors import Errors
 from .. import util


@@ -147,8 +147,6 @@ class Pipe:
 DOCS: https://spacy.io/api/pipe#begin_training
 """
 self.model.initialize()
-if hasattr(self, "vocab"):
-    link_vectors_to_models(self.vocab)
 if sgd is None:
     sgd = self.create_optimizer()
 return sgd


@@ -138,7 +138,6 @@ class SentenceRecognizer(Tagger):
 """
 self.set_output(len(self.labels))
 self.model.initialize()
-util.link_vectors_to_models(self.vocab)
 if sgd is None:
     sgd = self.create_optimizer()
 return sgd


@@ -168,7 +168,6 @@ class SimpleNER(Pipe):
 self.model.initialize()
 if pipeline is not None:
     self.init_multitask_objectives(get_examples, pipeline, sgd=sgd, **self.cfg)
-util.link_vectors_to_models(self.vocab)
 self.loss_func = SequenceCategoricalCrossentropy(
     names=self.get_tag_names(), normalize=True, missing_value=None
 )


@@ -318,7 +318,6 @@ class Tagger(Pipe):
 self.model.initialize(X=doc_sample)
 # Get batch of example docs, example outputs to call begin_training().
 # This lets the model infer shapes.
-util.link_vectors_to_models(self.vocab)
 if sgd is None:
     sgd = self.create_optimizer()
 return sgd


@@ -356,7 +356,6 @@ class TextCategorizer(Pipe):
 docs = [Doc(Vocab(), words=["hello"])]
 truths, _ = self._examples_to_truth(examples)
 self.set_output(len(self.labels))
-util.link_vectors_to_models(self.vocab)
 self.model.initialize(X=docs, Y=truths)
 if sgd is None:
     sgd = self.create_optimizer()


@@ -7,7 +7,7 @@ from ..tokens import Doc
 from ..vocab import Vocab
 from ..language import Language
 from ..errors import Errors
-from ..util import link_vectors_to_models, minibatch
+from ..util import minibatch


 default_model_config = """

@@ -196,9 +196,8 @@ class Tok2Vec(Pipe):

 DOCS: https://spacy.io/api/tok2vec#begin_training
 """
-docs = [Doc(Vocab(), words=["hello"])]
+docs = [Doc(self.vocab, words=["hello"])]
 self.model.initialize(X=docs)
-link_vectors_to_models(self.vocab)


 class Tok2VecListener(Model):
@@ -21,7 +21,7 @@ from .transition_system cimport Transition

 from ..compat import copy_array
 from ..errors import Errors, TempErrors
-from ..util import link_vectors_to_models, create_default_optimizer
+from ..util import create_default_optimizer
 from .. import util
 from . import nonproj


@@ -29,7 +29,7 @@ from .stateclass cimport StateClass
 from ._state cimport StateC
 from .transition_system cimport Transition

-from ..util import link_vectors_to_models, create_default_optimizer, registry
+from ..util import create_default_optimizer, registry
 from ..compat import copy_array
 from ..errors import Errors, Warnings
 from .. import util

@@ -456,7 +456,6 @@ cdef class Parser:
 self.model.initialize()
 if pipeline is not None:
     self.init_multitask_objectives(get_examples, pipeline, sgd=sgd, **self.cfg)
-link_vectors_to_models(self.vocab)
 return sgd

 def to_disk(self, path, exclude=tuple()):


@@ -9,7 +9,6 @@ from spacy.matcher import Matcher
 from spacy.tokens import Doc, Span
 from spacy.vocab import Vocab
 from spacy.compat import pickle
-from spacy.util import link_vectors_to_models
 import numpy
 import random

@@ -190,7 +189,6 @@ def test_issue2871():
 _ = vocab[word]  # noqa: F841
 vocab.set_vector(word, vector_data[0])
 vocab.vectors.name = "dummy_vectors"
-link_vectors_to_models(vocab)
 assert vocab["dog"].rank == 0
 assert vocab["cat"].rank == 1
 assert vocab["SUFFIX"].rank == 2
@ -5,6 +5,7 @@ from spacy.lang.en import English
|
||||||
from spacy.language import Language
|
from spacy.language import Language
|
||||||
from spacy.util import registry, deep_merge_configs, load_model_from_config
|
from spacy.util import registry, deep_merge_configs, load_model_from_config
|
||||||
from spacy.ml.models import build_Tok2Vec_model, build_tb_parser_model
|
from spacy.ml.models import build_Tok2Vec_model, build_tb_parser_model
|
||||||
|
from spacy.ml.models import MultiHashEmbed, MaxoutWindowEncoder
|
||||||
|
|
||||||
from ..util import make_tempdir
|
from ..util import make_tempdir
|
||||||
|
|
||||||
|
@ -40,7 +41,7 @@ factory = "tagger"
|
||||||
@architectures = "spacy.Tagger.v1"
|
@architectures = "spacy.Tagger.v1"
|
||||||
|
|
||||||
[components.tagger.model.tok2vec]
|
[components.tagger.model.tok2vec]
|
||||||
@architectures = "spacy.Tok2VecTensors.v1"
|
@architectures = "spacy.Tok2VecListener.v1"
|
||||||
width = ${components.tok2vec.model:width}
|
width = ${components.tok2vec.model:width}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -68,18 +69,18 @@ dropout = null
|
||||||
@registry.architectures.register("my_test_parser")
|
@registry.architectures.register("my_test_parser")
|
||||||
def my_parser():
|
def my_parser():
|
||||||
tok2vec = build_Tok2Vec_model(
|
tok2vec = build_Tok2Vec_model(
|
||||||
width=321,
|
MultiHashEmbed(
|
||||||
embed_size=5432,
|
width=321,
|
||||||
pretrained_vectors=None,
|
rows=5432,
|
||||||
window_size=3,
|
also_embed_subwords=True,
|
||||||
maxout_pieces=4,
|
also_use_static_vectors=False
|
||||||
subword_features=True,
|
),
|
||||||
char_embed=True,
|
MaxoutWindowEncoder(
|
||||||
nM=64,
|
width=321,
|
||||||
nC=8,
|
window_size=3,
|
||||||
conv_depth=2,
|
maxout_pieces=4,
|
||||||
bilstm_depth=0,
|
depth=2
|
||||||
dropout=None,
|
)
|
||||||
)
|
)
|
||||||
parser = build_tb_parser_model(
|
parser = build_tb_parser_model(
|
||||||
tok2vec=tok2vec, nr_feature_tokens=7, hidden_width=65, maxout_pieces=5
|
tok2vec=tok2vec, nr_feature_tokens=7, hidden_width=65, maxout_pieces=5
|
||||||
|
|
|
@ -5,12 +5,32 @@ from thinc.api import fix_random_seed, Adam, set_dropout_rate
|
||||||
from numpy.testing import assert_array_equal
|
from numpy.testing import assert_array_equal
|
||||||
import numpy
|
import numpy
|
||||||
|
|
||||||
from spacy.ml.models import build_Tok2Vec_model
|
+from spacy.ml.models import build_Tok2Vec_model, MultiHashEmbed, MaxoutWindowEncoder
 from spacy.ml.models import build_text_classifier, build_simple_cnn_text_classifier
 from spacy.lang.en import English
 from spacy.lang.en.examples import sentences as EN_SENTENCES


+def get_textcat_kwargs():
+    return {
+        "width": 64,
+        "embed_size": 2000,
+        "pretrained_vectors": None,
+        "exclusive_classes": False,
+        "ngram_size": 1,
+        "window_size": 1,
+        "conv_depth": 2,
+        "dropout": None,
+        "nO": 7,
+    }
+
+
+def get_textcat_cnn_kwargs():
+    return {
+        "tok2vec": test_tok2vec(),
+        "exclusive_classes": False,
+        "nO": 13,
+    }


 def get_all_params(model):
     params = []
     for node in model.walk():

@@ -35,50 +55,34 @@ def get_gradient(model, Y):
     raise ValueError(f"Could not get gradient for type {type(Y)}")


+def get_tok2vec_kwargs():
+    # This actually creates models, so seems best to put it in a function.
+    return {
+        "embed": MultiHashEmbed(
+            width=32,
+            rows=500,
+            also_embed_subwords=True,
+            also_use_static_vectors=False
+        ),
+        "encode": MaxoutWindowEncoder(
+            width=32,
+            depth=2,
+            maxout_pieces=2,
+            window_size=1,
+        )
+    }


 def test_tok2vec():
-    return build_Tok2Vec_model(**TOK2VEC_KWARGS)
+    return build_Tok2Vec_model(**get_tok2vec_kwargs())


-TOK2VEC_KWARGS = {
-    "width": 96,
-    "embed_size": 2000,
-    "subword_features": True,
-    "char_embed": False,
-    "conv_depth": 4,
-    "bilstm_depth": 0,
-    "maxout_pieces": 4,
-    "window_size": 1,
-    "dropout": 0.1,
-    "nM": 0,
-    "nC": 0,
-    "pretrained_vectors": None,
-}
-
-TEXTCAT_KWARGS = {
-    "width": 64,
-    "embed_size": 2000,
-    "pretrained_vectors": None,
-    "exclusive_classes": False,
-    "ngram_size": 1,
-    "window_size": 1,
-    "conv_depth": 2,
-    "dropout": None,
-    "nO": 7,
-}
-
-TEXTCAT_CNN_KWARGS = {
-    "tok2vec": test_tok2vec(),
-    "exclusive_classes": False,
-    "nO": 13,
-}


 @pytest.mark.parametrize(
     "seed,model_func,kwargs",
     [
-        (0, build_Tok2Vec_model, TOK2VEC_KWARGS),
+        (0, build_Tok2Vec_model, get_tok2vec_kwargs()),
-        (0, build_text_classifier, TEXTCAT_KWARGS),
+        (0, build_text_classifier, get_textcat_kwargs()),
-        (0, build_simple_cnn_text_classifier, TEXTCAT_CNN_KWARGS),
+        (0, build_simple_cnn_text_classifier, get_textcat_cnn_kwargs()),
     ],
 )
 def test_models_initialize_consistently(seed, model_func, kwargs):

@@ -96,9 +100,9 @@ def test_models_initialize_consistently(seed, model_func, kwargs):
 @pytest.mark.parametrize(
     "seed,model_func,kwargs,get_X",
     [
-        (0, build_Tok2Vec_model, TOK2VEC_KWARGS, get_docs),
+        (0, build_Tok2Vec_model, get_tok2vec_kwargs(), get_docs),
-        (0, build_text_classifier, TEXTCAT_KWARGS, get_docs),
+        (0, build_text_classifier, get_textcat_kwargs(), get_docs),
-        (0, build_simple_cnn_text_classifier, TEXTCAT_CNN_KWARGS, get_docs),
+        (0, build_simple_cnn_text_classifier, get_textcat_cnn_kwargs(), get_docs),
     ],
 )
 def test_models_predict_consistently(seed, model_func, kwargs, get_X):

@@ -131,9 +135,9 @@ def test_models_predict_consistently(seed, model_func, kwargs, get_X):
 @pytest.mark.parametrize(
     "seed,dropout,model_func,kwargs,get_X",
     [
-        (0, 0.2, build_Tok2Vec_model, TOK2VEC_KWARGS, get_docs),
+        (0, 0.2, build_Tok2Vec_model, get_tok2vec_kwargs(), get_docs),
-        (0, 0.2, build_text_classifier, TEXTCAT_KWARGS, get_docs),
+        (0, 0.2, build_text_classifier, get_textcat_kwargs(), get_docs),
-        (0, 0.2, build_simple_cnn_text_classifier, TEXTCAT_CNN_KWARGS, get_docs),
+        (0, 0.2, build_simple_cnn_text_classifier, get_textcat_cnn_kwargs(), get_docs),
     ],
 )
 def test_models_update_consistently(seed, dropout, model_func, kwargs, get_X):

@@ -1,6 +1,8 @@
 import pytest

 from spacy.ml.models.tok2vec import build_Tok2Vec_model
+from spacy.ml.models.tok2vec import MultiHashEmbed, CharacterEmbed
+from spacy.ml.models.tok2vec import MishWindowEncoder, MaxoutWindowEncoder
 from spacy.vocab import Vocab
 from spacy.tokens import Doc

@@ -13,18 +15,18 @@ def test_empty_doc():
     vocab = Vocab()
     doc = Doc(vocab, words=[])
     tok2vec = build_Tok2Vec_model(
-        width,
-        embed_size,
-        pretrained_vectors=None,
-        conv_depth=4,
-        bilstm_depth=0,
-        window_size=1,
-        maxout_pieces=3,
-        subword_features=True,
-        char_embed=False,
-        nM=64,
-        nC=8,
-        dropout=None,
+        MultiHashEmbed(
+            width=width,
+            rows=embed_size,
+            also_use_static_vectors=False,
+            also_embed_subwords=True
+        ),
+        MaxoutWindowEncoder(
+            width=width,
+            depth=4,
+            window_size=1,
+            maxout_pieces=3
+        )
     )
     tok2vec.initialize()
     vectors, backprop = tok2vec.begin_update([doc])

@@ -38,18 +40,18 @@ def test_empty_doc():
 def test_tok2vec_batch_sizes(batch_size, width, embed_size):
     batch = get_batch(batch_size)
     tok2vec = build_Tok2Vec_model(
-        width,
-        embed_size,
-        pretrained_vectors=None,
-        conv_depth=4,
-        bilstm_depth=0,
-        window_size=1,
-        maxout_pieces=3,
-        subword_features=True,
-        char_embed=False,
-        nM=64,
-        nC=8,
-        dropout=None,
+        MultiHashEmbed(
+            width=width,
+            rows=embed_size,
+            also_use_static_vectors=False,
+            also_embed_subwords=True
+        ),
+        MaxoutWindowEncoder(
+            width=width,
+            depth=4,
+            window_size=1,
+            maxout_pieces=3,
+        )
     )
     tok2vec.initialize()
     vectors, backprop = tok2vec.begin_update(batch)

@@ -60,24 +62,25 @@ def test_tok2vec_batch_sizes(batch_size, width, embed_size):

 # fmt: off
 @pytest.mark.parametrize(
-    "tok2vec_config",
+    "width,embed_arch,embed_config,encode_arch,encode_config",
     [
-        {"width": 8, "embed_size": 100, "char_embed": False, "nM": 64, "nC": 8, "pretrained_vectors": None, "window_size": 1, "conv_depth": 2, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": True, "dropout": None},
+        (8, MultiHashEmbed, {"rows": 100, "also_embed_subwords": True, "also_use_static_vectors": False}, MaxoutWindowEncoder, {"window_size": 1, "maxout_pieces": 3, "depth": 2}),
-        {"width": 8, "embed_size": 100, "char_embed": True, "nM": 64, "nC": 8, "pretrained_vectors": None, "window_size": 1, "conv_depth": 2, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": True, "dropout": None},
+        (8, MultiHashEmbed, {"rows": 100, "also_embed_subwords": True, "also_use_static_vectors": False}, MishWindowEncoder, {"window_size": 1, "depth": 6}),
-        {"width": 8, "embed_size": 100, "char_embed": False, "nM": 64, "nC": 8, "pretrained_vectors": None, "window_size": 1, "conv_depth": 6, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": True, "dropout": None},
+        (8, CharacterEmbed, {"rows": 100, "nM": 64, "nC": 8}, MaxoutWindowEncoder, {"window_size": 1, "maxout_pieces": 3, "depth": 3}),
-        {"width": 8, "embed_size": 100, "char_embed": False, "nM": 64, "nC": 8, "pretrained_vectors": None, "window_size": 1, "conv_depth": 6, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": True, "dropout": None},
+        (8, CharacterEmbed, {"rows": 100, "nM": 16, "nC": 2}, MishWindowEncoder, {"window_size": 1, "depth": 3}),
-        {"width": 8, "embed_size": 100, "char_embed": False, "nM": 64, "nC": 8, "pretrained_vectors": None, "window_size": 1, "conv_depth": 2, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": False, "dropout": None},
-        {"width": 8, "embed_size": 100, "char_embed": False, "nM": 64, "nC": 8, "pretrained_vectors": None, "window_size": 3, "conv_depth": 2, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": False, "dropout": None},
-        {"width": 8, "embed_size": 100, "char_embed": True, "nM": 81, "nC": 8, "pretrained_vectors": None, "window_size": 3, "conv_depth": 2, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": False, "dropout": None},
-        {"width": 8, "embed_size": 100, "char_embed": True, "nM": 81, "nC": 9, "pretrained_vectors": None, "window_size": 3, "conv_depth": 2, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": False, "dropout": None},
     ],
 )
 # fmt: on
-def test_tok2vec_configs(tok2vec_config):
+def test_tok2vec_configs(width, embed_arch, embed_config, encode_arch, encode_config):
+    embed_config["width"] = width
+    encode_config["width"] = width
     docs = get_batch(3)
-    tok2vec = build_Tok2Vec_model(**tok2vec_config)
+    tok2vec = build_Tok2Vec_model(
+        embed_arch(**embed_config),
+        encode_arch(**encode_config)
+    )
     tok2vec.initialize(docs)
     vectors, backprop = tok2vec.begin_update(docs)
     assert len(vectors) == len(docs)
-    assert vectors[0].shape == (len(docs[0]), tok2vec_config["width"])
+    assert vectors[0].shape == (len(docs[0]), width)
     backprop(vectors)

@@ -7,7 +7,7 @@ import importlib.util
 import re
 from pathlib import Path
 import thinc
-from thinc.api import NumpyOps, get_current_ops, Adam, Config, Optimizer
+from thinc.api import NumpyOps, get_current_ops, Adam, Config, Optimizer, Model
 import functools
 import itertools
 import numpy.random

@@ -24,6 +24,8 @@ import tempfile
 import shutil
 import shlex
 import inspect
+from thinc.types import Unserializable


 try:
     import cupy.random

@@ -187,6 +189,20 @@ def get_module_path(module: ModuleType) -> Path:
     return Path(sys.modules[module.__module__].__file__).parent


+def load_vectors_into_model(
+    nlp: "Language", name: Union[str, Path], *, add_strings=True
+) -> None:
+    """Load word vectors from an installed model or path into a model instance."""
+    vectors_nlp = load_model(name)
+    nlp.vocab.vectors = vectors_nlp.vocab.vectors
+    if add_strings:
+        # I guess we should add the strings from the vectors_nlp model?
+        # E.g. if someone does a similarity query, they might expect the strings.
+        for key in nlp.vocab.vectors.key2row:
+            if key in vectors_nlp.vocab.strings:
+                nlp.vocab.strings.add(vectors_nlp.vocab.strings[key])


 def load_model(
     name: Union[str, Path],
     disable: Iterable[str] = tuple(),
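A minimal usage sketch (not part of the commit) of the `load_vectors_into_model` helper added above, assuming a separately installed pipeline or vectors package to copy the vectors from, here hypothetically named "en_vectors_web_lg":

```python
from spacy.lang.en import English
from spacy.util import load_vectors_into_model

nlp = English()
# Copy the vectors table (and, by default, the matching strings) into nlp.vocab.
load_vectors_into_model(nlp, "en_vectors_web_lg")
print(nlp.vocab.vectors.shape)  # the (rows, width) table is now populated
```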
@@ -1184,22 +1200,6 @@ class DummyTokenizer:
         return self


-def link_vectors_to_models(vocab: "Vocab") -> None:
-    vectors = vocab.vectors
-    if vectors.name is None:
-        vectors.name = VECTORS_KEY
-        if vectors.data.size != 0:
-            warnings.warn(Warnings.W020.format(shape=vectors.data.shape))
-    for word in vocab:
-        if word.orth in vectors.key2row:
-            word.rank = vectors.key2row[word.orth]
-        else:
-            word.rank = 0
-
-
-VECTORS_KEY = "spacy_pretrained_vectors"
-
-
 def create_default_optimizer() -> Optimizer:
     # TODO: Do we still want to allow env_opt?
     learn_rate = env_opt("learn_rate", 0.001)

@@ -16,7 +16,7 @@ from .errors import Errors
 from .lemmatizer import Lemmatizer
 from .attrs import intify_attrs, NORM, IS_STOP
 from .vectors import Vectors
-from .util import link_vectors_to_models, registry
+from .util import registry
 from .lookups import Lookups, load_lookups
 from . import util
 from .lang.norm_exceptions import BASE_NORMS

@@ -344,7 +344,6 @@ cdef class Vocab:
             synonym = self.strings[syn_keys[i][0]]
             score = scores[i][0]
             remap[word] = (synonym, score)
-        link_vectors_to_models(self)
         return remap

     def get_vector(self, orth, minn=None, maxn=None):

@@ -476,8 +475,6 @@ cdef class Vocab:
         if "vectors" not in exclude:
             if self.vectors is not None:
                 self.vectors.from_disk(path, exclude=["strings"])
-                if self.vectors.name is not None:
-                    link_vectors_to_models(self)
         if "lookups" not in exclude:
             self.lookups.from_disk(path)
             if "lexeme_norm" in self.lookups:

@@ -537,8 +534,6 @@ cdef class Vocab:
         )
         self.length = 0
         self._by_orth = PreshMap()
-        if self.vectors.name is not None:
-            link_vectors_to_models(self)
         return self

     def _reset_cache(self, keys, strings):

@@ -5,54 +5,82 @@ menu:
   - ['Other Embeddings', 'embeddings']
 ---

-<!-- TODO: rewrite and include both details on word vectors, other word embeddings, spaCy transformers, doc.tensor, tok2vec -->

 ## Word vectors and similarity

-> #### Training word vectors
->
-> Dense, real valued vectors representing distributional similarity information
-> are now a cornerstone of practical NLP. The most common way to train these
-> vectors is the [Word2vec](https://en.wikipedia.org/wiki/Word2vec) family of
-> algorithms. If you need to train a word2vec model, we recommend the
-> implementation in the Python library
-> [Gensim](https://radimrehurek.com/gensim/).
-
-import Vectors101 from 'usage/101/\_vectors-similarity.md'
-
-<Vectors101 />
-
-### Customizing word vectors {#custom}
-
-Word vectors let you import knowledge from raw text into your model. The
-knowledge is represented as a table of numbers, with one row per term in your
-vocabulary. If two terms are used in similar contexts, the algorithm that learns
-the vectors should assign them **rows that are quite similar**, while words that
-are used in different contexts will have quite different values. This lets you
-use the row-values assigned to the words as a kind of dictionary, to tell you
-some things about what the words in your text mean.
-
-Word vectors are particularly useful for terms which **aren't well represented
-in your labelled training data**. For instance, if you're doing named entity
-recognition, there will always be lots of names that you don't have examples of.
-For instance, imagine your training data happens to contain some examples of the
-term "Microsoft", but it doesn't contain any examples of the term "Symantec". In
-your raw text sample, there are plenty of examples of both terms, and they're
-used in similar contexts. The word vectors make that fact available to the
-entity recognition model. It still won't see examples of "Symantec" labelled as
-a company. However, it'll see that "Symantec" has a word vector that usually
-corresponds to company terms, so it can **make the inference**.
-
-In order to make best use of the word vectors, you want the word vectors table
-to cover a **very large vocabulary**. However, most words are rare, so most of
-the rows in a large word vectors table will be accessed very rarely, or never at
-all. You can usually cover more than **95% of the tokens** in your corpus with
-just **a few thousand rows** in the vector table. However, it's those **5% of
-rare terms** where the word vectors are **most useful**. The problem is that
-increasing the size of the vector table produces rapidly diminishing returns in
-coverage over these rare terms.
-
-### Converting word vectors for use in spaCy {#converting new="2.0.10"}
+An old idea in linguistics is that you can "know a word by the company it
+keeps": that is, word meanings can be understood relationally, based on their
+patterns of usage. This idea inspired a branch of NLP research known as
+"distributional semantics" that has aimed to compute databases of lexical knowledge
+automatically. The [Word2vec](https://en.wikipedia.org/wiki/Word2vec) family of
+algorithms are a key milestone in this line of research. For simplicity, we
+will refer to a distributional word representation as a "word vector", and
+algorithms that compute word vectors (such as GloVe, FastText, etc) as
+"word2vec algorithms".
+
+Word vector tables are included in some of the spaCy model packages we
+distribute, and you can easily create your own model packages with word vectors
+you train or download yourself. In some cases you can also add word vectors to
+an existing pipeline, although each pipeline can only have a single word
+vectors table, and a model package that already has word vectors is unlikely to
+work correctly if you replace the vectors with new ones.
+
+## What's a word vector?
+
+For spaCy's purposes, a "word vector" is a 1-dimensional slice from
+a 2-dimensional _vectors table_, with a deterministic mapping from word types
+to rows in the table.
+
+```python
+def what_is_a_word_vector(
+    word_id: int,
+    key2row: Dict[int, int],
+    vectors_table: Floats2d,
+    *,
+    default_row: int=0
+) -> Floats1d:
+    return vectors_table[key2row.get(word_id, default_row)]
+```
+
+word2vec algorithms try to produce vectors tables that let you estimate useful
+relationships between words using simple linear algebra operations. For
+instance, you can often find close synonyms of a word by finding the vectors
+closest to it by cosine distance, and then finding the words that are mapped to
+those neighboring vectors. Word vectors can also be useful as features in
+statistical models.
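A rough sketch (not part of the diff above) of the nearest-neighbour lookup just described, assuming a plain numpy vectors table and a `key2row` mapping like the one in the snippet above:

```python
import numpy

def most_similar(word_id, key2row, vectors_table, n=5):
    # Rank every row of the table by cosine similarity to the query vector.
    query = vectors_table[key2row[word_id]]
    norms = numpy.linalg.norm(vectors_table, axis=1) * numpy.linalg.norm(query)
    scores = (vectors_table @ query) / numpy.maximum(norms, 1e-8)
    best_rows = set(scores.argsort()[::-1][:n])
    # Map the highest-scoring rows back to the word IDs that point at them.
    # Note that the query word itself will usually be among the results.
    return [wid for wid, row in key2row.items() if row in best_rows]
```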
+
+The key difference between word vectors and contextual language models such as
+ELMo, BERT and GPT-2 is that word vectors model _lexical types_, rather than
+_tokens_. If you have a list of terms with no context around them, a model like
+BERT can't really help you. BERT is designed to understand language in context,
+which isn't what you have. A word vectors table will be a much better fit for
+your task. However, if you do have words in context --- whole sentences or
+paragraphs of running text --- word vectors will only provide a very rough
+approximation of what the text is about.
+
+Word vectors are also very computationally efficient, as they map a word to a
+vector with a single indexing operation. Word vectors are therefore useful as a
+way to improve the accuracy of neural network models, especially models that
+are small or have received little or no pretraining. In spaCy, word vector
+tables are only used as static features. spaCy does not backpropagate gradients
+to the pretrained word vectors table. The static vectors table is usually used
+in combination with a smaller table of learned task-specific embeddings.
+
+## Using word vectors directly
+
+spaCy stores word vector information in the `vocab.vectors` attribute, so you
+can access the whole vectors table from most spaCy objects. You can also access
+the vector for a `Doc`, `Span`, `Token` or `Lexeme` instance via the `vector`
+attribute. If your `Doc` or `Span` has multiple tokens, the average of the
+word vectors will be returned, excluding any "out of vocabulary" entries that
+have no vector available. If none of the words have a vector, a zeroed vector
+will be returned.
+
+The `vector` attribute is a read-only numpy or cupy array (depending on whether
+you've configured spaCy to use GPU memory), with dtype `float32`. The array is
+read-only so that spaCy can avoid unnecessary copy operations where possible.
+You can modify the vectors via the `Vocab` or `Vectors` table.
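A short usage sketch (not part of the commit) of the access patterns described above, assuming a pipeline that ships with vectors, such as `en_core_web_md`, is installed:

```python
import spacy

nlp = spacy.load("en_core_web_md")   # any pipeline package that includes vectors
doc = nlp("The cat sat on the mat")

print(doc.vector.shape)              # average of the token vectors
print(doc[1].vector[:5])             # vector for the token "cat"
print(nlp.vocab["cat"].vector[:5])   # the same row, looked up via the vocab
print(nlp.vocab.vectors.shape)       # the full (rows, width) table
```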
+
+### Converting word vectors for use in spaCy

 Custom word vectors can be trained using a number of open-source libraries, such
 as [Gensim](https://radimrehurek.com/gensim), [Fast Text](https://fasttext.cc),

@@ -151,20 +179,7 @@ This will create a spaCy model with vectors for the first 10,000 words in the
 vectors model. All other words in the vectors model are mapped to the closest
 vector among those retained.

-### Adding vectors {#custom-vectors-add new="2"}
+### Adding vectors

-spaCy's new [`Vectors`](/api/vectors) class greatly improves the way word
-vectors are stored, accessed and used. The data is stored in two structures:
-
-- An array, which can be either on CPU or [GPU](#gpu).
-- A dictionary mapping string-hashes to rows in the table.
-
-Keep in mind that the `Vectors` class itself has no
-[`StringStore`](/api/stringstore), so you have to store the hash-to-string
-mapping separately. If you need to manage the strings, you should use the
-`Vectors` via the [`Vocab`](/api/vocab) class, e.g. `vocab.vectors`. To add
-vectors to the vocabulary, you can use the
-[`Vocab.set_vector`](/api/vocab#set_vector) method.
-
 ```python
 ### Adding vectors
@@ -196,38 +211,3 @@ For more details on **adding hooks** and **overwriting** the built-in `Doc`,

 ### Storing vectors on a GPU {#gpu}

-If you're using a GPU, it's much more efficient to keep the word vectors on the
-device. You can do that by setting the [`Vectors.data`](/api/vectors#attributes)
-attribute to a `cupy.ndarray` object if you're using spaCy or
-[Chainer](https://chainer.org), or a `torch.Tensor` object if you're using
-[PyTorch](http://pytorch.org). The `data` object just needs to support
-`__iter__` and `__getitem__`, so if you're using another library such as
-[TensorFlow](https://www.tensorflow.org), you could also create a wrapper for
-your vectors data.
-
-```python
-### spaCy, Thinc or Chainer
-import cupy.cuda
-from spacy.vectors import Vectors
-
-vector_table = numpy.zeros((3, 300), dtype="f")
-vectors = Vectors(["dog", "cat", "orange"], vector_table)
-with cupy.cuda.Device(0):
-    vectors.data = cupy.asarray(vectors.data)
-```
-
-```python
-### PyTorch
-import torch
-from spacy.vectors import Vectors
-
-vector_table = numpy.zeros((3, 300), dtype="f")
-vectors = Vectors(["dog", "cat", "orange"], vector_table)
-vectors.data = torch.Tensor(vectors.data).cuda(0)
-```
-
-## Other embeddings {#embeddings}
-
-<!-- TODO: explain spacy-transformers, doc.tensor, tok2vec? -->
-
-<!-- TODO: mention sense2vec somewhere? -->