From cb9654e98c6d2fe34cedd7d8dc43e233d133ba84 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Tue, 28 Jul 2020 00:52:50 +0200 Subject: [PATCH 01/37] WIP on new StaticVectors --- spacy/ml/models/tok2vec.py | 46 ++++++++++++++++++++------------------ spacy/util.py | 27 ++++++++++++++-------- 2 files changed, 42 insertions(+), 31 deletions(-) diff --git a/spacy/ml/models/tok2vec.py b/spacy/ml/models/tok2vec.py index 1766fa80e..caa9c467c 100644 --- a/spacy/ml/models/tok2vec.py +++ b/spacy/ml/models/tok2vec.py @@ -1,7 +1,9 @@ +from typing import Optional, List from thinc.api import chain, clone, concatenate, with_array, uniqued from thinc.api import Model, noop, with_padded, Maxout, expand_window from thinc.api import HashEmbed, StaticVectors, PyTorchLSTM from thinc.api import residual, LayerNorm, FeatureExtractor, Mish +from thinc.types import Floats2d from ... import util from ...util import registry @@ -42,15 +44,15 @@ def Doc2Feats(columns): @registry.architectures.register("spacy.HashEmbedCNN.v1") def hash_embed_cnn( - pretrained_vectors, - width, - depth, - embed_size, - maxout_pieces, - window_size, - subword_features, - dropout, -): + pretrained_vectors: str, + width: int, + depth: int, + embed_size: int, + maxout_pieces: int, + window_size: int, + subword_features: bool, + dropout: float, +) -> Model[List[Doc], List[Floats2d]: # Does not use character embeddings: set to False by default return build_Tok2Vec_model( width=width, @@ -182,7 +184,7 @@ def MultiHashEmbed( if pretrained_vectors: glove = StaticVectors( - vectors=pretrained_vectors.data, + vectors_name=pretrained_vectors, nO=width, column=columns.index(ID), dropout=dropout, @@ -261,18 +263,18 @@ def TorchBiLSTMEncoder(width, depth): def build_Tok2Vec_model( - width, - embed_size, - pretrained_vectors, - window_size, - maxout_pieces, - subword_features, - char_embed, - nM, - nC, - conv_depth, - bilstm_depth, - dropout, + width: int, + embed_size: int, + pretrained_vectors: Optional[str], + window_size: int, + maxout_pieces: int, + subword_features: bool, + char_embed: bool, + nM: int, + nC: int, + conv_depth: int, + bilstm_depth: int, + dropout: float, ) -> Model: if char_embed: subword_features = False diff --git a/spacy/util.py b/spacy/util.py index d1951145f..de6d9831b 100644 --- a/spacy/util.py +++ b/spacy/util.py @@ -24,6 +24,8 @@ import tempfile import shutil import shlex import inspect +from thinc.types import Unserializable + try: import cupy.random @@ -1184,20 +1186,27 @@ class DummyTokenizer: return self -def link_vectors_to_models(vocab: "Vocab") -> None: +def link_vectors_to_models( + vocab: "Vocab", + models: List[Model]=[], + *, + vectors_name_attr="vectors_name", + vectors_attr="vectors", + key2row_attr="key2row", + default_vectors_name="spacy_pretrained_vectors" +) -> None: + """Supply vectors data to models.""" vectors = vocab.vectors if vectors.name is None: - vectors.name = VECTORS_KEY + vectors.name = default_vectors_name if vectors.data.size != 0: warnings.warn(Warnings.W020.format(shape=vectors.data.shape)) - for word in vocab: - if word.orth in vectors.key2row: - word.rank = vectors.key2row[word.orth] - else: - word.rank = 0 - -VECTORS_KEY = "spacy_pretrained_vectors" + for model in models: + for node in model.walk(): + if node.attrs.get(vectors_name_attr) == vectors.name: + node.attrs[vectors_attr] = Unserializable(vectors.data) + node.attrs[key2row_attr] = Unserializable(vectors.key2row) def create_default_optimizer() -> Optimizer: From 9cc72622248808a7cd6807ed0d2f3afbfef4770b Mon Sep 17 00:00:00 
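A minimal sketch of how the reworked link_vectors_to_models helper above is meant to be wired up. Nothing here is from the patch except the helper and its keyword defaults; the layer is a stand-in, and later patches in this series drop the helper again once the new StaticVectors layer takes over, so this is only the transitional design.

    from thinc.api import noop
    from spacy.vocab import Vocab
    from spacy.util import link_vectors_to_models

    vocab = Vocab()  # assume a vectors table has been loaded into it
    layer = noop()   # stand-in for a real embedding layer
    layer.attrs["vectors_name"] = vocab.vectors.name or "spacy_pretrained_vectors"
    link_vectors_to_models(vocab, models=[layer])
    # The matching node now carries the data without it being serialized:
    #   layer.attrs["vectors"].obj   is vocab.vectors.data
    #   layer.attrs["key2row"].obj   is vocab.vectors.key2row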
2001 From: Matthew Honnibal Date: Tue, 28 Jul 2020 12:17:09 +0200 Subject: [PATCH 02/37] Draft StaticVectors layer --- spacy/ml/staticvectors.py | 98 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 98 insertions(+) create mode 100644 spacy/ml/staticvectors.py diff --git a/spacy/ml/staticvectors.py b/spacy/ml/staticvectors.py new file mode 100644 index 000000000..4c9e53563 --- /dev/null +++ b/spacy/ml/staticvectors.py @@ -0,0 +1,98 @@ +from typing import List, Tuple, Callable, Optional, cast + +from thinc.initializers import glorot_uniform_init +from thinc.util import partial +from thinc.types import Ragged, Floats2d, Floats1d +from thinc.api import Model, Ops, registry + +from ..tokens import Doc + + +@registry.layers("spacy.StaticVectors.v1") +def StaticVectors( + nO: Optional[int] = None, + nM: Optional[int] = None, + *, + dropout: Optional[float] = None, + init_W: Callable = glorot_uniform_init, + key_attr: str="ORTH" +) -> Model[List[Doc], Ragged]: + """Embed Doc objects with their vocab's vectors table, applying a learned + linear projection to control the dimensionality. If a dropout rate is + specified, the dropout is applied per dimension over the whole batch. + """ + return Model( + "static_vectors", + forward, + init=partial(init, init_W), + params={"W": None}, + attrs={"key_attr": key_attr, "dropout_rate": dropout}, + dims={"nO": nO, "nM": nM}, + ) + + +def forward( + model: Model[List[Doc], Ragged], docs: List[Doc], is_train: bool +) -> Tuple[Ragged, Callable]: + if not len(docs): + return _handle_empty(model.ops, model.get_dim("nO")) + key_attr = model.attrs["key_attr"] + W = cast(Floats2d, model.get_param("W")) + V = cast(Floats2d, docs[0].vocab.vectors.data) + mask = _get_drop_mask(model.ops, W.shape[0], model.attrs.get("dropout_rate")) + + rows = model.ops.flatten( + [doc.vocab.vectors.find(keys=doc.to_array(key_attr)) for doc in docs] + ) + output = Ragged( + model.ops.gemm(V[rows], W, trans2=True), + model.ops.asarray([len(doc) for doc in docs], dtype="i") + ) + if mask is not None: + output.data *= mask + + def backprop(d_output: Ragged) -> List[Doc]: + if mask is not None: + d_output.data *= mask + model.inc_grad("W", model.ops.gemm(d_output.data, V[rows], trans1=True)) + return [] + + return output, backprop + + +def init( + init_W: Callable, + model: Model[List[Doc], Ragged], + X: Optional[List[Doc]] = None, + Y: Optional[Ragged] = None, +) -> Model[List[Doc], Ragged]: + nM = model.get_dim("nM") if model.has_dim("nM") else None + nO = model.get_dim("nO") if model.has_dim("nO") else None + if X is not None and len(X): + nM = X[0].vocab.vectors.data.shape[1] + if Y is not None: + nO = Y.data.shape[1] + + if nM is None: + raise ValueError( + "Cannot initialize StaticVectors layer: nM dimension unset. " + "This dimension refers to the width of the vectors table." + ) + if nO is None: + raise ValueError( + "Cannot initialize StaticVectors layer: nO dimension unset. " + "This dimension refers to the output width, after the linear " + "projection has been applied." 
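A quick end-to-end sketch of the layer above, assuming an nlp pipeline whose vocab already has a vectors table loaded (the nlp object is a placeholder, not part of the patch):

    from spacy.ml.staticvectors import StaticVectors

    model = StaticVectors(nO=96)
    docs = [nlp("The quick brown fox"), nlp("jumped")]
    model.initialize(X=docs)  # nM is inferred from vocab.vectors.data
    ragged, backprop = model(docs, is_train=False)
    # One row per token across the whole batch, projected down to nO columns:
    assert ragged.data.shape == (sum(len(d) for d in docs), 96)
    assert ragged.lengths.tolist() == [len(d) for d in docs]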
+ ) + model.set_dim("nM", nM) + model.set_dim("nO", nO) + model.set_param("W", init_W(model.ops, (nO, nM))) + return model + + +def _handle_empty(ops: Ops, nO: int): + return Ragged(ops.alloc2f(0, nO), ops.alloc1i(0)), lambda d_ragged: [] + + +def _get_drop_mask(ops: Ops, nO: int, rate: Optional[float]) -> Optional[Floats1d]: + return ops.get_dropout_mask((nO,), rate) if rate is not None else None From c6b4f63c7c96a8c1dd52bb3afc1aade8fbfdfc3a Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Tue, 28 Jul 2020 12:18:28 +0200 Subject: [PATCH 03/37] Remove obsolete function --- spacy/ml/spacy_vectors.py | 27 --------------------------- 1 file changed, 27 deletions(-) delete mode 100644 spacy/ml/spacy_vectors.py diff --git a/spacy/ml/spacy_vectors.py b/spacy/ml/spacy_vectors.py deleted file mode 100644 index 2a4988494..000000000 --- a/spacy/ml/spacy_vectors.py +++ /dev/null @@ -1,27 +0,0 @@ -import numpy -from thinc.api import Model, Unserializable - - -def SpacyVectors(vectors) -> Model: - attrs = {"vectors": Unserializable(vectors)} - model = Model("spacy_vectors", forward, attrs=attrs) - return model - - -def forward(model, docs, is_train: bool): - batch = [] - vectors = model.attrs["vectors"].obj - for doc in docs: - indices = numpy.zeros((len(doc),), dtype="i") - for i, word in enumerate(doc): - if word.orth in vectors.key2row: - indices[i] = vectors.key2row[word.orth] - else: - indices[i] = 0 - batch_vectors = vectors.data[indices] - batch.append(batch_vectors) - - def backprop(dY): - return None - - return batch, backprop From 123f8b832d7c00e5479e3d814bbaadb54ba54966 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Tue, 28 Jul 2020 13:51:43 +0200 Subject: [PATCH 04/37] Refactor Tok2Vec model --- spacy/ml/models/tok2vec.py | 365 ++++++------------------------------- 1 file changed, 57 insertions(+), 308 deletions(-) diff --git a/spacy/ml/models/tok2vec.py b/spacy/ml/models/tok2vec.py index caa9c467c..4bcd61625 100644 --- a/spacy/ml/models/tok2vec.py +++ b/spacy/ml/models/tok2vec.py @@ -1,8 +1,8 @@ from typing import Optional, List -from thinc.api import chain, clone, concatenate, with_array, uniqued -from thinc.api import Model, noop, with_padded, Maxout, expand_window -from thinc.api import HashEmbed, StaticVectors, PyTorchLSTM -from thinc.api import residual, LayerNorm, FeatureExtractor, Mish +from thinc.api import chain, clone, concatenate, with_array, with_padded +from thinc.api import Model, noop +from thinc.api import FeatureExtractor, HashEmbed, StaticVectors +from thincapi import expand_window, residual, Maxout, Mish from thinc.types import Floats2d from ... 
import util @@ -12,199 +12,72 @@ from ...pipeline.tok2vec import Tok2VecListener from ...attrs import ID, ORTH, NORM, PREFIX, SUFFIX, SHAPE -@registry.architectures.register("spacy.Tok2VecTensors.v1") -def tok2vec_tensors_v1(width, upstream="*"): +@registry.architectures.register("spacy.Tok2VecListener.v1") +def tok2vec_listener_v1(width, upstream="*"): tok2vec = Tok2VecListener(upstream_name=upstream, width=width) return tok2vec -@registry.architectures.register("spacy.VocabVectors.v1") -def get_vocab_vectors(name): - nlp = util.load_model(name) - return nlp.vocab.vectors - - @registry.architectures.register("spacy.Tok2Vec.v1") -def Tok2Vec(extract, embed, encode): - field_size = 0 - if encode.attrs.get("receptive_field", None): - field_size = encode.attrs["receptive_field"] - with Model.define_operators({">>": chain, "|": concatenate}): - tok2vec = extract >> with_array(embed >> encode, pad=field_size) +def Tok2Vec( + embed: Model[List[Doc], List[Floats2d]], + encode: Model[List[Floats2d], List[Floats2d] +) -> Model[List[Doc], List[Floats2d]]: + tok2vec = with_array( + chain(embed, encode), + pad=encode.attrs.get("receptive_field", 0) + ) tok2vec.set_dim("nO", encode.get_dim("nO")) tok2vec.set_ref("embed", embed) tok2vec.set_ref("encode", encode) return tok2vec -@registry.architectures.register("spacy.Doc2Feats.v1") -def Doc2Feats(columns): - return FeatureExtractor(columns) - - -@registry.architectures.register("spacy.HashEmbedCNN.v1") -def hash_embed_cnn( - pretrained_vectors: str, +@registry.architectures.register("spacy.HashEmbed.v1") +def HashEmbed( width: int, - depth: int, - embed_size: int, - maxout_pieces: int, - window_size: int, - subword_features: bool, - dropout: float, -) -> Model[List[Doc], List[Floats2d]: - # Does not use character embeddings: set to False by default - return build_Tok2Vec_model( - width=width, - embed_size=embed_size, - pretrained_vectors=pretrained_vectors, - conv_depth=depth, - bilstm_depth=0, - maxout_pieces=maxout_pieces, - window_size=window_size, - subword_features=subword_features, - char_embed=False, - nM=0, - nC=0, - dropout=dropout, - ) - - -@registry.architectures.register("spacy.HashCharEmbedCNN.v1") -def hash_charembed_cnn( - pretrained_vectors, - width, - depth, - embed_size, - maxout_pieces, - window_size, - nM, - nC, - dropout, + rows: int, + also_embed_subwords: bool, + also_use_static_vectors: bool ): - # Allows using character embeddings by setting nC, nM and char_embed=True - return build_Tok2Vec_model( - width=width, - embed_size=embed_size, - pretrained_vectors=pretrained_vectors, - conv_depth=depth, - bilstm_depth=0, - maxout_pieces=maxout_pieces, - window_size=window_size, - subword_features=False, - char_embed=True, - nM=nM, - nC=nC, - dropout=dropout, - ) - - -@registry.architectures.register("spacy.HashEmbedBiLSTM.v1") -def hash_embed_bilstm_v1( - pretrained_vectors, - width, - depth, - embed_size, - subword_features, - maxout_pieces, - dropout, -): - # Does not use character embeddings: set to False by default - return build_Tok2Vec_model( - width=width, - embed_size=embed_size, - pretrained_vectors=pretrained_vectors, - bilstm_depth=depth, - conv_depth=0, - maxout_pieces=maxout_pieces, - window_size=1, - subword_features=subword_features, - char_embed=False, - nM=0, - nC=0, - dropout=dropout, - ) - - -@registry.architectures.register("spacy.HashCharEmbedBiLSTM.v1") -def hash_char_embed_bilstm_v1( - pretrained_vectors, width, depth, embed_size, maxout_pieces, nM, nC, dropout -): - # Allows using character embeddings by setting 
nC, nM and char_embed=True - return build_Tok2Vec_model( - width=width, - embed_size=embed_size, - pretrained_vectors=pretrained_vectors, - bilstm_depth=depth, - conv_depth=0, - maxout_pieces=maxout_pieces, - window_size=1, - subword_features=False, - char_embed=True, - nM=nM, - nC=nC, - dropout=dropout, - ) - - -@registry.architectures.register("spacy.LayerNormalizedMaxout.v1") -def LayerNormalizedMaxout(width, maxout_pieces): - return Maxout(nO=width, nP=maxout_pieces, dropout=0.0, normalize=True) - - -@registry.architectures.register("spacy.MultiHashEmbed.v1") -def MultiHashEmbed( - columns, width, rows, use_subwords, pretrained_vectors, mix, dropout -): - norm = HashEmbed( - nO=width, nV=rows, column=columns.index("NORM"), dropout=dropout, seed=6 - ) - if use_subwords: - prefix = HashEmbed( - nO=width, - nV=rows // 2, - column=columns.index("PREFIX"), - dropout=dropout, - seed=7, + cols = [NORM, PREFIX, SUFFIX, SHAPE, ORTH] + + seed = 7 + def make_hash_embed(feature): + nonlocal seed + seed += 1 + return HashEmbed( + width, + rows if feature == NORM else rows // 2, + column=cols.index(feature), + seed=seed ) - suffix = HashEmbed( - nO=width, - nV=rows // 2, - column=columns.index("SUFFIX"), - dropout=dropout, - seed=8, + + if also_embed_subwords: + embeddings = [ + make_hash_embed(NORM) + make_hash_embed(PREFIX) + make_hash_embed(SUFFIX) + make_hash_embed(SHAPE) + ] + else: + embeddings = [make_hash_embed(NORM)] + + if also_use_static_vectors: + model = chain( + concatenate( + chain(FeatureExtractor(cols), concatenate(*embeddings)), + StaticVectors(width, dropout=dropout) + ), + Maxout(width, dropout=dropout, normalize=True) ) - shape = HashEmbed( - nO=width, - nV=rows // 2, - column=columns.index("SHAPE"), - dropout=dropout, - seed=9, + else: + model = chain( + chain(FeatureExtractor(cols), concatenate(*embeddings)), + Maxout(width, concat_size, dropout=dropout, normalize=True) ) - - if pretrained_vectors: - glove = StaticVectors( - vectors_name=pretrained_vectors, - nO=width, - column=columns.index(ID), - dropout=dropout, - ) - - with Model.define_operators({">>": chain, "|": concatenate}): - if not use_subwords and not pretrained_vectors: - embed_layer = norm - else: - if use_subwords and pretrained_vectors: - concat_columns = glove | norm | prefix | suffix | shape - elif use_subwords: - concat_columns = norm | prefix | suffix | shape - else: - concat_columns = glove | norm - - embed_layer = uniqued(concat_columns >> mix, column=columns.index("ORTH")) - - return embed_layer - + return model + @registry.architectures.register("spacy.CharacterEmbed.v1") def CharacterEmbed(columns, width, rows, nM, nC, features, dropout): @@ -219,7 +92,7 @@ def CharacterEmbed(columns, width, rows, nM, nC, features, dropout): @registry.architectures.register("spacy.MaxoutWindowEncoder.v1") -def MaxoutWindowEncoder(width, window_size, maxout_pieces, depth): +def MaxoutWindowEncoder(width: int, window_size: int, maxout_pieces: int, depth: int): cnn = chain( expand_window(window_size=window_size), Maxout( @@ -249,133 +122,9 @@ def MishWindowEncoder(width, window_size, depth): @registry.architectures.register("spacy.TorchBiLSTMEncoder.v1") -def TorchBiLSTMEncoder(width, depth): - import torch.nn - - # TODO FIX - from thinc.api import PyTorchRNNWrapper - +def BiLSTMEncoder(width, depth, dropout): if depth == 0: return noop() return with_padded( - PyTorchRNNWrapper(torch.nn.LSTM(width, width // 2, depth, bidirectional=True)) + PyTorchLSTM(width, width, bi=True, depth=depth, dropout=dropout) ) - - -def 
build_Tok2Vec_model( - width: int, - embed_size: int, - pretrained_vectors: Optional[str], - window_size: int, - maxout_pieces: int, - subword_features: bool, - char_embed: bool, - nM: int, - nC: int, - conv_depth: int, - bilstm_depth: int, - dropout: float, -) -> Model: - if char_embed: - subword_features = False - cols = [ID, NORM, PREFIX, SUFFIX, SHAPE, ORTH] - with Model.define_operators({">>": chain, "|": concatenate, "**": clone}): - norm = HashEmbed( - nO=width, nV=embed_size, column=cols.index(NORM), dropout=None, seed=0 - ) - if subword_features: - prefix = HashEmbed( - nO=width, - nV=embed_size // 2, - column=cols.index(PREFIX), - dropout=None, - seed=1, - ) - suffix = HashEmbed( - nO=width, - nV=embed_size // 2, - column=cols.index(SUFFIX), - dropout=None, - seed=2, - ) - shape = HashEmbed( - nO=width, - nV=embed_size // 2, - column=cols.index(SHAPE), - dropout=None, - seed=3, - ) - else: - prefix, suffix, shape = (None, None, None) - if pretrained_vectors is not None: - glove = StaticVectors( - vectors=pretrained_vectors.data, - nO=width, - column=cols.index(ID), - dropout=dropout, - ) - - if subword_features: - columns = 5 - embed = uniqued( - (glove | norm | prefix | suffix | shape) - >> Maxout( - nO=width, nI=width * columns, nP=3, dropout=0.0, normalize=True, - ), - column=cols.index(ORTH), - ) - else: - columns = 2 - embed = uniqued( - (glove | norm) - >> Maxout( - nO=width, nI=width * columns, nP=3, dropout=0.0, normalize=True, - ), - column=cols.index(ORTH), - ) - elif subword_features: - columns = 4 - embed = uniqued( - concatenate(norm, prefix, suffix, shape) - >> Maxout( - nO=width, nI=width * columns, nP=3, dropout=0.0, normalize=True, - ), - column=cols.index(ORTH), - ) - elif char_embed: - embed = _character_embed.CharacterEmbed(nM=nM, nC=nC) | FeatureExtractor( - cols - ) >> with_array(norm) - reduce_dimensions = Maxout( - nO=width, nI=nM * nC + width, nP=3, dropout=0.0, normalize=True, - ) - else: - embed = norm - - convolution = residual( - expand_window(window_size=window_size) - >> Maxout( - nO=width, - nI=width * ((window_size * 2) + 1), - nP=maxout_pieces, - dropout=0.0, - normalize=True, - ) - ) - if char_embed: - tok2vec = embed >> with_array( - reduce_dimensions >> convolution ** conv_depth, pad=conv_depth - ) - else: - tok2vec = FeatureExtractor(cols) >> with_array( - embed >> convolution ** conv_depth, pad=conv_depth - ) - - if bilstm_depth >= 1: - tok2vec = tok2vec >> PyTorchLSTM( - nO=width, nI=width, depth=bilstm_depth, bi=True - ) - if tok2vec.has_dim("nO") is not False: - tok2vec.set_dim("nO", width) - tok2vec.set_ref("embed", embed) - return tok2vec From 034d803b7a4f0118b1d981554e62a1cf0a0da721 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Tue, 28 Jul 2020 13:52:05 +0200 Subject: [PATCH 05/37] Update ptb config --- .../ptb-joint-pos-dep/defaults.cfg | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/examples/experiments/ptb-joint-pos-dep/defaults.cfg b/examples/experiments/ptb-joint-pos-dep/defaults.cfg index d694ceac8..48741c433 100644 --- a/examples/experiments/ptb-joint-pos-dep/defaults.cfg +++ b/examples/experiments/ptb-joint-pos-dep/defaults.cfg @@ -64,7 +64,7 @@ min_action_freq = 1 @architectures = "spacy.Tagger.v1" [components.tagger.model.tok2vec] -@architectures = "spacy.Tok2VecTensors.v1" +@architectures = "spacy.Tok2VecListener.v1" width = ${components.tok2vec.model:width} [components.parser.model] @@ -74,16 +74,21 @@ hidden_width = 64 maxout_pieces = 3 [components.parser.model.tok2vec] 
-@architectures = "spacy.Tok2VecTensors.v1" +@architectures = "spacy.Tok2VecListener.v1" width = ${components.tok2vec.model:width} [components.tok2vec.model] -@architectures = "spacy.HashEmbedCNN.v1" -pretrained_vectors = ${training:vectors} +@architectures = "spacy.Tok2Vec.v1" + +[components.tok2vec.model.embed] +@architectures = "spacy.HashEmbed.v1" width = 96 +rows = 2000 +also_use_subwords = true +also_use_static_vectors = false + +[components.tok2vec.model.encode] +@architectures = "spacy.MaxoutWindowEncode.v1" depth = 4 window_size = 1 -embed_size = 2000 maxout_pieces = 3 -subword_features = true -dropout = null From fe0cdcd461f4c77db4ce120b3dee2d960d83d605 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Tue, 28 Jul 2020 13:59:46 +0200 Subject: [PATCH 06/37] Fixes --- spacy/ml/models/tok2vec.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/spacy/ml/models/tok2vec.py b/spacy/ml/models/tok2vec.py index 4bcd61625..4c4bd0d22 100644 --- a/spacy/ml/models/tok2vec.py +++ b/spacy/ml/models/tok2vec.py @@ -50,7 +50,8 @@ def HashEmbed( width, rows if feature == NORM else rows // 2, column=cols.index(feature), - seed=seed + seed=seed, + dropout=0.0 ) if also_embed_subwords: @@ -67,14 +68,14 @@ def HashEmbed( model = chain( concatenate( chain(FeatureExtractor(cols), concatenate(*embeddings)), - StaticVectors(width, dropout=dropout) + StaticVectors(width, dropout=0.0) ), - Maxout(width, dropout=dropout, normalize=True) + Maxout(width, pieces=3, dropout=0.0, normalize=True) ) else: model = chain( chain(FeatureExtractor(cols), concatenate(*embeddings)), - Maxout(width, concat_size, dropout=dropout, normalize=True) + Maxout(width, pieces=3, dropout=0.0, normalize=True) ) return model From 099e9331c50e92ba96bd3cc23498a023492965a7 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Tue, 28 Jul 2020 15:51:40 +0200 Subject: [PATCH 07/37] Fix tok2vec --- spacy/ml/models/tok2vec.py | 45 ++++++++++++++++++++++++-------------- 1 file changed, 28 insertions(+), 17 deletions(-) diff --git a/spacy/ml/models/tok2vec.py b/spacy/ml/models/tok2vec.py index 4c4bd0d22..448f9d1d0 100644 --- a/spacy/ml/models/tok2vec.py +++ b/spacy/ml/models/tok2vec.py @@ -1,13 +1,15 @@ from typing import Optional, List from thinc.api import chain, clone, concatenate, with_array, with_padded -from thinc.api import Model, noop -from thinc.api import FeatureExtractor, HashEmbed, StaticVectors -from thincapi import expand_window, residual, Maxout, Mish +from thinc.api import Model, noop, list2ragged, ragged2list +from thinc.api import FeatureExtractor, HashEmbed +from thinc.api import expand_window, residual, Maxout, Mish from thinc.types import Floats2d +from ...tokens import Doc from ... 
import util from ...util import registry from ...ml import _character_embed +from ..staticvectors import StaticVectors from ...pipeline.tok2vec import Tok2VecListener from ...attrs import ID, ORTH, NORM, PREFIX, SUFFIX, SHAPE @@ -21,20 +23,19 @@ def tok2vec_listener_v1(width, upstream="*"): @registry.architectures.register("spacy.Tok2Vec.v1") def Tok2Vec( embed: Model[List[Doc], List[Floats2d]], - encode: Model[List[Floats2d], List[Floats2d] + encode: Model[List[Floats2d], List[Floats2d]] ) -> Model[List[Doc], List[Floats2d]]: - tok2vec = with_array( - chain(embed, encode), - pad=encode.attrs.get("receptive_field", 0) - ) + + receptive_field = encode.attrs.get("receptive_field", 0) + tok2vec = chain(embed, with_array(encode, pad=receptive_field)) tok2vec.set_dim("nO", encode.get_dim("nO")) tok2vec.set_ref("embed", embed) tok2vec.set_ref("encode", encode) return tok2vec -@registry.architectures.register("spacy.HashEmbed.v1") -def HashEmbed( +@registry.architectures.register("spacy.MultiHashEmbed.v1") +def MultiHashEmbed( width: int, rows: int, also_embed_subwords: bool, @@ -56,9 +57,9 @@ def HashEmbed( if also_embed_subwords: embeddings = [ - make_hash_embed(NORM) - make_hash_embed(PREFIX) - make_hash_embed(SUFFIX) + make_hash_embed(NORM), + make_hash_embed(PREFIX), + make_hash_embed(SUFFIX), make_hash_embed(SHAPE) ] else: @@ -67,15 +68,25 @@ def HashEmbed( if also_use_static_vectors: model = chain( concatenate( - chain(FeatureExtractor(cols), concatenate(*embeddings)), + chain( + FeatureExtractor(cols), + list2ragged(), + with_array(concatenate(*embeddings)) + ), StaticVectors(width, dropout=0.0) ), - Maxout(width, pieces=3, dropout=0.0, normalize=True) + with_array(Maxout(width, nP=3, dropout=0.0, normalize=True)), + ragged2list() ) else: model = chain( - chain(FeatureExtractor(cols), concatenate(*embeddings)), - Maxout(width, pieces=3, dropout=0.0, normalize=True) + chain( + FeatureExtractor(cols), + list2ragged(), + with_array(concatenate(*embeddings)) + ), + with_array(Maxout(width, nP=3, dropout=0.0, normalize=True)), + ragged2list() ) return model From 9987ea9e4ddd920a8035c097b3360a80560352cb Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Tue, 28 Jul 2020 15:52:02 +0200 Subject: [PATCH 08/37] Fix Tok2Vec begin_training --- spacy/pipeline/tok2vec.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spacy/pipeline/tok2vec.py b/spacy/pipeline/tok2vec.py index 5bda12d1b..5caaf432f 100644 --- a/spacy/pipeline/tok2vec.py +++ b/spacy/pipeline/tok2vec.py @@ -196,7 +196,7 @@ class Tok2Vec(Pipe): DOCS: https://spacy.io/api/tok2vec#begin_training """ - docs = [Doc(Vocab(), words=["hello"])] + docs = [Doc(self.vocab, words=["hello"])] self.model.initialize(X=docs) link_vectors_to_models(self.vocab) From acc64e138aa4d334e58933f885debe763952b3e3 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Tue, 28 Jul 2020 15:52:20 +0200 Subject: [PATCH 09/37] Add import --- spacy/util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spacy/util.py b/spacy/util.py index de6d9831b..72e68463b 100644 --- a/spacy/util.py +++ b/spacy/util.py @@ -7,7 +7,7 @@ import importlib.util import re from pathlib import Path import thinc -from thinc.api import NumpyOps, get_current_ops, Adam, Config, Optimizer +from thinc.api import NumpyOps, get_current_ops, Adam, Config, Optimizer, Model import functools import itertools import numpy.random From 984754e3be65ddd0ed3ab77835b62ca67bb1266a Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Tue, 28 Jul 2020 15:52:30 +0200 Subject: 
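To see what the refactored embedding block from the tok2vec patch above produces on its own, a rough sketch (not from the patches); static vectors are switched off here so no vectors table is needed:

    import spacy
    from spacy.ml.models.tok2vec import MultiHashEmbed

    nlp = spacy.blank("en")
    embed = MultiHashEmbed(
        width=96, rows=2000, also_embed_subwords=True, also_use_static_vectors=False
    )
    docs = [nlp("hello world"), nlp("hello")]
    embed.initialize(X=docs)
    outputs = embed.predict(docs)  # List[Floats2d] after ragged2list()
    assert [y.shape for y in outputs] == [(2, 96), (1, 96)]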
[PATCH 10/37] Update config --- .../ptb-joint-pos-dep/defaults.cfg | 23 ++++++++++--------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/examples/experiments/ptb-joint-pos-dep/defaults.cfg b/examples/experiments/ptb-joint-pos-dep/defaults.cfg index 48741c433..5850eaf3a 100644 --- a/examples/experiments/ptb-joint-pos-dep/defaults.cfg +++ b/examples/experiments/ptb-joint-pos-dep/defaults.cfg @@ -4,16 +4,16 @@ patience = 10000 eval_frequency = 200 dropout = 0.2 init_tok2vec = null -vectors = null +vectors = "tmp/fasttext_vectors/vocab" max_epochs = 100 orth_variant_level = 0.0 gold_preproc = true max_length = 0 -scores = ["tag_acc", "dep_uas", "dep_las"] +scores = ["tag_acc", "dep_uas", "dep_las", "speed"] score_weights = {"dep_las": 0.8, "tag_acc": 0.2} limit = 0 seed = 0 -accumulate_gradient = 2 +accumulate_gradient = 1 discard_oversize = false raw_text = null tag_map = null @@ -22,7 +22,7 @@ base_model = null eval_batch_size = 128 use_pytorch_for_gpu_memory = false -batch_by = "padded" +batch_by = "words" [training.batch_size] @schedules = "compounding.v1" @@ -65,7 +65,7 @@ min_action_freq = 1 [components.tagger.model.tok2vec] @architectures = "spacy.Tok2VecListener.v1" -width = ${components.tok2vec.model:width} +width = ${components.tok2vec.model.encode:width} [components.parser.model] @architectures = "spacy.TransitionBasedParser.v1" @@ -75,20 +75,21 @@ maxout_pieces = 3 [components.parser.model.tok2vec] @architectures = "spacy.Tok2VecListener.v1" -width = ${components.tok2vec.model:width} +width = ${components.tok2vec.model.encode:width} [components.tok2vec.model] @architectures = "spacy.Tok2Vec.v1" [components.tok2vec.model.embed] -@architectures = "spacy.HashEmbed.v1" -width = 96 +@architectures = "spacy.MultiHashEmbed.v1" +width = ${components.tok2vec.model.encode:width} rows = 2000 -also_use_subwords = true -also_use_static_vectors = false +also_embed_subwords = true +also_use_static_vectors = true [components.tok2vec.model.encode] -@architectures = "spacy.MaxoutWindowEncode.v1" +@architectures = "spacy.MaxoutWindowEncoder.v1" +width = 96 depth = 4 window_size = 1 maxout_pieces = 3 From 44d350dc9476aa897cbf8d9502f0b5c87a2efa89 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Tue, 28 Jul 2020 15:52:46 +0200 Subject: [PATCH 11/37] Use spaCy's StaticVectors --- spacy/ml/models/textcat.py | 1 - 1 file changed, 1 deletion(-) diff --git a/spacy/ml/models/textcat.py b/spacy/ml/models/textcat.py index e5f4af2fb..a64a2487a 100644 --- a/spacy/ml/models/textcat.py +++ b/spacy/ml/models/textcat.py @@ -5,7 +5,6 @@ from thinc.api import SparseLinear, Softmax, softmax_activation, Maxout, reduce_ from thinc.api import HashEmbed, with_ragged, with_array, with_cpu, uniqued from thinc.api import Relu, residual, expand_window, FeatureExtractor -from ..spacy_vectors import SpacyVectors from ... 
import util from ...attrs import ID, ORTH, PREFIX, SUFFIX, SHAPE, LOWER from ...util import registry From 475d7c1c7c4520b280fad01e2a3c8db5d60a594b Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Tue, 28 Jul 2020 15:52:55 +0200 Subject: [PATCH 12/37] Fix StaticVectors class --- spacy/ml/staticvectors.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/spacy/ml/staticvectors.py b/spacy/ml/staticvectors.py index 4c9e53563..ce2c7efff 100644 --- a/spacy/ml/staticvectors.py +++ b/spacy/ml/staticvectors.py @@ -37,15 +37,14 @@ def forward( if not len(docs): return _handle_empty(model.ops, model.get_dim("nO")) key_attr = model.attrs["key_attr"] - W = cast(Floats2d, model.get_param("W")) + W = cast(Floats2d, model.ops.as_contig(model.get_param("W"))) V = cast(Floats2d, docs[0].vocab.vectors.data) mask = _get_drop_mask(model.ops, W.shape[0], model.attrs.get("dropout_rate")) - rows = model.ops.flatten( [doc.vocab.vectors.find(keys=doc.to_array(key_attr)) for doc in docs] ) output = Ragged( - model.ops.gemm(V[rows], W, trans2=True), + model.ops.gemm(model.ops.as_contig(V[rows]), W, trans2=True), model.ops.asarray([len(doc) for doc in docs], dtype="i") ) if mask is not None: @@ -54,7 +53,14 @@ def forward( def backprop(d_output: Ragged) -> List[Doc]: if mask is not None: d_output.data *= mask - model.inc_grad("W", model.ops.gemm(d_output.data, V[rows], trans1=True)) + model.inc_grad( + "W", + model.ops.gemm( + d_output.data, + model.ops.as_contig(V[rows]), + trans1=True + ) + ) return [] return output, backprop From df95e2af64a6c3d2862e4317a450e8e694e2d406 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Tue, 28 Jul 2020 21:56:02 +0200 Subject: [PATCH 13/37] Add load_vectors_into_model util --- spacy/util.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/spacy/util.py b/spacy/util.py index 72e68463b..4e3a8d203 100644 --- a/spacy/util.py +++ b/spacy/util.py @@ -189,6 +189,23 @@ def get_module_path(module: ModuleType) -> Path: return Path(sys.modules[module.__module__].__file__).parent +def load_vectors_into_model( + nlp: "Language", + name: Union[str, Path], + *, + add_strings=True +) -> None: + """Load word vectors from an installed model or path into a model instance.""" + vectors_nlp = load_model(name) + nlp.vocab.vectors = vectors_nlp.vocab.vectors + if add_strings: + # I guess we should add the strings from the vectors_nlp model? + # E.g. if someone does a similarity query, they might expect the strings. 
+ for key in nlp.vocab.vectors.key2row: + if key in vectors_nlp.strings: + nlp.vocab.strings.add(vectors_nlp.strings[key]) + + def load_model( name: Union[str, Path], disable: Iterable[str] = tuple(), From 30dd96c540ec07f3289649e8f946547407721eba Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Tue, 28 Jul 2020 21:56:28 +0200 Subject: [PATCH 14/37] Load vectors in Language.from_config --- spacy/language.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/spacy/language.py b/spacy/language.py index 9dd8a347e..9fde419b3 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -1408,6 +1408,8 @@ class Language: nlp = cls( create_tokenizer=create_tokenizer, create_lemmatizer=create_lemmatizer, ) + if config["training"]["vectors"] is not None: + util.load_vectors_into_model(nlp, config["training"]["vectors"]) pipeline = config.get("components", {}) for pipe_name in config["nlp"]["pipeline"]: if pipe_name not in pipeline: From 7299419fe4cb68459bb300ad1c6c2b4885861db0 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Tue, 28 Jul 2020 21:59:30 +0200 Subject: [PATCH 15/37] Dont load vectors in Language.from_config --- spacy/language.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/spacy/language.py b/spacy/language.py index 9fde419b3..3511a7691 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -1408,8 +1408,10 @@ class Language: nlp = cls( create_tokenizer=create_tokenizer, create_lemmatizer=create_lemmatizer, ) - if config["training"]["vectors"] is not None: - util.load_vectors_into_model(nlp, config["training"]["vectors"]) + # Note that we don't load vectors here, instead they get loaded explicitly + # inside stuff like the spacy train function. If we loaded them here, + # then we would load them twice at runtime: once when we make from config, + # and then again when we load from disk. pipeline = config.get("components", {}) for pipe_name in config["nlp"]["pipeline"]: if pipe_name not in pipeline: From 7852a68a7530b36c512a54163a1511afe102b625 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Tue, 28 Jul 2020 21:59:51 +0200 Subject: [PATCH 16/37] Fix load_vectors_into_model function --- spacy/util.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/spacy/util.py b/spacy/util.py index 4e3a8d203..7a26011f1 100644 --- a/spacy/util.py +++ b/spacy/util.py @@ -202,8 +202,8 @@ def load_vectors_into_model( # I guess we should add the strings from the vectors_nlp model? # E.g. if someone does a similarity query, they might expect the strings. 
for key in nlp.vocab.vectors.key2row: - if key in vectors_nlp.strings: - nlp.vocab.strings.add(vectors_nlp.strings[key]) + if key in vectors_nlp.vocab.strings: + nlp.vocab.strings.add(vectors_nlp.vocab.strings[key]) def load_model( From 2aff3c4b5aff4f5e17fa67d07424580609719fad Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Tue, 28 Jul 2020 22:00:24 +0200 Subject: [PATCH 17/37] Load vectors in 'spacy train' --- spacy/cli/train.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/spacy/cli/train.py b/spacy/cli/train.py index fbe3a5013..e152ae8ea 100644 --- a/spacy/cli/train.py +++ b/spacy/cli/train.py @@ -80,16 +80,20 @@ def train( msg.info("Using CPU") msg.info(f"Loading config and nlp from: {config_path}") config = Config().from_disk(config_path) + if config.get("training", {}).get("seed") is not None: + fix_random_seed(config["training"]["seed"]) with show_validation_error(): nlp, config = util.load_model_from_config(config, overrides=config_overrides) if config["training"]["base_model"]: - base_nlp = util.load_model(config["training"]["base_model"]) # TODO: do something to check base_nlp against regular nlp described in config? - nlp = base_nlp + # If everything matches it will look something like: + # base_nlp = util.load_model(config["training"]["base_model"]) + # nlp = base_nlp + raise NotImplementedError("base_model not supported yet.") + if config["training"]["vectors"] is not None: + util.load_vectors_into_model(nlp, config["training"]["vectors"]) verify_config(nlp) raw_text, tag_map, morph_rules, weights_data = load_from_paths(config) - if config["training"]["seed"] is not None: - fix_random_seed(config["training"]["seed"]) if config["training"]["use_pytorch_for_gpu_memory"]: # It feels kind of weird to not have a default for this. 
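Outside of 'spacy train', the same helper can be called directly. A small sketch; the vectors package name is only illustrative:

    import spacy
    from spacy.util import load_vectors_into_model

    nlp = spacy.blank("en")
    load_vectors_into_model(nlp, "en_vectors_web_lg")  # any installed model or path with vectors
    assert nlp.vocab.vectors.data.size > 0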
use_pytorch_for_gpu_memory() From 0c17ea4c851d2d5996447f1da8d6de2b601e5ec7 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Tue, 28 Jul 2020 22:02:34 +0200 Subject: [PATCH 18/37] Format --- spacy/ml/models/tok2vec.py | 32 ++++++++++++++------------------ spacy/ml/staticvectors.py | 14 +++++--------- spacy/util.py | 9 +++------ 3 files changed, 22 insertions(+), 33 deletions(-) diff --git a/spacy/ml/models/tok2vec.py b/spacy/ml/models/tok2vec.py index 448f9d1d0..f9183e709 100644 --- a/spacy/ml/models/tok2vec.py +++ b/spacy/ml/models/tok2vec.py @@ -23,7 +23,7 @@ def tok2vec_listener_v1(width, upstream="*"): @registry.architectures.register("spacy.Tok2Vec.v1") def Tok2Vec( embed: Model[List[Doc], List[Floats2d]], - encode: Model[List[Floats2d], List[Floats2d]] + encode: Model[List[Floats2d], List[Floats2d]], ) -> Model[List[Doc], List[Floats2d]]: receptive_field = encode.attrs.get("receptive_field", 0) @@ -36,14 +36,12 @@ def Tok2Vec( @registry.architectures.register("spacy.MultiHashEmbed.v1") def MultiHashEmbed( - width: int, - rows: int, - also_embed_subwords: bool, - also_use_static_vectors: bool + width: int, rows: int, also_embed_subwords: bool, also_use_static_vectors: bool ): cols = [NORM, PREFIX, SUFFIX, SHAPE, ORTH] - + seed = 7 + def make_hash_embed(feature): nonlocal seed seed += 1 @@ -52,15 +50,15 @@ def MultiHashEmbed( rows if feature == NORM else rows // 2, column=cols.index(feature), seed=seed, - dropout=0.0 + dropout=0.0, ) - + if also_embed_subwords: embeddings = [ make_hash_embed(NORM), make_hash_embed(PREFIX), make_hash_embed(SUFFIX), - make_hash_embed(SHAPE) + make_hash_embed(SHAPE), ] else: embeddings = [make_hash_embed(NORM)] @@ -71,25 +69,25 @@ def MultiHashEmbed( chain( FeatureExtractor(cols), list2ragged(), - with_array(concatenate(*embeddings)) + with_array(concatenate(*embeddings)), ), - StaticVectors(width, dropout=0.0) + StaticVectors(width, dropout=0.0), ), with_array(Maxout(width, nP=3, dropout=0.0, normalize=True)), - ragged2list() + ragged2list(), ) else: model = chain( chain( FeatureExtractor(cols), list2ragged(), - with_array(concatenate(*embeddings)) + with_array(concatenate(*embeddings)), ), with_array(Maxout(width, nP=3, dropout=0.0, normalize=True)), - ragged2list() + ragged2list(), ) return model - + @registry.architectures.register("spacy.CharacterEmbed.v1") def CharacterEmbed(columns, width, rows, nM, nC, features, dropout): @@ -137,6 +135,4 @@ def MishWindowEncoder(width, window_size, depth): def BiLSTMEncoder(width, depth, dropout): if depth == 0: return noop() - return with_padded( - PyTorchLSTM(width, width, bi=True, depth=depth, dropout=dropout) - ) + return with_padded(PyTorchLSTM(width, width, bi=True, depth=depth, dropout=dropout)) diff --git a/spacy/ml/staticvectors.py b/spacy/ml/staticvectors.py index ce2c7efff..41afdbf80 100644 --- a/spacy/ml/staticvectors.py +++ b/spacy/ml/staticvectors.py @@ -15,7 +15,7 @@ def StaticVectors( *, dropout: Optional[float] = None, init_W: Callable = glorot_uniform_init, - key_attr: str="ORTH" + key_attr: str = "ORTH" ) -> Model[List[Doc], Ragged]: """Embed Doc objects with their vocab's vectors table, applying a learned linear projection to control the dimensionality. 
If a dropout rate is @@ -45,21 +45,17 @@ def forward( ) output = Ragged( model.ops.gemm(model.ops.as_contig(V[rows]), W, trans2=True), - model.ops.asarray([len(doc) for doc in docs], dtype="i") + model.ops.asarray([len(doc) for doc in docs], dtype="i"), ) if mask is not None: output.data *= mask - + def backprop(d_output: Ragged) -> List[Doc]: if mask is not None: d_output.data *= mask model.inc_grad( "W", - model.ops.gemm( - d_output.data, - model.ops.as_contig(V[rows]), - trans1=True - ) + model.ops.gemm(d_output.data, model.ops.as_contig(V[rows]), trans1=True), ) return [] @@ -78,7 +74,7 @@ def init( nM = X[0].vocab.vectors.data.shape[1] if Y is not None: nO = Y.data.shape[1] - + if nM is None: raise ValueError( "Cannot initialize StaticVectors layer: nM dimension unset. " diff --git a/spacy/util.py b/spacy/util.py index 7a26011f1..898e1c2c3 100644 --- a/spacy/util.py +++ b/spacy/util.py @@ -190,10 +190,7 @@ def get_module_path(module: ModuleType) -> Path: def load_vectors_into_model( - nlp: "Language", - name: Union[str, Path], - *, - add_strings=True + nlp: "Language", name: Union[str, Path], *, add_strings=True ) -> None: """Load word vectors from an installed model or path into a model instance.""" vectors_nlp = load_model(name) @@ -1205,12 +1202,12 @@ class DummyTokenizer: def link_vectors_to_models( vocab: "Vocab", - models: List[Model]=[], + models: List[Model] = [], *, vectors_name_attr="vectors_name", vectors_attr="vectors", key2row_attr="key2row", - default_vectors_name="spacy_pretrained_vectors" + default_vectors_name="spacy_pretrained_vectors", ) -> None: """Supply vectors data to models.""" vectors = vocab.vectors From 1784c95827f7a2fe8f8df88facced72af73cc961 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Tue, 28 Jul 2020 22:17:47 +0200 Subject: [PATCH 19/37] Clean up link_vectors_to_models unused stuff --- spacy/cli/project/assets.py | 1 - spacy/language.py | 4 +--- spacy/pipeline/morphologizer.pyx | 1 - spacy/pipeline/multitask.pyx | 3 --- spacy/pipeline/pipe.pyx | 4 +--- spacy/pipeline/senter.pyx | 1 - spacy/pipeline/simple_ner.py | 1 - spacy/pipeline/tagger.pyx | 1 - spacy/pipeline/textcat.py | 1 - spacy/pipeline/tok2vec.py | 3 +-- spacy/syntax/_parser_model.pyx | 2 +- spacy/syntax/nn_parser.pyx | 3 +-- spacy/tests/regression/test_issue2501-3000.py | 2 -- spacy/util.py | 23 ------------------- spacy/vocab.pyx | 7 +----- 15 files changed, 6 insertions(+), 51 deletions(-) diff --git a/spacy/cli/project/assets.py b/spacy/cli/project/assets.py index 1bd28cb7e..e42935e2f 100644 --- a/spacy/cli/project/assets.py +++ b/spacy/cli/project/assets.py @@ -11,7 +11,6 @@ from ...util import ensure_path, working_dir from .._util import project_cli, Arg, PROJECT_FILE, load_project_config, get_checksum - # TODO: find a solution for caches # CACHES = [ # Path.home() / ".torch", diff --git a/spacy/language.py b/spacy/language.py index 3511a7691..4b7651d65 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -21,7 +21,7 @@ from .vocab import Vocab, create_vocab from .pipe_analysis import analyze_pipes, analyze_all_pipes, validate_attrs from .gold import Example from .scorer import Scorer -from .util import link_vectors_to_models, create_default_optimizer, registry +from .util import create_default_optimizer, registry from .util import SimpleFrozenDict, combine_score_weights from .lang.tokenizer_exceptions import URL_MATCH, BASE_EXCEPTIONS from .lang.punctuation import TOKENIZER_PREFIXES, TOKENIZER_SUFFIXES @@ -1049,7 +1049,6 @@ class Language: if 
self.vocab.vectors.data.shape[1] >= 1: ops = get_current_ops() self.vocab.vectors.data = ops.asarray(self.vocab.vectors.data) - link_vectors_to_models(self.vocab) if sgd is None: sgd = create_default_optimizer() self._optimizer = sgd @@ -1082,7 +1081,6 @@ class Language: ops = get_current_ops() if self.vocab.vectors.data.shape[1] >= 1: self.vocab.vectors.data = ops.asarray(self.vocab.vectors.data) - link_vectors_to_models(self.vocab) if sgd is None: sgd = create_default_optimizer() self._optimizer = sgd diff --git a/spacy/pipeline/morphologizer.pyx b/spacy/pipeline/morphologizer.pyx index a6be129ba..56ef44cb9 100644 --- a/spacy/pipeline/morphologizer.pyx +++ b/spacy/pipeline/morphologizer.pyx @@ -149,7 +149,6 @@ class Morphologizer(Tagger): self.cfg["labels_pos"][norm_label] = POS_IDS[pos] self.set_output(len(self.labels)) self.model.initialize() - util.link_vectors_to_models(self.vocab) if sgd is None: sgd = self.create_optimizer() return sgd diff --git a/spacy/pipeline/multitask.pyx b/spacy/pipeline/multitask.pyx index 4945afe4f..97826aaa6 100644 --- a/spacy/pipeline/multitask.pyx +++ b/spacy/pipeline/multitask.pyx @@ -11,7 +11,6 @@ from .tagger import Tagger from ..language import Language from ..syntax import nonproj from ..attrs import POS, ID -from ..util import link_vectors_to_models from ..errors import Errors @@ -91,7 +90,6 @@ class MultitaskObjective(Tagger): if label is not None and label not in self.labels: self.labels[label] = len(self.labels) self.model.initialize() - link_vectors_to_models(self.vocab) if sgd is None: sgd = self.create_optimizer() return sgd @@ -179,7 +177,6 @@ class ClozeMultitask(Pipe): pass def begin_training(self, get_examples=lambda: [], pipeline=None, sgd=None): - link_vectors_to_models(self.vocab) self.model.initialize() X = self.model.ops.alloc((5, self.model.get_ref("tok2vec").get_dim("nO"))) self.model.output_layer.begin_training(X) diff --git a/spacy/pipeline/pipe.pyx b/spacy/pipeline/pipe.pyx index f8ca28724..e4f7989b8 100644 --- a/spacy/pipeline/pipe.pyx +++ b/spacy/pipeline/pipe.pyx @@ -3,7 +3,7 @@ import srsly from ..tokens.doc cimport Doc -from ..util import link_vectors_to_models, create_default_optimizer +from ..util import create_default_optimizer from ..errors import Errors from .. 
import util @@ -145,8 +145,6 @@ class Pipe: DOCS: https://spacy.io/api/pipe#begin_training """ self.model.initialize() - if hasattr(self, "vocab"): - link_vectors_to_models(self.vocab) if sgd is None: sgd = self.create_optimizer() return sgd diff --git a/spacy/pipeline/senter.pyx b/spacy/pipeline/senter.pyx index 743ceb32b..568e6031b 100644 --- a/spacy/pipeline/senter.pyx +++ b/spacy/pipeline/senter.pyx @@ -138,7 +138,6 @@ class SentenceRecognizer(Tagger): """ self.set_output(len(self.labels)) self.model.initialize() - util.link_vectors_to_models(self.vocab) if sgd is None: sgd = self.create_optimizer() return sgd diff --git a/spacy/pipeline/simple_ner.py b/spacy/pipeline/simple_ner.py index ec7ab6b7a..9b9872b77 100644 --- a/spacy/pipeline/simple_ner.py +++ b/spacy/pipeline/simple_ner.py @@ -168,7 +168,6 @@ class SimpleNER(Pipe): self.model.initialize() if pipeline is not None: self.init_multitask_objectives(get_examples, pipeline, sgd=sgd, **self.cfg) - util.link_vectors_to_models(self.vocab) self.loss_func = SequenceCategoricalCrossentropy( names=self.get_tag_names(), normalize=True, missing_value=None ) diff --git a/spacy/pipeline/tagger.pyx b/spacy/pipeline/tagger.pyx index c52a7889b..b3f996acb 100644 --- a/spacy/pipeline/tagger.pyx +++ b/spacy/pipeline/tagger.pyx @@ -318,7 +318,6 @@ class Tagger(Pipe): self.model.initialize(X=doc_sample) # Get batch of example docs, example outputs to call begin_training(). # This lets the model infer shapes. - util.link_vectors_to_models(self.vocab) if sgd is None: sgd = self.create_optimizer() return sgd diff --git a/spacy/pipeline/textcat.py b/spacy/pipeline/textcat.py index 2aaa4a769..c235a2594 100644 --- a/spacy/pipeline/textcat.py +++ b/spacy/pipeline/textcat.py @@ -356,7 +356,6 @@ class TextCategorizer(Pipe): docs = [Doc(Vocab(), words=["hello"])] truths, _ = self._examples_to_truth(examples) self.set_output(len(self.labels)) - util.link_vectors_to_models(self.vocab) self.model.initialize(X=docs, Y=truths) if sgd is None: sgd = self.create_optimizer() diff --git a/spacy/pipeline/tok2vec.py b/spacy/pipeline/tok2vec.py index 5caaf432f..5e9e5b40e 100644 --- a/spacy/pipeline/tok2vec.py +++ b/spacy/pipeline/tok2vec.py @@ -7,7 +7,7 @@ from ..tokens import Doc from ..vocab import Vocab from ..language import Language from ..errors import Errors -from ..util import link_vectors_to_models, minibatch +from ..util import minibatch default_model_config = """ @@ -198,7 +198,6 @@ class Tok2Vec(Pipe): """ docs = [Doc(self.vocab, words=["hello"])] self.model.initialize(X=docs) - link_vectors_to_models(self.vocab) class Tok2VecListener(Model): diff --git a/spacy/syntax/_parser_model.pyx b/spacy/syntax/_parser_model.pyx index 7acee5efd..eedd84bac 100644 --- a/spacy/syntax/_parser_model.pyx +++ b/spacy/syntax/_parser_model.pyx @@ -21,7 +21,7 @@ from .transition_system cimport Transition from ..compat import copy_array from ..errors import Errors, TempErrors -from ..util import link_vectors_to_models, create_default_optimizer +from ..util import create_default_optimizer from .. import util from . 
import nonproj diff --git a/spacy/syntax/nn_parser.pyx b/spacy/syntax/nn_parser.pyx index 5313ec9bd..a0ee13a0a 100644 --- a/spacy/syntax/nn_parser.pyx +++ b/spacy/syntax/nn_parser.pyx @@ -29,7 +29,7 @@ from .stateclass cimport StateClass from ._state cimport StateC from .transition_system cimport Transition -from ..util import link_vectors_to_models, create_default_optimizer, registry +from ..util import create_default_optimizer, registry from ..compat import copy_array from ..errors import Errors, Warnings from .. import util @@ -456,7 +456,6 @@ cdef class Parser: self.model.initialize() if pipeline is not None: self.init_multitask_objectives(get_examples, pipeline, sgd=sgd, **self.cfg) - link_vectors_to_models(self.vocab) return sgd def to_disk(self, path, exclude=tuple()): diff --git a/spacy/tests/regression/test_issue2501-3000.py b/spacy/tests/regression/test_issue2501-3000.py index ac0867189..cf4e402e2 100644 --- a/spacy/tests/regression/test_issue2501-3000.py +++ b/spacy/tests/regression/test_issue2501-3000.py @@ -9,7 +9,6 @@ from spacy.matcher import Matcher from spacy.tokens import Doc, Span from spacy.vocab import Vocab from spacy.compat import pickle -from spacy.util import link_vectors_to_models import numpy import random @@ -190,7 +189,6 @@ def test_issue2871(): _ = vocab[word] # noqa: F841 vocab.set_vector(word, vector_data[0]) vocab.vectors.name = "dummy_vectors" - link_vectors_to_models(vocab) assert vocab["dog"].rank == 0 assert vocab["cat"].rank == 1 assert vocab["SUFFIX"].rank == 2 diff --git a/spacy/util.py b/spacy/util.py index 898e1c2c3..677f5e8e0 100644 --- a/spacy/util.py +++ b/spacy/util.py @@ -1200,29 +1200,6 @@ class DummyTokenizer: return self -def link_vectors_to_models( - vocab: "Vocab", - models: List[Model] = [], - *, - vectors_name_attr="vectors_name", - vectors_attr="vectors", - key2row_attr="key2row", - default_vectors_name="spacy_pretrained_vectors", -) -> None: - """Supply vectors data to models.""" - vectors = vocab.vectors - if vectors.name is None: - vectors.name = default_vectors_name - if vectors.data.size != 0: - warnings.warn(Warnings.W020.format(shape=vectors.data.shape)) - - for model in models: - for node in model.walk(): - if node.attrs.get(vectors_name_attr) == vectors.name: - node.attrs[vectors_attr] = Unserializable(vectors.data) - node.attrs[key2row_attr] = Unserializable(vectors.key2row) - - def create_default_optimizer() -> Optimizer: # TODO: Do we still want to allow env_opt? learn_rate = env_opt("learn_rate", 0.001) diff --git a/spacy/vocab.pyx b/spacy/vocab.pyx index f41ad2356..b7337b92e 100644 --- a/spacy/vocab.pyx +++ b/spacy/vocab.pyx @@ -16,7 +16,7 @@ from .errors import Errors from .lemmatizer import Lemmatizer from .attrs import intify_attrs, NORM, IS_STOP from .vectors import Vectors -from .util import link_vectors_to_models, registry +from .util import registry from .lookups import Lookups, load_lookups from . 
import util from .lang.norm_exceptions import BASE_NORMS @@ -344,7 +344,6 @@ cdef class Vocab: synonym = self.strings[syn_keys[i][0]] score = scores[i][0] remap[word] = (synonym, score) - link_vectors_to_models(self) return remap def get_vector(self, orth, minn=None, maxn=None): @@ -476,8 +475,6 @@ cdef class Vocab: if "vectors" not in exclude: if self.vectors is not None: self.vectors.from_disk(path, exclude=["strings"]) - if self.vectors.name is not None: - link_vectors_to_models(self) if "lookups" not in exclude: self.lookups.from_disk(path) if "lexeme_norm" in self.lookups: @@ -537,8 +534,6 @@ cdef class Vocab: ) self.length = 0 self._by_orth = PreshMap() - if self.vectors.name is not None: - link_vectors_to_models(self) return self def _reset_cache(self, keys, strings): From c35d6282fcd1d74f209e34b0f90c09cbe2882ded Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Tue, 28 Jul 2020 22:43:06 +0200 Subject: [PATCH 20/37] Add previous HashEmbedCNN tok2vec to make transition easier --- spacy/ml/models/tok2vec.py | 37 +++++++++++++++++++++++++++++++++---- 1 file changed, 33 insertions(+), 4 deletions(-) diff --git a/spacy/ml/models/tok2vec.py b/spacy/ml/models/tok2vec.py index f9183e709..881f25a3b 100644 --- a/spacy/ml/models/tok2vec.py +++ b/spacy/ml/models/tok2vec.py @@ -20,8 +20,37 @@ def tok2vec_listener_v1(width, upstream="*"): return tok2vec +@registry.architectures.register("spacy.HashEmbedCNN.v1") +def build_hash_embed_cnn_tok2vec( + *, + width: int, + depth: int, + embed_size: int, + window_size: int, + maxout_pieces: int, + subword_features: bool, + dropout: Optional[float], + pretrained_vectors: Optional[bool] +) -> Model[List[Doc], List[Floats2d]]: + """Build spaCy's 'standard' tok2vec layer, which uses hash embedding + with subword features and a CNN with layer-normalized maxout.""" + return build_Tok2Vec_model( + embed=MultiHashEmbed( + width=width, + rows=embed_size, + also_embed_subwords=subword_features, + also_use_static_vectors=bool(pretrained_vectors), + ), + encode=MaxoutWindowEncoder( + width=width, + depth=depth, + window_size=window_size, + maxout_pieces=maxout_pieces + ) + ) + @registry.architectures.register("spacy.Tok2Vec.v1") -def Tok2Vec( +def build_Tok2Vec_model( embed: Model[List[Doc], List[Floats2d]], encode: Model[List[Floats2d], List[Floats2d]], ) -> Model[List[Doc], List[Floats2d]]: @@ -62,7 +91,7 @@ def MultiHashEmbed( ] else: embeddings = [make_hash_embed(NORM)] - + concat_size = width * (len(embeddings) + also_use_static_vectors) if also_use_static_vectors: model = chain( concatenate( @@ -73,7 +102,7 @@ def MultiHashEmbed( ), StaticVectors(width, dropout=0.0), ), - with_array(Maxout(width, nP=3, dropout=0.0, normalize=True)), + with_array(Maxout(width, concat_size, nP=3, dropout=0.0, normalize=True)), ragged2list(), ) else: @@ -83,7 +112,7 @@ def MultiHashEmbed( list2ragged(), with_array(concatenate(*embeddings)), ), - with_array(Maxout(width, nP=3, dropout=0.0, normalize=True)), + with_array(Maxout(width, concat_size, nP=3, dropout=0.0, normalize=True)), ragged2list(), ) return model From 20e9098e3f527fadff62aa31bb6342bc14763e91 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Tue, 28 Jul 2020 22:43:19 +0200 Subject: [PATCH 21/37] Update tests --- .../tests/serialize/test_serialize_config.py | 24 ++--- spacy/tests/test_models.py | 94 ++++++++++--------- 2 files changed, 61 insertions(+), 57 deletions(-) diff --git a/spacy/tests/serialize/test_serialize_config.py b/spacy/tests/serialize/test_serialize_config.py index 90a79994e..25673b8c4 100644 
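The compatibility architecture added above keeps the old spacy.HashEmbedCNN.v1 config surface and just wires the new embed and encode pieces together. Calling it directly looks roughly like this; the values mirror the old ptb defaults and are only illustrative:

    from spacy.ml.models.tok2vec import build_hash_embed_cnn_tok2vec

    tok2vec = build_hash_embed_cnn_tok2vec(
        width=96,
        depth=4,
        embed_size=2000,
        window_size=1,
        maxout_pieces=3,
        subword_features=True,
        dropout=None,
        pretrained_vectors=None,  # falsey, so no static vectors are used
    )
    # Equivalent to MultiHashEmbed(width=96, rows=2000, ...) feeding
    # MaxoutWindowEncoder(width=96, depth=4, window_size=1, maxout_pieces=3).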
--- a/spacy/tests/serialize/test_serialize_config.py +++ b/spacy/tests/serialize/test_serialize_config.py @@ -68,18 +68,18 @@ dropout = null @registry.architectures.register("my_test_parser") def my_parser(): tok2vec = build_Tok2Vec_model( - width=321, - embed_size=5432, - pretrained_vectors=None, - window_size=3, - maxout_pieces=4, - subword_features=True, - char_embed=True, - nM=64, - nC=8, - conv_depth=2, - bilstm_depth=0, - dropout=None, + MultiHashEmbed( + width=321, + embed_size=5432, + also_embed_subwords=True, + also_use_static_vectors=False + ), + MaxoutWindowEncoder( + width=321, + window_size=3, + maxout_pieces=4, + depth=2 + ) ) parser = build_tb_parser_model( tok2vec=tok2vec, nr_feature_tokens=7, hidden_width=65, maxout_pieces=5 diff --git a/spacy/tests/test_models.py b/spacy/tests/test_models.py index fc1988fcd..4c38ea6c6 100644 --- a/spacy/tests/test_models.py +++ b/spacy/tests/test_models.py @@ -5,12 +5,32 @@ from thinc.api import fix_random_seed, Adam, set_dropout_rate from numpy.testing import assert_array_equal import numpy -from spacy.ml.models import build_Tok2Vec_model +from spacy.ml.models import build_Tok2Vec_model, MultiHashEmbed, MaxoutWindowEncoder from spacy.ml.models import build_text_classifier, build_simple_cnn_text_classifier from spacy.lang.en import English from spacy.lang.en.examples import sentences as EN_SENTENCES +def get_textcat_kwargs(): + return { + "width": 64, + "embed_size": 2000, + "pretrained_vectors": None, + "exclusive_classes": False, + "ngram_size": 1, + "window_size": 1, + "conv_depth": 2, + "dropout": None, + "nO": 7, + } + +def get_textcat_cnn_kwargs(): + return { + "tok2vec": test_tok2vec(), + "exclusive_classes": False, + "nO": 13, + } + def get_all_params(model): params = [] for node in model.walk(): @@ -35,50 +55,34 @@ def get_gradient(model, Y): raise ValueError(f"Could not get gradient for type {type(Y)}") +def get_tok2vec_kwargs(): + # This actually creates models, so seems best to put it in a function. 
+ return { + "embed": MultiHashEmbed( + width=32, + rows=500, + also_embed_subwords=True, + also_use_static_vectors=False + ), + "encode": MaxoutWindowEncoder( + width=32, + depth=2, + maxout_pieces=2, + window_size=1, + ) + } + + def test_tok2vec(): - return build_Tok2Vec_model(**TOK2VEC_KWARGS) - - -TOK2VEC_KWARGS = { - "width": 96, - "embed_size": 2000, - "subword_features": True, - "char_embed": False, - "conv_depth": 4, - "bilstm_depth": 0, - "maxout_pieces": 4, - "window_size": 1, - "dropout": 0.1, - "nM": 0, - "nC": 0, - "pretrained_vectors": None, -} - -TEXTCAT_KWARGS = { - "width": 64, - "embed_size": 2000, - "pretrained_vectors": None, - "exclusive_classes": False, - "ngram_size": 1, - "window_size": 1, - "conv_depth": 2, - "dropout": None, - "nO": 7, -} - -TEXTCAT_CNN_KWARGS = { - "tok2vec": test_tok2vec(), - "exclusive_classes": False, - "nO": 13, -} + return build_Tok2Vec_model(**get_tok2vec_kwargs()) @pytest.mark.parametrize( "seed,model_func,kwargs", [ - (0, build_Tok2Vec_model, TOK2VEC_KWARGS), - (0, build_text_classifier, TEXTCAT_KWARGS), - (0, build_simple_cnn_text_classifier, TEXTCAT_CNN_KWARGS), + (0, build_Tok2Vec_model, get_tok2vec_kwargs()), + (0, build_text_classifier, get_textcat_kwargs()), + (0, build_simple_cnn_text_classifier, get_textcat_cnn_kwargs()), ], ) def test_models_initialize_consistently(seed, model_func, kwargs): @@ -96,9 +100,9 @@ def test_models_initialize_consistently(seed, model_func, kwargs): @pytest.mark.parametrize( "seed,model_func,kwargs,get_X", [ - (0, build_Tok2Vec_model, TOK2VEC_KWARGS, get_docs), - (0, build_text_classifier, TEXTCAT_KWARGS, get_docs), - (0, build_simple_cnn_text_classifier, TEXTCAT_CNN_KWARGS, get_docs), + (0, build_Tok2Vec_model, get_tok2vec_kwargs(), get_docs), + (0, build_text_classifier, get_textcat_kwargs(), get_docs), + (0, build_simple_cnn_text_classifier, get_textcat_cnn_kwargs(), get_docs), ], ) def test_models_predict_consistently(seed, model_func, kwargs, get_X): @@ -131,9 +135,9 @@ def test_models_predict_consistently(seed, model_func, kwargs, get_X): @pytest.mark.parametrize( "seed,dropout,model_func,kwargs,get_X", [ - (0, 0.2, build_Tok2Vec_model, TOK2VEC_KWARGS, get_docs), - (0, 0.2, build_text_classifier, TEXTCAT_KWARGS, get_docs), - (0, 0.2, build_simple_cnn_text_classifier, TEXTCAT_CNN_KWARGS, get_docs), + (0, 0.2, build_Tok2Vec_model, get_tok2vec_kwargs(), get_docs), + (0, 0.2, build_text_classifier, get_textcat_kwargs(), get_docs), + (0, 0.2, build_simple_cnn_text_classifier, get_textcat_cnn_kwargs(), get_docs), ], ) def test_models_update_consistently(seed, dropout, model_func, kwargs, get_X): From 6a6b09bd32f6e687246f1162a645040e90019570 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Tue, 28 Jul 2020 22:59:42 +0200 Subject: [PATCH 22/37] Update morphologizer model --- spacy/pipeline/morphologizer.pyx | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/spacy/pipeline/morphologizer.pyx b/spacy/pipeline/morphologizer.pyx index 56ef44cb9..e76b7fb77 100644 --- a/spacy/pipeline/morphologizer.pyx +++ b/spacy/pipeline/morphologizer.pyx @@ -22,17 +22,23 @@ default_model_config = """ @architectures = "spacy.Tagger.v1" [model.tok2vec] -@architectures = "spacy.HashCharEmbedCNN.v1" -pretrained_vectors = null +@architectures = "spacy.Tok2Vec.v1" + +[model.tok2vec.embed] +@architectures = "spacy.CharacterEmbed.v1" width = 128 -depth = 4 -embed_size = 7000 -window_size = 1 -maxout_pieces = 3 +rows = 7000 nM = 64 nC = 8 -dropout = null + +[model.tok2vec.encode] 
+@architectures = "spacy.MaxoutWindowEncoder.v1" +width = 128 +depth = 4 +window_size = 1 +maxout_pieces = 3 """ + DEFAULT_MORPH_MODEL = Config().from_str(default_model_config)["model"] From 00de30bcc28379ffb28be4d0b0c28ce9391eabb8 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Tue, 28 Jul 2020 23:06:30 +0200 Subject: [PATCH 23/37] Update CharacterEmbed function --- spacy/ml/models/tok2vec.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/spacy/ml/models/tok2vec.py b/spacy/ml/models/tok2vec.py index 881f25a3b..acd9dc0b0 100644 --- a/spacy/ml/models/tok2vec.py +++ b/spacy/ml/models/tok2vec.py @@ -119,15 +119,16 @@ def MultiHashEmbed( @registry.architectures.register("spacy.CharacterEmbed.v1") -def CharacterEmbed(columns, width, rows, nM, nC, features, dropout): - norm = HashEmbed( - nO=width, nV=rows, column=columns.index("NORM"), dropout=dropout, seed=5 +def CharacterEmbed(width: int, rows: int, nM: int, nC: int): + model = concatenate( + _character_embed.CharacterEmbed(nM=nM, nC=nC), + chain( + FeatureExtractor([NORM]), + with_array(HashEmbed(nO=width, nV=rows, column=0, seed=5)) + ) ) - chr_embed = _character_embed.CharacterEmbed(nM=nM, nC=nC) - with Model.define_operators({">>": chain, "|": concatenate}): - embed_layer = chr_embed | features >> with_array(norm) - embed_layer.set_dim("nO", nM * nC + width) - return embed_layer + model.set_dim("nO", nM * nC + width) + return model @registry.architectures.register("spacy.MaxoutWindowEncoder.v1") From c7d1ece3ebf5e4fb45d14faf106a0aba7b179ee2 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Tue, 28 Jul 2020 23:06:46 +0200 Subject: [PATCH 24/37] Update tests --- .../tests/serialize/test_serialize_config.py | 1 + spacy/tests/test_tok2vec.py | 52 ++++++++++--------- 2 files changed, 28 insertions(+), 25 deletions(-) diff --git a/spacy/tests/serialize/test_serialize_config.py b/spacy/tests/serialize/test_serialize_config.py index 25673b8c4..ef5c7f8f4 100644 --- a/spacy/tests/serialize/test_serialize_config.py +++ b/spacy/tests/serialize/test_serialize_config.py @@ -5,6 +5,7 @@ from spacy.lang.en import English from spacy.language import Language from spacy.util import registry, deep_merge_configs, load_model_from_config from spacy.ml.models import build_Tok2Vec_model, build_tb_parser_model +from spacy.ml.models import MultiHashEmbed, MaxoutWindowEncoder from ..util import make_tempdir diff --git a/spacy/tests/test_tok2vec.py b/spacy/tests/test_tok2vec.py index 32f4c5774..6b7170fe3 100644 --- a/spacy/tests/test_tok2vec.py +++ b/spacy/tests/test_tok2vec.py @@ -1,6 +1,7 @@ import pytest from spacy.ml.models.tok2vec import build_Tok2Vec_model +from spacy.ml.models.tok2vec import MultiHashEmbed, MaxoutWindowEncoder from spacy.vocab import Vocab from spacy.tokens import Doc @@ -13,18 +14,18 @@ def test_empty_doc(): vocab = Vocab() doc = Doc(vocab, words=[]) tok2vec = build_Tok2Vec_model( - width, - embed_size, - pretrained_vectors=None, - conv_depth=4, - bilstm_depth=0, - window_size=1, - maxout_pieces=3, - subword_features=True, - char_embed=False, - nM=64, - nC=8, - dropout=None, + MultiHashEmbed( + width=width, + rows=embed_size, + also_use_static_vectors=False, + also_embed_subwords=True + ), + MaxoutWindowEncoder( + width=width, + depth=4, + window_size=1, + maxout_pieces=3 + ) ) tok2vec.initialize() vectors, backprop = tok2vec.begin_update([doc]) @@ -38,18 +39,18 @@ def test_empty_doc(): def test_tok2vec_batch_sizes(batch_size, width, embed_size): batch = get_batch(batch_size) tok2vec = 
build_Tok2Vec_model( - width, - embed_size, - pretrained_vectors=None, - conv_depth=4, - bilstm_depth=0, - window_size=1, - maxout_pieces=3, - subword_features=True, - char_embed=False, - nM=64, - nC=8, - dropout=None, + MultiHashEmbed( + width=width, + rows=embed_size, + also_use_static_vectors=False, + also_embed_subwords=True + ), + MaxoutWindowEncoder( + width=width, + depth=4, + window_size=1, + maxout_pieces=3, + ) ) tok2vec.initialize() vectors, backprop = tok2vec.begin_update(batch) @@ -59,6 +60,7 @@ def test_tok2vec_batch_sizes(batch_size, width, embed_size): # fmt: off +@pytest.mark.xfail(reason="TODO: Update for new signature") @pytest.mark.parametrize( "tok2vec_config", [ @@ -75,7 +77,7 @@ def test_tok2vec_batch_sizes(batch_size, width, embed_size): # fmt: on def test_tok2vec_configs(tok2vec_config): docs = get_batch(3) - tok2vec = build_Tok2Vec_model(**tok2vec_config) + tok2vec = build_Tok2Vec_model_from_old_args(**tok2vec_config) tok2vec.initialize(docs) vectors, backprop = tok2vec.begin_update(docs) assert len(vectors) == len(docs) From 97d36515747640b11e8447a6177ab867353b0915 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Wed, 29 Jul 2020 13:38:13 +0200 Subject: [PATCH 25/37] Fix stray link_vectors_to_models call --- spacy/language.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/spacy/language.py b/spacy/language.py index 4b7651d65..0ec29f3b1 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -1615,8 +1615,6 @@ def _fix_pretrained_vectors_name(nlp: Language) -> None: nlp.vocab.vectors.name = vectors_name else: raise ValueError(Errors.E092) - if nlp.vocab.vectors.size != 0: - link_vectors_to_models(nlp.vocab) for name, proc in nlp.pipeline: if not hasattr(proc, "cfg"): continue From 5ae862857108db67d331ed68e703b034436b9e08 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Wed, 29 Jul 2020 13:38:30 +0200 Subject: [PATCH 26/37] Fix CharacterEmbed layer --- spacy/ml/_character_embed.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/spacy/ml/_character_embed.py b/spacy/ml/_character_embed.py index 57fbf73b3..ab0cb85c7 100644 --- a/spacy/ml/_character_embed.py +++ b/spacy/ml/_character_embed.py @@ -1,16 +1,18 @@ +from typing import List from thinc.api import Model +from thinc.types import Floats2d +from ..tokens import Doc -def CharacterEmbed(nM, nC): +def CharacterEmbed(nM: int, nC: int) -> Model[List[Doc], List[Floats2d]]: # nM: Number of dimensions per character. nC: Number of characters. 
- nO = nM * nC if (nM is not None and nC is not None) else None return Model( "charembed", forward, init=init, - dims={"nM": nM, "nC": nC, "nO": nO, "nV": 256}, + dims={"nM": nM, "nC": nC, "nO": nM * nC, "nV": 256}, params={"E": None}, - ).initialize() + ) def init(model, X=None, Y=None): From 07b47eaac8bc169fdf88677e0660b34ea5f24d7a Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Wed, 29 Jul 2020 13:38:41 +0200 Subject: [PATCH 27/37] Update tok2vec layer --- spacy/ml/models/tok2vec.py | 34 ++++++++++++++++++++-------------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/spacy/ml/models/tok2vec.py b/spacy/ml/models/tok2vec.py index acd9dc0b0..d81c9f918 100644 --- a/spacy/ml/models/tok2vec.py +++ b/spacy/ml/models/tok2vec.py @@ -107,11 +107,9 @@ def MultiHashEmbed( ) else: model = chain( - chain( - FeatureExtractor(cols), - list2ragged(), - with_array(concatenate(*embeddings)), - ), + FeatureExtractor(cols), + list2ragged(), + with_array(concatenate(*embeddings)), with_array(Maxout(width, concat_size, nP=3, dropout=0.0, normalize=True)), ragged2list(), ) @@ -120,14 +118,18 @@ def MultiHashEmbed( @registry.architectures.register("spacy.CharacterEmbed.v1") def CharacterEmbed(width: int, rows: int, nM: int, nC: int): - model = concatenate( - _character_embed.CharacterEmbed(nM=nM, nC=nC), - chain( - FeatureExtractor([NORM]), - with_array(HashEmbed(nO=width, nV=rows, column=0, seed=5)) - ) + model = chain( + concatenate( + chain(_character_embed.CharacterEmbed(nM=nM, nC=nC), list2ragged()), + chain( + FeatureExtractor([NORM]), + list2ragged(), + with_array(HashEmbed(nO=width, nV=rows, column=0, seed=5)) + ) + ), + with_array(Maxout(width, nM * nC + width, nP=3, normalize=True, dropout=0.0)), + ragged2list() ) - model.set_dim("nO", nM * nC + width) return model @@ -153,8 +155,12 @@ def MaxoutWindowEncoder(width: int, window_size: int, maxout_pieces: int, depth: def MishWindowEncoder(width, window_size, depth): cnn = chain( expand_window(window_size=window_size), - Mish(nO=width, nI=width * ((window_size * 2) + 1)), - LayerNorm(width), + Mish( + nO=width, + nI=width * ((window_size * 2) + 1), + dropout=0.0, + normalize=True + ), ) model = clone(residual(cnn), depth) model.set_dim("nO", width) From f0cf4a2dca7cd2685d0842dbe5111d541288d661 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Wed, 29 Jul 2020 13:47:37 +0200 Subject: [PATCH 28/37] Update tests --- .../tests/serialize/test_serialize_config.py | 4 +-- spacy/tests/test_tok2vec.py | 29 ++++++++++--------- 2 files changed, 17 insertions(+), 16 deletions(-) diff --git a/spacy/tests/serialize/test_serialize_config.py b/spacy/tests/serialize/test_serialize_config.py index ef5c7f8f4..ce35add42 100644 --- a/spacy/tests/serialize/test_serialize_config.py +++ b/spacy/tests/serialize/test_serialize_config.py @@ -41,7 +41,7 @@ factory = "tagger" @architectures = "spacy.Tagger.v1" [components.tagger.model.tok2vec] -@architectures = "spacy.Tok2VecTensors.v1" +@architectures = "spacy.Tok2VecListener.v1" width = ${components.tok2vec.model:width} """ @@ -71,7 +71,7 @@ def my_parser(): tok2vec = build_Tok2Vec_model( MultiHashEmbed( width=321, - embed_size=5432, + rows=5432, also_embed_subwords=True, also_use_static_vectors=False ), diff --git a/spacy/tests/test_tok2vec.py b/spacy/tests/test_tok2vec.py index 6b7170fe3..76b5e64df 100644 --- a/spacy/tests/test_tok2vec.py +++ b/spacy/tests/test_tok2vec.py @@ -1,7 +1,8 @@ import pytest from spacy.ml.models.tok2vec import build_Tok2Vec_model -from spacy.ml.models.tok2vec import 
MultiHashEmbed, MaxoutWindowEncoder +from spacy.ml.models.tok2vec import MultiHashEmbed, CharacterEmbed +from spacy.ml.models.tok2vec import MishWindowEncoder, MaxoutWindowEncoder from spacy.vocab import Vocab from spacy.tokens import Doc @@ -60,26 +61,26 @@ def test_tok2vec_batch_sizes(batch_size, width, embed_size): # fmt: off -@pytest.mark.xfail(reason="TODO: Update for new signature") @pytest.mark.parametrize( - "tok2vec_config", + "width,embed_arch,embed_config,encode_arch,encode_config", [ - {"width": 8, "embed_size": 100, "char_embed": False, "nM": 64, "nC": 8, "pretrained_vectors": None, "window_size": 1, "conv_depth": 2, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": True, "dropout": None}, - {"width": 8, "embed_size": 100, "char_embed": True, "nM": 64, "nC": 8, "pretrained_vectors": None, "window_size": 1, "conv_depth": 2, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": True, "dropout": None}, - {"width": 8, "embed_size": 100, "char_embed": False, "nM": 64, "nC": 8, "pretrained_vectors": None, "window_size": 1, "conv_depth": 6, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": True, "dropout": None}, - {"width": 8, "embed_size": 100, "char_embed": False, "nM": 64, "nC": 8, "pretrained_vectors": None, "window_size": 1, "conv_depth": 6, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": True, "dropout": None}, - {"width": 8, "embed_size": 100, "char_embed": False, "nM": 64, "nC": 8, "pretrained_vectors": None, "window_size": 1, "conv_depth": 2, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": False, "dropout": None}, - {"width": 8, "embed_size": 100, "char_embed": False, "nM": 64, "nC": 8, "pretrained_vectors": None, "window_size": 3, "conv_depth": 2, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": False, "dropout": None}, - {"width": 8, "embed_size": 100, "char_embed": True, "nM": 81, "nC": 8, "pretrained_vectors": None, "window_size": 3, "conv_depth": 2, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": False, "dropout": None}, - {"width": 8, "embed_size": 100, "char_embed": True, "nM": 81, "nC": 9, "pretrained_vectors": None, "window_size": 3, "conv_depth": 2, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": False, "dropout": None}, + (8, MultiHashEmbed, {"rows": 100, "also_embed_subwords": True, "also_use_static_vectors": False}, MaxoutWindowEncoder, {"window_size": 1, "maxout_pieces": 3, "depth": 2}), + (8, MultiHashEmbed, {"rows": 100, "also_embed_subwords": True, "also_use_static_vectors": False}, MishWindowEncoder, {"window_size": 1, "depth": 6}), + (8, CharacterEmbed, {"rows": 100, "nM": 64, "nC": 8}, MaxoutWindowEncoder, {"window_size": 1, "maxout_pieces": 3, "depth": 3}), + (8, CharacterEmbed, {"rows": 100, "nM": 16, "nC": 2}, MishWindowEncoder, {"window_size": 1, "depth": 3}), ], ) # fmt: on -def test_tok2vec_configs(tok2vec_config): +def test_tok2vec_configs(width, embed_arch, embed_config, encode_arch, encode_config): + embed_config["width"] = width + encode_config["width"] = width docs = get_batch(3) - tok2vec = build_Tok2Vec_model_from_old_args(**tok2vec_config) + tok2vec = build_Tok2Vec_model( + embed_arch(**embed_config), + encode_arch(**encode_config) + ) tok2vec.initialize(docs) vectors, backprop = tok2vec.begin_update(docs) assert len(vectors) == len(docs) - assert vectors[0].shape == (len(docs[0]), tok2vec_config["width"]) + assert vectors[0].shape == (len(docs[0]), width) backprop(vectors) From 4bbbb41bf8d70a3acfd45d9da2172c4401fc5452 Mon Sep 17 00:00:00 2001 From: Matthew 
Honnibal Date: Wed, 29 Jul 2020 13:48:34 +0200 Subject: [PATCH 29/37] Update config --- examples/experiments/ptb-joint-pos-dep/defaults.cfg | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/experiments/ptb-joint-pos-dep/defaults.cfg b/examples/experiments/ptb-joint-pos-dep/defaults.cfg index 5850eaf3a..eed76cb7b 100644 --- a/examples/experiments/ptb-joint-pos-dep/defaults.cfg +++ b/examples/experiments/ptb-joint-pos-dep/defaults.cfg @@ -4,7 +4,7 @@ patience = 10000 eval_frequency = 200 dropout = 0.2 init_tok2vec = null -vectors = "tmp/fasttext_vectors/vocab" +vectors = null max_epochs = 100 orth_variant_level = 0.0 gold_preproc = true @@ -85,7 +85,7 @@ width = ${components.tok2vec.model.encode:width} width = ${components.tok2vec.model.encode:width} rows = 2000 also_embed_subwords = true -also_use_static_vectors = true +also_use_static_vectors = false [components.tok2vec.model.encode] @architectures = "spacy.MaxoutWindowEncoder.v1" From 105cf2996785961a2c6bf717c3d80736daaa1c60 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Wed, 29 Jul 2020 14:23:13 +0200 Subject: [PATCH 30/37] Fix DocBin --- spacy/tokens/_serialize.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spacy/tokens/_serialize.py b/spacy/tokens/_serialize.py index 0a5fd0c59..bc371199a 100644 --- a/spacy/tokens/_serialize.py +++ b/spacy/tokens/_serialize.py @@ -50,7 +50,7 @@ class DocBin: self, attrs: Iterable[str] = ALL_ATTRS, store_user_data: bool = False, - docs=Iterable[Doc], + docs: Iterable[Doc]=[], ) -> None: """Create a DocBin object to hold serialized annotations. From b5bbfec591b9cb659bf51add783e0935fbee452b Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Wed, 29 Jul 2020 14:26:44 +0200 Subject: [PATCH 31/37] Update config --- examples/experiments/onto-joint/defaults.cfg | 83 +++++++++++--------- 1 file changed, 44 insertions(+), 39 deletions(-) diff --git a/examples/experiments/onto-joint/defaults.cfg b/examples/experiments/onto-joint/defaults.cfg index 95c2f28bd..d37929ff1 100644 --- a/examples/experiments/onto-joint/defaults.cfg +++ b/examples/experiments/onto-joint/defaults.cfg @@ -20,20 +20,20 @@ seed = 0 accumulate_gradient = 1 use_pytorch_for_gpu_memory = false # Control how scores are printed and checkpoints are evaluated. -scores = ["speed", "tags_acc", "uas", "las", "ents_f"] +eval_batch_size = 128 score_weights = {"las": 0.4, "ents_f": 0.4, "tags_acc": 0.2} -# These settings are invalid for the transformer models. 
init_tok2vec = null discard_oversize = false -omit_extra_lookups = false batch_by = "words" -use_gpu = -1 raw_text = null tag_map = null +vectors = null +base_model = null +morph_rules = null [training.batch_size] @schedules = "compounding.v1" -start = 1000 +start = 100 stop = 1000 compound = 1.001 @@ -46,74 +46,79 @@ L2 = 0.01 grad_clip = 1.0 use_averages = false eps = 1e-8 -#learn_rate = 0.001 - -[training.optimizer.learn_rate] -@schedules = "warmup_linear.v1" -warmup_steps = 250 -total_steps = 20000 -initial_rate = 0.001 +learn_rate = 0.001 [nlp] lang = "en" -base_model = null -vectors = null +load_vocab_data = false +pipeline = ["tok2vec", "ner", "tagger", "parser"] -[nlp.pipeline] +[nlp.tokenizer] +@tokenizers = "spacy.Tokenizer.v1" -[nlp.pipeline.tok2vec] +[nlp.lemmatizer] +@lemmatizers = "spacy.Lemmatizer.v1" + +[components] + +[components.tok2vec] factory = "tok2vec" - -[nlp.pipeline.ner] +[components.ner] factory = "ner" learn_tokens = false min_action_freq = 1 -[nlp.pipeline.tagger] +[components.tagger] factory = "tagger" -[nlp.pipeline.parser] +[components.parser] factory = "parser" learn_tokens = false min_action_freq = 30 -[nlp.pipeline.tagger.model] +[components.tagger.model] @architectures = "spacy.Tagger.v1" -[nlp.pipeline.tagger.model.tok2vec] -@architectures = "spacy.Tok2VecTensors.v1" -width = ${nlp.pipeline.tok2vec.model:width} +[components.tagger.model.tok2vec] +@architectures = "spacy.Tok2VecListener.v1" +width = ${components.tok2vec.model.encode:width} -[nlp.pipeline.parser.model] +[components.parser.model] @architectures = "spacy.TransitionBasedParser.v1" nr_feature_tokens = 8 hidden_width = 128 maxout_pieces = 2 use_upper = true -[nlp.pipeline.parser.model.tok2vec] -@architectures = "spacy.Tok2VecTensors.v1" -width = ${nlp.pipeline.tok2vec.model:width} +[components.parser.model.tok2vec] +@architectures = "spacy.Tok2VecListener.v1" +width = ${components.tok2vec.model.encode:width} -[nlp.pipeline.ner.model] +[components.ner.model] @architectures = "spacy.TransitionBasedParser.v1" nr_feature_tokens = 3 hidden_width = 128 maxout_pieces = 2 use_upper = true -[nlp.pipeline.ner.model.tok2vec] -@architectures = "spacy.Tok2VecTensors.v1" -width = ${nlp.pipeline.tok2vec.model:width} +[components.ner.model.tok2vec] +@architectures = "spacy.Tok2VecListener.v1" +width = ${components.tok2vec.model.encode:width} -[nlp.pipeline.tok2vec.model] -@architectures = "spacy.HashEmbedCNN.v1" -pretrained_vectors = ${nlp:vectors} -width = 128 +[components.tok2vec.model] +@architectures = "spacy.Tok2Vec.v1" + +[components.tok2vec.model.embed] +@architectures = "spacy.MultiHashEmbed.v1" +width = ${components.tok2vec.model.encode:width} +rows = 2000 +also_embed_subwords = true +also_use_static_vectors = false + +[components.tok2vec.model.encode] +@architectures = "spacy.MaxoutWindowEncoder.v1" +width = 96 depth = 4 window_size = 1 -embed_size = 7000 maxout_pieces = 3 -subword_features = true -dropout = ${training:dropout} From 9e1b11dd8158b938798ae5ba7384623b8ed535a1 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Wed, 29 Jul 2020 14:35:36 +0200 Subject: [PATCH 32/37] Update vectors in textcat --- spacy/ml/models/textcat.py | 15 +++------------ 1 file changed, 3 insertions(+), 12 deletions(-) diff --git a/spacy/ml/models/textcat.py b/spacy/ml/models/textcat.py index a64a2487a..139917581 100644 --- a/spacy/ml/models/textcat.py +++ b/spacy/ml/models/textcat.py @@ -9,6 +9,7 @@ from ... 
import util from ...attrs import ID, ORTH, PREFIX, SUFFIX, SHAPE, LOWER from ...util import registry from ..extract_ngrams import extract_ngrams +from ..staticvectors import StaticVectors @registry.architectures.register("spacy.TextCatCNN.v1") @@ -101,13 +102,7 @@ def build_text_classifier( ) if pretrained_vectors: - nlp = util.load_model(pretrained_vectors) - vectors = nlp.vocab.vectors - vector_dim = vectors.data.shape[1] - - static_vectors = SpacyVectors(vectors) >> with_array( - Linear(width, vector_dim) - ) + static_vectors = StaticVectors(width) vector_layer = trained_vectors | static_vectors vectors_width = width * 2 else: @@ -158,14 +153,10 @@ def build_text_classifier( @registry.architectures.register("spacy.TextCatLowData.v1") def build_text_classifier_lowdata(width, pretrained_vectors, dropout, nO=None): - nlp = util.load_model(pretrained_vectors) - vectors = nlp.vocab.vectors - vector_dim = vectors.data.shape[1] - # Note, before v.3, this was the default if setting "low_data" and "pretrained_dims" with Model.define_operators({">>": chain, "**": clone}): model = ( - SpacyVectors(vectors) + StaticVectors(width) >> list2ragged() >> with_ragged(0, Linear(width, vector_dim)) >> ParametricAttention(width) From c99a65307037440406a3b35d04440bb671e931d8 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Wed, 29 Jul 2020 14:38:15 +0200 Subject: [PATCH 33/37] Adjust textcat model --- spacy/ml/models/textcat.py | 1 - 1 file changed, 1 deletion(-) diff --git a/spacy/ml/models/textcat.py b/spacy/ml/models/textcat.py index 139917581..53200c165 100644 --- a/spacy/ml/models/textcat.py +++ b/spacy/ml/models/textcat.py @@ -158,7 +158,6 @@ def build_text_classifier_lowdata(width, pretrained_vectors, dropout, nO=None): model = ( StaticVectors(width) >> list2ragged() - >> with_ragged(0, Linear(width, vector_dim)) >> ParametricAttention(width) >> reduce_sum() >> residual(Relu(width, width)) ** 2 From 142b58be92dbc1ee63d3424f7afaf4fe44cab417 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Wed, 29 Jul 2020 14:45:09 +0200 Subject: [PATCH 34/37] Fix import --- spacy/ml/models/tok2vec.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spacy/ml/models/tok2vec.py b/spacy/ml/models/tok2vec.py index d81c9f918..1460b3005 100644 --- a/spacy/ml/models/tok2vec.py +++ b/spacy/ml/models/tok2vec.py @@ -2,7 +2,7 @@ from typing import Optional, List from thinc.api import chain, clone, concatenate, with_array, with_padded from thinc.api import Model, noop, list2ragged, ragged2list from thinc.api import FeatureExtractor, HashEmbed -from thinc.api import expand_window, residual, Maxout, Mish +from thinc.api import expand_window, residual, Maxout, Mish, PyTorchLSTM from thinc.types import Floats2d from ...tokens import Doc From 2af741d7e3b96be4e24319c2e8284fa168c6ab99 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Wed, 29 Jul 2020 14:56:01 +0200 Subject: [PATCH 35/37] Fix train arg --- spacy/cli/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spacy/cli/train.py b/spacy/cli/train.py index e152ae8ea..b0bc145ff 100644 --- a/spacy/cli/train.py +++ b/spacy/cli/train.py @@ -246,7 +246,7 @@ def create_evaluation_callback( ) -> Callable[[], Tuple[float, Dict[str, float]]]: def evaluate() -> Tuple[float, Dict[str, float]]: dev_examples = corpus.dev_dataset( - nlp, gold_preproc=cfg["gold_preproc"], ignore_misaligned=True + nlp, gold_preproc=cfg["gold_preproc"] ) dev_examples = list(dev_examples) n_words = sum(len(ex.predicted) for ex in dev_examples) From 
ebdb3f5f04e32d05ed59046b4655a4bc6869bc79 Mon Sep 17 00:00:00 2001 From: Matthew
Honnibal Date: Wed, 29 Jul 2020 14:56:11 +0200 Subject: [PATCH 36/37] Fix config --- examples/experiments/onto-joint/defaults.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/experiments/onto-joint/defaults.cfg b/examples/experiments/onto-joint/defaults.cfg index d37929ff1..0e0d4d4c3 100644 --- a/examples/experiments/onto-joint/defaults.cfg +++ b/examples/experiments/onto-joint/defaults.cfg @@ -21,7 +21,7 @@ accumulate_gradient = 1 use_pytorch_for_gpu_memory = false # Control how scores are printed and checkpoints are evaluated. eval_batch_size = 128 -score_weights = {"las": 0.4, "ents_f": 0.4, "tags_acc": 0.2} +score_weights = {"dep_las": 0.4, "ents_f": 0.4, "tag_acc": 0.2} init_tok2vec = null discard_oversize = false batch_by = "words" From f7adc9d3b713ea473469e2371f5ce816bdc7e406 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Wed, 29 Jul 2020 17:10:06 +0200 Subject: [PATCH 37/37] Start rewriting vectors docs --- website/docs/usage/vectors-embeddings.md | 156 ++++++++++------------- 1 file changed, 68 insertions(+), 88 deletions(-) diff --git a/website/docs/usage/vectors-embeddings.md b/website/docs/usage/vectors-embeddings.md index 7725068ec..8f6315901 100644 --- a/website/docs/usage/vectors-embeddings.md +++ b/website/docs/usage/vectors-embeddings.md @@ -5,54 +5,82 @@ menu: - ['Other Embeddings', 'embeddings'] --- - - ## Word vectors and similarity -> #### Training word vectors -> -> Dense, real valued vectors representing distributional similarity information -> are now a cornerstone of practical NLP. The most common way to train these -> vectors is the [Word2vec](https://en.wikipedia.org/wiki/Word2vec) family of -> algorithms. If you need to train a word2vec model, we recommend the -> implementation in the Python library -> [Gensim](https://radimrehurek.com/gensim/). +An old idea in linguistics is that you can "know a word by the company it +keeps": that is, word meanings can be understood relationally, based on their +patterns of usage. This idea inspired a branch of NLP research known as +"distributional semantics" that has aimed to compute databases of lexical knowledge +automatically. The [Word2vec](https://en.wikipedia.org/wiki/Word2vec) family of +algorithms are a key milestone in this line of research. For simplicity, we +will refer to a distributional word representation as a "word vector", and +algorithms that compute word vectors (such as GloVe, FastText, etc.) as +"word2vec algorithms". -import Vectors101 from 'usage/101/\_vectors-similarity.md' +Word vector tables are included in some of the spaCy model packages we +distribute, and you can easily create your own model packages with word vectors +you train or download yourself. In some cases you can also add word vectors to +an existing pipeline, although each pipeline can only have a single word +vectors table, and a model package that already has word vectors is unlikely to +work correctly if you replace the vectors with new ones. - +## What's a word vector? -### Customizing word vectors {#custom} +For spaCy's purposes, a "word vector" is a 1-dimensional slice from +a 2-dimensional _vectors table_, with a deterministic mapping from word types +to rows in the table. -Word vectors let you import knowledge from raw text into your model. The -knowledge is represented as a table of numbers, with one row per term in your -vocabulary.
If two terms are used in similar contexts, the algorithm that learns -the vectors should assign them **rows that are quite similar**, while words that -are used in different contexts will have quite different values. This lets you -use the row-values assigned to the words as a kind of dictionary, to tell you -some things about what the words in your text mean. +```python +def what_is_a_word_vector( + word_id: int, + key2row: Dict[int, int], + vectors_table: Floats2d, + *, + default_row: int=0 +) -> Floats1d: + return vectors_table[key2row.get(word_id, default_row)] +``` -Word vectors are particularly useful for terms which **aren't well represented -in your labelled training data**. For instance, if you're doing named entity -recognition, there will always be lots of names that you don't have examples of. -For instance, imagine your training data happens to contain some examples of the -term "Microsoft", but it doesn't contain any examples of the term "Symantec". In -your raw text sample, there are plenty of examples of both terms, and they're -used in similar contexts. The word vectors make that fact available to the -entity recognition model. It still won't see examples of "Symantec" labelled as -a company. However, it'll see that "Symantec" has a word vector that usually -corresponds to company terms, so it can **make the inference**. +word2vec algorithms try to produce vectors tables that let you estimate useful +relationships between words using simple linear algebra operations. For +instance, you can often find close synonyms of a word by finding the vectors +closest to it by cosine distance, and then finding the words that are mapped to +those neighboring vectors. Word vectors can also be useful as features in +statistical models. -In order to make best use of the word vectors, you want the word vectors table -to cover a **very large vocabulary**. However, most words are rare, so most of -the rows in a large word vectors table will be accessed very rarely, or never at -all. You can usually cover more than **95% of the tokens** in your corpus with -just **a few thousand rows** in the vector table. However, it's those **5% of -rare terms** where the word vectors are **most useful**. The problem is that -increasing the size of the vector table produces rapidly diminishing returns in -coverage over these rare terms. +The key difference between word vectors and contextual language models such as +ElMo, BERT and GPT-2 is that word vectors model _lexical types_, rather than +_tokens_. If you have a list of terms with no context around them, a model like +BERT can't really help you. BERT is designed to understand language in context, +which isn't what you have. A word vectors table will be a much better fit for +your task. However, if you do have words in context --- whole sentences or +paragraphs of running text --- word vectors will only provide a very rough +approximation of what the text is about. -### Converting word vectors for use in spaCy {#converting new="2.0.10"} +Word vectors are also very computationally efficient, as they map a word to a +vector with a single indexing operation. Word vectors are therefore useful as a +way to improve the accuracy of neural network models, especially models that +are small or have received little or no pretraining. In spaCy, word vector +tables are only used as static features. spaCy does not backpropagate gradients +to the pretrained word vectors table. 
The static vectors table is usually used +in combination with a smaller table of learned task-specific embeddings. + +## Using word vectors directly + +spaCy stores word vector information in the `vocab.vectors` attribute, so you +can access the whole vectors table from most spaCy objects. You can also access +the vector for a `Doc`, `Span`, `Token` or `Lexeme` instance via the `vector` +attribute. If your `Doc` or `Span` has multiple tokens, the average of the +word vectors will be returned, excluding any "out of vocabulary" entries that +have no vector available. If none of the words have a vector, a zeroed vector +will be returned. + +The `vector` attribute is a read-only numpy or cupy array (depending on whether +you've configured spaCy to use GPU memory), with dtype `float32`. The array is +read-only so that spaCy can avoid unnecessary copy operations where possible. +You can modify the vectors via the `Vocab` or `Vectors` table. + +### Converting word vectors for use in spaCy Custom word vectors can be trained using a number of open-source libraries, such as [Gensim](https://radimrehurek.com/gensim), [Fast Text](https://fasttext.cc), @@ -151,20 +179,7 @@ This will create a spaCy model with vectors for the first 10,000 words in the vectors model. All other words in the vectors model are mapped to the closest vector among those retained. -### Adding vectors {#custom-vectors-add new="2"} - -spaCy's new [`Vectors`](/api/vectors) class greatly improves the way word -vectors are stored, accessed and used. The data is stored in two structures: - -- An array, which can be either on CPU or [GPU](#gpu). -- A dictionary mapping string-hashes to rows in the table. - -Keep in mind that the `Vectors` class itself has no -[`StringStore`](/api/stringstore), so you have to store the hash-to-string -mapping separately. If you need to manage the strings, you should use the -`Vectors` via the [`Vocab`](/api/vocab) class, e.g. `vocab.vectors`. To add -vectors to the vocabulary, you can use the -[`Vocab.set_vector`](/api/vocab#set_vector) method. +### Adding vectors ```python ### Adding vectors @@ -196,38 +211,3 @@ For more details on **adding hooks** and **overwriting** the built-in `Doc`, ### Storing vectors on a GPU {#gpu} -If you're using a GPU, it's much more efficient to keep the word vectors on the -device. You can do that by setting the [`Vectors.data`](/api/vectors#attributes) -attribute to a `cupy.ndarray` object if you're using spaCy or -[Chainer](https://chainer.org), or a `torch.Tensor` object if you're using -[PyTorch](http://pytorch.org). The `data` object just needs to support -`__iter__` and `__getitem__`, so if you're using another library such as -[TensorFlow](https://www.tensorflow.org), you could also create a wrapper for -your vectors data. - -```python -### spaCy, Thinc or Chainer -import cupy.cuda -from spacy.vectors import Vectors - -vector_table = numpy.zeros((3, 300), dtype="f") -vectors = Vectors(["dog", "cat", "orange"], vector_table) -with cupy.cuda.Device(0): - vectors.data = cupy.asarray(vectors.data) -``` - -```python -### PyTorch -import torch -from spacy.vectors import Vectors - -vector_table = numpy.zeros((3, 300), dtype="f") -vectors = Vectors(["dog", "cat", "orange"], vector_table) -vectors.data = torch.Tensor(vectors.data).cuda(0) -``` - -## Other embeddings {#embeddings} - - - -
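The rewritten vectors documentation above describes finding close synonyms by taking the rows of the vectors table nearest to a word's vector by cosine distance and mapping those rows back to words. A minimal sketch of that lookup over a `key2row` mapping and `vectors_table` array (the same names used in the docs' `what_is_a_word_vector` example) might look like the following; `most_similar_rows` is a hypothetical helper for illustration, not part of spaCy's API.

```python
from typing import Dict, List
import numpy


def most_similar_rows(
    word_id: int,
    key2row: Dict[int, int],
    vectors_table: numpy.ndarray,
    *,
    n: int = 10,
) -> List[int]:
    # Hypothetical sketch: rank vectors-table rows by cosine similarity to the
    # row mapped to word_id. The query row itself will rank first, so callers
    # would usually skip it and map the remaining rows back to string keys.
    query = vectors_table[key2row[word_id]]
    norms = numpy.linalg.norm(vectors_table, axis=1, keepdims=True)
    normalized = vectors_table / numpy.maximum(norms, 1e-8)
    scores = normalized @ (query / max(numpy.linalg.norm(query), 1e-8))
    return numpy.argsort(-scores)[:n].tolist()
```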
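The documentation also states that the `vector` attribute of a multi-token `Doc` or `Span` is the average of the available word vectors, skipping "out of vocabulary" tokens and falling back to a zeroed vector when none of the tokens have one. A short sketch of that described behaviour (an illustration under those assumptions, not spaCy's actual implementation; `averaged_vector` is a hypothetical helper) could be:

```python
from typing import List
import numpy
from spacy.tokens import Doc


def averaged_vector(doc: Doc) -> numpy.ndarray:
    # Sketch of the behaviour described above: average the vectors of tokens
    # that have one, and return a zeroed vector if none of them do.
    rows: List[numpy.ndarray] = [token.vector for token in doc if token.has_vector]
    if not rows:
        width = doc.vocab.vectors.data.shape[1]
        return numpy.zeros((width,), dtype="float32")
    return numpy.mean(rows, axis=0)
```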