Update spaCy for thinc 8.0.0 (#4920)

* Add load_from_config function

* Add train_from_config script

* Merge configs and expose via spacy.config

* Fix script

* Suggest create_evaluation_callback

* Hard-code for NER

* Fix errors

* Register command

* Add TODO

* Update train-from-config todos

* Fix imports

* Allow delayed setting of parser model nr_class

* Get train-from-config working

* Tidy up and fix scores and printing

* Hide traceback if cancelled

* Fix weighted score formatting

* Fix score formatting

* Make output_path optional

* Add Tok2Vec component

* Tidy up and add tok2vec_tensors

* Add option to copy docs in nlp.update

* Copy docs in nlp.update

* Adjust nlp.update() for set_annotations

* Don't shuffle pipes in nlp.update, decruft

* Support set_annotations arg in component update

* Support set_annotations in parser update

* Add get_gradients method

* Add get_gradients to parser

* Update errors.py

* Fix problems caused by merge

* Add _link_components method in nlp

* Add concept of 'listeners' and ControlledModel

* Support optional attributes arg in ControlledModel

* Try having tok2vec component in pipeline

* Fix tok2vec component

* Fix config

* Fix tok2vec

* Update for Example

* Update for Example

* Update config

* Add eg2doc util

* Update and add schemas/types

* Update schemas

* Fix nlp.update

* Fix tagger

* Remove hacks from train-from-config

* Remove hard-coded config str

* Calculate loss in tok2vec component

* Tidy up and use function signatures instead of models

* Support union types for registry models

* Minor cleaning in Language.update

* Make ControlledModel specifically Tok2VecListener

* Fix train_from_config

* Fix tok2vec

* Tidy up

* Add function for bilstm tok2vec

* Fix type

* Fix syntax

* Fix pytorch optimizer

* Add example configs

* Update for thinc describe changes

* Update for Thinc changes

* Update for dropout/sgd changes

* Update for dropout/sgd changes

* Unhack gradient update

* Work on refactoring _ml

* Remove _ml.py module

* WIP upgrade cli scripts for thinc

* Move some _ml stuff to util

* Import link_vectors from util

* Update train_from_config

* Import from util

* Import from util

* Temporarily add ml.component_models module

* Move ml methods

* Move typedefs

* Update load vectors

* Update gitignore

* Move imports

* Add PrecomputableAffine

* Fix imports

* Fix imports

* Fix imports

* Fix missing imports

* Update CLI scripts

* Update spacy.language

* Add stubs for building the models

* Update model definition

* Update create_default_optimizer

* Fix import

* Fix comment

* Update imports in tests

* Update imports in spacy.cli

* Fix import

* fix obsolete thinc imports

* update srsly pin

* from thinc to ml_datasets for example data such as imdb

* update ml_datasets pin

* using STATE.vectors

* small fix

* fix Sentencizer.pipe

* black formatting

* rename Affine to Linear as in thinc

* set validate explicitly to True

* rename with_square_sequences to with_list2padded

* rename with_flatten to with_list2array

* chaining layernorm

* small fixes

* revert Optimizer import

* build_nel_encoder with new thinc style

* fixes using model's get and set methods

* Tok2Vec in component models, various fixes

* fix up legacy tok2vec code

* add model initialize calls

* add in build_tagger_model

* small fixes

* setting model dims

* fixes for ParserModel

* various small fixes

* initialize thinc Models

* fixes

* consistent naming of window_size

* fixes, removing set_dropout

* work around Iterable issue

* remove legacy tok2vec

* util fix

* fix forward function of tok2vec listener

* more fixes

* trying to fix PrecomputableAffine (not successful yet)

* alloc instead of allocate

* add morphologizer

* rename residual

* rename fixes

* Fix predict function

* Update parser and parser model

* fixing few more tests

* Fix precomputable affine

* Update component model

* Update parser model

* Move backprop padding to own function, for test

* Update test

* Fix p. affine

* Update NEL

* build_bow_text_classifier and extract_ngrams

* Fix parser init

* Fix test add label

* add build_simple_cnn_text_classifier

* Fix parser init

* Set gpu off by default in example

* Fix tok2vec listener

* Fix parser model

* Small fixes

* small fix for PyTorchLSTM parameters

* revert my_compounding hack (iterable fixed now)

* fix biLSTM

* Fix uniqued

* PyTorchRNNWrapper fix

* small fixes

* use helper function to calculate cosine loss

* small fixes for build_simple_cnn_text_classifier

* putting dropout default at 0.0 to ensure the layer gets built

* using thinc util's set_dropout_rate

* moving layer normalization inside of maxout definition to optimize dropout

* temp debugging in NEL

* fixed NEL model by using init defaults !

* fixing after set_dropout_rate refactor

* proper fix

* fix test_update_doc after refactoring optimizers in thinc

* Add CharacterEmbed layer

* Construct tagger Model

* Add missing import

* Remove unused stuff

* Work on textcat

* fix test (again :)) after optimizer refactor

* fixes to allow reading Tagger from_disk without overwriting dimensions

* don't build the tok2vec prematurely

* fix CharacterEmbed init

* CharacterEmbed fixes

* Fix CharacterEmbed architecture

* fix imports

* renames from latest thinc update

* one more rename

* add initialize calls where appropriate

* fix parser initialization

* Update Thinc version

* Fix errors, auto-format and tidy up imports

* Fix validation

* fix if bias is cupy array

* revert for now

* ensure it's a numpy array before running bp in ParserStepModel

* no reason to call require_gpu twice

* use CupyOps.to_numpy instead of cupy directly

* fix initialize of ParserModel

* remove unnecessary import

* fixes for CosineDistance

* fix device renaming

* use refactored loss functions (Thinc PR 251)

* overfitting test for tagger

* experimental settings for the tagger: avoid zero-init and subword normalization

* clean up tagger overfitting test

* use previous default value for nP

* remove toy config

* bringing layernorm back (had a bug - fixed in thinc)

* revert setting nP explicitly

* remove setting default in constructor

* restore values as they used to be

* add overfitting test for NER

* add overfitting test for dep parser

* add overfitting test for textcat

* fixing init for linear (previously affine)

* larger eps window for textcat

* ensure doc is not None

* Require newer thinc

* Make float check vaguer

* Slop the textcat overfit test more

* Fix textcat test

* Fix exclusive classes for textcat

* fix after renaming of alloc methods

* fixing renames and mandatory arguments (staticvectors WIP)

* upgrade to thinc==8.0.0.dev3

* refer to vocab.vectors directly instead of its name

* rename alpha to learn_rate

* adding hashembed and staticvectors dropout

* upgrade to thinc 8.0.0.dev4

* add name back to avoid warning W020

* thinc dev4

* update srsly

* using thinc 8.0.0a0 !

Co-authored-by: Matthew Honnibal <honnibal+gh@gmail.com>
Co-authored-by: Ines Montani <ines@ines.io>
Sofie Van Landeghem authored on 2020-01-29 17:06:46 +01:00; committed by GitHub
commit 569cc98982 (parent 06b251dd1e)
70 changed files with 2141 additions and 1675 deletions

.gitignore

@@ -39,6 +39,7 @@ __pycache__/
 .env*
 .~env/
 .venv
+env3.6/
 venv/
 .dev
 .denv
@@ -111,3 +112,6 @@ Desktop.ini
 # Pycharm project files
 *.idea
+
+# IPython
+.ipynb_checkpoints/


@@ -4,12 +4,12 @@ from random import shuffle
 import logging
 import numpy as np
-from spacy._ml import zero_init, create_default_optimizer
-from spacy.cli.pretrain import get_cossim_loss
-from thinc.v2v import Model
+from thinc.model import Model
 from thinc.api import chain
-from thinc.neural._classes.affine import Affine
+from thinc.loss import CosineDistance
+from thinc.layers import Linear
+from spacy.util import create_default_optimizer
 
 logger = logging.getLogger(__name__)
@@ -34,6 +34,7 @@ class EntityEncoder:
         self.input_dim = input_dim
         self.desc_width = desc_width
         self.epochs = epochs
+        self.distance = CosineDistance(ignore_zeros=True, normalize=False)
 
     def apply_encoder(self, description_list):
         if self.encoder is None:
@@ -132,21 +133,17 @@ class EntityEncoder:
     def _build_network(self, orig_width, hidden_with):
         with Model.define_operators({">>": chain}):
             # very simple encoder-decoder model
-            self.encoder = Affine(hidden_with, orig_width)
-            self.model = self.encoder >> zero_init(
-                Affine(orig_width, hidden_with, drop_factor=0.0)
-            )
-            self.sgd = create_default_optimizer(self.model.ops)
+            self.encoder = Linear(hidden_with, orig_width)
+            # TODO: removed the zero_init here - is oK?
+            self.model = self.encoder >> Linear(orig_width, hidden_with)
+            self.sgd = create_default_optimizer()
 
     def _update(self, vectors):
+        truths = self.model.ops.asarray(vectors)
         predictions, bp_model = self.model.begin_update(
-            np.asarray(vectors), drop=self.DROP
+            truths, drop=self.DROP
         )
-        loss, d_scores = self._get_loss(scores=predictions, golds=np.asarray(vectors))
+        d_scores, loss = self.distance(predictions, truths)
         bp_model(d_scores, sgd=self.sgd)
         return loss / len(vectors)
-
-    @staticmethod
-    def _get_loss(golds, scores):
-        loss, gradients = get_cossim_loss(scores, golds)
-        return loss, gradients


@@ -103,7 +103,7 @@ def main(
     logger.info("STEP 3: Creating and training an Entity Linking pipe")
     el_pipe = nlp.create_pipe(
-        name="entity_linker", config={"pretrained_vectors": nlp.vocab.vectors.name,
+        name="entity_linker", config={"pretrained_vectors": nlp.vocab.vectors,
                                       "labels_discard": labels_discard}
     )
     el_pipe.set_kb(kb)


@@ -14,7 +14,7 @@ pip install keras==2.0.9
 Compatible with: spaCy v2.0.0+
 """
+import ml_datasets
 import plac
 import random
 import pathlib
@@ -24,7 +24,6 @@ from keras.models import Sequential, model_from_json
 from keras.layers import LSTM, Dense, Embedding, Bidirectional
 from keras.layers import TimeDistributed
 from keras.optimizers import Adam
-import thinc.extra.datasets
 from spacy.compat import pickle
 import spacy
@@ -224,7 +223,7 @@ def main(
     if model_dir is not None:
         model_dir = pathlib.Path(model_dir)
     if train_dir is None or dev_dir is None:
-        imdb_data = thinc.extra.datasets.imdb()
+        imdb_data = ml_datasets.imdb()
     if is_runtime:
         if dev_dir is None:
             dev_texts, dev_labels = zip(*imdb_data[1])


@@ -0,0 +1,63 @@
[training]
patience = 10000
eval_frequency = 200
dropout = 0.2
init_tok2vec = null
vectors = null
max_epochs = 100
orth_variant_level = 0.0
gold_preproc = true
max_length = 0
use_gpu = 0
scores = ["tags_acc", "uas", "las"]
score_weights = {"las": 0.8, "tags_acc": 0.2}
limit = 0
[training.batch_size]
@schedules = "compounding.v1"
start = 100
stop = 1000
compound = 1.001
[optimizer]
@optimizers = "Adam.v1"
learn_rate = 0.001
beta1 = 0.9
beta2 = 0.999
[nlp]
lang = "en"
vectors = ${training:vectors}
[nlp.pipeline.tok2vec]
factory = "tok2vec"
[nlp.pipeline.tagger]
factory = "tagger"
[nlp.pipeline.parser]
factory = "parser"
[nlp.pipeline.tagger.model]
@architectures = "tagger_model.v1"
[nlp.pipeline.tagger.model.tok2vec]
@architectures = "tok2vec_tensors.v1"
width = ${nlp.pipeline.tok2vec.model:width}
[nlp.pipeline.parser.model]
@architectures = "transition_based_parser.v1"
nr_feature_tokens = 8
hidden_width = 64
maxout_pieces = 3
[nlp.pipeline.parser.model.tok2vec]
@architectures = "tok2vec_tensors.v1"
width = ${nlp.pipeline.tok2vec.model:width}
[nlp.pipeline.tok2vec.model]
@architectures = "hash_embed_bilstm.v1"
pretrained_vectors = ${nlp:vectors}
width = 96
depth = 4
embed_size = 2000
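
For reference (not part of the diff): configs like the one above are meant to be read by thinc 8's config system, which the new train-from-config machinery builds on. A minimal, hypothetical sketch of parsing such a config, assuming thinc's Config class from the 8.0.0 dev line:

# Hypothetical example, not taken from this PR.
from thinc.config import Config

config_str = """
[training]
dropout = 0.2

[nlp]
lang = "en"
"""

config = Config().from_str(config_str)
print(config["training"]["dropout"])  # values are parsed (0.2 as a float), not raw strings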


@@ -0,0 +1,65 @@
[training]
patience = 10000
eval_frequency = 200
dropout = 0.2
init_tok2vec = null
vectors = null
max_epochs = 100
orth_variant_level = 0.0
gold_preproc = true
max_length = 0
use_gpu = -1
scores = ["tags_acc", "uas", "las"]
score_weights = {"las": 0.8, "tags_acc": 0.2}
limit = 0
[training.batch_size]
@schedules = "compounding.v1"
start = 100
stop = 1000
compound = 1.001
[optimizer]
@optimizers = "Adam.v1"
learn_rate = 0.001
beta1 = 0.9
beta2 = 0.999
[nlp]
lang = "en"
vectors = ${training:vectors}
[nlp.pipeline.tok2vec]
factory = "tok2vec"
[nlp.pipeline.tagger]
factory = "tagger"
[nlp.pipeline.parser]
factory = "parser"
[nlp.pipeline.tagger.model]
@architectures = "tagger_model.v1"
[nlp.pipeline.tagger.model.tok2vec]
@architectures = "tok2vec_tensors.v1"
width = ${nlp.pipeline.tok2vec.model:width}
[nlp.pipeline.parser.model]
@architectures = "transition_based_parser.v1"
nr_feature_tokens = 8
hidden_width = 64
maxout_pieces = 3
[nlp.pipeline.parser.model.tok2vec]
@architectures = "tok2vec_tensors.v1"
width = ${nlp.pipeline.tok2vec.model:width}
[nlp.pipeline.tok2vec.model]
@architectures = "hash_embed_cnn.v1"
pretrained_vectors = ${nlp:vectors}
width = 96
depth = 4
window_size = 1
embed_size = 2000
maxout_pieces = 3


@@ -13,9 +13,10 @@ Prerequisites: pip install joblib
 from __future__ import print_function, unicode_literals
 from pathlib import Path
+import ml_datasets
 from joblib import Parallel, delayed
 from functools import partial
-import thinc.extra.datasets
 import plac
 import spacy
 from spacy.util import minibatch
@@ -35,7 +36,7 @@ def main(output_dir, model="en_core_web_sm", n_jobs=4, batch_size=1000, limit=10
     output_dir.mkdir()
     # load and pre-process the IMBD dataset
     print("Loading IMDB data...")
-    data, _ = thinc.extra.datasets.imdb()
+    data, _ = ml_datasets.imdb()
     texts, _ = zip(*data[-limit:])
     print("Processing texts...")
     partitions = minibatch(texts, size=batch_size)


@@ -16,16 +16,18 @@ the development labels, after all --- only the unlabelled text.
 import plac
 import tqdm
 import random
+import ml_datasets
 import spacy
-import thinc.extra.datasets
 from spacy.util import minibatch, use_gpu, compounding
-from spacy._ml import Tok2Vec
 from spacy.pipeline import TextCategorizer
+from spacy.ml.tok2vec import Tok2Vec
 import numpy
 
 
 def load_texts(limit=0):
-    train, dev = thinc.extra.datasets.imdb()
+    train, dev = ml_datasets.imdb()
     train_texts, train_labels = zip(*train)
     dev_texts, dev_labels = zip(*train)
     train_texts = list(train_texts)
@@ -41,7 +43,7 @@ def load_texts(limit=0):
 def load_textcat_data(limit=0):
     """Load data from the IMDB dataset."""
     # Partition off part of the train data for evaluation
-    train_data, eval_data = thinc.extra.datasets.imdb()
+    train_data, eval_data = ml_datasets.imdb()
     random.shuffle(train_data)
     train_data = train_data[-limit:]
     texts, labels = zip(*train_data)
@@ -63,17 +65,15 @@ def prefer_gpu():
 def build_textcat_model(tok2vec, nr_class, width):
-    from thinc.v2v import Model, Softmax, Maxout
-    from thinc.api import flatten_add_lengths, chain
-    from thinc.t2v import Pooling, sum_pool, mean_pool, max_pool
-    from thinc.misc import Residual, LayerNorm
-    from spacy._ml import logistic, zero_init
+    from thinc.model import Model
+    from thinc.layers import Softmax, chain, reduce_mean
+    from thinc.layers import list2ragged
 
     with Model.define_operators({">>": chain}):
         model = (
             tok2vec
-            >> flatten_add_lengths
-            >> Pooling(mean_pool)
+            >> list2ragged()
+            >> reduce_mean()
             >> Softmax(nr_class, width)
         )
     model.tok2vec = tok2vec
@@ -81,7 +81,7 @@ def build_textcat_model(tok2vec, nr_class, width):
 def block_gradients(model):
-    from thinc.api import wrap
+    from thinc.api import wrap  # TODO FIX
 
     def forward(X, drop=0.0):
         Y, _ = model.begin_update(X, drop=drop)


@@ -58,7 +58,7 @@ def main(model_name, unlabelled_loc):
     # yet, but I'm getting weird results from Adam. Try commenting out the
     # nlp.update(), and using Adam -- you'll find the models drift apart.
     # I guess Adam is losing precision, introducing gradient noise?
-    optimizer.alpha = 0.1
+    optimizer.learn_rate = 0.1
     optimizer.b1 = 0.0
     optimizer.b2 = 0.0


@@ -17,7 +17,7 @@ import plac
 import random
 from pathlib import Path
 
-from spacy.symbols import PERSON
+import srsly
 from spacy.vocab import Vocab
 import spacy
@@ -68,7 +68,7 @@ def main(kb_path, vocab_path=None, output_dir=None, n_iter=50):
     vocab = Vocab().from_disk(vocab_path)
     # create blank Language class with correct vocab
     nlp = spacy.blank("en", vocab=vocab)
-    nlp.vocab.vectors.name = "spacy_pretrained_vectors"
+    nlp.vocab.vectors.name = "nel_vectors"
     print("Created blank 'en' model with vocab from '%s'" % vocab_path)
 
     # Add a sentencizer component. Alternatively, add a dependency parser for higher accuracy.
@@ -93,7 +93,7 @@ def main(kb_path, vocab_path=None, output_dir=None, n_iter=50):
     nlp.add_pipe(entity_linker, last=True)
 
     # Convert the texts to docs to make sure we have doc.ents set for the training examples.
-    # Also ensure that the annotated examples correspond to known identifiers in the knowlege base.
+    # Also ensure that the annotated examples correspond to known identifiers in the knowledge base.
     kb_ids = nlp.get_pipe("entity_linker").kb.get_entity_strings()
     TRAIN_DOCS = []
     for text, annotation in TRAIN_DATA:
@@ -117,6 +117,7 @@ def main(kb_path, vocab_path=None, output_dir=None, n_iter=50):
     with nlp.disable_pipes(*other_pipes):  # only train entity linker
         # reset and initialize the weights randomly
         optimizer = nlp.begin_training()
+
         for itn in range(n_iter):
             random.shuffle(TRAIN_DOCS)
             losses = {}


@@ -10,10 +10,11 @@ see the documentation:
 Compatible with: spaCy v2.0.0+
 """
 from __future__ import unicode_literals, print_function
+import ml_datasets
 import plac
 import random
 from pathlib import Path
-import thinc.extra.datasets
 import spacy
 from spacy.util import minibatch, compounding
@@ -115,7 +116,7 @@ def main(model=None, output_dir=None, n_iter=20, n_texts=2000, init_tok2vec=None
 def load_data(limit=0, split=0.8):
     """Load data from the IMDB dataset."""
     # Partition off part of the train data for evaluation
-    train_data, _ = thinc.extra.datasets.imdb()
+    train_data, _ = ml_datasets.imdb()
     random.shuffle(train_data)
     train_data = train_data[-limit:]
     texts, labels = zip(*train_data)


@@ -1,17 +1,20 @@
 # Our libraries
 cymem>=2.0.2,<2.1.0
 preshed>=3.0.2,<3.1.0
-thinc==7.4.0.dev0
+thinc==8.0.0a0
 blis>=0.4.0,<0.5.0
+ml_datasets>=0.1.1
 murmurhash>=0.28.0,<1.1.0
 wasabi>=0.4.0,<1.1.0
-srsly>=0.1.0,<1.1.0
+srsly>=2.0.0,<3.0.0
 catalogue>=0.0.7,<1.1.0
 # Third party dependencies
 numpy>=1.15.0
 requests>=2.13.0,<3.0.0
 plac>=0.9.6,<1.2.0
 tqdm>=4.38.0,<5.0.0
+# Optional dependencies
+jsonschema>=2.6.0,<3.1.0
 pydantic>=1.0.0,<2.0.0
 # Development dependencies
 cython>=0.25


@@ -35,16 +35,16 @@ setup_requires =
     cymem>=2.0.2,<2.1.0
     preshed>=3.0.2,<3.1.0
     murmurhash>=0.28.0,<1.1.0
-    thinc==7.4.0.dev0
+    thinc==8.0.0a0
 install_requires =
     # Our libraries
     murmurhash>=0.28.0,<1.1.0
     cymem>=2.0.2,<2.1.0
     preshed>=3.0.2,<3.1.0
-    thinc==7.4.0.dev0
+    thinc==8.0.0a0
     blis>=0.4.0,<0.5.0
     wasabi>=0.4.0,<1.1.0
-    srsly>=0.1.0,<1.1.0
+    srsly>=2.0.0,<3.0.0
     catalogue>=0.0.7,<1.1.0
     # Third-party dependencies
     setuptools


@@ -5,7 +5,7 @@ warnings.filterwarnings("ignore", message="numpy.dtype size changed")
 warnings.filterwarnings("ignore", message="numpy.ufunc size changed")
 
 # These are imported as part of the API
-from thinc.neural.util import prefer_gpu, require_gpu
+from thinc.util import prefer_gpu, require_gpu
 
 from . import pipeline
 from .cli.info import info as cli_info
@@ -21,6 +21,9 @@ if sys.maxunicode == 65535:
     raise SystemError(Errors.E130)
 
+config = registry
+
+
 def load(name, **overrides):
     depr_path = overrides.get("path")
     if depr_path not in (True, False, None):
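
The config = registry alias added above exposes spaCy's catalogue-based registries, which is how the @architectures (and similar) references in the example configs get resolved. A hypothetical sketch of registering a custom architecture; the name "custom_linear.v1" and the function body are placeholders, not part of this PR:

# Hypothetical example, not taken from this PR.
from spacy import util
from thinc.layers import Linear

@util.registry.architectures.register("custom_linear.v1")
def custom_linear(nO, nI):
    # A config block could now refer to @architectures = "custom_linear.v1"
    return Linear(nO, nI)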


@@ -4,12 +4,14 @@ if __name__ == "__main__":
     from wasabi import msg
     from spacy.cli import download, link, info, package, train, pretrain, convert
     from spacy.cli import init_model, profile, evaluate, validate, debug_data
+    from spacy.cli import train_from_config_cli
 
     commands = {
         "download": download,
         "link": link,
         "info": info,
         "train": train,
+        "train-from-config": train_from_config_cli,
         "pretrain": pretrain,
         "debug-data": debug_data,
         "evaluate": evaluate,

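With the command registered above, config-driven training is started through python -m spacy. The exact arguments are defined by the new train_from_config CLI script added at the end of this diff; the invocation below is only illustrative, with placeholder paths:

python -m spacy train-from-config <train_path> <dev_path> <config_path>
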

@@ -1,982 +0,0 @@
import numpy
from thinc.v2v import Model, Maxout, Softmax, Affine, ReLu
from thinc.t2t import ExtractWindow, ParametricAttention
from thinc.t2v import Pooling, sum_pool, mean_pool
from thinc.i2v import HashEmbed
from thinc.misc import Residual, FeatureExtracter
from thinc.misc import LayerNorm as LN
from thinc.api import add, layerize, chain, clone, concatenate, with_flatten
from thinc.api import with_getitem, flatten_add_lengths
from thinc.api import uniqued, wrap, noop
from thinc.linear.linear import LinearModel
from thinc.neural.ops import NumpyOps, CupyOps
from thinc.neural.util import get_array_module, copy_array
from thinc.neural.optimizers import Adam
from thinc import describe
from thinc.describe import Dimension, Synapses, Biases, Gradient
from thinc.neural._classes.affine import _set_dimensions_if_needed
import thinc.extra.load_nlp
from .attrs import ID, ORTH, LOWER, NORM, PREFIX, SUFFIX, SHAPE
from .errors import Errors, user_warning, Warnings
from . import util
from . import ml as new_ml
from .ml import _legacy_tok2vec
VECTORS_KEY = "spacy_pretrained_vectors"
# Backwards compatibility with <2.2.2
USE_MODEL_REGISTRY_TOK2VEC = False
def cosine(vec1, vec2):
xp = get_array_module(vec1)
norm1 = xp.linalg.norm(vec1)
norm2 = xp.linalg.norm(vec2)
if norm1 == 0.0 or norm2 == 0.0:
return 0
else:
return vec1.dot(vec2) / (norm1 * norm2)
def create_default_optimizer(ops, **cfg):
learn_rate = util.env_opt("learn_rate", 0.001)
beta1 = util.env_opt("optimizer_B1", 0.9)
beta2 = util.env_opt("optimizer_B2", 0.999)
eps = util.env_opt("optimizer_eps", 1e-8)
L2 = util.env_opt("L2_penalty", 1e-6)
max_grad_norm = util.env_opt("grad_norm_clip", 1.0)
optimizer = Adam(ops, learn_rate, L2=L2, beta1=beta1, beta2=beta2, eps=eps)
optimizer.max_grad_norm = max_grad_norm
optimizer.device = ops.device
return optimizer
@layerize
def _flatten_add_lengths(seqs, pad=0, drop=0.0):
ops = Model.ops
lengths = ops.asarray([len(seq) for seq in seqs], dtype="i")
def finish_update(d_X, sgd=None):
return ops.unflatten(d_X, lengths, pad=pad)
X = ops.flatten(seqs, pad=pad)
return (X, lengths), finish_update
def _zero_init(model):
def _zero_init_impl(self, *args, **kwargs):
self.W.fill(0)
model.on_init_hooks.append(_zero_init_impl)
if model.W is not None:
model.W.fill(0.0)
return model
def with_cpu(ops, model):
"""Wrap a model that should run on CPU, transferring inputs and outputs
as necessary."""
model.to_cpu()
def with_cpu_forward(inputs, drop=0.0):
cpu_outputs, backprop = model.begin_update(_to_cpu(inputs), drop=drop)
gpu_outputs = _to_device(ops, cpu_outputs)
def with_cpu_backprop(d_outputs, sgd=None):
cpu_d_outputs = _to_cpu(d_outputs)
return backprop(cpu_d_outputs, sgd=sgd)
return gpu_outputs, with_cpu_backprop
return wrap(with_cpu_forward, model)
def _to_cpu(X):
if isinstance(X, numpy.ndarray):
return X
elif isinstance(X, tuple):
return tuple([_to_cpu(x) for x in X])
elif isinstance(X, list):
return [_to_cpu(x) for x in X]
elif hasattr(X, "get"):
return X.get()
else:
return X
def _to_device(ops, X):
if isinstance(X, tuple):
return tuple([_to_device(ops, x) for x in X])
elif isinstance(X, list):
return [_to_device(ops, x) for x in X]
else:
return ops.asarray(X)
class extract_ngrams(Model):
def __init__(self, ngram_size, attr=LOWER):
Model.__init__(self)
self.ngram_size = ngram_size
self.attr = attr
def begin_update(self, docs, drop=0.0):
batch_keys = []
batch_vals = []
for doc in docs:
unigrams = doc.to_array([self.attr])
ngrams = [unigrams]
for n in range(2, self.ngram_size + 1):
ngrams.append(self.ops.ngrams(n, unigrams))
keys = self.ops.xp.concatenate(ngrams)
keys, vals = self.ops.xp.unique(keys, return_counts=True)
batch_keys.append(keys)
batch_vals.append(vals)
# The dtype here matches what thinc is expecting -- which differs per
# platform (by int definition). This should be fixed once the problem
# is fixed on Thinc's side.
lengths = self.ops.asarray(
[arr.shape[0] for arr in batch_keys], dtype=numpy.int_
)
batch_keys = self.ops.xp.concatenate(batch_keys)
batch_vals = self.ops.asarray(self.ops.xp.concatenate(batch_vals), dtype="f")
return (batch_keys, batch_vals, lengths), None
@describe.on_data(
_set_dimensions_if_needed, lambda model, X, y: model.init_weights(model)
)
@describe.attributes(
nI=Dimension("Input size"),
nF=Dimension("Number of features"),
nO=Dimension("Output size"),
nP=Dimension("Maxout pieces"),
W=Synapses("Weights matrix", lambda obj: (obj.nF, obj.nO, obj.nP, obj.nI)),
b=Biases("Bias vector", lambda obj: (obj.nO, obj.nP)),
pad=Synapses(
"Pad",
lambda obj: (1, obj.nF, obj.nO, obj.nP),
lambda M, ops: ops.normal_init(M, 1.0),
),
d_W=Gradient("W"),
d_pad=Gradient("pad"),
d_b=Gradient("b"),
)
class PrecomputableAffine(Model):
def __init__(self, nO=None, nI=None, nF=None, nP=None, **kwargs):
Model.__init__(self, **kwargs)
self.nO = nO
self.nP = nP
self.nI = nI
self.nF = nF
def begin_update(self, X, drop=0.0):
Yf = self.ops.gemm(
X, self.W.reshape((self.nF * self.nO * self.nP, self.nI)), trans2=True
)
Yf = Yf.reshape((Yf.shape[0], self.nF, self.nO, self.nP))
Yf = self._add_padding(Yf)
def backward(dY_ids, sgd=None):
dY, ids = dY_ids
dY, ids = self._backprop_padding(dY, ids)
Xf = X[ids]
Xf = Xf.reshape((Xf.shape[0], self.nF * self.nI))
self.d_b += dY.sum(axis=0)
dY = dY.reshape((dY.shape[0], self.nO * self.nP))
Wopfi = self.W.transpose((1, 2, 0, 3))
Wopfi = self.ops.xp.ascontiguousarray(Wopfi)
Wopfi = Wopfi.reshape((self.nO * self.nP, self.nF * self.nI))
dXf = self.ops.gemm(dY.reshape((dY.shape[0], self.nO * self.nP)), Wopfi)
# Reuse the buffer
dWopfi = Wopfi
dWopfi.fill(0.0)
self.ops.gemm(dY, Xf, out=dWopfi, trans1=True)
dWopfi = dWopfi.reshape((self.nO, self.nP, self.nF, self.nI))
# (o, p, f, i) --> (f, o, p, i)
self.d_W += dWopfi.transpose((2, 0, 1, 3))
if sgd is not None:
sgd(self._mem.weights, self._mem.gradient, key=self.id)
return dXf.reshape((dXf.shape[0], self.nF, self.nI))
return Yf, backward
def _add_padding(self, Yf):
Yf_padded = self.ops.xp.vstack((self.pad, Yf))
return Yf_padded
def _backprop_padding(self, dY, ids):
# (1, nF, nO, nP) += (nN, nF, nO, nP) where IDs (nN, nF) < 0
mask = ids < 0.0
mask = mask.sum(axis=1)
d_pad = dY * mask.reshape((ids.shape[0], 1, 1))
self.d_pad += d_pad.sum(axis=0)
return dY, ids
@staticmethod
def init_weights(model):
"""This is like the 'layer sequential unit variance', but instead
of taking the actual inputs, we randomly generate whitened data.
Why's this all so complicated? We have a huge number of inputs,
and the maxout unit makes guessing the dynamics tricky. Instead
we set the maxout weights to values that empirically result in
whitened outputs given whitened inputs.
"""
if (model.W ** 2).sum() != 0.0:
return
ops = model.ops
xp = ops.xp
ops.normal_init(model.W, model.nF * model.nI, inplace=True)
ids = ops.allocate((5000, model.nF), dtype="f")
ids += xp.random.uniform(0, 1000, ids.shape)
ids = ops.asarray(ids, dtype="i")
tokvecs = ops.allocate((5000, model.nI), dtype="f")
tokvecs += xp.random.normal(loc=0.0, scale=1.0, size=tokvecs.size).reshape(
tokvecs.shape
)
def predict(ids, tokvecs):
# nS ids. nW tokvecs. Exclude the padding array.
hiddens = model(tokvecs[:-1]) # (nW, f, o, p)
vectors = model.ops.allocate((ids.shape[0], model.nO * model.nP), dtype="f")
# need nS vectors
hiddens = hiddens.reshape(
(hiddens.shape[0] * model.nF, model.nO * model.nP)
)
model.ops.scatter_add(vectors, ids.flatten(), hiddens)
vectors = vectors.reshape((vectors.shape[0], model.nO, model.nP))
vectors += model.b
vectors = model.ops.asarray(vectors)
if model.nP >= 2:
return model.ops.maxout(vectors)[0]
else:
return vectors * (vectors >= 0)
tol_var = 0.01
tol_mean = 0.01
t_max = 10
t_i = 0
for t_i in range(t_max):
acts1 = predict(ids, tokvecs)
var = model.ops.xp.var(acts1)
mean = model.ops.xp.mean(acts1)
if abs(var - 1.0) >= tol_var:
model.W /= model.ops.xp.sqrt(var)
elif abs(mean) >= tol_mean:
model.b -= mean
else:
break
def link_vectors_to_models(vocab):
vectors = vocab.vectors
if vectors.name is None:
vectors.name = VECTORS_KEY
if vectors.data.size != 0:
user_warning(Warnings.W020.format(shape=vectors.data.shape))
ops = Model.ops
for word in vocab:
if word.orth in vectors.key2row:
word.rank = vectors.key2row[word.orth]
else:
word.rank = 0
data = ops.asarray(vectors.data)
# Set an entry here, so that vectors are accessed by StaticVectors
# (unideal, I know)
key = (ops.device, vectors.name)
if key in thinc.extra.load_nlp.VECTORS:
if thinc.extra.load_nlp.VECTORS[key].shape != data.shape:
# This is a hack to avoid the problem in #3853. Maybe we should
# print a warning as well?
old_name = vectors.name
new_name = f"{vectors.name}_{data.shape[0]}"
user_warning(Warnings.W019.format(old=old_name, new=new_name))
vectors.name = new_name
key = (ops.device, vectors.name)
thinc.extra.load_nlp.VECTORS[key] = data
def PyTorchBiLSTM(nO, nI, depth, dropout=0.2):
import torch.nn
from thinc.api import with_square_sequences
from thinc.extra.wrappers import PyTorchWrapperRNN
if depth == 0:
return layerize(noop())
model = torch.nn.LSTM(nI, nO // 2, depth, bidirectional=True, dropout=dropout)
return with_square_sequences(PyTorchWrapperRNN(model))
def Tok2Vec(width, embed_size, **kwargs):
if not USE_MODEL_REGISTRY_TOK2VEC:
# Preserve prior tok2vec for backwards compat, in v2.2.2
return _legacy_tok2vec.Tok2Vec(width, embed_size, **kwargs)
pretrained_vectors = kwargs.get("pretrained_vectors", None)
cnn_maxout_pieces = kwargs.get("cnn_maxout_pieces", 3)
subword_features = kwargs.get("subword_features", True)
char_embed = kwargs.get("char_embed", False)
conv_depth = kwargs.get("conv_depth", 4)
bilstm_depth = kwargs.get("bilstm_depth", 0)
conv_window = kwargs.get("conv_window", 1)
cols = ["ID", "NORM", "PREFIX", "SUFFIX", "SHAPE", "ORTH"]
doc2feats_cfg = {"arch": "spacy.Doc2Feats.v1", "config": {"columns": cols}}
if char_embed:
embed_cfg = {
"arch": "spacy.CharacterEmbed.v1",
"config": {
"width": 64,
"chars": 6,
"@mix": {
"arch": "spacy.LayerNormalizedMaxout.v1",
"config": {"width": width, "pieces": 3},
},
"@embed_features": None,
},
}
else:
embed_cfg = {
"arch": "spacy.MultiHashEmbed.v1",
"config": {
"width": width,
"rows": embed_size,
"columns": cols,
"use_subwords": subword_features,
"@pretrained_vectors": None,
"@mix": {
"arch": "spacy.LayerNormalizedMaxout.v1",
"config": {"width": width, "pieces": 3},
},
},
}
if pretrained_vectors:
embed_cfg["config"]["@pretrained_vectors"] = {
"arch": "spacy.PretrainedVectors.v1",
"config": {
"vectors_name": pretrained_vectors,
"width": width,
"column": cols.index("ID"),
},
}
if cnn_maxout_pieces >= 2:
cnn_cfg = {
"arch": "spacy.MaxoutWindowEncoder.v1",
"config": {
"width": width,
"window_size": conv_window,
"pieces": cnn_maxout_pieces,
"depth": conv_depth,
},
}
else:
cnn_cfg = {
"arch": "spacy.MishWindowEncoder.v1",
"config": {"width": width, "window_size": conv_window, "depth": conv_depth},
}
bilstm_cfg = {
"arch": "spacy.TorchBiLSTMEncoder.v1",
"config": {"width": width, "depth": bilstm_depth},
}
if conv_depth == 0 and bilstm_depth == 0:
encode_cfg = {}
elif conv_depth >= 1 and bilstm_depth >= 1:
encode_cfg = {
"arch": "thinc.FeedForward.v1",
"config": {"children": [cnn_cfg, bilstm_cfg]},
}
elif conv_depth >= 1:
encode_cfg = cnn_cfg
else:
encode_cfg = bilstm_cfg
config = {"@doc2feats": doc2feats_cfg, "@embed": embed_cfg, "@encode": encode_cfg}
return new_ml.Tok2Vec(config)
def reapply(layer, n_times):
def reapply_fwd(X, drop=0.0):
backprops = []
for i in range(n_times):
Y, backprop = layer.begin_update(X, drop=drop)
X = Y
backprops.append(backprop)
def reapply_bwd(dY, sgd=None):
dX = None
for backprop in reversed(backprops):
dY = backprop(dY, sgd=sgd)
if dX is None:
dX = dY
else:
dX += dY
return dX
return Y, reapply_bwd
return wrap(reapply_fwd, layer)
def asarray(ops, dtype):
def forward(X, drop=0.0):
return ops.asarray(X, dtype=dtype), None
return layerize(forward)
def _divide_array(X, size):
parts = []
index = 0
while index < len(X):
parts.append(X[index : index + size])
index += size
return parts
def get_col(idx):
if idx < 0:
raise IndexError(Errors.E066.format(value=idx))
def forward(X, drop=0.0):
if isinstance(X, numpy.ndarray):
ops = NumpyOps()
else:
ops = CupyOps()
output = ops.xp.ascontiguousarray(X[:, idx], dtype=X.dtype)
def backward(y, sgd=None):
dX = ops.allocate(X.shape)
dX[:, idx] += y
return dX
return output, backward
return layerize(forward)
def doc2feats(cols=None):
if cols is None:
cols = [ID, NORM, PREFIX, SUFFIX, SHAPE, ORTH]
def forward(docs, drop=0.0):
feats = []
for doc in docs:
feats.append(doc.to_array(cols))
return feats, None
model = layerize(forward)
model.cols = cols
return model
def print_shape(prefix):
def forward(X, drop=0.0):
return X, lambda dX, **kwargs: dX
return layerize(forward)
@layerize
def get_token_vectors(tokens_attrs_vectors, drop=0.0):
tokens, attrs, vectors = tokens_attrs_vectors
def backward(d_output, sgd=None):
return (tokens, d_output)
return vectors, backward
@layerize
def logistic(X, drop=0.0):
xp = get_array_module(X)
if not isinstance(X, xp.ndarray):
X = xp.asarray(X)
# Clip to range (-10, 10)
X = xp.minimum(X, 10.0, X)
X = xp.maximum(X, -10.0, X)
Y = 1.0 / (1.0 + xp.exp(-X))
def logistic_bwd(dY, sgd=None):
dX = dY * (Y * (1 - Y))
return dX
return Y, logistic_bwd
def zero_init(model):
def _zero_init_impl(self, X, y):
self.W.fill(0)
model.on_data_hooks.append(_zero_init_impl)
return model
def getitem(i):
def getitem_fwd(X, drop=0.0):
return X[i], None
return layerize(getitem_fwd)
@describe.attributes(
W=Synapses("Weights matrix", lambda obj: (obj.nO, obj.nI), lambda W, ops: None)
)
class MultiSoftmax(Affine):
"""Neural network layer that predicts several multi-class attributes at once.
For instance, we might predict one class with 6 variables, and another with 5.
We predict the 11 neurons required for this, and then softmax them such
that columns 0-6 make a probability distribution and coumns 6-11 make another.
"""
name = "multisoftmax"
def __init__(self, out_sizes, nI=None, **kwargs):
Model.__init__(self, **kwargs)
self.out_sizes = out_sizes
self.nO = sum(out_sizes)
self.nI = nI
def predict(self, input__BI):
output__BO = self.ops.affine(self.W, self.b, input__BI)
i = 0
for out_size in self.out_sizes:
self.ops.softmax(output__BO[:, i : i + out_size], inplace=True)
i += out_size
return output__BO
def begin_update(self, input__BI, drop=0.0):
output__BO = self.predict(input__BI)
def finish_update(grad__BO, sgd=None):
self.d_W += self.ops.gemm(grad__BO, input__BI, trans1=True)
self.d_b += grad__BO.sum(axis=0)
grad__BI = self.ops.gemm(grad__BO, self.W)
if sgd is not None:
sgd(self._mem.weights, self._mem.gradient, key=self.id)
return grad__BI
return output__BO, finish_update
def build_tagger_model(nr_class, **cfg):
embed_size = util.env_opt("embed_size", 2000)
if "token_vector_width" in cfg:
token_vector_width = cfg["token_vector_width"]
else:
token_vector_width = util.env_opt("token_vector_width", 96)
pretrained_vectors = cfg.get("pretrained_vectors")
subword_features = cfg.get("subword_features", True)
with Model.define_operators({">>": chain, "+": add}):
if "tok2vec" in cfg:
tok2vec = cfg["tok2vec"]
else:
tok2vec = Tok2Vec(
token_vector_width,
embed_size,
subword_features=subword_features,
pretrained_vectors=pretrained_vectors,
)
softmax = with_flatten(Softmax(nr_class, token_vector_width))
model = tok2vec >> softmax
model.nI = None
model.tok2vec = tok2vec
model.softmax = softmax
return model
def build_morphologizer_model(class_nums, **cfg):
embed_size = util.env_opt("embed_size", 7000)
if "token_vector_width" in cfg:
token_vector_width = cfg["token_vector_width"]
else:
token_vector_width = util.env_opt("token_vector_width", 128)
pretrained_vectors = cfg.get("pretrained_vectors")
char_embed = cfg.get("char_embed", True)
with Model.define_operators({">>": chain, "+": add, "**": clone}):
if "tok2vec" in cfg:
tok2vec = cfg["tok2vec"]
else:
tok2vec = Tok2Vec(
token_vector_width,
embed_size,
char_embed=char_embed,
pretrained_vectors=pretrained_vectors,
)
softmax = with_flatten(MultiSoftmax(class_nums, token_vector_width))
softmax.out_sizes = class_nums
model = tok2vec >> softmax
model.nI = None
model.tok2vec = tok2vec
model.softmax = softmax
return model
@layerize
def SpacyVectors(docs, drop=0.0):
batch = []
for doc in docs:
indices = numpy.zeros((len(doc),), dtype="i")
for i, word in enumerate(doc):
if word.orth in doc.vocab.vectors.key2row:
indices[i] = doc.vocab.vectors.key2row[word.orth]
else:
indices[i] = 0
vectors = doc.vocab.vectors.data[indices]
batch.append(vectors)
return batch, None
def build_text_classifier(nr_class, width=64, **cfg):
depth = cfg.get("depth", 2)
nr_vector = cfg.get("nr_vector", 5000)
pretrained_dims = cfg.get("pretrained_dims", 0)
with Model.define_operators({">>": chain, "+": add, "|": concatenate, "**": clone}):
if cfg.get("low_data") and pretrained_dims:
model = (
SpacyVectors
>> flatten_add_lengths
>> with_getitem(0, Affine(width, pretrained_dims))
>> ParametricAttention(width)
>> Pooling(sum_pool)
>> Residual(ReLu(width, width)) ** 2
>> zero_init(Affine(nr_class, width, drop_factor=0.0))
>> logistic
)
return model
lower = HashEmbed(width, nr_vector, column=1)
prefix = HashEmbed(width // 2, nr_vector, column=2)
suffix = HashEmbed(width // 2, nr_vector, column=3)
shape = HashEmbed(width // 2, nr_vector, column=4)
trained_vectors = FeatureExtracter(
[ORTH, LOWER, PREFIX, SUFFIX, SHAPE, ID]
) >> with_flatten(
uniqued(
(lower | prefix | suffix | shape)
>> LN(Maxout(width, width + (width // 2) * 3)),
column=0,
)
)
if pretrained_dims:
static_vectors = SpacyVectors >> with_flatten(
Affine(width, pretrained_dims)
)
# TODO Make concatenate support lists
vectors = concatenate_lists(trained_vectors, static_vectors)
vectors_width = width * 2
else:
vectors = trained_vectors
vectors_width = width
static_vectors = None
tok2vec = vectors >> with_flatten(
LN(Maxout(width, vectors_width))
>> Residual((ExtractWindow(nW=1) >> LN(Maxout(width, width * 3)))) ** depth,
pad=depth,
)
cnn_model = (
tok2vec
>> flatten_add_lengths
>> ParametricAttention(width)
>> Pooling(sum_pool)
>> Residual(zero_init(Maxout(width, width)))
>> zero_init(Affine(nr_class, width, drop_factor=0.0))
)
linear_model = build_bow_text_classifier(
nr_class, ngram_size=cfg.get("ngram_size", 1), exclusive_classes=False
)
if cfg.get("exclusive_classes"):
output_layer = Softmax(nr_class, nr_class * 2)
else:
output_layer = (
zero_init(Affine(nr_class, nr_class * 2, drop_factor=0.0)) >> logistic
)
model = (linear_model | cnn_model) >> output_layer
model.tok2vec = chain(tok2vec, flatten)
model.nO = nr_class
model.lsuv = False
return model
def build_bow_text_classifier(
nr_class, ngram_size=1, exclusive_classes=False, no_output_layer=False, **cfg
):
with Model.define_operators({">>": chain}):
model = with_cpu(
Model.ops, extract_ngrams(ngram_size, attr=ORTH) >> LinearModel(nr_class)
)
if not no_output_layer:
model = model >> (cpu_softmax if exclusive_classes else logistic)
model.nO = nr_class
return model
@layerize
def cpu_softmax(X, drop=0.0):
ops = NumpyOps()
def cpu_softmax_backward(dY, sgd=None):
return dY
return ops.softmax(X), cpu_softmax_backward
def build_simple_cnn_text_classifier(tok2vec, nr_class, exclusive_classes=False, **cfg):
"""
Build a simple CNN text classifier, given a token-to-vector model as inputs.
If exclusive_classes=True, a softmax non-linearity is applied, so that the
outputs sum to 1. If exclusive_classes=False, a logistic non-linearity
is applied instead, so that outputs are in the range [0, 1].
"""
with Model.define_operators({">>": chain}):
if exclusive_classes:
output_layer = Softmax(nr_class, tok2vec.nO)
else:
output_layer = (
zero_init(Affine(nr_class, tok2vec.nO, drop_factor=0.0)) >> logistic
)
model = tok2vec >> flatten_add_lengths >> Pooling(mean_pool) >> output_layer
model.tok2vec = chain(tok2vec, flatten)
model.nO = nr_class
return model
def build_nel_encoder(embed_width, hidden_width, ner_types, **cfg):
if "entity_width" not in cfg:
raise ValueError(Errors.E144.format(param="entity_width"))
conv_depth = cfg.get("conv_depth", 2)
cnn_maxout_pieces = cfg.get("cnn_maxout_pieces", 3)
pretrained_vectors = cfg.get("pretrained_vectors", None)
context_width = cfg.get("entity_width")
with Model.define_operators({">>": chain, "**": clone}):
# context encoder
tok2vec = Tok2Vec(
width=hidden_width,
embed_size=embed_width,
pretrained_vectors=pretrained_vectors,
cnn_maxout_pieces=cnn_maxout_pieces,
subword_features=True,
conv_depth=conv_depth,
bilstm_depth=0,
)
model = (
tok2vec
>> flatten_add_lengths
>> Pooling(mean_pool)
>> Residual(zero_init(Maxout(hidden_width, hidden_width)))
>> zero_init(Affine(context_width, hidden_width, drop_factor=0.0))
)
model.tok2vec = tok2vec
model.nO = context_width
return model
@layerize
def flatten(seqs, drop=0.0):
ops = Model.ops
lengths = ops.asarray([len(seq) for seq in seqs], dtype="i")
def finish_update(d_X, sgd=None):
return ops.unflatten(d_X, lengths, pad=0)
X = ops.flatten(seqs, pad=0)
return X, finish_update
def concatenate_lists(*layers, **kwargs): # pragma: no cover
"""Compose two or more models `f`, `g`, etc, such that their outputs are
concatenated, i.e. `concatenate(f, g)(x)` computes `hstack(f(x), g(x))`
"""
if not layers:
return noop()
drop_factor = kwargs.get("drop_factor", 1.0)
ops = layers[0].ops
layers = [chain(layer, flatten) for layer in layers]
concat = concatenate(*layers)
def concatenate_lists_fwd(Xs, drop=0.0):
if drop is not None:
drop *= drop_factor
lengths = ops.asarray([len(X) for X in Xs], dtype="i")
flat_y, bp_flat_y = concat.begin_update(Xs, drop=drop)
ys = ops.unflatten(flat_y, lengths)
def concatenate_lists_bwd(d_ys, sgd=None):
return bp_flat_y(ops.flatten(d_ys), sgd=sgd)
return ys, concatenate_lists_bwd
model = wrap(concatenate_lists_fwd, concat)
return model
def masked_language_model(vocab, model, mask_prob=0.15):
"""Convert a model into a BERT-style masked language model"""
random_words = _RandomWords(vocab)
def mlm_forward(docs, drop=0.0):
mask, docs = _apply_mask(docs, random_words, mask_prob=mask_prob)
mask = model.ops.asarray(mask).reshape((mask.shape[0], 1))
output, backprop = model.begin_update(docs, drop=drop)
def mlm_backward(d_output, sgd=None):
d_output *= 1 - mask
return backprop(d_output, sgd=sgd)
return output, mlm_backward
return wrap(mlm_forward, model)
class _RandomWords(object):
def __init__(self, vocab):
self.words = [lex.text for lex in vocab if lex.prob != 0.0]
self.probs = [lex.prob for lex in vocab if lex.prob != 0.0]
self.words = self.words[:10000]
self.probs = self.probs[:10000]
self.probs = numpy.exp(numpy.array(self.probs, dtype="f"))
self.probs /= self.probs.sum()
self._cache = []
def next(self):
if not self._cache:
self._cache.extend(
numpy.random.choice(len(self.words), 10000, p=self.probs)
)
index = self._cache.pop()
return self.words[index]
def _apply_mask(docs, random_words, mask_prob=0.15):
# This needs to be here to avoid circular imports
from .tokens.doc import Doc
N = sum(len(doc) for doc in docs)
mask = numpy.random.uniform(0.0, 1.0, (N,))
mask = mask >= mask_prob
i = 0
masked_docs = []
for doc in docs:
words = []
for token in doc:
if not mask[i]:
word = _replace_word(token.text, random_words)
else:
word = token.text
words.append(word)
i += 1
spaces = [bool(w.whitespace_) for w in doc]
# NB: If you change this implementation to instead modify
# the docs in place, take care that the IDs reflect the original
# words. Currently we use the original docs to make the vectors
# for the target, so we don't lose the original tokens. But if
# you modified the docs in place here, you would.
masked_docs.append(Doc(doc.vocab, words=words, spaces=spaces))
return mask, masked_docs
def _replace_word(word, random_words, mask="[MASK]"):
roll = numpy.random.random()
if roll < 0.8:
return mask
elif roll < 0.9:
return random_words.next()
else:
return word
def _uniform_init(lo, hi):
def wrapped(W, ops):
copy_array(W, ops.xp.random.uniform(lo, hi, W.shape))
return wrapped
@describe.attributes(
nM=Dimension("Vector dimensions"),
nC=Dimension("Number of characters per word"),
vectors=Synapses(
"Embed matrix", lambda obj: (obj.nC, obj.nV, obj.nM), _uniform_init(-0.1, 0.1)
),
d_vectors=Gradient("vectors"),
)
class CharacterEmbed(Model):
def __init__(self, nM=None, nC=None, **kwargs):
Model.__init__(self, **kwargs)
self.nM = nM
self.nC = nC
@property
def nO(self):
return self.nM * self.nC
@property
def nV(self):
return 256
def begin_update(self, docs, drop=0.0):
if not docs:
return []
ids = []
output = []
weights = self.vectors
# This assists in indexing; it's like looping over this dimension.
# Still consider this weird witch craft...But thanks to Mark Neumann
# for the tip.
nCv = self.ops.xp.arange(self.nC)
for doc in docs:
doc_ids = doc.to_utf8_array(nr_char=self.nC)
doc_vectors = self.ops.allocate((len(doc), self.nC, self.nM))
# Let's say I have a 2d array of indices, and a 3d table of data. What numpy
# incantation do I chant to get
# output[i, j, k] == data[j, ids[i, j], k]?
doc_vectors[:, nCv] = weights[nCv, doc_ids[:, nCv]]
output.append(doc_vectors.reshape((len(doc), self.nO)))
ids.append(doc_ids)
def backprop_character_embed(d_vectors, sgd=None):
gradient = self.d_vectors
for doc_ids, d_doc_vectors in zip(ids, d_vectors):
d_doc_vectors = d_doc_vectors.reshape((len(doc_ids), self.nC, self.nM))
gradient[nCv, doc_ids[:, nCv]] += d_doc_vectors[:, nCv]
if sgd is not None:
sgd(self._mem.weights, self._mem.gradient, key=self.id)
return None
return output, backprop_character_embed
def get_cossim_loss(yh, y, ignore_zeros=False):
xp = get_array_module(yh)
# Find the zero vectors
if ignore_zeros:
zero_indices = xp.abs(y).sum(axis=1) == 0
# Add a small constant to avoid 0 vectors
yh = yh + 1e-8
y = y + 1e-8
# https://math.stackexchange.com/questions/1923613/partial-derivative-of-cosine-similarity
norm_yh = xp.linalg.norm(yh, axis=1, keepdims=True)
norm_y = xp.linalg.norm(y, axis=1, keepdims=True)
mul_norms = norm_yh * norm_y
cosine = (yh * y).sum(axis=1, keepdims=True) / mul_norms
d_yh = (y / mul_norms) - (cosine * (yh / norm_yh ** 2))
losses = xp.abs(cosine - 1)
if ignore_zeros:
# If the target was a zero vector, don't count it in the loss.
d_yh[zero_indices] = 0
losses[zero_indices] = 0
loss = losses.sum()
return loss, -d_yh
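
Note that this PR replaces the get_cossim_loss() helper above with thinc's CosineDistance loss (see the EntityEncoder and pretraining changes elsewhere in this diff). A minimal sketch of the replacement pattern, using the gradient-first return order shown in this diff:

# Small self-contained example of the CosineDistance usage adopted in this PR.
import numpy
from thinc.loss import CosineDistance

predictions = numpy.asarray([[0.5, 0.5], [1.0, 0.0]], dtype="f")
targets = numpy.asarray([[1.0, 0.0], [1.0, 0.0]], dtype="f")

distance = CosineDistance(ignore_zeros=True, normalize=False)
d_scores, loss = distance(predictions, targets)  # gradient first, then the scalar loss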


@@ -4,6 +4,7 @@ from .link import link  # noqa: F401
 from .package import package  # noqa: F401
 from .profile import profile  # noqa: F401
 from .train import train  # noqa: F401
+from .train_from_config import train_from_config_cli  # noqa: F401
 from .pretrain import pretrain  # noqa: F401
 from .debug_data import debug_data  # noqa: F401
 from .evaluate import evaluate  # noqa: F401


@@ -4,19 +4,21 @@ import time
 import re
 from collections import Counter
 from pathlib import Path
-from thinc.v2v import Affine, Maxout
-from thinc.misc import LayerNorm as LN
-from thinc.neural.util import prefer_gpu
+from thinc.layers import Linear, Maxout
+from thinc.util import prefer_gpu
 from wasabi import msg
 import srsly
+from thinc.layers import chain, list2array
+from thinc.loss import CosineDistance, L2Distance
 
 from spacy.gold import Example
 from ..errors import Errors
 from ..tokens import Doc
 from ..attrs import ID, HEAD
-from .._ml import Tok2Vec, flatten, chain, create_default_optimizer
-from .._ml import masked_language_model, get_cossim_loss
+from ..ml.component_models import Tok2Vec
+from ..ml.component_models import masked_language_model
 from .. import util
+from ..util import create_default_optimizer
 from .train import _load_pretrained_tok2vec
@@ -99,7 +101,7 @@ def pretrain(
     with msg.loading(f"Loading model '{vectors_model}'..."):
         nlp = util.load_model(vectors_model)
     msg.good(f"Loaded model '{vectors_model}'")
-    pretrained_vectors = None if not use_vectors else nlp.vocab.vectors.name
+    pretrained_vectors = None if not use_vectors else nlp.vocab.vectors
     model = create_pretraining_model(
         nlp,
         Tok2Vec(
@@ -136,7 +138,7 @@ def pretrain(
         # Without '--init-tok2vec' the '--epoch-start' argument is ignored
         epoch_start = 0
 
-    optimizer = create_default_optimizer(model.ops)
+    optimizer = create_default_optimizer()
     tracker = ProgressTracker(frequency=10000)
     msg.divider(f"Pre-training tok2vec layer - starting at epoch {epoch_start}")
     row_settings = {"widths": (3, 10, 10, 6, 4), "aligns": ("r", "r", "r", "r", "r")}
@@ -251,13 +253,14 @@ def get_vectors_loss(ops, docs, prediction, objective="L2"):
     # and look them up all at once. This prevents data copying.
     ids = ops.flatten([doc.to_array(ID).ravel() for doc in docs])
     target = docs[0].vocab.vectors.data[ids]
+    # TODO: this code originally didn't normalize, but shouldn't normalize=True ?
     if objective == "L2":
-        d_target = prediction - target
-        loss = (d_target ** 2).sum()
+        distance = L2Distance(normalize=False)
     elif objective == "cosine":
-        loss, d_target = get_cossim_loss(prediction, target)
+        distance = CosineDistance(normalize=False)
     else:
         raise ValueError(Errors.E142.format(loss_func=objective))
+    d_target, loss = distance(prediction, target)
     return loss, d_target
@@ -269,18 +272,18 @@ def create_pretraining_model(nlp, tok2vec):
     """
     output_size = nlp.vocab.vectors.data.shape[1]
     output_layer = chain(
-        LN(Maxout(300, pieces=3)), Affine(output_size, drop_factor=0.0)
+        Maxout(300, pieces=3, normalize=True, dropout=0.0), Linear(output_size)
     )
     # This is annoying, but the parser etc have the flatten step after
     # the tok2vec. To load the weights in cleanly, we need to match
     # the shape of the models' components exactly. So what we cann
     # "tok2vec" has to be the same set of processes as what the components do.
-    tok2vec = chain(tok2vec, flatten)
+    tok2vec = chain(tok2vec, list2array())
     model = chain(tok2vec, output_layer)
     model = masked_language_model(nlp.vocab, model)
-    model.tok2vec = tok2vec
-    model.output_layer = output_layer
-    model.begin_training([nlp.make_doc("Give it a doc to infer shapes")])
+    model.set_ref("tok2vec", tok2vec)
+    model.set_ref("output_layer", output_layer)
+    model.initialize(X=[nlp.make_doc("Give it a doc to infer shapes")])
     return model
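
The set_ref() calls above reflect thinc 8's switch from ad-hoc model attributes (model.tok2vec = ...) to named references. A minimal, hypothetical sketch of the pattern:

# Hypothetical example, not taken from this PR.
from thinc.layers import Linear, chain

model = chain(Linear(2, 2), Linear(2, 2))
model.set_ref("output_layer", model.layers[-1])
assert model.get_ref("output_layer") is model.layers[-1]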


@@ -5,7 +5,7 @@ import cProfile
 import pstats
 import sys
 import itertools
-import thinc.extra.datasets
+import ml_datasets
 from wasabi import msg
 
 from ..util import load_model
@@ -29,7 +29,7 @@ def profile(
     if inputs is None:
         n_inputs = 25000
         with msg.loading("Loading IMDB dataset via Thinc..."):
-            imdb_train, _ = thinc.extra.datasets.imdb()
+            imdb_train, _ = ml_datasets.imdb()
         inputs, _ = zip(*imdb_train)
         msg.info(f"Loaded IMDB dataset and using {n_inputs} examples")
         inputs = inputs[:n_inputs]


@ -1,7 +1,7 @@
import os import os
import tqdm import tqdm
from pathlib import Path from pathlib import Path
from thinc.neural._classes.model import Model from thinc.backends import use_ops
from timeit import default_timer as timer from timeit import default_timer as timer
import shutil import shutil
import srsly import srsly
@ -9,7 +9,7 @@ from wasabi import msg
import contextlib import contextlib
import random import random
from .._ml import create_default_optimizer from ..util import create_default_optimizer
from ..attrs import PROB, IS_OOV, CLUSTER, LANG from ..attrs import PROB, IS_OOV, CLUSTER, LANG
from ..gold import GoldCorpus from ..gold import GoldCorpus
from .. import util from .. import util
@ -200,7 +200,7 @@ def train(
if base_model: if base_model:
# Start with an existing model, use default optimizer # Start with an existing model, use default optimizer
optimizer = create_default_optimizer(Model.ops) optimizer = create_default_optimizer()
else: else:
# Start with a blank model, call begin_training # Start with a blank model, call begin_training
optimizer = nlp.begin_training(lambda: corpus.train_examples, device=use_gpu) optimizer = nlp.begin_training(lambda: corpus.train_examples, device=use_gpu)
@ -367,7 +367,7 @@ def train(
cpu_wps = nwords / (end_time - start_time) cpu_wps = nwords / (end_time - start_time)
else: else:
gpu_wps = nwords / (end_time - start_time) gpu_wps = nwords / (end_time - start_time)
with Model.use_device("cpu"): with use_ops("numpy"):
nlp_loaded = util.load_model_from_path(epoch_model_path) nlp_loaded = util.load_model_from_path(epoch_model_path)
for name, component in nlp_loaded.pipeline: for name, component in nlp_loaded.pipeline:
if hasattr(component, "cfg"): if hasattr(component, "cfg"):
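Two thinc 8 changes recur throughout the CLI: the optimizer factory no longer takes an ops argument, and CPU execution is selected with use_ops rather than Model.use_device. A hedged sketch of what that amounts to (the Adam hyperparameters mirror the example config later in this PR and are an assumption about create_default_optimizer's defaults, not a quote of it):

from thinc.api import Adam, use_ops

# Roughly what util.create_default_optimizer() returns, expressed in thinc 8 terms.
optimizer = Adam(learn_rate=0.001, beta1=0.9, beta2=0.999)

with use_ops("numpy"):
    # Anything constructed in this block uses the numpy backend, e.g. loading a
    # model for CPU evaluation as in the hunk above.
    pass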


@ -0,0 +1,445 @@
import plac
from thinc.util import require_gpu
from wasabi import msg
from pathlib import Path
import thinc
import thinc.schedules
from thinc.model import Model
from spacy.gold import GoldCorpus
import spacy
from spacy.pipeline.tok2vec import Tok2VecListener
from typing import Optional, Dict, List, Union, Sequence
from pydantic import BaseModel, FilePath, StrictInt
import tqdm
from ..ml import component_models
from .. import util
registry = util.registry
CONFIG_STR = """
[training]
patience = 10
eval_frequency = 10
dropout = 0.2
init_tok2vec = null
vectors = null
max_epochs = 100
orth_variant_level = 0.0
gold_preproc = false
max_length = 0
use_gpu = 0
scores = ["ents_p", "ents_r", "ents_f"]
score_weights = {"ents_f": 1.0}
limit = 0
[training.batch_size]
@schedules = "compounding.v1"
start = 100
stop = 1000
compound = 1.001
[optimizer]
@optimizers = "Adam.v1"
learn_rate = 0.001
beta1 = 0.9
beta2 = 0.999
[nlp]
lang = "en"
vectors = ${training:vectors}
[nlp.pipeline.tok2vec]
factory = "tok2vec"
[nlp.pipeline.ner]
factory = "ner"
[nlp.pipeline.ner.model]
@architectures = "transition_based_ner.v1"
nr_feature_tokens = 3
hidden_width = 64
maxout_pieces = 3
[nlp.pipeline.ner.model.tok2vec]
@architectures = "tok2vec_tensors.v1"
width = ${nlp.pipeline.tok2vec.model:width}
[nlp.pipeline.tok2vec.model]
@architectures = "hash_embed_cnn.v1"
pretrained_vectors = ${nlp:vectors}
width = 128
depth = 4
window_size = 1
embed_size = 10000
maxout_pieces = 3
"""
class PipelineComponent(BaseModel):
factory: str
model: Model
class Config:
arbitrary_types_allowed = True
class ConfigSchema(BaseModel):
optimizer: Optional["Optimizer"]
class training(BaseModel):
patience: int = 10
eval_frequency: int = 100
dropout: float = 0.2
init_tok2vec: Optional[FilePath] = None
vectors: Optional[str] = None
max_epochs: int = 100
orth_variant_level: float = 0.0
gold_preproc: bool = False
max_length: int = 0
use_gpu: int = 0
scores: List[str] = ["ents_p", "ents_r", "ents_f"]
score_weights: Dict[str, Union[int, float]] = {"ents_f": 1.0}
limit: int = 0
batch_size: Union[Sequence[int], int]
class nlp(BaseModel):
lang: str
vectors: Optional[str]
pipeline: Optional[Dict[str, PipelineComponent]]
class Config:
extra = "allow"
# Of course, these would normally decorate the functions where they're defined.
# But for now...
@registry.architectures.register("hash_embed_cnn.v1")
def hash_embed_cnn(
pretrained_vectors, width, depth, embed_size, maxout_pieces, window_size
):
return component_models.Tok2Vec(
width=width,
embed_size=embed_size,
pretrained_vectors=pretrained_vectors,
conv_depth=depth,
cnn_maxout_pieces=maxout_pieces,
bilstm_depth=0,
window_size=window_size,
)
@registry.architectures.register("hash_embed_bilstm.v1")
def hash_embed_bilstm_v1(pretrained_vectors, width, depth, embed_size):
return component_models.Tok2Vec(
width=width,
embed_size=embed_size,
pretrained_vectors=pretrained_vectors,
bilstm_depth=depth,
conv_depth=0,
cnn_maxout_pieces=0,
)
@registry.architectures.register("tagger_model.v1")
def build_tagger_model_v1(tok2vec):
return component_models.build_tagger_model(nr_class=None, tok2vec=tok2vec)
@registry.architectures.register("transition_based_parser.v1")
def create_tb_parser_model(
tok2vec: Model,
nr_feature_tokens: StrictInt = 3,
hidden_width: StrictInt = 64,
maxout_pieces: StrictInt = 3,
):
from thinc.layers import Linear, chain, list2array
from spacy.ml._layers import PrecomputableAffine
from spacy.syntax._parser_model import ParserModel
from thinc.api import use_ops, zero_init
token_vector_width = tok2vec.get_dim("nO")
tok2vec = chain(tok2vec, list2array())
tok2vec.set_dim("nO", token_vector_width)
lower = PrecomputableAffine(
hidden_width, nF=nr_feature_tokens, nI=tok2vec.get_dim("nO"), nP=maxout_pieces
)
lower.set_dim("nP", maxout_pieces)
with use_ops("numpy"):
# Initialize weights at zero, as it's a classification layer.
upper = Linear(init_W=zero_init)
return ParserModel(tok2vec, lower, upper)
@plac.annotations(
# fmt: off
train_path=("Location of JSON-formatted training data", "positional", None, Path),
dev_path=("Location of JSON-formatted development data", "positional", None, Path),
config_path=("Path to config file", "positional", None, Path),
output_path=("Output directory to store model in", "option", "o", Path),
meta_path=("Optional path to meta.json to use as base.", "option", "m", Path),
raw_text=("Path to jsonl file with unlabelled text documents.", "option", "rt", Path),
# fmt: on
)
def train_from_config_cli(
train_path,
dev_path,
config_path,
output_path=None,
meta_path=None,
raw_text=None,
debug=False,
verbose=False,
):
"""
Train or update a spaCy model. Requires data to be formatted in spaCy's
JSON format. To convert data from other formats, use the `spacy convert`
command.
"""
if not config_path or not config_path.exists():
msg.fail("Config file not found", config_path, exits=1)
if not train_path or not train_path.exists():
msg.fail("Training data not found", train_path, exits=1)
if not dev_path or not dev_path.exists():
msg.fail("Development data not found", dev_path, exits=1)
if meta_path is not None and not meta_path.exists():
msg.fail("Can't find model meta.json", meta_path, exits=1)
if output_path is not None and not output_path.exists():
output_path.mkdir()
try:
train_from_config(
config_path,
{"train": train_path, "dev": dev_path},
output_path=output_path,
meta_path=meta_path,
raw_text=raw_text,
)
except KeyboardInterrupt:
msg.warn("Cancelled.")
def train_from_config(
config_path,
data_paths,
raw_text=None,
meta_path=None,
output_path=None,
):
msg.info("Loading config from: {}".format(config_path))
config = util.load_from_config(config_path, create_objects=True)
use_gpu = config["training"]["use_gpu"]
if use_gpu >= 0:
msg.info("Using GPU")
else:
msg.info("Using CPU")
msg.info("Creating nlp from config")
nlp = create_nlp_from_config(**config["nlp"])
optimizer = config["optimizer"]
limit = config["training"]["limit"]
msg.info("Loading training corpus")
corpus = GoldCorpus(data_paths["train"], data_paths["dev"], limit=limit)
msg.info("Initializing the nlp pipeline")
nlp.begin_training(
lambda: corpus.train_examples, device=use_gpu
)
train_batches = create_train_batches(nlp, corpus, config["training"])
evaluate = create_evaluation_callback(nlp, optimizer, corpus, config["training"])
# Create iterator, which yields out info after each optimization step.
msg.info("Start training")
training_step_iterator = train_while_improving(
nlp,
optimizer,
train_batches,
evaluate,
config["training"]["dropout"],
config["training"]["patience"],
config["training"]["eval_frequency"],
)
msg.info("Training. Initial learn rate: {}".format(optimizer.learn_rate))
print_row = setup_printer(config)
try:
progress = tqdm.tqdm(total=config["training"]["eval_frequency"], leave=False)
for batch, info, is_best_checkpoint in training_step_iterator:
progress.update(1)
if is_best_checkpoint is not None:
progress.close()
print_row(info)
if is_best_checkpoint and output_path is not None:
nlp.to_disk(output_path)
progress = tqdm.tqdm(
total=config["training"]["eval_frequency"], leave=False
)
finally:
if output_path is not None:
with nlp.use_params(optimizer.averages):
final_model_path = output_path / "model-final"
nlp.to_disk(final_model_path)
msg.good("Saved model to output directory", final_model_path)
# with msg.loading("Creating best model..."):
# best_model_path = _collate_best_model(meta, output_path, nlp.pipe_names)
# msg.good("Created best model", best_model_path)
def create_nlp_from_config(lang, vectors, pipeline):
lang_class = spacy.util.get_lang_class(lang)
nlp = lang_class()
if vectors is not None:
spacy.cli.train._load_vectors(nlp, vectors)
for name, component_cfg in pipeline.items():
factory = component_cfg.pop("factory")
component = nlp.create_pipe(factory, config=component_cfg)
nlp.add_pipe(component, name=name)
return nlp
def create_train_batches(nlp, corpus, cfg):
while True:
train_examples = corpus.train_dataset(
nlp,
noise_level=0.0,
orth_variant_level=cfg["orth_variant_level"],
gold_preproc=cfg["gold_preproc"],
max_length=cfg["max_length"],
ignore_misaligned=True,
)
for batch in util.minibatch_by_words(train_examples, size=cfg["batch_size"]):
yield batch
def create_evaluation_callback(nlp, optimizer, corpus, cfg):
def evaluate():
with nlp.use_params(optimizer.averages):
dev_examples = list(
corpus.dev_dataset(
nlp, gold_preproc=cfg["gold_preproc"], ignore_misaligned=True
)
)
scorer = nlp.evaluate(dev_examples)
scores = scorer.scores
# Calculate a weighted sum based on score_weights for the main score
weights = cfg["score_weights"]
weighted_score = sum(scores[s] * weights.get(s, 0.0) for s in weights)
return weighted_score, scorer.scores
return evaluate
def train_while_improving(
nlp, optimizer, train_data, evaluate, dropout, patience, eval_frequency
):
"""Train until an evaluation stops improving. Works as a generator,
with each iteration yielding a tuple `(batch, info, is_best_checkpoint)`,
where info is a dict, and is_best_checkpoint is in [True, False, None] --
None indicating that the iteration was not evaluated as a checkpoint.
The evaluation is conducted by calling the evaluate callback, which should return a (main_score, other_scores) tuple, as described under `evaluate` below.
Positional arguments:
nlp: The spaCy pipeline to evaluate.
train_data (Iterable[Batch]): A generator of batches, with the training
data. Each batch should be a Sized[Tuple[Input, Annot]]. The training
data iterable needs to take care of iterating over the epochs and
shuffling.
evaluate (Callable[[], Tuple[float, Any]]): A callback to perform evaluation.
The callback should take no arguments and return a tuple
`(main_score, other_scores)`. The main_score should be a float where
higher is better. other_scores can be any object.
Every iteration, the function yields out a tuple with:
* batch: A zipped sequence of Tuple[Doc, GoldParse] pairs.
* info: A dict with various information about the last update (see below).
* is_best_checkpoint: A value in None, False, True, indicating whether this
was the best evaluation so far. You should use this to save the model
checkpoints during training. If None, evaluation was not conducted on
that iteration. False means evaluation was conducted, but a previous
evaluation was better.
The info dict provides the following information:
epoch (int): How many passes over the data have been completed.
step (int): How many steps have been completed.
score (float): The main score from the last evaluation.
other_scores: The other scores from the last evaluation.
loss: The accumulated losses throughout training.
checkpoints: A list of previous results, where each result is a
(score, step, epoch) tuple.
"""
if isinstance(dropout, float):
dropouts = thinc.schedules.constant(dropout)
else:
dropouts = dropout
results = []
losses = {}
for step, batch in enumerate(train_data):
dropout = next(dropouts)
for subbatch in subdivide_batch(batch):
nlp.update(subbatch, drop=dropout, losses=losses, sgd=False)
for name, proc in nlp.pipeline:
if hasattr(proc, "model"):
proc.model.finish_update(optimizer)
optimizer.step_schedules()
if not (step % eval_frequency):
score, other_scores = evaluate()
results.append((score, step))
is_best_checkpoint = score == max(results)[0]
else:
score, other_scores = (None, None)
is_best_checkpoint = None
info = {
"step": step,
"score": score,
"other_scores": other_scores,
"losses": losses,
"checkpoints": results,
}
yield batch, info, is_best_checkpoint
if is_best_checkpoint is not None:
losses = {}
# Stop if no improvement in `patience` updates
best_score, best_step = max(results)
if (step - best_step) >= patience:
break
def subdivide_batch(batch):
return [batch]
def setup_printer(config):
score_cols = config["training"]["scores"]
score_widths = [max(len(col), 6) for col in score_cols]
loss_cols = ["Loss {}".format(pipe) for pipe in config["nlp"]["pipeline"]]
loss_widths = [max(len(col), 8) for col in loss_cols]
table_header = ["#"] + loss_cols + score_cols + ["Score"]
table_header = [col.upper() for col in table_header]
table_widths = [6] + loss_widths + score_widths + [6]
table_aligns = ["r" for _ in table_widths]
msg.row(table_header, widths=table_widths)
msg.row(["-" * width for width in table_widths])
def print_row(info):
losses = [
"{0:.2f}".format(info["losses"].get(col, 0.0))
for col in config["nlp"]["pipeline"]
]
scores = [
"{0:.2f}".format(info["other_scores"].get(col, 0.0))
for col in config["training"]["scores"]
]
data = [info["step"]] + losses + scores + ["{0:.2f}".format(info["score"])]
msg.row(data, widths=table_widths, aligns=table_aligns)
return print_row
@registry.architectures.register("tok2vec_tensors.v1")
def tok2vec_tensors_v1(width):
tok2vec = Tok2VecListener("tok2vec", width=width)
return tok2vec
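The @architectures strings in CONFIG_STR resolve through spacy.util.registry (a catalogue-style registry shared with thinc): load_from_config(..., create_objects=True) looks up each registered function and calls it with the remaining keys of its block. A hand-rolled sketch of that resolution for the tok2vec block, for illustration only (values copied from CONFIG_STR; it assumes the module containing the @registry.architectures.register decorators above has already been imported):

from spacy import util

build_tok2vec = util.registry.architectures.get("hash_embed_cnn.v1")
tok2vec = build_tok2vec(
    pretrained_vectors=None,
    width=128,
    depth=4,
    embed_size=10000,
    maxout_pieces=3,
    window_size=1,
)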


@ -8,7 +8,7 @@ DOCS: https://spacy.io/api/top-level#compat
import os import os
import sys import sys
from thinc.neural.util import copy_array from thinc.util import copy_array
try: try:
import cPickle as pickle import cPickle as pickle
@ -30,10 +30,7 @@ try:
except ImportError: except ImportError:
cupy = None cupy = None
try: from thinc.optimizers import Optimizer # noqa: F401
from thinc.neural.optimizers import Optimizer # noqa: F401
except ImportError:
from thinc.neural.optimizers import Adam as Optimizer # noqa: F401
pickle = pickle pickle = pickle
copy_reg = copy_reg copy_reg = copy_reg


@ -4,7 +4,8 @@ import weakref
import functools import functools
from contextlib import contextmanager from contextlib import contextmanager
from copy import copy, deepcopy from copy import copy, deepcopy
from thinc.neural import Model from thinc.model import Model
from thinc.backends import get_current_ops
import srsly import srsly
import multiprocessing as mp import multiprocessing as mp
from itertools import chain, cycle from itertools import chain, cycle
@ -16,7 +17,7 @@ from .lookups import Lookups
from .analysis import analyze_pipes, analyze_all_pipes, validate_attrs from .analysis import analyze_pipes, analyze_all_pipes, validate_attrs
from .gold import Example from .gold import Example
from .scorer import Scorer from .scorer import Scorer
from ._ml import link_vectors_to_models, create_default_optimizer from .util import link_vectors_to_models, create_default_optimizer
from .attrs import IS_STOP, LANG from .attrs import IS_STOP, LANG
from .lang.punctuation import TOKENIZER_PREFIXES, TOKENIZER_SUFFIXES from .lang.punctuation import TOKENIZER_PREFIXES, TOKENIZER_SUFFIXES
from .lang.punctuation import TOKENIZER_INFIXES from .lang.punctuation import TOKENIZER_INFIXES
@ -468,30 +469,27 @@ class Language(object):
if sgd is None: if sgd is None:
if self._optimizer is None: if self._optimizer is None:
self._optimizer = create_default_optimizer(Model.ops) self._optimizer = create_default_optimizer()
sgd = self._optimizer sgd = self._optimizer
grads = {}
def get_grads(W, dW, key=None):
grads[key] = (W, dW)
get_grads.alpha = sgd.alpha
get_grads.b1 = sgd.b1
get_grads.b2 = sgd.b2
pipes = list(self.pipeline)
random.shuffle(pipes)
if component_cfg is None: if component_cfg is None:
component_cfg = {} component_cfg = {}
for name, proc in pipes: # Determine whether component should set annotations. In theory I guess
# we should do this by inspecting the meta? Or we could just always
# say "yes"
for name, proc in self.pipeline:
component_cfg.setdefault(name, {})
component_cfg[name].setdefault("drop", drop)
component_cfg[name].setdefault("set_annotations", False)
grads = {}
for name, proc in self.pipeline:
if not hasattr(proc, "update"): if not hasattr(proc, "update"):
continue continue
grads = {} proc.update(examples, sgd=None, losses=losses, **component_cfg[name])
kwargs = component_cfg.get(name, {}) if sgd is not False:
kwargs.setdefault("drop", drop) for name, proc in self.pipeline:
proc.update(examples, sgd=get_grads, losses=losses, **kwargs) if hasattr(proc, "model"):
for key, (W, dW) in grads.items(): proc.model.finish_update(sgd)
sgd(W, dW, key=key)
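The get_grads closure disappears because thinc 8 models accumulate their own gradients: begin_update returns a backprop callback that stores gradients on the model, and finish_update(optimizer) applies and clears them. A minimal sketch with an illustrative standalone layer (not a spaCy component):

import numpy
from thinc.api import Adam, Linear

X = numpy.zeros((4, 8), dtype="f")
Y = numpy.zeros((4, 2), dtype="f")
model = Linear(nO=2, nI=8)
model.initialize(X=X, Y=Y)
optimizer = Adam(0.001)

Yh, backprop = model.begin_update(X)   # forward pass; gradients are recorded lazily
backprop(Yh - Y)                       # gradients accumulate on the model itself
model.finish_update(optimizer)         # apply the optimizer and clear the gradients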
def rehearse(self, examples, sgd=None, losses=None, config=None): def rehearse(self, examples, sgd=None, losses=None, config=None):
"""Make a "rehearsal" update to the models in the pipeline, to prevent """Make a "rehearsal" update to the models in the pipeline, to prevent
@ -518,7 +516,7 @@ class Language(object):
examples = Example.to_example_objects(examples, make_doc=self.make_doc) examples = Example.to_example_objects(examples, make_doc=self.make_doc)
if sgd is None: if sgd is None:
if self._optimizer is None: if self._optimizer is None:
self._optimizer = create_default_optimizer(Model.ops) self._optimizer = create_default_optimizer()
sgd = self._optimizer sgd = self._optimizer
pipes = list(self.pipeline) pipes = list(self.pipeline)
random.shuffle(pipes) random.shuffle(pipes)
@ -529,7 +527,7 @@ class Language(object):
def get_grads(W, dW, key=None): def get_grads(W, dW, key=None):
grads[key] = (W, dW) grads[key] = (W, dW)
get_grads.alpha = sgd.alpha get_grads.learn_rate = sgd.learn_rate
get_grads.b1 = sgd.b1 get_grads.b1 = sgd.b1
get_grads.b2 = sgd.b2 get_grads.b2 = sgd.b2
for name, proc in pipes: for name, proc in pipes:
@ -537,8 +535,8 @@ class Language(object):
continue continue
grads = {} grads = {}
proc.rehearse(examples, sgd=get_grads, losses=losses, **config.get(name, {})) proc.rehearse(examples, sgd=get_grads, losses=losses, **config.get(name, {}))
for key, (W, dW) in grads.items(): for key, (W, dW) in grads.items():
sgd(W, dW, key=key) sgd(W, dW, key=key)
return losses return losses
def preprocess_gold(self, examples): def preprocess_gold(self, examples):
@ -577,12 +575,13 @@ class Language(object):
if cfg.get("device", -1) >= 0: if cfg.get("device", -1) >= 0:
util.use_gpu(cfg["device"]) util.use_gpu(cfg["device"])
if self.vocab.vectors.data.shape[1] >= 1: if self.vocab.vectors.data.shape[1] >= 1:
self.vocab.vectors.data = Model.ops.asarray(self.vocab.vectors.data) ops = get_current_ops()
self.vocab.vectors.data = ops.asarray(self.vocab.vectors.data)
link_vectors_to_models(self.vocab) link_vectors_to_models(self.vocab)
if self.vocab.vectors.data.shape[1]: if self.vocab.vectors.data.shape[1]:
cfg["pretrained_vectors"] = self.vocab.vectors.name cfg["pretrained_vectors"] = self.vocab.vectors
if sgd is None: if sgd is None:
sgd = create_default_optimizer(Model.ops) sgd = create_default_optimizer()
self._optimizer = sgd self._optimizer = sgd
if component_cfg is None: if component_cfg is None:
component_cfg = {} component_cfg = {}
@ -596,6 +595,7 @@ class Language(object):
sgd=self._optimizer, sgd=self._optimizer,
**kwargs **kwargs
) )
self._link_components()
return self._optimizer return self._optimizer
def resume_training(self, sgd=None, **cfg): def resume_training(self, sgd=None, **cfg):
@ -609,13 +609,14 @@ class Language(object):
""" """
if cfg.get("device", -1) >= 0: if cfg.get("device", -1) >= 0:
util.use_gpu(cfg["device"]) util.use_gpu(cfg["device"])
ops = get_current_ops()
if self.vocab.vectors.data.shape[1] >= 1: if self.vocab.vectors.data.shape[1] >= 1:
self.vocab.vectors.data = Model.ops.asarray(self.vocab.vectors.data) self.vocab.vectors.data = ops.asarray(self.vocab.vectors.data)
link_vectors_to_models(self.vocab) link_vectors_to_models(self.vocab)
if self.vocab.vectors.data.shape[1]: if self.vocab.vectors.data.shape[1]:
cfg["pretrained_vectors"] = self.vocab.vectors.name cfg["pretrained_vectors"] = self.vocab.vectors
if sgd is None: if sgd is None:
sgd = create_default_optimizer(Model.ops) sgd = create_default_optimizer()
self._optimizer = sgd self._optimizer = sgd
for name, proc in self.pipeline: for name, proc in self.pipeline:
if hasattr(proc, "_rehearsal_model"): if hasattr(proc, "_rehearsal_model"):
@ -736,7 +737,7 @@ class Language(object):
disable=disable, disable=disable,
n_process=n_process, n_process=n_process,
component_cfg=component_cfg, component_cfg=component_cfg,
as_example=False as_example=False # TODO: shouldn't this be as_example=as_example ?
) )
for doc, context in zip(docs, contexts): for doc, context in zip(docs, contexts):
yield (doc, context) yield (doc, context)
@ -838,6 +839,16 @@ class Language(object):
for proc in procs: for proc in procs:
proc.terminate() proc.terminate()
def _link_components(self):
"""Register 'listeners' within pipeline components, to allow them to
effectively share weights.
"""
for i, (name1, proc1) in enumerate(self.pipeline):
if hasattr(proc1, "find_listeners"):
for name2, proc2 in self.pipeline[i:]:
if hasattr(proc2, "model"):
proc1.find_listeners(proc2.model)
def to_disk(self, path, exclude=tuple(), disable=None): def to_disk(self, path, exclude=tuple(), disable=None):
"""Save the current state to a directory. If a model is loaded, this """Save the current state to a directory. If a model is loaded, this
will include the model. will include the model.
@ -906,6 +917,7 @@ class Language(object):
exclude = list(exclude) + ["vocab"] exclude = list(exclude) + ["vocab"]
util.from_disk(path, deserializers, exclude) util.from_disk(path, deserializers, exclude)
self._path = path self._path = path
self._link_components()
return self return self
def to_bytes(self, exclude=tuple(), disable=None, **kwargs): def to_bytes(self, exclude=tuple(), disable=None, **kwargs):
@ -962,6 +974,7 @@ class Language(object):
) )
exclude = util.get_serialization_exclude(deserializers, exclude, kwargs) exclude = util.get_serialization_exclude(deserializers, exclude, kwargs)
util.from_bytes(bytes_data, deserializers, exclude) util.from_bytes(bytes_data, deserializers, exclude)
self._link_components()
return self return self
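The new _link_components/find_listeners hook is what lets a single tok2vec component feed several downstream models. A conceptual sketch using only names that appear in this diff; the pipeline layout and variable names are illustrative, not prescriptive:

from spacy.pipeline.tok2vec import Tok2VecListener

# The "tok2vec_tensors.v1" architecture places one of these inside e.g. the NER model:
listener = Tok2VecListener("tok2vec", width=128)
# After nlp._link_components(), the Tok2Vec pipeline component has located this node
# via find_listeners() and pushes its per-Doc vectors into it during nlp.update, so
# the downstream component shares the embedding weights instead of owning a copy.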


@ -6,7 +6,7 @@ cimport numpy as np
np.import_array() np.import_array()
import numpy import numpy
from thinc.neural.util import get_array_module from thinc.util import get_array_module
from .typedefs cimport attr_t, flags_t from .typedefs cimport attr_t, flags_t
from .attrs cimport IS_ALPHA, IS_ASCII, IS_DIGIT, IS_LOWER, IS_PUNCT, IS_SPACE from .attrs cimport IS_ALPHA, IS_ASCII, IS_DIGIT, IS_LOWER, IS_PUNCT, IS_SPACE


@ -1,2 +0,0 @@
from .tok2vec import Tok2Vec # noqa: F401
from .common import FeedForward, LayerNormalizedMaxout # noqa: F401


@ -0,0 +1,52 @@
from thinc.api import Model
def CharacterEmbed(nM, nC):
# nM: Number of dimensions per character. nC: Number of characters.
nO = nM*nC if (nM is not None and nC is not None) else None
return Model(
"charembed",
forward,
init=init,
dims={"nM": nM, "nC": nC, "nO": nO, "nV": 256},
params={"E": None}
).initialize()
def init(model, X=None, Y=None):
vectors_table = model.ops.alloc3f(model.get_dim("nC"), model.get_dim("nV"), model.get_dim("nM"))
model.set_param("E", vectors_table)
def forward(model, docs, is_train):
if not docs:
return []
ids = []
output = []
E = model.get_param("E")
nC = model.get_dim("nC")
nM = model.get_dim("nM")
nO = model.get_dim("nO")
# This assists in indexing; it's like looping over this dimension.
# Still consider this weird witchcraft... But thanks to Mark Neumann
# for the tip.
nCv = model.ops.xp.arange(nC)
for doc in docs:
doc_ids = doc.to_utf8_array(nr_char=nC)
doc_vectors = model.ops.alloc3f(len(doc), nC, nM)
# Let's say I have a 2d array of indices, and a 3d table of data. What numpy
# incantation do I chant to get
# output[i, j, k] == data[j, ids[i, j], k]?
doc_vectors[:, nCv] = E[nCv, doc_ids[:, nCv]]
output.append(doc_vectors.reshape((len(doc), nO)))
ids.append(doc_ids)
def backprop(d_output):
dE = model.ops.alloc(E.shape, dtype=E.dtype)
for doc_ids, d_doc_vectors in zip(ids, d_output):
d_doc_vectors = d_doc_vectors.reshape((len(doc_ids), nC, nM))
dE[nCv, doc_ids[:, nCv]] += d_doc_vectors[:, nCv]
model.inc_grad("E", dE)
return []
return output, backprop
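A hedged usage sketch for the new layer (the module path is assumed from the relative import in component_models.py later in this PR, and the forward pass relies on the internal Doc.to_utf8_array hook):

import spacy
from spacy.ml._character_embed import CharacterEmbed  # path assumed

nlp = spacy.blank("en")
docs = [nlp("give it a doc")]
model = CharacterEmbed(nM=64, nC=8)            # parameters are allocated in the factory
outputs, backprop = model(docs, is_train=True)
# outputs[0] has shape (len(docs[0]), nM * nC), i.e. one 512-dim row per token.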

spacy/ml/_layers.py

@ -0,0 +1,165 @@
from thinc.model import Model
from thinc.api import normal_init
def PrecomputableAffine(nO, nI, nF, nP):
model = Model(
"precomputable_affine",
forward,
init=init,
dims={"nO": nO, "nI": nI, "nF": nF, "nP": nP},
params={"W": None, "b": None, "pad": None},
)
model.initialize()
return model
def forward(model, X, is_train):
nF = model.get_dim("nF")
nO = model.get_dim("nO")
nP = model.get_dim("nP")
nI = model.get_dim("nI")
W = model.get_param("W")
Yf = model.ops.gemm(
X, W.reshape((nF * nO * nP, nI)), trans2=True
)
Yf = Yf.reshape((Yf.shape[0], nF, nO, nP))
Yf = model.ops.xp.vstack((model.get_param("pad"), Yf))
def backward(dY_ids):
# This backprop is particularly tricky, because we get back a different
# thing from what we put out. We put out an array of shape:
# (nB, nF, nO, nP), and get back:
# (nB, nO, nP) and ids (nB, nF)
# The ids tell us the values of nF, so we would have:
#
# dYf = zeros((nB, nF, nO, nP))
# for b in range(nB):
# for f in range(nF):
# dYf[b, ids[b, f]] += dY[b]
#
# However, we avoid building that array for efficiency -- and just pass
# in the indices.
dY, ids = dY_ids
assert dY.ndim == 3
assert dY.shape[1] == nO, dY.shape
assert dY.shape[2] == nP, dY.shape
nB = dY.shape[0]
model.inc_grad("pad", _backprop_precomputable_affine_padding(model, dY, ids))
Xf = X[ids]
Xf = Xf.reshape((Xf.shape[0], nF * nI))
model.inc_grad("b", dY.sum(axis=0))
dY = dY.reshape((dY.shape[0], nO * nP))
Wopfi = W.transpose((1, 2, 0, 3))
Wopfi = model.ops.xp.ascontiguousarray(Wopfi)
Wopfi = Wopfi.reshape((nO * nP, nF * nI))
dXf = model.ops.gemm(dY.reshape((dY.shape[0], nO * nP)), Wopfi)
# Reuse the buffer
dWopfi = Wopfi
dWopfi.fill(0.0)
model.ops.gemm(dY, Xf, out=dWopfi, trans1=True)
dWopfi = dWopfi.reshape((nO, nP, nF, nI))
# (o, p, f, i) --> (f, o, p, i)
model.inc_grad("W", dWopfi.transpose((2, 0, 1, 3)))
return dXf.reshape((dXf.shape[0], nF, nI))
return Yf, backward
def _backprop_precomputable_affine_padding(model, dY, ids):
nB = dY.shape[0]
nF = model.get_dim("nF")
nP = model.get_dim("nP")
nO = model.get_dim("nO")
# Backprop the "padding", used as a filler for missing values.
# Values that are missing are set to -1, and each state vector could
# have multiple missing values. The padding has different values for
# different missing features. The gradient of the padding vector is:
#
# for b in range(nB):
# for f in range(nF):
# if ids[b, f] < 0:
# d_padding[0, f] += dY[b]
#
# Which can be rewritten as:
#
# for b in range(nB):
# d_pad[0, ids[b] < 0] += dY[b]
#
# I don't know how to avoid the loop without building a whole array :(.
# Cursed numpy.
d_pad = model.ops.alloc((1, nF, nO, nP))
for b in range(nB):
d_pad[0, ids[b] < 0] += dY[b]
return d_pad
def init(model, X=None, Y=None):
"""This is like the 'layer sequential unit variance', but instead
of taking the actual inputs, we randomly generate whitened data.
Why's this all so complicated? We have a huge number of inputs,
and the maxout unit makes guessing the dynamics tricky. Instead
we set the maxout weights to values that empirically result in
whitened outputs given whitened inputs.
"""
if model.has_param("W") and model.get_param("W").any():
return
nF = model.get_dim("nF")
nO = model.get_dim("nO")
nP = model.get_dim("nP")
nI = model.get_dim("nI")
W = model.ops.alloc4f(nF, nO, nP, nI)
b = model.ops.alloc2f(nO, nP)
pad = model.ops.alloc4f(1, nF, nO, nP)
ops = model.ops
W = normal_init(ops, W.shape, fan_in=nF*nI)
model.set_param("W", W)
model.set_param("b", b)
model.set_param("pad", pad)
ids = ops.alloc((5000, nF), dtype="f")
ids += ops.xp.random.uniform(0, 1000, ids.shape)
ids = ops.asarray(ids, dtype="i")
tokvecs = ops.alloc((5000, nI), dtype="f")
tokvecs += ops.xp.random.normal(loc=0.0, scale=1.0, size=tokvecs.size).reshape(
tokvecs.shape
)
def predict(ids, tokvecs):
# nS ids. nW tokvecs. Exclude the padding array.
hiddens = model.predict(tokvecs[:-1]) # (nW, f, o, p)
vectors = model.ops.alloc((ids.shape[0], nO * nP), dtype="f")
# need nS vectors
hiddens = hiddens.reshape((hiddens.shape[0] * nF, nO * nP))
model.ops.scatter_add(vectors, ids.flatten(), hiddens)
vectors = vectors.reshape((vectors.shape[0], nO, nP))
vectors += b
vectors = model.ops.asarray(vectors)
if nP >= 2:
return model.ops.maxout(vectors)[0]
else:
return vectors * (vectors >= 0)
tol_var = 0.01
tol_mean = 0.01
t_max = 10
W = model.get_param("W").copy()
b = model.get_param("b").copy()
for t_i in range(t_max):
acts1 = predict(ids, tokvecs)
var = model.ops.xp.var(acts1)
mean = model.ops.xp.mean(acts1)
if abs(var - 1.0) >= tol_var:
W /= model.ops.xp.sqrt(var)
model.set_param("W", W)
elif abs(mean) >= tol_mean:
b -= mean
model.set_param("b", b)
else:
break
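A short usage sketch for PrecomputableAffine: the forward pass precomputes an (nF, nO, nP) block per input row and prepends the learned padding row, so the parser can later gather rows by feature index. The shapes below follow directly from the code above:

from spacy.ml._layers import PrecomputableAffine

model = PrecomputableAffine(nO=64, nI=96, nF=3, nP=2)   # initialized in the factory
X = model.ops.alloc2f(10, 96)                           # 10 token vectors of width nI
Yf, backprop = model(X, is_train=True)
assert Yf.shape == (11, 3, 64, 2)                       # row 0 is the padding block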


@ -1,129 +0,0 @@
from thinc.v2v import Model, Maxout
from thinc.i2v import HashEmbed, StaticVectors
from thinc.t2t import ExtractWindow
from thinc.misc import Residual
from thinc.misc import LayerNorm as LN
from thinc.misc import FeatureExtracter
from thinc.api import layerize, chain, clone, concatenate, with_flatten
from thinc.api import uniqued, wrap, noop
from ..attrs import ID, ORTH, NORM, PREFIX, SUFFIX, SHAPE
def Tok2Vec(width, embed_size, **kwargs):
# Circular imports :(
from .._ml import CharacterEmbed
from .._ml import PyTorchBiLSTM
pretrained_vectors = kwargs.get("pretrained_vectors", None)
cnn_maxout_pieces = kwargs.get("cnn_maxout_pieces", 3)
subword_features = kwargs.get("subword_features", True)
char_embed = kwargs.get("char_embed", False)
if char_embed:
subword_features = False
conv_depth = kwargs.get("conv_depth", 4)
bilstm_depth = kwargs.get("bilstm_depth", 0)
cols = [ID, NORM, PREFIX, SUFFIX, SHAPE, ORTH]
with Model.define_operators({">>": chain, "|": concatenate, "**": clone}):
norm = HashEmbed(width, embed_size, column=cols.index(NORM), name="embed_norm")
if subword_features:
prefix = HashEmbed(
width, embed_size // 2, column=cols.index(PREFIX), name="embed_prefix"
)
suffix = HashEmbed(
width, embed_size // 2, column=cols.index(SUFFIX), name="embed_suffix"
)
shape = HashEmbed(
width, embed_size // 2, column=cols.index(SHAPE), name="embed_shape"
)
else:
prefix, suffix, shape = (None, None, None)
if pretrained_vectors is not None:
glove = StaticVectors(pretrained_vectors, width, column=cols.index(ID))
if subword_features:
embed = uniqued(
(glove | norm | prefix | suffix | shape)
>> LN(Maxout(width, width * 5, pieces=3)),
column=cols.index(ORTH),
)
else:
embed = uniqued(
(glove | norm) >> LN(Maxout(width, width * 2, pieces=3)),
column=cols.index(ORTH),
)
elif subword_features:
embed = uniqued(
(norm | prefix | suffix | shape)
>> LN(Maxout(width, width * 4, pieces=3)),
column=cols.index(ORTH),
)
elif char_embed:
embed = concatenate_lists(
CharacterEmbed(nM=64, nC=8),
FeatureExtracter(cols) >> with_flatten(norm),
)
reduce_dimensions = LN(
Maxout(width, 64 * 8 + width, pieces=cnn_maxout_pieces)
)
else:
embed = norm
convolution = Residual(
ExtractWindow(nW=1)
>> LN(Maxout(width, width * 3, pieces=cnn_maxout_pieces))
)
if char_embed:
tok2vec = embed >> with_flatten(
reduce_dimensions >> convolution ** conv_depth, pad=conv_depth
)
else:
tok2vec = FeatureExtracter(cols) >> with_flatten(
embed >> convolution ** conv_depth, pad=conv_depth
)
if bilstm_depth >= 1:
tok2vec = tok2vec >> PyTorchBiLSTM(width, width, bilstm_depth)
# Work around thinc API limitations :(. TODO: Revise in Thinc 7
tok2vec.nO = width
tok2vec.embed = embed
return tok2vec
@layerize
def flatten(seqs, drop=0.0):
ops = Model.ops
lengths = ops.asarray([len(seq) for seq in seqs], dtype="i")
def finish_update(d_X, sgd=None):
return ops.unflatten(d_X, lengths, pad=0)
X = ops.flatten(seqs, pad=0)
return X, finish_update
def concatenate_lists(*layers, **kwargs): # pragma: no cover
"""Compose two or more models `f`, `g`, etc, such that their outputs are
concatenated, i.e. `concatenate(f, g)(x)` computes `hstack(f(x), g(x))`
"""
if not layers:
return noop()
drop_factor = kwargs.get("drop_factor", 1.0)
ops = layers[0].ops
layers = [chain(layer, flatten) for layer in layers]
concat = concatenate(*layers)
def concatenate_lists_fwd(Xs, drop=0.0):
if drop is not None:
drop *= drop_factor
lengths = ops.asarray([len(X) for X in Xs], dtype="i")
flat_y, bp_flat_y = concat.begin_update(Xs, drop=drop)
ys = ops.unflatten(flat_y, lengths)
def concatenate_lists_bwd(d_ys, sgd=None):
return bp_flat_y(ops.flatten(d_ys), sgd=sgd)
return ys, concatenate_lists_bwd
model = wrap(concatenate_lists_fwd, concat)
return model


@ -1,41 +0,0 @@
from thinc.api import layerize, wrap, noop, chain, concatenate
from thinc.v2v import Model
def concatenate_lists(*layers, **kwargs): # pragma: no cover
"""Compose two or more models `f`, `g`, etc, such that their outputs are
concatenated, i.e. `concatenate(f, g)(x)` computes `hstack(f(x), g(x))`
"""
if not layers:
return layerize(noop())
drop_factor = kwargs.get("drop_factor", 1.0)
ops = layers[0].ops
layers = [chain(layer, flatten) for layer in layers]
concat = concatenate(*layers)
def concatenate_lists_fwd(Xs, drop=0.0):
if drop is not None:
drop *= drop_factor
lengths = ops.asarray([len(X) for X in Xs], dtype="i")
flat_y, bp_flat_y = concat.begin_update(Xs, drop=drop)
ys = ops.unflatten(flat_y, lengths)
def concatenate_lists_bwd(d_ys, sgd=None):
return bp_flat_y(ops.flatten(d_ys), sgd=sgd)
return ys, concatenate_lists_bwd
model = wrap(concatenate_lists_fwd, concat)
return model
@layerize
def flatten(seqs, drop=0.0):
ops = Model.ops
lengths = ops.asarray([len(seq) for seq in seqs], dtype="i")
def finish_update(d_X, sgd=None):
return ops.unflatten(d_X, lengths, pad=0)
X = ops.flatten(seqs, pad=0)
return X, finish_update


@ -1,21 +0,0 @@
from thinc.api import chain
from thinc.v2v import Maxout
from thinc.misc import LayerNorm
from ..util import registry, make_layer
@registry.architectures.register("thinc.FeedForward.v1")
def FeedForward(config):
layers = [make_layer(layer_cfg) for layer_cfg in config["layers"]]
model = chain(*layers)
model.cfg = config
return model
@registry.architectures.register("spacy.LayerNormalizedMaxout.v1")
def LayerNormalizedMaxout(config):
width = config["width"]
pieces = config["pieces"]
layer = LayerNorm(Maxout(width, pieces=pieces))
layer.nO = width
return layer


@ -0,0 +1,222 @@
from spacy import util
from spacy.ml.extract_ngrams import extract_ngrams
from ..attrs import ID, ORTH, NORM, PREFIX, SUFFIX, SHAPE
from ..errors import Errors
from ._character_embed import CharacterEmbed
from thinc.api import Model, Maxout, Linear, residual, reduce_mean, list2ragged
from thinc.api import PyTorchLSTM, add, MultiSoftmax, HashEmbed, StaticVectors
from thinc.api import expand_window, FeatureExtractor, SparseLinear, chain
from thinc.api import clone, concatenate, with_array, Softmax, Logistic, uniqued
from thinc.api import zero_init, glorot_uniform_init
def build_text_classifier(arch, config):
if arch == "cnn":
return build_simple_cnn_text_classifier(**config)
elif arch == "bow":
return build_bow_text_classifier(**config)
else:
raise ValueError("Unexpected textcat arch")
def build_simple_cnn_text_classifier(tok2vec, nr_class, exclusive_classes, **cfg):
"""
Build a simple CNN text classifier, given a token-to-vector model as input.
If exclusive_classes=True, a softmax non-linearity is applied, so that the
outputs sum to 1. If exclusive_classes=False, a logistic non-linearity
is applied instead, so that outputs are in the range [0, 1].
"""
with Model.define_operators({">>": chain}):
if exclusive_classes:
output_layer = Softmax(nO=nr_class, nI=tok2vec.get_dim("nO"))
else:
# TODO: experiment with init_w=zero_init
output_layer = (
Linear(nO=nr_class, nI=tok2vec.get_dim("nO"))
>> Logistic()
)
model = tok2vec >> list2ragged() >> reduce_mean() >> output_layer
model.set_ref("tok2vec", tok2vec)
model.set_dim("nO", nr_class)
return model
def build_bow_text_classifier(
nr_class, exclusive_classes, ngram_size=1, no_output_layer=False, **cfg
):
with Model.define_operators({">>": chain}):
model = extract_ngrams(ngram_size, attr=ORTH) >> SparseLinear(nr_class)
model.to_cpu()
if not no_output_layer:
output_layer = (
Softmax(nO=nr_class) if exclusive_classes else Logistic(nO=nr_class)
)
output_layer.to_cpu()
model = model >> output_layer
model.set_dim("nO", nr_class)
return model
def build_nel_encoder(embed_width, hidden_width, ner_types, **cfg):
if "entity_width" not in cfg:
raise ValueError(Errors.E144.format(param="entity_width"))
conv_depth = cfg.get("conv_depth", 2)
cnn_maxout_pieces = cfg.get("cnn_maxout_pieces", 3)
pretrained_vectors = cfg.get("pretrained_vectors", None)
context_width = cfg.get("entity_width")
with Model.define_operators({">>": chain, "**": clone}):
nel_tok2vec = Tok2Vec(
width=hidden_width,
embed_size=embed_width,
pretrained_vectors=pretrained_vectors,
cnn_maxout_pieces=cnn_maxout_pieces,
subword_features=True,
conv_depth=conv_depth,
bilstm_depth=0,
)
model = (
nel_tok2vec
>> list2ragged()
>> reduce_mean()
>> residual(Maxout(nO=hidden_width, nI=hidden_width, nP=2, dropout=0.0))
>> Linear(nO=context_width, nI=hidden_width)
)
model.initialize()
model.set_ref("tok2vec", nel_tok2vec)
model.set_dim("nO", context_width)
return model
def masked_language_model(*args, **kwargs):
raise NotImplementedError
def build_tagger_model(nr_class, tok2vec):
token_vector_width = tok2vec.get_dim("nO")
# TODO: glorot_uniform_init seems to work a bit better than zero_init here?!
softmax = with_array(Softmax(nO=nr_class, nI=token_vector_width, init_W=zero_init))
model = chain(tok2vec, softmax)
model.set_ref("tok2vec", tok2vec)
model.set_ref("softmax", softmax)
return model
def build_morphologizer_model(class_nums, **cfg):
embed_size = util.env_opt("embed_size", 7000)
if "token_vector_width" in cfg:
token_vector_width = cfg["token_vector_width"]
else:
token_vector_width = util.env_opt("token_vector_width", 128)
pretrained_vectors = cfg.get("pretrained_vectors")
char_embed = cfg.get("char_embed", True)
with Model.define_operators({">>": chain, "+": add, "**": clone}):
if "tok2vec" in cfg:
tok2vec = cfg["tok2vec"]
else:
tok2vec = Tok2Vec(
token_vector_width,
embed_size,
char_embed=char_embed,
pretrained_vectors=pretrained_vectors,
)
softmax = with_array(MultiSoftmax(nOs=class_nums, nI=token_vector_width))
model = tok2vec >> softmax
model.set_ref("tok2vec", tok2vec)
model.set_ref("softmax", softmax)
return model
def Tok2Vec(
width,
embed_size,
pretrained_vectors=None,
window_size=1,
cnn_maxout_pieces=3,
subword_features=True,
char_embed=False,
conv_depth=4,
bilstm_depth=0,
):
if char_embed:
subword_features = False
cols = [ID, NORM, PREFIX, SUFFIX, SHAPE, ORTH]
with Model.define_operators({">>": chain, "|": concatenate, "**": clone}):
norm = HashEmbed(nO=width, nV=embed_size, column=cols.index(NORM), dropout=0.0)
if subword_features:
prefix = HashEmbed(nO=width, nV=embed_size // 2, column=cols.index(PREFIX), dropout=0.0)
suffix = HashEmbed(nO=width, nV=embed_size // 2, column=cols.index(SUFFIX), dropout=0.0)
shape = HashEmbed(nO=width, nV=embed_size // 2, column=cols.index(SHAPE), dropout=0.0)
else:
prefix, suffix, shape = (None, None, None)
if pretrained_vectors is not None:
glove = StaticVectors(vectors=pretrained_vectors, nO=width, column=cols.index(ID), dropout=0.0)
if subword_features:
embed = uniqued(
(glove | norm | prefix | suffix | shape)
>> Maxout(
nO=width, nI=width * 5, nP=3, dropout=0.0, normalize=True
),
column=cols.index(ORTH),
)
else:
embed = uniqued(
(glove | norm)
>> Maxout(
nO=width, nI=width * 2, nP=3, dropout=0.0, normalize=True
),
column=cols.index(ORTH),
)
elif subword_features:
embed = uniqued(
concatenate(norm, prefix, suffix, shape)
>> Maxout(nO=width, nI=width * 4, nP=3, dropout=0.0, normalize=True),
column=cols.index(ORTH),
)
elif char_embed:
embed = CharacterEmbed(nM=64, nC=8) | FeatureExtractor(cols) >> with_array(
norm
)
reduce_dimensions = Maxout(
nO=width,
nI=64 * 8 + width,
nP=cnn_maxout_pieces,
dropout=0.0,
normalize=True,
)
else:
embed = norm
convolution = residual(
expand_window(window_size=window_size)
>> Maxout(
nO=width,
nI=width * 3,
nP=cnn_maxout_pieces,
dropout=0.0,
normalize=True,
)
)
if char_embed:
tok2vec = embed >> with_array(
reduce_dimensions >> convolution ** conv_depth, pad=conv_depth
)
else:
tok2vec = FeatureExtractor(cols) >> with_array(
embed >> convolution ** conv_depth, pad=conv_depth
)
if bilstm_depth >= 1:
tok2vec = tok2vec >> PyTorchLSTM(
nO=width, nI=width, depth=bilstm_depth, bi=True
)
# Work around thinc API limitations :(. TODO: Revise in Thinc 7
tok2vec.set_dim("nO", width)
tok2vec.set_ref("embed", embed)
return tok2vec
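A hedged sketch of using this builder directly (the module path spacy.ml.component_models matches the imports elsewhere in this PR; calling initialize() with no sample data is assumed to work here because every dimension is passed explicitly):

import spacy
from spacy.ml.component_models import Tok2Vec

nlp = spacy.blank("en")
docs = [nlp("This is a sentence."), nlp("Another one.")]
tok2vec = Tok2Vec(width=96, embed_size=2000)
tok2vec.initialize()
vectors, backprop = tok2vec(docs, is_train=False)
# vectors is a list with one (n_tokens, 96) array per Doc.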


@ -0,0 +1,39 @@
import numpy
from thinc.model import Model
from ..attrs import LOWER
def extract_ngrams(ngram_size, attr=LOWER) -> Model:
model = Model("extract_ngrams", forward)
model.attrs["ngram_size"] = ngram_size
model.attrs["attr"] = attr
return model
def forward(self, docs, is_train: bool):
batch_keys = []
batch_vals = []
for doc in docs:
unigrams = doc.to_array([self.attrs["attr"]])
ngrams = [unigrams]
for n in range(2, self.attrs["ngram_size"] + 1):
ngrams.append(self.ops.ngrams(n, unigrams))
keys = self.ops.xp.concatenate(ngrams)
keys, vals = self.ops.xp.unique(keys, return_counts=True)
batch_keys.append(keys)
batch_vals.append(vals)
# The dtype here matches what thinc is expecting -- which differs per
# platform (by int definition). This should be fixed once the problem
# is fixed on Thinc's side.
lengths = self.ops.asarray(
[arr.shape[0] for arr in batch_keys], dtype=numpy.int_
)
batch_keys = self.ops.xp.concatenate(batch_keys)
batch_vals = self.ops.asarray(self.ops.xp.concatenate(batch_vals), dtype="f")
def backprop(dY):
return dY
return (batch_keys, batch_vals, lengths), backprop
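A usage sketch, hedged because it assumes the ops backend provides the ngrams kernel this forward pass calls: the output triple is the sparse (keys, values, lengths) format that SparseLinear consumes in the bag-of-words text classifier above.

import spacy
from spacy.ml.extract_ngrams import extract_ngrams

nlp = spacy.blank("en")
docs = [nlp("the cat sat on the mat")]
model = extract_ngrams(ngram_size=2)        # unigrams + bigrams, hashed on LOWER
(keys, vals, lengths), backprop = model(docs, is_train=False)
# keys/vals hold the unique ngram hashes and their counts; lengths has one entry per Doc.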


@ -1,11 +1,12 @@
from thinc.api import chain, layerize, clone, concatenate, with_flatten, uniqued from thinc.layers import chain, clone, concatenate, with_array, uniqued
from thinc.api import noop, with_square_sequences from thinc.model import Model
from thinc.v2v import Maxout, Model from thinc.layers import noop, with_padded
from thinc.i2v import HashEmbed, StaticVectors from thinc.layers import Maxout, expand_window
from thinc.t2t import ExtractWindow from thinc.layers import HashEmbed, StaticVectors
from thinc.misc import Residual, LayerNorm, FeatureExtracter from thinc.layers import residual, LayerNorm, FeatureExtractor
from spacy.ml import _character_embed
from ..util import make_layer, registry from ..util import make_layer, registry
from ._wire import concatenate_lists
@registry.architectures.register("spacy.Tok2Vec.v1") @registry.architectures.register("spacy.Tok2Vec.v1")
@ -13,19 +14,21 @@ def Tok2Vec(config):
doc2feats = make_layer(config["@doc2feats"]) doc2feats = make_layer(config["@doc2feats"])
embed = make_layer(config["@embed"]) embed = make_layer(config["@embed"])
encode = make_layer(config["@encode"]) encode = make_layer(config["@encode"])
field_size = getattr(encode, "receptive_field", 0) field_size = 0
tok2vec = chain(doc2feats, with_flatten(chain(embed, encode), pad=field_size)) if encode.has_attr("receptive_field"):
tok2vec.cfg = config field_size = encode.attrs["receptive_field"]
tok2vec.nO = encode.nO tok2vec = chain(doc2feats, with_array(chain(embed, encode), pad=field_size))
tok2vec.embed = embed tok2vec.attrs["cfg"] = config
tok2vec.encode = encode tok2vec.set_dim("nO", encode.get_dim("nO"))
tok2vec.set_ref("embed", embed)
tok2vec.set_ref("encode", encode)
return tok2vec return tok2vec
@registry.architectures.register("spacy.Doc2Feats.v1") @registry.architectures.register("spacy.Doc2Feats.v1")
def Doc2Feats(config): def Doc2Feats(config):
columns = config["columns"] columns = config["columns"]
return FeatureExtracter(columns) return FeatureExtractor(columns)
@registry.architectures.register("spacy.MultiHashEmbed.v1") @registry.architectures.register("spacy.MultiHashEmbed.v1")
@ -40,55 +43,47 @@ def MultiHashEmbed(config):
width = config["width"] width = config["width"]
rows = config["rows"] rows = config["rows"]
norm = HashEmbed(width, rows, column=cols.index("NORM"), name="embed_norm") norm = HashEmbed(width, rows, column=cols.index("NORM"), dropout=0.0)
if config["use_subwords"]: if config["use_subwords"]:
prefix = HashEmbed( prefix = HashEmbed(width, rows // 2, column=cols.index("PREFIX"), dropout=0.0)
width, rows // 2, column=cols.index("PREFIX"), name="embed_prefix" suffix = HashEmbed(width, rows // 2, column=cols.index("SUFFIX"), dropout=0.0)
) shape = HashEmbed(width, rows // 2, column=cols.index("SHAPE"), dropout=0.0)
suffix = HashEmbed(
width, rows // 2, column=cols.index("SUFFIX"), name="embed_suffix"
)
shape = HashEmbed(
width, rows // 2, column=cols.index("SHAPE"), name="embed_shape"
)
if config.get("@pretrained_vectors"): if config.get("@pretrained_vectors"):
glove = make_layer(config["@pretrained_vectors"]) glove = make_layer(config["@pretrained_vectors"])
mix = make_layer(config["@mix"]) mix = make_layer(config["@mix"])
with Model.define_operators({">>": chain, "|": concatenate}): with Model.define_operators({">>": chain, "|": concatenate}):
if config["use_subwords"] and config["@pretrained_vectors"]: if config["use_subwords"] and config["@pretrained_vectors"]:
mix._layers[0].nI = width * 5 mix._layers[0].set_dim("nI", width * 5)
layer = uniqued( layer = uniqued(
(glove | norm | prefix | suffix | shape) >> mix, (glove | norm | prefix | suffix | shape) >> mix,
column=cols.index("ORTH"), column=cols.index("ORTH"),
) )
elif config["use_subwords"]: elif config["use_subwords"]:
mix._layers[0].nI = width * 4 mix._layers[0].set_dim("nI", width * 4)
layer = uniqued( layer = uniqued(
(norm | prefix | suffix | shape) >> mix, column=cols.index("ORTH") (norm | prefix | suffix | shape) >> mix, column=cols.index("ORTH")
) )
elif config["@pretrained_vectors"]: elif config["@pretrained_vectors"]:
mix._layers[0].nI = width * 2 mix._layers[0].set_dim("nI", width * 2)
layer = uniqued((glove | norm) >> mix, column=cols.index("ORTH"),) layer = uniqued((glove | norm) >> mix, column=cols.index("ORTH"),)
else: else:
layer = norm layer = norm
layer.cfg = config layer.attrs["cfg"] = config
return layer return layer
@registry.architectures.register("spacy.CharacterEmbed.v1") @registry.architectures.register("spacy.CharacterEmbed.v1")
def CharacterEmbed(config): def CharacterEmbed(config):
from .. import _ml
width = config["width"] width = config["width"]
chars = config["chars"] chars = config["chars"]
chr_embed = _ml.CharacterEmbedModel(nM=width, nC=chars) chr_embed = _character_embed.CharacterEmbed(nM=width, nC=chars)
other_tables = make_layer(config["@embed_features"]) other_tables = make_layer(config["@embed_features"])
mix = make_layer(config["@mix"]) mix = make_layer(config["@mix"])
model = chain(concatenate_lists(chr_embed, other_tables), mix) model = chain(concatenate(chr_embed, other_tables), mix)
model.cfg = config model.attrs["cfg"] = config
return model return model
@ -99,48 +94,49 @@ def MaxoutWindowEncoder(config):
nP = config["pieces"] nP = config["pieces"]
depth = config["depth"] depth = config["depth"]
cnn = chain( cnn = expand_window(window_size=nW), Maxout(nO=nO, nI=nO * ((nW * 2) + 1), nP=nP, dropout=0.0, normalize=True)
ExtractWindow(nW=nW), LayerNorm(Maxout(nO, nO * ((nW * 2) + 1), pieces=nP)) model = clone(residual(cnn), depth)
) model.set_dim("nO", nO)
model = clone(Residual(cnn), depth) model.attrs["receptive_field"] = nW * depth
model.nO = nO
model.receptive_field = nW * depth
return model return model
@registry.architectures.register("spacy.MishWindowEncoder.v1") @registry.architectures.register("spacy.MishWindowEncoder.v1")
def MishWindowEncoder(config): def MishWindowEncoder(config):
from thinc.v2v import Mish from thinc.layers import Mish
nO = config["width"] nO = config["width"]
nW = config["window_size"] nW = config["window_size"]
depth = config["depth"] depth = config["depth"]
cnn = chain(ExtractWindow(nW=nW), LayerNorm(Mish(nO, nO * ((nW * 2) + 1)))) cnn = chain(expand_window(window_size=nW), Mish(nO=nO, nI=nO * ((nW * 2) + 1)), LayerNorm(nO))
model = clone(Residual(cnn), depth) model = clone(residual(cnn), depth)
model.nO = nO model.set_dim("nO", nO)
return model return model
@registry.architectures.register("spacy.PretrainedVectors.v1") @registry.architectures.register("spacy.PretrainedVectors.v1")
def PretrainedVectors(config): def PretrainedVectors(config):
return StaticVectors(config["vectors_name"], config["width"], config["column"]) # TODO: actual vectors instead of name
return StaticVectors(vectors=config["vectors_name"], nO=config["width"], column=config["column"], dropout=0.0)
@registry.architectures.register("spacy.TorchBiLSTMEncoder.v1") @registry.architectures.register("spacy.TorchBiLSTMEncoder.v1")
def TorchBiLSTMEncoder(config): def TorchBiLSTMEncoder(config):
import torch.nn import torch.nn
from thinc.extra.wrappers import PyTorchWrapperRNN # TODO FIX
from thinc.layers import PyTorchRNNWrapper
width = config["width"] width = config["width"]
depth = config["depth"] depth = config["depth"]
if depth == 0: if depth == 0:
return layerize(noop()) return noop()
return with_square_sequences( return with_padded(
PyTorchWrapperRNN(torch.nn.LSTM(width, width // 2, depth, bidirectional=True)) PyTorchRNNWrapper(torch.nn.LSTM(width, width // 2, depth, bidirectional=True))
) )
# TODO: update
_EXAMPLE_CONFIG = { _EXAMPLE_CONFIG = {
"@doc2feats": { "@doc2feats": {
"arch": "Doc2Feats", "arch": "Doc2Feats",


@ -3,6 +3,7 @@ from .pipes import TextCategorizer, Tensorizer, Pipe, Sentencizer
from .pipes import SentenceRecognizer from .pipes import SentenceRecognizer
from .morphologizer import Morphologizer from .morphologizer import Morphologizer
from .entityruler import EntityRuler from .entityruler import EntityRuler
from .tok2vec import Tok2Vec
from .hooks import SentenceSegmenter, SimilarityHook from .hooks import SentenceSegmenter, SimilarityHook
from .functions import merge_entities, merge_noun_chunks, merge_subtokens from .functions import merge_entities, merge_noun_chunks, merge_subtokens
@ -13,6 +14,7 @@ __all__ = [
"EntityLinker", "EntityLinker",
"TextCategorizer", "TextCategorizer",
"Tensorizer", "Tensorizer",
"Tok2Vec",
"Pipe", "Pipe",
"Morphologizer", "Morphologizer",
"EntityRuler", "EntityRuler",


@ -1,9 +1,8 @@
from thinc.t2v import Pooling, max_pool, mean_pool from thinc.layers import concatenate, reduce_max, reduce_mean, siamese, CauchySimilarity
from thinc.neural._classes.difference import Siamese, CauchySimilarity
from .pipes import Pipe from .pipes import Pipe
from ..language import component from ..language import component
from .._ml import link_vectors_to_models from ..util import link_vectors_to_models
@component("sentencizer_hook", assigns=["doc.user_hooks"]) @component("sentencizer_hook", assigns=["doc.user_hooks"])
@ -63,7 +62,10 @@ class SimilarityHook(Pipe):
@classmethod @classmethod
def Model(cls, length): def Model(cls, length):
return Siamese(Pooling(max_pool, mean_pool), CauchySimilarity(length)) return siamese(
concatenate(reduce_max(), reduce_mean()),
CauchySimilarity(length * 2)
)
def __call__(self, doc): def __call__(self, doc):
"""Install similarity hook""" """Install similarity hook"""
@ -80,7 +82,7 @@ class SimilarityHook(Pipe):
def update(self, doc1_doc2, golds, sgd=None, drop=0.0): def update(self, doc1_doc2, golds, sgd=None, drop=0.0):
self.require_model() self.require_model()
sims, bp_sims = self.model.begin_update(doc1_doc2, drop=drop) sims, bp_sims = self.model.begin_update(doc1_doc2)
def begin_training(self, _=tuple(), pipeline=None, sgd=None, **kwargs): def begin_training(self, _=tuple(), pipeline=None, sgd=None, **kwargs):
"""Allocate model, using width from tensorizer in pipeline. """Allocate model, using width from tensorizer in pipeline.
@ -89,7 +91,7 @@ class SimilarityHook(Pipe):
pipeline (list): The pipeline the model is part of. pipeline (list): The pipeline the model is part of.
""" """
if self.model is True: if self.model is True:
self.model = self.Model(pipeline[0].model.nO) self.model = self.Model(pipeline[0].model.get_dim("nO"))
link_vectors_to_models(self.vocab) link_vectors_to_models(self.vocab)
if sgd is None: if sgd is None:
sgd = self.create_optimizer() sgd = self.create_optimizer()


@ -3,19 +3,20 @@ from collections import defaultdict
import numpy import numpy
cimport numpy as np cimport numpy as np
from thinc.api import chain from thinc.layers import chain, list2array
from thinc.neural.util import to_categorical, copy_array, get_array_module from thinc.util import to_categorical, copy_array, get_array_module
from .. import util from .. import util
from .pipes import Pipe from .pipes import Pipe
from ..language import component from ..language import component
from .._ml import Tok2Vec, build_morphologizer_model from ..util import link_vectors_to_models, create_default_optimizer
from .._ml import link_vectors_to_models, zero_init, flatten
from .._ml import create_default_optimizer
from ..errors import Errors, TempErrors from ..errors import Errors, TempErrors
from ..tokens.doc cimport Doc from ..tokens.doc cimport Doc
from ..vocab cimport Vocab from ..vocab cimport Vocab
from ..morphology cimport Morphology from ..morphology cimport Morphology
from ..ml.component_models import build_morphologizer_model
@component("morphologizer", assigns=["token.morph", "token.pos"]) @component("morphologizer", assigns=["token.morph", "token.pos"])
class Morphologizer(Pipe): class Morphologizer(Pipe):
@ -43,7 +44,7 @@ class Morphologizer(Pipe):
if self.model in (None, True, False): if self.model in (None, True, False):
return None return None
else: else:
return chain(self.model.tok2vec, flatten) return chain(self.model.get_ref("tok2vec"), list2array())
def __call__(self, doc): def __call__(self, doc):
features, tokvecs = self.predict([doc]) features, tokvecs = self.predict([doc])
@ -60,9 +61,9 @@ class Morphologizer(Pipe):
def predict(self, docs): def predict(self, docs):
if not any(len(doc) for doc in docs): if not any(len(doc) for doc in docs):
# Handle case where there are no tokens in any docs. # Handle case where there are no tokens in any docs.
n_labels = self.model.nO n_labels = self.model.get_dim("nO")
guesses = [self.model.ops.allocate((0, n_labels)) for doc in docs] guesses = [self.model.ops.alloc((0, n_labels)) for doc in docs]
tokvecs = self.model.ops.allocate((0, self.model.tok2vec.nO)) tokvecs = self.model.ops.alloc((0, self.model.get_ref("tok2vec").get_dim("nO")))
return guesses, tokvecs return guesses, tokvecs
tokvecs = self.model.tok2vec(docs) tokvecs = self.model.tok2vec(docs)
scores = self.model.softmax(tokvecs) scores = self.model.softmax(tokvecs)
@ -77,7 +78,7 @@ class Morphologizer(Pipe):
for field in self._class_map.fields] for field in self._class_map.fields]
for i, doc in enumerate(docs): for i, doc in enumerate(docs):
doc_scores = batch_scores[i] doc_scores = batch_scores[i]
doc_guesses = scores_to_guesses(doc_scores, self.model.softmax.out_sizes) doc_guesses = scores_to_guesses(doc_scores, self.model.get_ref("softmax").attrs["nOs"])
# Convert the neuron indices into feature IDs. # Convert the neuron indices into feature IDs.
doc_feat_ids = numpy.zeros((len(doc), len(self._class_map.fields)), dtype='i') doc_feat_ids = numpy.zeros((len(doc), len(self._class_map.fields)), dtype='i')
for j in range(len(doc)): for j in range(len(doc)):
@ -110,7 +111,7 @@ class Morphologizer(Pipe):
def get_loss(self, examples, scores): def get_loss(self, examples, scores):
guesses = [] guesses = []
for doc_scores in scores: for doc_scores in scores:
guesses.append(scores_to_guesses(doc_scores, self.model.softmax.out_sizes)) guesses.append(scores_to_guesses(doc_scores, self.model.get_ref("softmax").attrs["nOs"]))
guesses = self.model.ops.xp.vstack(guesses) guesses = self.model.ops.xp.vstack(guesses)
scores = self.model.ops.xp.vstack(scores) scores = self.model.ops.xp.vstack(scores)
if not isinstance(scores, numpy.ndarray): if not isinstance(scores, numpy.ndarray):
@ -120,7 +121,7 @@ class Morphologizer(Pipe):
cdef int idx = 0 cdef int idx = 0
# Do this on CPU, as we can't vectorize easily. # Do this on CPU, as we can't vectorize easily.
target = numpy.zeros(scores.shape, dtype='f') target = numpy.zeros(scores.shape, dtype='f')
field_sizes = self.model.softmax.out_sizes field_sizes = self.model.get_ref("softmax").attrs["nOs"]
for example in examples: for example in examples:
doc = example.doc doc = example.doc
gold = example.gold gold = example.gold


@ -3,11 +3,11 @@
 import numpy
 import srsly
 import random
-from thinc.api import chain
-from thinc.v2v import Affine, Maxout, Softmax
-from thinc.misc import LayerNorm
-from thinc.neural.util import to_categorical
-from thinc.neural.util import get_array_module
+from thinc.layers import chain, Linear, Maxout, Softmax, LayerNorm, list2array
+from thinc.initializers import zero_init
+from thinc.loss import CosineDistance
+from thinc.util import to_categorical, get_array_module
+from thinc.model import set_dropout_rate
 from ..tokens.doc cimport Doc
 from ..syntax.nn_parser cimport Parser
@ -21,13 +21,14 @@ from ..language import Language, component
 from ..syntax import nonproj
 from ..gold import Example
 from ..attrs import POS, ID
+from ..util import link_vectors_to_models, create_default_optimizer
 from ..parts_of_speech import X
 from ..kb import KnowledgeBase
-from .._ml import Tok2Vec, build_tagger_model, cosine, get_cossim_loss
-from .._ml import build_text_classifier, build_simple_cnn_text_classifier
-from .._ml import build_bow_text_classifier, build_nel_encoder
-from .._ml import link_vectors_to_models, zero_init, flatten
-from .._ml import masked_language_model, create_default_optimizer, get_cossim_loss
+from ..ml.component_models import Tok2Vec, build_tagger_model
+from ..ml.component_models import build_text_classifier
+from ..ml.component_models import build_simple_cnn_text_classifier
+from ..ml.component_models import build_bow_text_classifier, build_nel_encoder
+from ..ml.component_models import masked_language_model
 from ..errors import Errors, TempErrors, user_warning, Warnings
 from .. import util
@ -126,13 +127,15 @@ class Pipe(object):
"""Modify a batch of documents, using pre-computed scores.""" """Modify a batch of documents, using pre-computed scores."""
raise NotImplementedError raise NotImplementedError
def update(self, examples, drop=0.0, sgd=None, losses=None): def update(self, examples, set_annotations=False, drop=0.0, sgd=None, losses=None):
"""Learn from a batch of documents and gold-standard information, """Learn from a batch of documents and gold-standard information,
updating the pipe's model. updating the pipe's model.
Delegates to predict() and get_loss(). Delegates to predict() and get_loss().
""" """
pass if set_annotations:
docs = (self._get_doc(ex) for ex in examples)
docs = list(self.pipe(docs))
def rehearse(self, examples, sgd=None, losses=None, **config): def rehearse(self, examples, sgd=None, losses=None, **config):
pass pass
@ -152,7 +155,7 @@ class Pipe(object):
raise NotImplementedError raise NotImplementedError
def create_optimizer(self): def create_optimizer(self):
return create_default_optimizer(self.model.ops, **self.cfg.get("optimizer", {})) return create_default_optimizer()
def begin_training( def begin_training(
self, get_examples=lambda: [], pipeline=None, sgd=None, **kwargs self, get_examples=lambda: [], pipeline=None, sgd=None, **kwargs
@ -163,10 +166,30 @@ class Pipe(object):
self.model = self.Model(**self.cfg) self.model = self.Model(**self.cfg)
if hasattr(self, "vocab"): if hasattr(self, "vocab"):
link_vectors_to_models(self.vocab) link_vectors_to_models(self.vocab)
self.model.initialize()
if sgd is None: if sgd is None:
sgd = self.create_optimizer() sgd = self.create_optimizer()
return sgd return sgd
def get_gradients(self):
"""Get non-zero gradients of the model's parameters, as a dictionary
keyed by the parameter ID. The values are (weights, gradients) tuples.
"""
gradients = {}
if self.model in (None, True, False):
return gradients
queue = [self.model]
seen = set()
for node in queue:
if node.id in seen:
continue
seen.add(node.id)
if hasattr(node, "_mem") and node._mem.gradient.any():
gradients[node.id] = [node._mem.weights, node._mem.gradient]
if hasattr(node, "_layers"):
queue.extend(node._layers)
return gradients
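An illustrative use of the new `get_gradients()` helper (a hypothetical snippet, not from the PR; `tagger` stands in for any pipe with an allocated model):

grads = tagger.get_gradients()
for node_id, (weights, gradient) in grads.items():
    print(node_id, weights.shape, float((gradient ** 2).sum()))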
def use_params(self, params): def use_params(self, params):
"""Modify the pipe's model, to use the given parameter values.""" """Modify the pipe's model, to use the given parameter values."""
with self.model.use_params(params): with self.model.use_params(params):
@ -193,7 +216,7 @@ class Pipe(object):
def load_model(b): def load_model(b):
# TODO: Remove this once we don't have to handle previous models # TODO: Remove this once we don't have to handle previous models
if self.cfg.get("pretrained_dims") and "pretrained_vectors" not in self.cfg: if self.cfg.get("pretrained_dims") and "pretrained_vectors" not in self.cfg:
self.cfg["pretrained_vectors"] = self.vocab.vectors.name self.cfg["pretrained_vectors"] = self.vocab.vectors
if self.model is True: if self.model is True:
self.model = self.Model(**self.cfg) self.model = self.Model(**self.cfg)
try: try:
@ -226,7 +249,7 @@ class Pipe(object):
def load_model(p): def load_model(p):
# TODO: Remove this once we don't have to handle previous models # TODO: Remove this once we don't have to handle previous models
if self.cfg.get("pretrained_dims") and "pretrained_vectors" not in self.cfg: if self.cfg.get("pretrained_dims") and "pretrained_vectors" not in self.cfg:
self.cfg["pretrained_vectors"] = self.vocab.vectors.name self.cfg["pretrained_vectors"] = self.vocab.vectors
if self.model is True: if self.model is True:
self.model = self.Model(**self.cfg) self.model = self.Model(**self.cfg)
try: try:
@ -254,10 +277,10 @@ class Tensorizer(Pipe):
width (int): Output size of the model. width (int): Output size of the model.
embed_size (int): Number of vectors in the embedding table. embed_size (int): Number of vectors in the embedding table.
**cfg: Config parameters. **cfg: Config parameters.
RETURNS (Model): A `thinc.neural.Model` or similar instance. RETURNS (Model): A `thinc.model.Model` or similar instance.
""" """
input_size = util.env_opt("token_vector_width", cfg.get("input_size", 96)) input_size = util.env_opt("token_vector_width", cfg.get("input_size", 96))
return zero_init(Affine(output_size, input_size, drop_factor=0.0)) return Linear(output_size, input_size, init_W=zero_init)
def __init__(self, vocab, model=True, **cfg): def __init__(self, vocab, model=True, **cfg):
"""Construct a new statistical model. Weights are not allocated on """Construct a new statistical model. Weights are not allocated on
@ -277,7 +300,6 @@ class Tensorizer(Pipe):
self.model = model self.model = model
self.input_models = [] self.input_models = []
self.cfg = dict(cfg) self.cfg = dict(cfg)
self.cfg.setdefault("cnn_maxout_pieces", 3)
def __call__(self, example): def __call__(self, example):
"""Add context-sensitive vectors to a `Doc`, e.g. from a CNN or LSTM """Add context-sensitive vectors to a `Doc`, e.g. from a CNN or LSTM
@ -337,7 +359,7 @@ class Tensorizer(Pipe):
raise ValueError(Errors.E076.format(rows=tensor.shape[0], words=len(doc))) raise ValueError(Errors.E076.format(rows=tensor.shape[0], words=len(doc)))
doc.tensor = tensor doc.tensor = tensor
def update(self, examples, state=None, drop=0.0, sgd=None, losses=None): def update(self, examples, state=None, drop=0.0, set_annotations=False, sgd=None, losses=None):
"""Update the model. """Update the model.
docs (iterable): A batch of `Doc` objects. docs (iterable): A batch of `Doc` objects.
@ -350,17 +372,23 @@ class Tensorizer(Pipe):
examples = Example.to_example_objects(examples) examples = Example.to_example_objects(examples)
inputs = [] inputs = []
bp_inputs = [] bp_inputs = []
set_dropout_rate(self.model, drop)
for tok2vec in self.input_models: for tok2vec in self.input_models:
tensor, bp_tensor = tok2vec.begin_update([ex.doc for ex in examples], drop=drop) set_dropout_rate(tok2vec, drop)
tensor, bp_tensor = tok2vec.begin_update([ex.doc for ex in examples])
inputs.append(tensor) inputs.append(tensor)
bp_inputs.append(bp_tensor) bp_inputs.append(bp_tensor)
inputs = self.model.ops.xp.hstack(inputs) inputs = self.model.ops.xp.hstack(inputs)
scores, bp_scores = self.model.begin_update(inputs, drop=drop) scores, bp_scores = self.model.begin_update(inputs)
loss, d_scores = self.get_loss(examples, scores) loss, d_scores = self.get_loss(examples, scores)
d_inputs = bp_scores(d_scores, sgd=sgd) d_inputs = bp_scores(d_scores, sgd=sgd)
d_inputs = self.model.ops.xp.split(d_inputs, len(self.input_models), axis=1) d_inputs = self.model.ops.xp.split(d_inputs, len(self.input_models), axis=1)
for d_input, bp_input in zip(d_inputs, bp_inputs): for d_input, bp_input in zip(d_inputs, bp_inputs):
bp_input(d_input, sgd=sgd) bp_input(d_input)
if sgd is not None:
for tok2vec in self.input_models:
tok2vec.finish_update(sgd)
self.model.finish_update(sgd)
if losses is not None: if losses is not None:
losses.setdefault(self.name, 0.0) losses.setdefault(self.name, 0.0)
losses[self.name] += loss losses[self.name] += loss
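The update methods in this file all follow the same thinc 8 step: set the dropout rate on the model, run `begin_update` without a `drop=` argument, call the backprop callback without `sgd=`, then apply the optimizer explicitly with `finish_update`. A self-contained sketch with dummy data and a plain `Linear` layer (the optimizer comes from spaCy's own `create_default_optimizer`, as imported in this diff):

import numpy
from thinc.layers import Linear
from thinc.model import set_dropout_rate
from spacy.util import create_default_optimizer

model = Linear(nO=2, nI=4)
X = numpy.random.uniform(-1, 1, (8, 4)).astype("f")
Y_target = numpy.zeros((8, 2), dtype="f")
model.initialize(X=X, Y=Y_target)
optimizer = create_default_optimizer()

set_dropout_rate(model, 0.2)           # dropout is set on the model, not passed to begin_update
Y, backprop = model.begin_update(X)    # forward pass, no drop= argument
d_X = backprop(Y - Y_target)           # backward pass, no sgd= argument
model.finish_update(optimizer)         # applying gradients is now an explicit, separate step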
@ -387,6 +415,7 @@ class Tensorizer(Pipe):
self.input_models.append(model.tok2vec) self.input_models.append(model.tok2vec)
if self.model is True: if self.model is True:
self.model = self.Model(**self.cfg) self.model = self.Model(**self.cfg)
self.model.initialize()
link_vectors_to_models(self.vocab) link_vectors_to_models(self.vocab)
if sgd is None: if sgd is None:
sgd = self.create_optimizer() sgd = self.create_optimizer()
@ -405,7 +434,6 @@ class Tagger(Pipe):
self.model = model self.model = model
self._rehearsal_model = None self._rehearsal_model = None
self.cfg = dict(sorted(cfg.items())) self.cfg = dict(sorted(cfg.items()))
self.cfg.setdefault("cnn_maxout_pieces", 2)
@property @property
def labels(self): def labels(self):
@ -416,12 +444,12 @@ class Tagger(Pipe):
if self.model in (None, True, False): if self.model in (None, True, False):
return None return None
else: else:
return chain(self.model.tok2vec, flatten) return chain(self.model.get_ref("tok2vec"), list2array())
def __call__(self, example): def __call__(self, example):
doc = self._get_doc(example) doc = self._get_doc(example)
tags, tokvecs = self.predict([doc]) tags = self.predict([doc])
self.set_annotations([doc], tags, tensors=tokvecs) self.set_annotations([doc], tags)
if isinstance(example, Example): if isinstance(example, Example):
example.doc = doc example.doc = doc
return example return example
@ -430,8 +458,10 @@ class Tagger(Pipe):
def pipe(self, stream, batch_size=128, n_threads=-1, as_example=False): def pipe(self, stream, batch_size=128, n_threads=-1, as_example=False):
for examples in util.minibatch(stream, size=batch_size): for examples in util.minibatch(stream, size=batch_size):
docs = [self._get_doc(ex) for ex in examples] docs = [self._get_doc(ex) for ex in examples]
tag_ids, tokvecs = self.predict(docs) tag_ids = self.predict(docs)
self.set_annotations(docs, tag_ids, tensors=tokvecs) assert len(docs) == len(examples)
assert len(tag_ids) == len(examples)
self.set_annotations(docs, tag_ids)
if as_example: if as_example:
annotated_examples = [] annotated_examples = []
@ -447,20 +477,25 @@ class Tagger(Pipe):
if not any(len(doc) for doc in docs): if not any(len(doc) for doc in docs):
# Handle cases where there are no tokens in any docs. # Handle cases where there are no tokens in any docs.
n_labels = len(self.labels) n_labels = len(self.labels)
guesses = [self.model.ops.allocate((0, n_labels)) for doc in docs] guesses = [self.model.ops.alloc((0, n_labels)) for doc in docs]
tokvecs = self.model.ops.allocate((0, self.model.tok2vec.nO)) assert len(guesses) == len(docs)
return guesses, tokvecs return guesses
tokvecs = self.model.tok2vec(docs) scores = self.model.predict(docs)
scores = self.model.softmax(tokvecs) assert len(scores) == len(docs), (len(scores), len(docs))
guesses = self._scores2guesses(scores)
assert len(guesses) == len(docs)
return guesses
def _scores2guesses(self, scores):
guesses = [] guesses = []
for doc_scores in scores: for doc_scores in scores:
doc_guesses = doc_scores.argmax(axis=1) doc_guesses = doc_scores.argmax(axis=1)
if not isinstance(doc_guesses, numpy.ndarray): if not isinstance(doc_guesses, numpy.ndarray):
doc_guesses = doc_guesses.get() doc_guesses = doc_guesses.get()
guesses.append(doc_guesses) guesses.append(doc_guesses)
return guesses, tokvecs return guesses
def set_annotations(self, docs, batch_tag_ids, tensors=None): def set_annotations(self, docs, batch_tag_ids):
if isinstance(docs, Doc): if isinstance(docs, Doc):
docs = [docs] docs = [docs]
cdef Doc doc cdef Doc doc
@ -483,15 +518,9 @@ class Tagger(Pipe):
else: else:
doc.c[j].tag = self.vocab.strings[self.labels[tag_id]] doc.c[j].tag = self.vocab.strings[self.labels[tag_id]]
idx += 1 idx += 1
if tensors is not None and len(tensors):
if isinstance(doc.tensor, numpy.ndarray) \
and not isinstance(tensors[i], numpy.ndarray):
doc.extend_tensor(tensors[i].get())
else:
doc.extend_tensor(tensors[i])
doc.is_tagged = True doc.is_tagged = True
def update(self, examples, drop=0., sgd=None, losses=None): def update(self, examples, drop=0., sgd=None, losses=None, set_annotations=False):
self.require_model() self.require_model()
examples = Example.to_example_objects(examples) examples = Example.to_example_objects(examples)
if losses is not None and self.name not in losses: if losses is not None and self.name not in losses:
@ -500,13 +529,18 @@ class Tagger(Pipe):
if not any(len(ex.doc) if ex.doc else 0 for ex in examples): if not any(len(ex.doc) if ex.doc else 0 for ex in examples):
# Handle cases where there are no tokens in any docs. # Handle cases where there are no tokens in any docs.
return return
set_dropout_rate(self.model, drop)
tag_scores, bp_tag_scores = self.model.begin_update([ex.doc for ex in examples], drop=drop) tag_scores, bp_tag_scores = self.model.begin_update([ex.doc for ex in examples])
loss, d_tag_scores = self.get_loss(examples, tag_scores) loss, d_tag_scores = self.get_loss(examples, tag_scores)
bp_tag_scores(d_tag_scores, sgd=sgd) bp_tag_scores(d_tag_scores)
if sgd not in (None, False):
self.model.finish_update(sgd)
if losses is not None: if losses is not None:
losses[self.name] += loss losses[self.name] += loss
if set_annotations:
docs = [ex.doc for ex in examples]
self.set_annotations(docs, self._scores2guesses(tag_scores))
def rehearse(self, examples, drop=0., sgd=None, losses=None): def rehearse(self, examples, drop=0., sgd=None, losses=None):
"""Perform a 'rehearsal' update, where we try to match the output of """Perform a 'rehearsal' update, where we try to match the output of
@ -519,10 +553,12 @@ class Tagger(Pipe):
if not any(len(doc) for doc in docs): if not any(len(doc) for doc in docs):
# Handle cases where there are no tokens in any docs. # Handle cases where there are no tokens in any docs.
return return
guesses, backprop = self.model.begin_update(docs, drop=drop) set_dropout_rate(self.model, drop)
guesses, backprop = self.model.begin_update(docs)
target = self._rehearsal_model(examples) target = self._rehearsal_model(examples)
gradient = guesses - target gradient = guesses - target
backprop(gradient, sgd=sgd) backprop(gradient)
self.model.finish_update(sgd)
if losses is not None: if losses is not None:
losses.setdefault(self.name, 0.0) losses.setdefault(self.name, 0.0)
losses[self.name] += (gradient**2).sum() losses[self.name] += (gradient**2).sum()
@ -546,7 +582,7 @@ class Tagger(Pipe):
known_labels[idx] = 0. known_labels[idx] = 0.
idx += 1 idx += 1
correct = self.model.ops.xp.array(correct, dtype="i") correct = self.model.ops.xp.array(correct, dtype="i")
d_scores = scores - to_categorical(correct, nb_classes=scores.shape[1]) d_scores = scores - to_categorical(correct, n_classes=scores.shape[1])
d_scores *= self.model.ops.asarray(known_labels) d_scores *= self.model.ops.asarray(known_labels)
loss = (d_scores**2).sum() loss = (d_scores**2).sum()
docs = [ex.doc for ex in examples] docs = [ex.doc for ex in examples]
@ -566,6 +602,7 @@ class Tagger(Pipe):
new_tag_map[tag] = orig_tag_map[tag] new_tag_map[tag] = orig_tag_map[tag]
else: else:
new_tag_map[tag] = {POS: X} new_tag_map[tag] = {POS: X}
cdef Vocab vocab = self.vocab cdef Vocab vocab = self.vocab
if new_tag_map: if new_tag_map:
vocab.morphology = Morphology(vocab.strings, new_tag_map, vocab.morphology = Morphology(vocab.strings, new_tag_map,
@ -577,16 +614,39 @@ class Tagger(Pipe):
if hp in kwargs: if hp in kwargs:
self.cfg[hp] = kwargs[hp] self.cfg[hp] = kwargs[hp]
self.model = self.Model(self.vocab.morphology.n_tags, **self.cfg) self.model = self.Model(self.vocab.morphology.n_tags, **self.cfg)
# Get batch of example docs, example outputs to call begin_training().
# This lets the model infer shapes.
n_tags = self.vocab.morphology.n_tags
for node in self.model.walk():
# TODO: softmax hack ?
if node.name == "softmax" and node.has_dim("nO") is None:
node.set_dim("nO", n_tags)
link_vectors_to_models(self.vocab) link_vectors_to_models(self.vocab)
self.model.initialize()
if sgd is None: if sgd is None:
sgd = self.create_optimizer() sgd = self.create_optimizer()
return sgd return sgd
@classmethod @classmethod
def Model(cls, n_tags, **cfg): def Model(cls, n_tags=None, **cfg):
if cfg.get("pretrained_dims") and not cfg.get("pretrained_vectors"): if cfg.get("pretrained_dims") and not cfg.get("pretrained_vectors"):
raise ValueError(TempErrors.T008) raise ValueError(TempErrors.T008)
return build_tagger_model(n_tags, **cfg) if "tok2vec" in cfg:
tok2vec = cfg["tok2vec"]
else:
config = {
"width": cfg.get("token_vector_width", 96),
"embed_size": cfg.get("embed_size", 2000),
"pretrained_vectors": cfg.get("pretrained_vectors", None),
"window_size": cfg.get("window_size", 1),
"cnn_maxout_pieces": cfg.get("cnn_maxout_pieces", 3),
"subword_features": cfg.get("subword_features", True),
"char_embed": cfg.get("char_embed", False),
"conv_depth": cfg.get("conv_depth", 4),
"bilstm_depth": cfg.get("bilstm_depth", 0),
}
tok2vec = Tok2Vec(**config)
return build_tagger_model(n_tags, tok2vec)
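A sketch of calling these building blocks directly, using the same default config keys as above (`17` is an arbitrary tag count; this only constructs the graph, and linking vectors or initializing weights is omitted):

from spacy.ml.component_models import Tok2Vec, build_tagger_model

tok2vec = Tok2Vec(
    width=96,
    embed_size=2000,
    pretrained_vectors=None,
    window_size=1,
    cnn_maxout_pieces=3,
    subword_features=True,
    char_embed=False,
    conv_depth=4,
    bilstm_depth=0,
)
tagger_model = build_tagger_model(17, tok2vec)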
def add_label(self, label, values=None): def add_label(self, label, values=None):
if not isinstance(label, str): if not isinstance(label, str):
@ -633,12 +693,12 @@ class Tagger(Pipe):
def load_model(b): def load_model(b):
# TODO: Remove this once we don't have to handle previous models # TODO: Remove this once we don't have to handle previous models
if self.cfg.get("pretrained_dims") and "pretrained_vectors" not in self.cfg: if self.cfg.get("pretrained_dims") and "pretrained_vectors" not in self.cfg:
self.cfg["pretrained_vectors"] = self.vocab.vectors.name self.cfg["pretrained_vectors"] = self.vocab.vectors
if self.model is True: if self.model is True:
token_vector_width = util.env_opt( token_vector_width = util.env_opt(
"token_vector_width", "token_vector_width",
self.cfg.get("token_vector_width", 96)) self.cfg.get("token_vector_width", 96))
self.model = self.Model(self.vocab.morphology.n_tags, **self.cfg) self.model = self.Model(**self.cfg)
try: try:
self.model.from_bytes(b) self.model.from_bytes(b)
except AttributeError: except AttributeError:
@ -676,9 +736,9 @@ class Tagger(Pipe):
def load_model(p): def load_model(p):
# TODO: Remove this once we don't have to handle previous models # TODO: Remove this once we don't have to handle previous models
if self.cfg.get("pretrained_dims") and "pretrained_vectors" not in self.cfg: if self.cfg.get("pretrained_dims") and "pretrained_vectors" not in self.cfg:
self.cfg["pretrained_vectors"] = self.vocab.vectors.name self.cfg["pretrained_vectors"] = self.vocab.vectors
if self.model is True: if self.model is True:
self.model = self.Model(self.vocab.morphology.n_tags, **self.cfg) self.model = self.Model(**self.cfg)
with p.open("rb") as file_: with p.open("rb") as file_:
try: try:
self.model.from_bytes(file_.read()) self.model.from_bytes(file_.read())
@ -753,10 +813,12 @@ class SentenceRecognizer(Tagger):
if not any(len(ex.doc) if ex.doc else 0 for ex in examples): if not any(len(ex.doc) if ex.doc else 0 for ex in examples):
# Handle cases where there are no tokens in any docs. # Handle cases where there are no tokens in any docs.
return return
set_dropout_rate(self.model, drop)
tag_scores, bp_tag_scores = self.model.begin_update([ex.doc for ex in examples], drop=drop) tag_scores, bp_tag_scores = self.model.begin_update([ex.doc for ex in examples])
loss, d_tag_scores = self.get_loss(examples, tag_scores) loss, d_tag_scores = self.get_loss(examples, tag_scores)
bp_tag_scores(d_tag_scores, sgd=sgd) bp_tag_scores(d_tag_scores)
if sgd is not None:
self.model.finish_update(sgd)
if losses is not None: if losses is not None:
losses[self.name] += loss losses[self.name] += loss
@ -780,7 +842,7 @@ class SentenceRecognizer(Tagger):
known_labels[idx] = 0. known_labels[idx] = 0.
idx += 1 idx += 1
correct = self.model.ops.xp.array(correct, dtype="i") correct = self.model.ops.xp.array(correct, dtype="i")
d_scores = scores - to_categorical(correct, nb_classes=scores.shape[1]) d_scores = scores - to_categorical(correct, n_classes=scores.shape[1])
d_scores *= self.model.ops.asarray(known_labels) d_scores *= self.model.ops.asarray(known_labels)
loss = (d_scores**2).sum() loss = (d_scores**2).sum()
docs = [ex.doc for ex in examples] docs = [ex.doc for ex in examples]
@ -797,6 +859,7 @@ class SentenceRecognizer(Tagger):
self.model = self.Model(len(self.labels), **self.cfg) self.model = self.Model(len(self.labels), **self.cfg)
if sgd is None: if sgd is None:
sgd = self.create_optimizer() sgd = self.create_optimizer()
self.model.initialize()
return sgd return sgd
@classmethod @classmethod
@ -918,6 +981,7 @@ class MultitaskObjective(Tagger):
token_vector_width = util.env_opt("token_vector_width") token_vector_width = util.env_opt("token_vector_width")
self.model = self.Model(len(self.labels), tok2vec=tok2vec) self.model = self.Model(len(self.labels), tok2vec=tok2vec)
link_vectors_to_models(self.vocab) link_vectors_to_models(self.vocab)
self.model.initialize()
if sgd is None: if sgd is None:
sgd = self.create_optimizer() sgd = self.create_optimizer()
return sgd return sgd
@ -925,14 +989,12 @@ class MultitaskObjective(Tagger):
@classmethod @classmethod
def Model(cls, n_tags, tok2vec=None, **cfg): def Model(cls, n_tags, tok2vec=None, **cfg):
token_vector_width = util.env_opt("token_vector_width", 96) token_vector_width = util.env_opt("token_vector_width", 96)
softmax = Softmax(n_tags, token_vector_width*2)
model = chain( model = chain(
tok2vec, tok2vec,
LayerNorm(Maxout(token_vector_width*2, token_vector_width, pieces=3)), Maxout(nO=token_vector_width*2, nI=token_vector_width, nP=3, dropout=0.0),
softmax LayerNorm(token_vector_width*2),
Softmax(nO=n_tags, nI=token_vector_width*2)
) )
model.tok2vec = tok2vec
model.softmax = softmax
return model return model
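A standalone sketch of the composition style used here: keyword dimensions, and `LayerNorm` chained as its own layer instead of wrapping `Maxout` (sizes are illustrative, and the tok2vec input is replaced by plain arrays):

import numpy
from thinc.layers import chain, Maxout, LayerNorm, Softmax

width, n_tags = 96, 17
output_layer = chain(
    Maxout(nO=width * 2, nI=width, nP=3, dropout=0.0),
    LayerNorm(width * 2),
    Softmax(nO=n_tags, nI=width * 2),
)
output_layer.initialize(X=numpy.zeros((4, width), dtype="f"))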
def predict(self, docs): def predict(self, docs):
@ -958,7 +1020,7 @@ class MultitaskObjective(Tagger):
correct[idx] = self.labels[label] correct[idx] = self.labels[label]
idx += 1 idx += 1
correct = self.model.ops.xp.array(correct, dtype="i") correct = self.model.ops.xp.array(correct, dtype="i")
d_scores = scores - to_categorical(correct, nb_classes=scores.shape[1]) d_scores = scores - to_categorical(correct, n_classes=scores.shape[1])
loss = (d_scores**2).sum() loss = (d_scores**2).sum()
return float(loss), d_scores return float(loss), d_scores
@ -1047,19 +1109,18 @@ class ClozeMultitask(Pipe):
def Model(cls, vocab, tok2vec, **cfg): def Model(cls, vocab, tok2vec, **cfg):
output_size = vocab.vectors.data.shape[1] output_size = vocab.vectors.data.shape[1]
output_layer = chain( output_layer = chain(
LayerNorm(Maxout(output_size, tok2vec.nO, pieces=3)), Maxout(nO=output_size, nI=tok2vec.get_dim("nO"), nP=3, normalize=True, dropout=0.0),
zero_init(Affine(output_size, output_size, drop_factor=0.0)) Linear(nO=output_size, nI=output_size, init_W=zero_init)
) )
model = chain(tok2vec, output_layer) model = chain(tok2vec, output_layer)
model = masked_language_model(vocab, model) model = masked_language_model(vocab, model)
model.tok2vec = tok2vec
model.output_layer = output_layer
return model return model
def __init__(self, vocab, model=True, **cfg): def __init__(self, vocab, model=True, **cfg):
self.vocab = vocab self.vocab = vocab
self.model = model self.model = model
self.cfg = cfg self.cfg = cfg
self.distance = CosineDistance(ignore_zeros=True, normalize=False)
def set_annotations(self, docs, dep_ids, tensors=None): def set_annotations(self, docs, dep_ids, tensors=None):
pass pass
@ -1069,7 +1130,8 @@ class ClozeMultitask(Pipe):
link_vectors_to_models(self.vocab) link_vectors_to_models(self.vocab)
if self.model is True: if self.model is True:
self.model = self.Model(self.vocab, tok2vec) self.model = self.Model(self.vocab, tok2vec)
X = self.model.ops.allocate((5, self.model.tok2vec.nO)) X = self.model.ops.alloc((5, self.model.get_ref("tok2vec").get_dim("nO")))
self.model.initialize()
self.model.output_layer.begin_training(X) self.model.output_layer.begin_training(X)
if sgd is None: if sgd is None:
sgd = self.create_optimizer() sgd = self.create_optimizer()
@ -1088,10 +1150,11 @@ class ClozeMultitask(Pipe):
# and look them up all at once. This prevents data copying. # and look them up all at once. This prevents data copying.
ids = self.model.ops.flatten([ex.doc.to_array(ID).ravel() for ex in examples]) ids = self.model.ops.flatten([ex.doc.to_array(ID).ravel() for ex in examples])
target = vectors[ids] target = vectors[ids]
loss, gradient = get_cossim_loss(prediction, target, ignore_zeros=True) gradient = self.distance.get_grad(prediction, target)
return float(loss), gradient loss = self.distance.get_loss(prediction, target)
return loss, gradient
def update(self, examples, drop=0., sgd=None, losses=None): def update(self, examples, drop=0., set_annotations=False, sgd=None, losses=None):
pass pass
def rehearse(self, examples, drop=0., sgd=None, losses=None): def rehearse(self, examples, drop=0., sgd=None, losses=None):
@ -1099,9 +1162,12 @@ class ClozeMultitask(Pipe):
examples = Example.to_example_objects(examples) examples = Example.to_example_objects(examples)
if losses is not None and self.name not in losses: if losses is not None and self.name not in losses:
losses[self.name] = 0. losses[self.name] = 0.
predictions, bp_predictions = self.model.begin_update([ex.doc for ex in examples], drop=drop) set_dropout_rate(self.model, drop)
predictions, bp_predictions = self.model.begin_update([ex.doc for ex in examples])
loss, d_predictions = self.get_loss(examples, self.vocab.vectors.data, predictions) loss, d_predictions = self.get_loss(examples, self.vocab.vectors.data, predictions)
bp_predictions(d_predictions, sgd=sgd) bp_predictions(d_predictions)
if sgd is not None:
self.model.finish_update(sgd)
if losses is not None: if losses is not None:
losses[self.name] += loss losses[self.name] += loss
@ -1115,19 +1181,45 @@ class TextCategorizer(Pipe):
""" """
@classmethod @classmethod
def Model(cls, nr_class=1, **cfg): def Model(cls, nr_class=1, exclusive_classes=None, **cfg):
embed_size = util.env_opt("embed_size", 2000) if nr_class == 1:
if "token_vector_width" in cfg: exclusive_classes = False
token_vector_width = cfg["token_vector_width"] if exclusive_classes is None:
raise ValueError(
"TextCategorizer Model must specify 'exclusive_classes'. "
"This setting determines whether the model will output "
"scores that sum to 1 for each example. If only one class "
"is true for each example, you should set exclusive_classes=True. "
"For 'multi_label' classification, set exclusive_classes=False."
)
if "embed_size" not in cfg:
cfg["embed_size"] = util.env_opt("embed_size", 2000)
if "token_vector_width" not in cfg:
cfg["token_vector_width"] = util.env_opt("token_vector_width", 96)
if cfg.get("architecture") == "bow":
return build_bow_text_classifier(nr_class, exclusive_classes, **cfg)
else: else:
token_vector_width = util.env_opt("token_vector_width", 96) if "tok2vec" in cfg:
if cfg.get("architecture") == "simple_cnn": tok2vec = cfg["tok2vec"]
tok2vec = Tok2Vec(token_vector_width, embed_size, **cfg) else:
return build_simple_cnn_text_classifier(tok2vec, nr_class, **cfg) config = {
elif cfg.get("architecture") == "bow": "width": cfg.get("token_vector_width", 96),
return build_bow_text_classifier(nr_class, **cfg) "embed_size": cfg.get("embed_size", 2000),
else: "pretrained_vectors": cfg.get("pretrained_vectors", None),
return build_text_classifier(nr_class, **cfg) "window_size": cfg.get("window_size", 1),
"cnn_maxout_pieces": cfg.get("cnn_maxout_pieces", 3),
"subword_features": cfg.get("subword_features", True),
"char_embed": cfg.get("char_embed", False),
"conv_depth": cfg.get("conv_depth", 4),
"bilstm_depth": cfg.get("bilstm_depth", 0),
}
tok2vec = Tok2Vec(**config)
return build_simple_cnn_text_classifier(
tok2vec,
nr_class,
exclusive_classes,
**cfg
)
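A sketch of the new requirement: with more than one label, `exclusive_classes` must now be chosen explicitly (assuming the default CNN text-classifier architecture builds without further config; whether this runs outside a full pipeline is not verified here):

from spacy.pipeline import TextCategorizer

multi_label = TextCategorizer.Model(nr_class=4, exclusive_classes=False)   # independent per-label scores
single_label = TextCategorizer.Model(nr_class=4, exclusive_classes=True)   # scores sum to 1 per example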
@property @property
def tok2vec(self): def tok2vec(self):
@ -1141,6 +1233,8 @@ class TextCategorizer(Pipe):
self.model = model self.model = model
self._rehearsal_model = None self._rehearsal_model = None
self.cfg = dict(cfg) self.cfg = dict(cfg)
if "exclusive_classes" not in cfg:
self.cfg["exclusive_classes"] = True
@property @property
def labels(self): def labels(self):
@ -1180,7 +1274,7 @@ class TextCategorizer(Pipe):
scores = xp.zeros((len(docs), len(self.labels))) scores = xp.zeros((len(docs), len(self.labels)))
return scores, tensors return scores, tensors
scores = self.model(docs) scores = self.model.predict(docs)
scores = self.model.ops.asarray(scores) scores = self.model.ops.asarray(scores)
return scores, tensors return scores, tensors
@ -1189,18 +1283,24 @@ class TextCategorizer(Pipe):
for j, label in enumerate(self.labels): for j, label in enumerate(self.labels):
doc.cats[label] = float(scores[i, j]) doc.cats[label] = float(scores[i, j])
def update(self, examples, state=None, drop=0., sgd=None, losses=None): def update(self, examples, state=None, drop=0., set_annotations=False, sgd=None, losses=None):
self.require_model() self.require_model()
examples = Example.to_example_objects(examples) examples = Example.to_example_objects(examples)
if not any(len(ex.doc) if ex.doc else 0 for ex in examples): if not any(len(ex.doc) if ex.doc else 0 for ex in examples):
# Handle cases where there are no tokens in any docs. # Handle cases where there are no tokens in any docs.
return return
scores, bp_scores = self.model.begin_update([ex.doc for ex in examples], drop=drop) set_dropout_rate(self.model, drop)
scores, bp_scores = self.model.begin_update([ex.doc for ex in examples])
loss, d_scores = self.get_loss(examples, scores) loss, d_scores = self.get_loss(examples, scores)
bp_scores(d_scores, sgd=sgd) bp_scores(d_scores)
if sgd is not None:
self.model.finish_update(sgd)
if losses is not None: if losses is not None:
losses.setdefault(self.name, 0.0) losses.setdefault(self.name, 0.0)
losses[self.name] += loss losses[self.name] += loss
if set_annotations:
docs = [ex.doc for ex in examples]
self.set_annotations(docs, scores=scores)
def rehearse(self, examples, drop=0., sgd=None, losses=None): def rehearse(self, examples, drop=0., sgd=None, losses=None):
if self._rehearsal_model is None: if self._rehearsal_model is None:
@ -1210,10 +1310,13 @@ class TextCategorizer(Pipe):
if not any(len(doc) for doc in docs): if not any(len(doc) for doc in docs):
# Handle cases where there are no tokens in any docs. # Handle cases where there are no tokens in any docs.
return return
scores, bp_scores = self.model.begin_update(docs, drop=drop) set_dropout_rate(self.model, drop)
scores, bp_scores = self.model.begin_update(docs)
target = self._rehearsal_model(examples) target = self._rehearsal_model(examples)
gradient = scores - target gradient = scores - target
bp_scores(gradient, sgd=sgd) bp_scores(gradient)
if sgd is not None:
self.model.finish_update(sgd)
if losses is not None: if losses is not None:
losses.setdefault(self.name, 0.0) losses.setdefault(self.name, 0.0)
losses[self.name] += (gradient**2).sum() losses[self.name] += (gradient**2).sum()
@ -1247,7 +1350,7 @@ class TextCategorizer(Pipe):
# - a huge problem. # - a huge problem.
raise ValueError(Errors.E116) raise ValueError(Errors.E116)
# smaller = self.model._layers[-1] # smaller = self.model._layers[-1]
# larger = Affine(len(self.labels)+1, smaller.nI) # larger = Linear(len(self.labels)+1, smaller.nI)
# copy_array(larger.W[:smaller.nO], smaller.W) # copy_array(larger.W[:smaller.nO], smaller.W)
# copy_array(larger.b[:smaller.nO], smaller.b) # copy_array(larger.b[:smaller.nO], smaller.b)
# self.model._layers[-1] = larger # self.model._layers[-1] = larger
@ -1259,12 +1362,15 @@ class TextCategorizer(Pipe):
for cat in example.doc_annotation.cats: for cat in example.doc_annotation.cats:
self.add_label(cat) self.add_label(cat)
if self.model is True: if self.model is True:
self.cfg["pretrained_vectors"] = kwargs.get("pretrained_vectors") self.cfg.update(kwargs)
self.require_labels() self.require_labels()
self.model = self.Model(len(self.labels), **self.cfg) self.model = self.Model(len(self.labels), **self.cfg)
link_vectors_to_models(self.vocab) link_vectors_to_models(self.vocab)
if sgd is None: if sgd is None:
sgd = self.create_optimizer() sgd = self.create_optimizer()
# TODO: use get_examples instead
docs = [Doc(Vocab(), words=["hello"])]
self.model.initialize(X=docs)
return sgd return sgd
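A minimal sketch of why a sample batch is passed to `initialize()` above: dimensions that are still unset are inferred from example data (plain arrays here rather than `Doc` objects):

import numpy
from thinc.layers import Linear

layer = Linear(nO=2)                                  # nI left unset
X = numpy.random.uniform(-1, 1, (4, 7)).astype("f")
layer.initialize(X=X)
assert layer.get_dim("nI") == 7                       # inferred from the sample batch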
@ -1382,6 +1488,7 @@ class EntityLinker(Pipe):
self.model = True self.model = True
self.kb = None self.kb = None
self.cfg = dict(cfg) self.cfg = dict(cfg)
self.distance = CosineDistance(normalize=False)
def set_kb(self, kb): def set_kb(self, kb):
self.kb = kb self.kb = kb
@ -1399,16 +1506,14 @@ class EntityLinker(Pipe):
def begin_training(self, get_examples=lambda: [], pipeline=None, sgd=None, **kwargs): def begin_training(self, get_examples=lambda: [], pipeline=None, sgd=None, **kwargs):
self.require_kb() self.require_kb()
self.cfg["entity_width"] = self.kb.entity_vector_length self.cfg["entity_width"] = self.kb.entity_vector_length
if self.model is True: if self.model is True:
self.model = self.Model(**self.cfg) self.model = self.Model(**self.cfg)
self.model.initialize()
if sgd is None: if sgd is None:
sgd = self.create_optimizer() sgd = self.create_optimizer()
return sgd return sgd
def update(self, examples, state=None, drop=0.0, sgd=None, losses=None): def update(self, examples, state=None, set_annotations=False, drop=0.0, sgd=None, losses=None):
self.require_model() self.require_model()
self.require_kb() self.require_kb()
if losses is not None: if losses is not None:
@ -1416,9 +1521,12 @@ class EntityLinker(Pipe):
if not examples: if not examples:
return 0 return 0
examples = Example.to_example_objects(examples) examples = Example.to_example_objects(examples)
sentence_docs = [] sentence_docs = []
docs = [ex.doc for ex in examples] docs = [ex.doc for ex in examples]
if set_annotations:
# This seems simpler than other ways to get that exact output -- but
# it does run the model twice :(
predictions = self.model.predict(docs)
golds = [ex.gold for ex in examples] golds = [ex.gold for ex in examples]
for doc, gold in zip(docs, golds): for doc, gold in zip(docs, golds):
@ -1443,13 +1551,17 @@ class EntityLinker(Pipe):
except AttributeError: except AttributeError:
# Catch the exception when ent.sent is None and provide a user-friendly warning # Catch the exception when ent.sent is None and provide a user-friendly warning
raise RuntimeError(Errors.E030) raise RuntimeError(Errors.E030)
set_dropout_rate(self.model, drop)
sentence_encodings, bp_context = self.model.begin_update(sentence_docs, drop=drop) sentence_encodings, bp_context = self.model.begin_update(sentence_docs)
loss, d_scores = self.get_similarity_loss(scores=sentence_encodings, golds=golds) loss, d_scores = self.get_similarity_loss(scores=sentence_encodings, golds=golds)
bp_context(d_scores, sgd=sgd) bp_context(d_scores)
if sgd is not None:
self.model.finish_update(sgd)
if losses is not None: if losses is not None:
losses[self.name] += loss losses[self.name] += loss
if set_annotations:
self.set_annotations(docs, predictions)
return loss return loss
def get_similarity_loss(self, golds, scores): def get_similarity_loss(self, golds, scores):
@ -1467,7 +1579,8 @@ class EntityLinker(Pipe):
if scores.shape != entity_encodings.shape: if scores.shape != entity_encodings.shape:
raise RuntimeError(Errors.E147.format(method="get_similarity_loss", msg="gold entities do not match up")) raise RuntimeError(Errors.E147.format(method="get_similarity_loss", msg="gold entities do not match up"))
loss, gradients = get_cossim_loss(yh=scores, y=entity_encodings) gradients = self.distance.get_grad(scores, entity_encodings)
loss = self.distance.get_loss(scores, entity_encodings)
loss = loss / len(entity_encodings) loss = loss / len(entity_encodings)
return loss, gradients return loss, gradients
@ -1533,7 +1646,7 @@ class EntityLinker(Pipe):
for sent in doc.sents: for sent in doc.sents:
sent_doc = sent.as_doc() sent_doc = sent.as_doc()
# currently, the context is the same for each entity in a sentence (should be refined) # currently, the context is the same for each entity in a sentence (should be refined)
sentence_encoding = self.model([sent_doc])[0] sentence_encoding = self.model.predict([sent_doc])[0]
xp = get_array_module(sentence_encoding) xp = get_array_module(sentence_encoding)
sentence_encoding_t = sentence_encoding.T sentence_encoding_t = sentence_encoding.T
sentence_norm = xp.linalg.norm(sentence_encoding_t) sentence_norm = xp.linalg.norm(sentence_encoding_t)
@ -1720,7 +1833,6 @@ class Sentencizer(Pipe):
self.set_annotations(docs, scores, tensors=tensors) self.set_annotations(docs, scores, tensors=tensors)
else: else:
self.set_annotations(docs, predictions) self.set_annotations(docs, predictions)
if as_example: if as_example:
annotated_examples = [] annotated_examples = []
for ex, doc in zip(examples, docs): for ex, doc in zip(examples, docs):
@ -1729,7 +1841,7 @@ class Sentencizer(Pipe):
yield from annotated_examples yield from annotated_examples
else: else:
yield from docs yield from docs
def predict(self, docs): def predict(self, docs):
"""Apply the pipeline's model to a batch of docs, without """Apply the pipeline's model to a batch of docs, without
modifying them. modifying them.

spacy/pipeline/tok2vec.py (new file, 188 lines)

@ -0,0 +1,188 @@
from .pipes import Pipe
from ..gold import Example
from ..tokens import Doc
from ..vocab import Vocab
from ..language import component
from ..util import link_vectors_to_models, minibatch, registry, eg2doc
from thinc.model import Model, set_dropout_rate
@component("tok2vec", assigns=["doc.tensor"])
class Tok2Vec(Pipe):
@classmethod
def from_nlp(cls, nlp, **cfg):
return cls(nlp.vocab, **cfg)
@classmethod
def Model(cls, architecture, **cfg):
"""Create a new statistical model for the class.
architecture (str): The registered model architecture to use.
**cfg: Config parameters.
RETURNS (Model): A `thinc.model.Model` or similar instance.
"""
model = registry.architectures.get(architecture)
return model(**cfg)
def __init__(self, vocab, model=True, **cfg):
"""Construct a new statistical model. Weights are not allocated on
initialisation.
vocab (Vocab): A `Vocab` instance. The model must share the same `Vocab`
instance with the `Doc` objects it will process.
model (Model): A `Model` instance or `True` to allocate one later.
**cfg: Config parameters.
"""
self.vocab = vocab
self.model = model
self.cfg = dict(cfg)
self.listeners = []
def create_listener(self):
listener = Tok2VecListener(upstream_name="tok2vec", width=self.model.get_dim("nO"))
self.listeners.append(listener)
def add_listener(self, listener):
self.listeners.append(listener)
def find_listeners(self, model):
for node in model.walk():
if isinstance(node, Tok2VecListener) and node.upstream_name == self.name:
self.add_listener(node)
def __call__(self, doc):
"""Add context-sensitive vectors to a `Doc`, e.g. from a CNN or LSTM
model. Vectors are set to the `Doc.tensor` attribute.
doc (Doc): The document to add vectors to.
RETURNS (Doc): The processed document.
"""
tokvecses = self.predict([doc])
self.set_annotations([doc], tokvecses)
return doc
def pipe(self, stream, batch_size=128, n_threads=-1, as_example=False):
"""Process `Doc` objects as a stream.
stream (iterator): A sequence of `Doc` objects to process.
batch_size (int): Number of `Doc` objects to group.
n_threads (int): Number of threads.
YIELDS (iterator): A sequence of `Doc` objects, in order of input.
"""
for batch in minibatch(stream, batch_size):
batch = list(batch)
if as_example:
docs = [eg2doc(doc) for doc in batch]
else:
docs = batch
tokvecses = self.predict(docs)
self.set_annotations(docs, tokvecses)
yield from batch
def predict(self, docs):
"""Return a single tensor for a batch of documents.
docs (iterable): A sequence of `Doc` objects.
RETURNS (object): Vector representations for each token in the documents.
"""
tokvecs = self.model.predict(docs)
batch_id = Tok2VecListener.get_batch_id(docs)
for listener in self.listeners:
listener.receive(batch_id, tokvecs, None)
return tokvecs
def set_annotations(self, docs, tokvecses):
"""Set the tensor attribute for a batch of documents.
docs (iterable): A sequence of `Doc` objects.
tokvecses (iterable): Vector representations for the tokens of each document.
"""
for doc, tokvecs in zip(docs, tokvecses):
assert tokvecs.shape[0] == len(doc)
doc.tensor = tokvecs
def update(self, examples, drop=0.0, sgd=None, losses=None, set_annotations=False):
"""Update the model.
examples (iterable): A batch of examples
drop (float): The dropout rate.
sgd (callable): An optimizer.
RETURNS (dict): Results from the update.
"""
if losses is None:
losses = {}
examples = Example.to_example_objects(examples)
docs = [eg.doc for eg in examples]
if isinstance(docs, Doc):
docs = [docs]
set_dropout_rate(self.model, drop)
tokvecs, bp_tokvecs = self.model.begin_update(docs)
def capture_losses(d_tokvecs):
"""Accumulate tok2vec loss before doing backprop."""
l2_loss = sum((d_t2v**2).sum() for d_t2v in d_tokvecs)
if self.name in losses:
losses[self.name] += l2_loss / len(d_tokvecs)
else:
losses[self.name] = l2_loss / len(d_tokvecs)
return bp_tokvecs(d_tokvecs)
batch_id = Tok2VecListener.get_batch_id(docs)
for listener in self.listeners:
listener.receive(batch_id, tokvecs, capture_losses)
if sgd is not None:
self.model.finish_update(sgd)
if set_annotations:
self.set_annotations(docs, tokvecs)
def get_loss(self, docs, golds, scores):
pass
def begin_training(self, get_examples=lambda: [], pipeline=None, sgd=None, **kwargs):
"""Allocate models and pre-process training data
get_examples (function): Function returning example training data.
pipeline (list): The pipeline the model is part of.
"""
if self.model is True:
self.model = self.Model(**self.cfg)
# TODO: use examples instead ?
docs = [Doc(Vocab(), words=["hello"])]
self.model.initialize(X=docs)
link_vectors_to_models(self.vocab)
class Tok2VecListener(Model):
"""A layer that gets fed its answers from an upstream connection,
for instance from a component earlier in the pipeline.
"""
name = "tok2vec-listener"
def __init__(self, upstream_name, width):
Model.__init__(self, name=self.name, forward=forward, dims={"nO": width})
self.upstream_name = upstream_name
self._batch_id = None
self._outputs = None
self._backprop = None
@classmethod
def get_batch_id(cls, inputs):
return sum(sum(token.orth for token in doc) for doc in inputs)
def receive(self, batch_id, outputs, backprop):
self._batch_id = batch_id
self._outputs = outputs
self._backprop = backprop
def verify_inputs(self, inputs):
if self._batch_id is None and self._outputs is None:
raise ValueError
else:
batch_id = self.get_batch_id(inputs)
if batch_id != self._batch_id:
raise ValueError(f"Mismatched IDs! {batch_id} vs {self._batch_id}")
else:
return True
def forward(model: Tok2VecListener, inputs, is_train):
if is_train:
model.verify_inputs(inputs)
return model._outputs, model._backprop
else:
return [doc.tensor for doc in inputs], lambda dX: []
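A minimal sketch of the listener handshake implemented in this file, run outside a real pipeline (dummy vectors; the import path assumes this new module location):

import numpy
from spacy.tokens import Doc
from spacy.vocab import Vocab
from spacy.pipeline.tok2vec import Tok2VecListener

doc = Doc(Vocab(), words=["hello", "world"])
listener = Tok2VecListener(upstream_name="tok2vec", width=96)
tokvecs = [numpy.zeros((len(doc), 96), dtype="f")]
batch_id = Tok2VecListener.get_batch_id([doc])
listener.receive(batch_id, tokvecs, lambda d_tokvecs: d_tokvecs)  # upstream Tok2Vec pushes its output
outputs, backprop = listener([doc], is_train=True)                # downstream layer pulls it back out
assert outputs is tokvecs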


@ -1,4 +1,4 @@
-from thinc.typedefs cimport class_t, hash_t
+from ..typedefs cimport hash_t, class_t
 # These are passed as callbacks to thinc.search.Beam
 cdef int transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1


@ -5,9 +5,9 @@ import numpy
 from cpython.ref cimport PyObject, Py_XDECREF
 from thinc.extra.search cimport Beam
 from thinc.extra.search import MaxViolation
-from thinc.typedefs cimport hash_t, class_t
 from thinc.extra.search cimport MaxViolation
+from ..typedefs cimport hash_t, class_t
 from .transition_system cimport TransitionSystem, Transition
 from ..gold cimport GoldParse
 from ..errors import Errors


@ -1,6 +1,6 @@
 from libc.string cimport memset, memcpy
 from libc.stdlib cimport calloc, free, realloc
-from thinc.typedefs cimport weight_t, class_t, hash_t
+from ..typedefs cimport weight_t, class_t, hash_t
 from ._state cimport StateC


@ -10,18 +10,14 @@ from libcpp.vector cimport vector
 from libc.string cimport memset, memcpy
 from libc.stdlib cimport calloc, free, realloc
 from cymem.cymem cimport Pool
-from thinc.typedefs cimport weight_t, class_t, hash_t
 from thinc.extra.search cimport Beam
-from thinc.api import chain, clone
-from thinc.v2v import Model, Maxout, Affine
-from thinc.misc import LayerNorm
-from thinc.neural.ops import CupyOps, NumpyOps
-from thinc.neural.util import get_array_module
-from thinc.linalg cimport Vec, VecVec
+from thinc.layers import Linear
+from thinc.model import Model
+from thinc.backends import CupyOps, NumpyOps, use_ops
+from thinc.backends.linalg cimport Vec, VecVec
 cimport blis.cy
-from .._ml import zero_init, PrecomputableAffine, Tok2Vec, flatten
-from .._ml import link_vectors_to_models, create_default_optimizer
+from ..typedefs cimport weight_t, class_t, hash_t
 from ..compat import copy_array
 from ..tokens.doc cimport Doc
 from ..gold cimport GoldParse
@ -31,6 +27,7 @@ from .stateclass cimport StateClass
 from .transition_system cimport Transition
 from . import _beam_utils
 from . import nonproj
+from ..util import link_vectors_to_models, create_default_optimizer
cdef WeightsC get_c_weights(model) except *: cdef WeightsC get_c_weights(model) except *:
@ -44,8 +41,8 @@ cdef WeightsC get_c_weights(model) except *:
output.hidden_weights = NULL output.hidden_weights = NULL
output.hidden_bias = NULL output.hidden_bias = NULL
else: else:
vec2scores_W = model.vec2scores.W vec2scores_W = model.vec2scores.get_param("W")
vec2scores_b = model.vec2scores.b vec2scores_b = model.vec2scores.get_param("b")
output.hidden_weights = <const float*>vec2scores_W.data output.hidden_weights = <const float*>vec2scores_W.data
output.hidden_bias = <const float*>vec2scores_b.data output.hidden_bias = <const float*>vec2scores_b.data
cdef np.ndarray class_mask = model._class_mask cdef np.ndarray class_mask = model._class_mask
@ -57,12 +54,12 @@ cdef SizesC get_c_sizes(model, int batch_size) except *:
cdef SizesC output cdef SizesC output
output.states = batch_size output.states = batch_size
if model.vec2scores is None: if model.vec2scores is None:
output.classes = model.state2vec.nO output.classes = model.state2vec.get_dim("nO")
else: else:
output.classes = model.vec2scores.nO output.classes = model.vec2scores.get_dim("nO")
output.hiddens = model.state2vec.nO output.hiddens = model.state2vec.get_dim("nO")
output.pieces = model.state2vec.nP output.pieces = model.state2vec.get_dim("nP")
output.feats = model.state2vec.nF output.feats = model.state2vec.get_dim("nF")
output.embed_width = model.tokvecs.shape[1] output.embed_width = model.tokvecs.shape[1]
return output return output
@ -226,7 +223,7 @@ cdef int arg_max_if_valid(const weight_t* scores, const int* is_valid, int n) no
class ParserModel(Model): class ParserModel(Model):
def __init__(self, tok2vec, lower_model, upper_model, unseen_classes=None): def __init__(self, tok2vec, lower_model, upper_model, unseen_classes=None):
Model.__init__(self) Model.__init__(self, name="parser_model", forward=forward)
self._layers = [tok2vec, lower_model] self._layers = [tok2vec, lower_model]
if upper_model is not None: if upper_model is not None:
self._layers.append(upper_model) self._layers.append(upper_model)
@ -235,41 +232,47 @@ class ParserModel(Model):
for class_ in unseen_classes: for class_ in unseen_classes:
self.unseen_classes.add(class_) self.unseen_classes.add(class_)
def begin_update(self, docs, drop=0.): def predict(self, docs):
step_model = ParserStepModel(docs, self._layers, drop=drop, step_model = ParserStepModel(docs, self._layers,
unseen_classes=self.unseen_classes) unseen_classes=self.unseen_classes, train=False)
def finish_parser_update(golds, sgd=None): return step_model
step_model.make_updates(sgd)
return None
return step_model, finish_parser_update
def resize_output(self, new_output): def resize_output(self, new_nO):
if len(self._layers) == 2: if len(self._layers) == 2:
return return
if new_output == self.upper.nO: if new_nO == self.upper.get_dim("nO"):
return return
smaller = self.upper smaller = self.upper
nI = smaller.get_dim("nI")
with Model.use_device('cpu'): with use_ops('numpy'):
larger = Affine(new_output, smaller.nI) larger = Linear(new_nO, nI)
larger.W.fill(0.0) larger_W = larger.ops.alloc2f(new_nO, nI)
larger.b.fill(0.0) larger_b = larger.ops.alloc1f(new_nO)
# It seems very unhappy if I pass these as smaller.W? smaller_W = smaller.get_param("W")
# Seems to segfault. Maybe it's a descriptor protocol thing? smaller_b = smaller.get_param("b")
smaller_W = smaller.W
larger_W = larger.W
smaller_b = smaller.b
larger_b = larger.b
# Weights are stored in (nr_out, nr_in) format, so we're basically # Weights are stored in (nr_out, nr_in) format, so we're basically
# just adding rows here. # just adding rows here.
larger_W[:smaller.nO] = smaller_W larger_W[:smaller.get_dim("nO")] = smaller_W
larger_b[:smaller.nO] = smaller_b larger_b[:smaller.get_dim("nO")] = smaller_b
larger.set_param("W", larger_W)
larger.set_param("b", larger_b)
self._layers[-1] = larger self._layers[-1] = larger
for i in range(smaller.nO, new_output): for i in range(smaller.get_dim("nO"), new_nO):
self.unseen_classes.add(i) self.unseen_classes.add(i)
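The parameter handling in `resize_output()` above, shown in isolation on toy sizes (a sketch, not parser code): weights are read and written through `get_param`/`set_param`, and buffers come from `ops.alloc2f`/`ops.alloc1f`.

from thinc.layers import Linear

smaller = Linear(nO=3, nI=5)
smaller.initialize()
larger = Linear(nO=4, nI=5)
larger_W = larger.ops.alloc2f(4, 5)          # zeroed buffers for the wider layer
larger_b = larger.ops.alloc1f(4)
larger_W[:3] = smaller.get_param("W")        # copy the existing rows; the new row stays zero
larger_b[:3] = smaller.get_param("b")
larger.set_param("W", larger_W)
larger.set_param("b", larger_b)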
def begin_training(self, X, y=None): def initialize(self, X=None, Y=None):
self.lower.begin_training(X, y=y) self.tok2vec.initialize()
self.lower.initialize(X=X, Y=Y)
if self.upper is not None:
# In case we need to trigger the callbacks
statevecs = self.ops.alloc((2, self.lower.get_dim("nO")))
self.upper.initialize(X=statevecs)
def finish_update(self, optimizer):
self.tok2vec.finish_update(optimizer)
self.lower.finish_update(optimizer)
if self.upper is not None:
self.upper.finish_update(optimizer)
@property @property
def tok2vec(self): def tok2vec(self):
@ -284,17 +287,25 @@ class ParserModel(Model):
return self._layers[2] return self._layers[2]
def forward(model:ParserModel, X, is_train):
step_model = ParserStepModel(X, model._layers, unseen_classes=model.unseen_classes,
train=is_train)
return step_model, step_model.finish_steps
class ParserStepModel(Model): class ParserStepModel(Model):
-    def __init__(self, docs, layers, unseen_classes=None, drop=0.):
-        self.tokvecs, self.bp_tokvecs = layers[0].begin_update(docs, drop=drop)
-        if layers[1].nP >= 2:
+    def __init__(self, docs, layers, unseen_classes=None, train=True):
+        Model.__init__(self, name="parser_step_model", forward=step_forward)
+        self.tokvecs, self.bp_tokvecs = layers[0](docs, is_train=train)
+        if layers[1].get_dim("nP") >= 2:
             activation = "maxout"
         elif len(layers) == 2:
             activation = None
         else:
             activation = "relu"
         self.state2vec = precompute_hiddens(len(docs), self.tokvecs, layers[1],
-                                            activation=activation, drop=drop)
+                                            activation=activation, train=train)
         if len(layers) == 3:
             self.vec2scores = layers[-1]
         else:
@@ -304,7 +315,7 @@ class ParserStepModel(Model):
         if self.vec2scores is None:
             self._class_mask = numpy.zeros((self.state2vec.nO,), dtype='f')
         else:
-            self._class_mask = numpy.zeros((self.vec2scores.nO,), dtype='f')
+            self._class_mask = numpy.zeros((self.vec2scores.get_dim("nO"),), dtype='f')
         self._class_mask.fill(1)
         if unseen_classes is not None:
             for class_ in unseen_classes:
@@ -323,40 +334,6 @@ class ParserStepModel(Model):
     def mark_class_seen(self, class_):
         self._class_mask[class_] = 1

-    def begin_update(self, states, drop=0.):
-        token_ids = self.get_token_ids(states)
-        vector, get_d_tokvecs = self.state2vec.begin_update(token_ids, drop=0.0)
-        if self.vec2scores is not None:
-            mask = self.vec2scores.ops.get_dropout_mask(vector.shape, drop)
-            if mask is not None:
-                vector *= mask
-            scores, get_d_vector = self.vec2scores.begin_update(vector, drop=drop)
-        else:
-            scores = NumpyOps().asarray(vector)
-            get_d_vector = lambda d_scores, sgd=None: d_scores
-            mask = None
-        # If the class is unseen, make sure its score is minimum
-        scores[:, self._class_mask == 0] = numpy.nanmin(scores)
-
-        def backprop_parser_step(d_scores, sgd=None):
-            # Zero vectors for unseen classes
-            d_scores *= self._class_mask
-            d_vector = get_d_vector(d_scores, sgd=sgd)
-            if mask is not None:
-                d_vector *= mask
-            if isinstance(self.state2vec.ops, CupyOps) \
-                    and not isinstance(token_ids, self.state2vec.ops.xp.ndarray):
-                # Move token_ids and d_vector to GPU, asynchronously
-                self.backprops.append((
-                    util.get_async(self.cuda_stream, token_ids),
-                    util.get_async(self.cuda_stream, d_vector),
-                    get_d_tokvecs
-                ))
-            else:
-                self.backprops.append((token_ids, d_vector, get_d_tokvecs))
-            return None
-        return scores, backprop_parser_step

     def get_token_ids(self, batch):
         states = _beam_utils.collect_states(batch)
         cdef StateClass state
@@ -370,25 +347,56 @@ class ParserStepModel(Model):
             c_ids += ids.shape[1]
         return ids

-    def make_updates(self, sgd):
+    def finish_steps(self, golds):
         # Add a padding vector to the d_tokvecs gradient, so that missing
         # values don't affect the real gradient.
-        d_tokvecs = self.ops.allocate((self.tokvecs.shape[0]+1, self.tokvecs.shape[1]))
+        d_tokvecs = self.ops.alloc((self.tokvecs.shape[0]+1, self.tokvecs.shape[1]))
         # Tells CUDA to block, so our async copies complete.
         if self.cuda_stream is not None:
             self.cuda_stream.synchronize()
         for ids, d_vector, bp_vector in self.backprops:
-            d_state_features = bp_vector((d_vector, ids), sgd=sgd)
+            d_state_features = bp_vector((d_vector, ids))
             ids = ids.flatten()
             d_state_features = d_state_features.reshape(
                 (ids.size, d_state_features.shape[2]))
             self.ops.scatter_add(d_tokvecs, ids,
                                  d_state_features)
         # Padded -- see update()
-        self.bp_tokvecs(d_tokvecs[:-1], sgd=sgd)
+        if isinstance(self.ops, CupyOps):
+            d_tokvecs = self.ops.to_numpy(d_tokvecs)
+        self.bp_tokvecs(d_tokvecs[:-1])
         return d_tokvecs

+
+def step_forward(model: ParserStepModel, states, is_train):
+    token_ids = model.get_token_ids(states)
+    vector, get_d_tokvecs = model.state2vec(token_ids, is_train)
+    if model.vec2scores is not None:
+        scores, get_d_vector = model.vec2scores(vector, is_train)
+    else:
+        scores = NumpyOps().asarray(vector)
+        get_d_vector = lambda d_scores: d_scores
+    # If the class is unseen, make sure its score is minimum
+    scores[:, model._class_mask == 0] = numpy.nanmin(scores)
+
+    def backprop_parser_step(d_scores):
+        # Zero vectors for unseen classes
+        d_scores *= model._class_mask
+        d_vector = get_d_vector(d_scores)
+        if isinstance(model.state2vec.ops, CupyOps) \
+                and not isinstance(token_ids, model.state2vec.ops.xp.ndarray):
+            # Move token_ids and d_vector to GPU, asynchronously
+            model.backprops.append((
+                util.get_async(model.cuda_stream, token_ids),
+                util.get_async(model.cuda_stream, d_vector),
+                get_d_tokvecs
+            ))
+        else:
+            model.backprops.append((token_ids, d_vector, get_d_tokvecs))
+        return None
+    return scores, backprop_parser_step
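The block above replaces ParserStepModel.begin_update with a standalone step_forward function that is handed to Model.__init__ as forward=step_forward. This is the thinc 8 pattern of a Model wrapping a plain forward function that returns (output, backprop). A minimal, self-contained sketch of that pattern, illustrative only and not spaCy code:

import numpy
from thinc.model import Model  # thinc 8

def double_forward(model, X, is_train):
    # forward pass returns the output plus a callback for the backward pass
    Y = X * 2.0
    def backprop(dY):
        return dY * 2.0
    return Y, backprop

double = Model("double", double_forward)
Y, backprop = double(numpy.asarray([1.0, 2.0], dtype="f"), is_train=True)
dX = backprop(numpy.ones_like(Y))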
 cdef class precompute_hiddens:
     """Allow a model to be "primed" by pre-computing input features in bulk.
@@ -406,7 +414,7 @@ cdef class precompute_hiddens:
     we can do all our hard maths up front, packed into large multiplications,
     and do the hard-to-program parsing on the CPU.
     """
-    cdef readonly int nF, nO, nP
+    cdef readonly int nF, nO, nP  # TODO: make these more like the dimensions in thinc
     cdef bint _is_synchronized
     cdef public object ops
     cdef np.ndarray _features
@@ -417,8 +425,8 @@ cdef class precompute_hiddens:
     cdef object activation

     def __init__(self, batch_size, tokvecs, lower_model, cuda_stream=None,
-                 activation="maxout", drop=0.):
-        gpu_cached, bp_features = lower_model.begin_update(tokvecs, drop=drop)
+                 activation="maxout", train=False):
+        gpu_cached, bp_features = lower_model(tokvecs, train)
         cdef np.ndarray cached
         if not isinstance(gpu_cached, numpy.ndarray):
             # Note the passing of cuda_stream here: it lets
@@ -427,12 +435,16 @@ cdef class precompute_hiddens:
             cached = gpu_cached.get(stream=cuda_stream)
         else:
             cached = gpu_cached
-        if not isinstance(lower_model.b, numpy.ndarray):
-            self.bias = lower_model.b.get()
+        if not isinstance(lower_model.get_param("b"), numpy.ndarray):
+            # self.bias = lower_model.get_param("b").get(stream=cuda_stream) ???
+            self.bias = lower_model.get_param("b")
         else:
-            self.bias = lower_model.b
+            self.bias = lower_model.get_param("b")
         self.nF = cached.shape[1]
-        self.nP = getattr(lower_model, 'nP', 1)
+        if lower_model.has_dim("nP"):
+            self.nP = lower_model.get_dim("nP")
+        else:
+            self.nP = 1
         self.nO = cached.shape[2]
         self.ops = lower_model.ops
         assert activation in (None, "relu", "maxout")
@@ -448,10 +460,26 @@ cdef class precompute_hiddens:
         self._is_synchronized = True
         return <float*>self._cached.data

-    def __call__(self, X):
-        return self.begin_update(X, drop=None)[0]
+    def get_dim(self, name):
+        if name == "nF":
+            return self.nF
+        elif name == "nP":
+            return self.nP
+        elif name == "nO":
+            return self.nO
+        else:
+            raise ValueError(f"Dimension {name} invalid -- only nO, nF, nP")

-    def begin_update(self, token_ids, drop=0.):
+    def __call__(self, X, bint is_train):
+        if is_train:
+            return self.begin_update(X)
+        else:
+            return self.predict(X), lambda X: X
+
+    def predict(self, X):
+        return self.begin_update(X)[0]
+
+    def begin_update(self, token_ids):
         cdef np.ndarray state_vector = numpy.zeros(
             (token_ids.shape[0], self.nO, self.nP), dtype='f')
         # This is tricky, but (assuming GPU available);
@@ -466,13 +494,13 @@ cdef class precompute_hiddens:
         sum_state_features(<float*>state_vector.data,
             feat_weights, &ids[0,0],
             token_ids.shape[0], self.nF, self.nO*self.nP)
-        state_vector += self.bias
+        state_vector = state_vector + self.bias
         state_vector, bp_nonlinearity = self._nonlinearity(state_vector)

-        def backward(d_state_vector_ids, sgd=None):
+        def backward(d_state_vector_ids):
             d_state_vector, token_ids = d_state_vector_ids
-            d_state_vector = bp_nonlinearity(d_state_vector, sgd)
-            d_tokens = bp_hiddens((d_state_vector, token_ids), sgd)
+            d_state_vector = bp_nonlinearity(d_state_vector)
+            d_tokens = bp_hiddens((d_state_vector, token_ids))
             return d_tokens
         return state_vector, backward
@@ -492,7 +520,7 @@ cdef class precompute_hiddens:
         else:
             mask = None

-        def backprop_nonlinearity(d_best, sgd=None):
+        def backprop_nonlinearity(d_best):
             if isinstance(d_best, numpy.ndarray):
                 ops = NumpyOps()
             else:

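precompute_hiddens is not a thinc Model subclass, but after this change it mimics the slice of the Model API that the stepping code relies on: get_dim, __call__ with an is_train flag, predict, and begin_update. A schematic of that duck-typed surface, with an identity computation standing in for the real state-vector maths (assumption: these four entry points are the only ones the parser needs):

class ModelLikeSketch:
    """Illustrative stand-in for the interface precompute_hiddens exposes."""

    def __init__(self, nF, nO, nP):
        self._dims = {"nF": nF, "nO": nO, "nP": nP}

    def get_dim(self, name):
        if name not in self._dims:
            raise ValueError(f"Dimension {name} invalid -- only nO, nF, nP")
        return self._dims[name]

    def __call__(self, X, is_train):
        if is_train:
            return self.begin_update(X)
        return self.predict(X), lambda dY: dY

    def predict(self, X):
        return self.begin_update(X)[0]

    def begin_update(self, X):
        # the real class sums precomputed features here; identity keeps the sketch runnable
        return X, lambda dY: dY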

@@ -1,6 +1,6 @@
 from cymem.cymem cimport Pool
-from thinc.typedefs cimport weight_t
+from ..typedefs cimport weight_t
 from .stateclass cimport StateClass
 from ..typedefs cimport attr_t


@@ -1,7 +1,7 @@
-from thinc.typedefs cimport weight_t
 from thinc.extra.search cimport Beam
 from collections import Counter
+from ..typedefs cimport weight_t
 from .stateclass cimport StateClass
 from ._state cimport StateC
 from .transition_system cimport Transition


@@ -1,5 +1,3 @@
-from thinc.typedefs cimport atom_t
 from .stateclass cimport StateClass
 from .arc_eager cimport TransitionSystem
 from ..vocab cimport Vocab


@@ -13,24 +13,23 @@ from libcpp.vector cimport vector
 from libc.string cimport memset, memcpy
 from libc.stdlib cimport calloc, free
 from cymem.cymem cimport Pool
-from thinc.typedefs cimport weight_t, class_t, hash_t
 from thinc.extra.search cimport Beam
-from thinc.api import chain, clone
-from thinc.v2v import Model, Maxout, Affine
-from thinc.misc import LayerNorm
-from thinc.neural.ops import NumpyOps, CupyOps
-from thinc.neural.util import get_array_module
-from thinc.linalg cimport Vec, VecVec
+from thinc.layers import chain, clone, Linear, list2array
+from thinc.backends import NumpyOps, CupyOps, use_ops
+from thinc.util import get_array_module
+from thinc.backends.linalg cimport Vec, VecVec
+from thinc.initializers import zero_init
+from thinc.model import set_dropout_rate
 import srsly

 from spacy.gold import Example
+from ..typedefs cimport weight_t, class_t, hash_t
 from ._parser_model cimport alloc_activations, free_activations
 from ._parser_model cimport predict_states, arg_max_if_valid
 from ._parser_model cimport WeightsC, ActivationsC, SizesC, cpu_log_loss
 from ._parser_model cimport get_c_weights, get_c_sizes
 from ._parser_model import ParserModel
-from .._ml import zero_init, PrecomputableAffine, Tok2Vec, flatten
-from .._ml import link_vectors_to_models, create_default_optimizer
+from ..util import link_vectors_to_models, create_default_optimizer
 from ..compat import copy_array
 from ..tokens.doc cimport Doc
 from ..gold cimport GoldParse
@@ -44,6 +43,10 @@ from . import _beam_utils
 from . import nonproj

+from ..ml._layers import PrecomputableAffine
+from ..ml.component_models import Tok2Vec
+

 cdef class Parser:
     """
     Base class of the DependencyParser and EntityRecognizer.
@@ -54,7 +57,7 @@ cdef class Parser:
         subword_features = util.env_opt('subword_features',
                                         cfg.get('subword_features', True))
         conv_depth = util.env_opt('conv_depth', cfg.get('conv_depth', 4))
-        conv_window = util.env_opt('conv_window', cfg.get('conv_depth', 1))
+        window_size = util.env_opt('window_size', cfg.get('window_size', 1))
         t2v_pieces = util.env_opt('cnn_maxout_pieces', cfg.get('cnn_maxout_pieces', 3))
         bilstm_depth = util.env_opt('bilstm_depth', cfg.get('bilstm_depth', 0))
         self_attn_depth = util.env_opt('self_attn_depth', cfg.get('self_attn_depth', 0))
@@ -71,23 +74,23 @@ cdef class Parser:
             parser_maxout_pieces = 1
         embed_size = util.env_opt('embed_size', cfg.get('embed_size', 2000))
         pretrained_vectors = cfg.get('pretrained_vectors', None)
-        tok2vec = Tok2Vec(token_vector_width, embed_size,
+        tok2vec = Tok2Vec(width=token_vector_width,
+                          embed_size=embed_size,
                           conv_depth=conv_depth,
-                          conv_window=conv_window,
+                          window_size=window_size,
                           cnn_maxout_pieces=t2v_pieces,
                           subword_features=subword_features,
                           pretrained_vectors=pretrained_vectors,
                           bilstm_depth=bilstm_depth)
-        tok2vec = chain(tok2vec, flatten)
-        tok2vec.nO = token_vector_width
+        tok2vec = chain(tok2vec, list2array())
+        tok2vec.set_dim("nO", token_vector_width)
         lower = PrecomputableAffine(hidden_width,
                                     nF=nr_feature_tokens, nI=token_vector_width,
                                     nP=parser_maxout_pieces)
-        lower.nP = parser_maxout_pieces
+        lower.set_dim("nP", parser_maxout_pieces)
         if depth == 1:
-            with Model.use_device('cpu'):
-                upper = Affine(nr_class, hidden_width, drop_factor=0.0)
-            upper.W *= 0
+            with use_ops('numpy'):
+                upper = Linear(nr_class, hidden_width, init_W=zero_init)
         else:
             upper = None
@@ -102,11 +105,13 @@ cdef class Parser:
             'bilstm_depth': bilstm_depth,
             'self_attn_depth': self_attn_depth,
             'conv_depth': conv_depth,
-            'conv_window': conv_window,
+            'window_size': window_size,
             'embed_size': embed_size,
             'cnn_maxout_pieces': t2v_pieces
         }
-        return ParserModel(tok2vec, lower, upper), cfg
+        model = ParserModel(tok2vec, lower, upper)
+        model.initialize()
+        return model, cfg

     name = 'base_parser'
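The output layer construction above swaps thinc 7's Affine plus manual weight zeroing (upper.W *= 0) for thinc 8's Linear with init_W=zero_init, built on the numpy backend via use_ops. A hedged, standalone sketch of the same construction (the sizes here are made up):

from thinc.layers import Linear
from thinc.backends import use_ops
from thinc.initializers import zero_init

nr_class, hidden_width = 10, 64        # illustrative sizes only
with use_ops("numpy"):                 # keep this layer on CPU ops
    upper = Linear(nr_class, hidden_width, init_W=zero_init)
upper.initialize()
assert not upper.get_param("W").any()  # weights start at zero, as before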
@@ -283,12 +288,13 @@ cdef class Parser:
     def greedy_parse(self, docs, drop=0.):
         cdef vector[StateC*] states
         cdef StateClass state
+        set_dropout_rate(self.model, drop)
         batch = self.moves.init_batch(docs)
         # This is pretty dirty, but the NER can resize itself in init_batch,
         # if labels are missing. We therefore have to check whether we need to
         # expand our model output.
         self._resize()
-        model = self.model(docs)
+        model = self.model.predict(docs)
         weights = get_c_weights(model)
         for state in batch:
             if not state.is_final():
@@ -303,18 +309,19 @@ cdef class Parser:
         cdef Beam beam
         cdef Doc doc
         cdef np.ndarray token_ids
+        set_dropout_rate(self.model, drop)
         beams = self.moves.init_beams(docs, beam_width, beam_density=beam_density)
         # This is pretty dirty, but the NER can resize itself in init_batch,
         # if labels are missing. We therefore have to check whether we need to
         # expand our model output.
         self._resize()
-        model = self.model(docs)
+        model = self.model.predict(docs)
         token_ids = numpy.zeros((len(docs) * beam_width, self.nr_feature),
                                 dtype='i', order='C')
         cdef int* c_ids
         cdef int nr_feature = self.cfg["nr_feature_tokens"]
         cdef int n_states
-        model = self.model(docs)
+        model = self.model.predict(docs)
         todo = [beam for beam in beams if not beam.is_done]
         while todo:
             token_ids.fill(-1)
@@ -331,8 +338,8 @@ cdef class Parser:
                     n_states += 1
             if n_states == 0:
                 break
-            vectors = model.state2vec(token_ids[:n_states])
-            scores = model.vec2scores(vectors)
+            vectors = model.state2vec.predict(token_ids[:n_states])
+            scores = model.vec2scores.predict(vectors)
             todo = self.transition_beams(todo, scores)
         return beams
@@ -424,7 +431,7 @@ cdef class Parser:
             beam.check_done(_beam_utils.check_final_state, NULL)
         return [b for b in beams if not b.is_done]

-    def update(self, examples, drop=0., sgd=None, losses=None):
+    def update(self, examples, drop=0., set_annotations=False, sgd=None, losses=None):
         self.require_model()
         examples = Example.to_example_objects(examples)
@@ -438,8 +445,10 @@ cdef class Parser:
         beam_update_prob = self.cfg.get('beam_update_prob', 0.5)
         if self.cfg.get('beam_width', 1) >= 2 and numpy.random.random() < beam_update_prob:
             return self.update_beam(examples, self.cfg.get('beam_width', 1),
-                drop=drop, sgd=sgd, losses=losses,
+                drop=drop, sgd=sgd, losses=losses, set_annotations=set_annotations,
                 beam_density=self.cfg.get('beam_density', 0.001))
+        set_dropout_rate(self.model, drop)
         # Chop sequences into lengths of this many transitions, to make the
         # batch uniform length.
         cut_gold = numpy.random.choice(range(20, 100))
@@ -448,19 +457,24 @@ cdef class Parser:
                         if not s.is_final() and g is not None]
         # Prepare the stepwise model, and get the callback for finishing the batch
-        model, finish_update = self.model.begin_update([ex.doc for ex in examples], drop=drop)
+        model, backprop_tok2vec = self.model.begin_update([ex.doc for ex in examples])
+        all_states = list(states)
         for _ in range(max_steps):
             if not states_golds:
                 break
             states, golds = zip(*states_golds)
-            scores, backprop = model.begin_update(states, drop=drop)
+            scores, backprop = model.begin_update(states)
             d_scores = self.get_batch_loss(states, golds, scores, losses)
-            backprop(d_scores, sgd=sgd)
+            backprop(d_scores)
             # Follow the predicted action
             self.transition_states(states, scores)
             states_golds = [eg for eg in states_golds if not eg[0].is_final()]
-        # Do the backprop
-        finish_update(golds, sgd=sgd)
+        backprop_tok2vec(golds)
+        if sgd is not None:
+            self.model.finish_update(sgd)
+        if set_annotations:
+            docs = [ex.doc for ex in examples]
+            self.set_annotations(docs, all_states)
         return losses
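The rewritten update() follows the thinc 8 training contract: dropout is set once on the model with set_dropout_rate, begin_update takes no drop or sgd arguments, gradients accumulate on the model, and the optimizer is applied in a single finish_update call. A minimal sketch of that cycle using a generic thinc layer, not the parser model itself:

import numpy
from thinc.layers import Linear
from thinc.optimizers import Adam
from thinc.model import set_dropout_rate

model = Linear(2, 4)
model.initialize()
optimizer = Adam(0.001)

X = numpy.zeros((8, 4), dtype="f")
dY = numpy.ones((8, 2), dtype="f")

set_dropout_rate(model, 0.2)         # dropout is model state now, not an argument
Y, backprop = model.begin_update(X)  # no drop=... and no sgd=...
backprop(dY)                         # gradients accumulate on the model
model.finish_update(optimizer)       # parameters are updated once, at the end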
     def rehearse(self, examples, sgd=None, losses=None, **cfg):
@@ -482,13 +496,15 @@ cdef class Parser:
         # expand our model output.
         self._resize()
         # Prepare the stepwise model, and get the callback for finishing the batch
-        tutor, _ = self._rehearsal_model.begin_update(docs, drop=0.0)
-        model, finish_update = self.model.begin_update(docs, drop=0.0)
+        set_dropout_rate(self._rehearsal_model, 0.0)
+        set_dropout_rate(self.model, 0.0)
+        tutor, _ = self._rehearsal_model.begin_update(docs)
+        model, finish_update = self.model.begin_update(docs)
         n_scores = 0.
         loss = 0.
         while states:
-            targets, _ = tutor.begin_update(states, drop=0.)
-            guesses, backprop = model.begin_update(states, drop=0.)
+            targets, _ = tutor.begin_update(states)
+            guesses, backprop = model.begin_update(states)
             d_scores = (guesses - targets) / targets.shape[0]
             # If all weights for an output are 0 in the original model, don't
             # supervise that output. This allows us to add classes.
@@ -499,12 +515,14 @@ cdef class Parser:
             states = [state for state in states if not state.is_final()]
             n_scores += d_scores.size
         # Do the backprop
-        finish_update(docs, sgd=sgd)
+        finish_update(docs)
+        if sgd is not None:
+            self.model.finish_update(sgd)
         losses[self.name] += loss / n_scores
         return losses

     def update_beam(self, examples, width, drop=0., sgd=None, losses=None,
-                    beam_density=0.0):
+                    set_annotations=False, beam_density=0.0):
         examples = Example.to_example_objects(examples)
         docs = [ex.doc for ex in examples]
         golds = [ex.gold for ex in examples]
@@ -514,15 +532,16 @@ cdef class Parser:
         for gold in golds:
             self.moves.preprocess_gold(gold)
             new_golds.append(gold)
-        model, finish_update = self.model.begin_update(docs, drop=drop)
+        set_dropout_rate(self.model, drop)
+        model, backprop_tok2vec = self.model.begin_update(docs)
         states_d_scores, backprops, beams = _beam_utils.update_beam(
-            self.moves, self.cfg["nr_feature_tokens"], 10000, states, golds, model.state2vec,
-            model.vec2scores, width, drop=drop, losses=losses,
+            self.moves, self.cfg["nr_feature_tokens"], 10000, states, golds,
+            model.state2vec, model.vec2scores, width, losses=losses,
             beam_density=beam_density)
         for i, d_scores in enumerate(states_d_scores):
             losses[self.name] += (d_scores**2).mean()
             ids, bp_vectors, bp_scores = backprops[i]
-            d_vector = bp_scores(d_scores, sgd=sgd)
+            d_vector = bp_scores(d_scores)
             if isinstance(model.ops, CupyOps) \
                     and not isinstance(ids, model.state2vec.ops.xp.ndarray):
                 model.backprops.append((
@@ -531,11 +550,34 @@ cdef class Parser:
                     bp_vectors))
             else:
                 model.backprops.append((ids, d_vector, bp_vectors))
-        model.make_updates(sgd)
+        backprop_tok2vec(golds)
+        if sgd is not None:
+            self.model.finish_update(sgd)
+        if set_annotations:
+            self.set_annotations(docs, beams)
         cdef Beam beam
         for beam in beams:
             _beam_utils.cleanup_beam(beam)

+    def get_gradients(self):
+        """Get non-zero gradients of the model's parameters, as a dictionary
+        keyed by the parameter ID. The values are (weights, gradients) tuples.
+        """
+        gradients = {}
+        if self.model in (None, True, False):
+            return gradients
+        queue = [self.model]
+        seen = set()
+        for node in queue:
+            if node.id in seen:
+                continue
+            seen.add(node.id)
+            if hasattr(node, "_mem") and node._mem.gradient.any():
+                gradients[node.id] = [node._mem.weights, node._mem.gradient]
+            if hasattr(node, "_layers"):
+                queue.extend(node._layers)
+        return gradients
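get_gradients walks the model tree and returns a dict keyed by node id, with [weights, gradient] pairs as values. A hypothetical way to consume it, for example to log gradient magnitudes between backprop and the optimizer step (the helper name is made up):

import numpy

def log_gradient_norms(parser):
    # parser.get_gradients() maps node id -> [weights, gradient]
    for node_id, (weights, gradient) in parser.get_gradients().items():
        print(node_id, weights.size, float(numpy.abs(gradient).sum()))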
     def _init_gold_batch(self, whole_examples, min_length=5, max_length=500):
         """Make a square batch, of length equal to the shortest doc. A long
         doc will get multiple states. Let's say we have a doc of length 2*N,
@@ -605,8 +647,7 @@ cdef class Parser:
         return d_scores

     def create_optimizer(self):
-        return create_default_optimizer(self.model.ops,
-                                        **self.cfg.get('optimizer', {}))
+        return create_default_optimizer()

     def begin_training(self, get_examples, pipeline=None, sgd=None, **cfg):
         if 'model' in cfg:
@@ -636,14 +677,16 @@ cdef class Parser:
             for doc, gold in parses:
                 doc_sample.append(doc)
                 gold_sample.append(gold)
-            self.model.begin_training(doc_sample, gold_sample)
+            self.model.initialize(doc_sample, gold_sample)
             if pipeline is not None:
                 self.init_multitask_objectives(get_examples, pipeline, sgd=sgd, **cfg)
             link_vectors_to_models(self.vocab)
         else:
             if sgd is None:
                 sgd = self.create_optimizer()
-            self.model.begin_training([])
+            if self.model.upper.has_dim("nO") is None:
+                self.model.upper.set_dim("nO", self.moves.n_moves)
+            self.model.initialize()
         self.cfg.update(cfg)
         return sgd
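begin_training now calls model.initialize(docs, golds) instead of model.begin_training(...): in thinc 8, initialize both allocates parameters and infers any unset dimensions from the sample data. A generic illustration of that shape inference with a plain thinc layer, not the parser model:

import numpy
from thinc.layers import Linear

layer = Linear(nO=2)                      # nI deliberately left unset
X_sample = numpy.zeros((4, 5), dtype="f")
layer.initialize(X=X_sample)              # nI is inferred from the sample batch
assert layer.get_dim("nI") == 5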
@@ -709,7 +752,7 @@ cdef class Parser:
         if 'model' not in exclude:
             # TODO: Remove this once we don't have to handle previous models
             if self.cfg.get('pretrained_dims') and 'pretrained_vectors' not in self.cfg:
-                self.cfg['pretrained_vectors'] = self.vocab.vectors.name
+                self.cfg['pretrained_vectors'] = self.vocab.vectors
             if self.model is True:
                 self.model, cfg = self.Model(**self.cfg)
             else:


@@ -1,7 +1,6 @@
 from cymem.cymem cimport Pool
-from thinc.typedefs cimport weight_t
-from ..typedefs cimport attr_t
+from ..typedefs cimport attr_t, weight_t
 from ..structs cimport TokenC
 from ..gold cimport GoldParse
 from ..gold cimport GoldParseC


@@ -1,7 +1,7 @@
 # cython: infer_types=True
 from cpython.ref cimport Py_INCREF
 from cymem.cymem cimport Pool
-from thinc.typedefs cimport weight_t
+from ..typedefs cimport weight_t
 from thinc.extra.search cimport Beam
 from collections import Counter
 import srsly


@@ -1,6 +1,6 @@
 import pytest
-from thinc.neural.optimizers import Adam
-from thinc.neural.ops import NumpyOps
+from thinc.optimizers import Adam
+from thinc.backends import NumpyOps
 from spacy.attrs import NORM
 from spacy.gold import GoldParse
 from spacy.vocab import Vocab
@@ -28,7 +28,7 @@ def _train_parser(parser):
     fix_random_seed(1)
     parser.add_label("left")
     parser.begin_training([], **parser.cfg)
-    sgd = Adam(NumpyOps(), 0.001)
+    sgd = Adam(0.001, ops=NumpyOps())
     for i in range(5):
         losses = {}
@@ -41,8 +41,8 @@ def _train_parser(parser):
 def test_add_label(parser):
     parser = _train_parser(parser)
     parser.add_label("right")
-    sgd = Adam(NumpyOps(), 0.001)
-    for i in range(10):
+    sgd = Adam(0.001, ops=NumpyOps())
+    for i in range(100):
         losses = {}
         doc = Doc(parser.vocab, words=["a", "b", "c", "d"])
         gold = GoldParse(


@@ -7,6 +7,11 @@ from spacy.syntax.ner import BiluoPushDown
 from spacy.gold import GoldParse
 from spacy.tokens import Doc

+TRAIN_DATA = [
+    ("Who is Shaka Khan?", {"entities": [(7, 17, "PERSON")]}),
+    ("I like London and Berlin.", {"entities": [(7, 13, "LOC"), (18, 24, "LOC")]}),
+]
+

 @pytest.fixture
 def vocab():
@@ -263,7 +268,7 @@ def test_change_number_features():
     nlp.add_pipe(ner)
     ner.add_label("PERSON")
     nlp.begin_training()
-    assert ner.model.lower.nF == ner.nr_feature
+    assert ner.model.lower.get_dim("nF") == ner.nr_feature
     # Test we can change it
     nlp = English()
     ner = nlp.create_pipe("ner")
@@ -272,11 +277,36 @@ def test_change_number_features():
     nlp.begin_training(
         component_cfg={"ner": {"nr_feature_tokens": 3, "token_vector_width": 128}}
     )
-    assert ner.model.lower.nF == 3
+    assert ner.model.lower.get_dim("nF") == 3
     # Test the model runs
     nlp("hello world")

+
+def test_overfitting():
+    # Simple test to try and quickly overfit the NER component - ensuring the ML models work correctly
+    nlp = English()
+    ner = nlp.create_pipe("ner")
+    for _, annotations in TRAIN_DATA:
+        for ent in annotations.get("entities"):
+            ner.add_label(ent[2])
+    nlp.add_pipe(ner)
+    optimizer = nlp.begin_training()
+
+    for i in range(50):
+        losses = {}
+        nlp.update(TRAIN_DATA, sgd=optimizer, losses=losses)
+    assert losses["ner"] < 0.00001
+
+    # test the trained model
+    test_text = "I like London."
+    doc = nlp(test_text)
+    ents = doc.ents
+    assert len(ents) == 1
+    assert ents[0].text == "London"
+    assert ents[0].label_ == "LOC"
+

 class BlockerComponent1(object):
     name = "my_blocker"


@@ -1,5 +1,5 @@
 import pytest
-from spacy._ml import Tok2Vec
+from spacy.ml.component_models import Tok2Vec
 from spacy.vocab import Vocab
 from spacy.syntax.arc_eager import ArcEager
 from spacy.syntax.nn_parser import Parser
@@ -20,7 +20,9 @@ def arc_eager(vocab):
 @pytest.fixture
 def tok2vec():
-    return Tok2Vec(8, 100)
+    tok2vec = Tok2Vec(8, 100)
+    tok2vec.initialize()
+    return tok2vec

 @pytest.fixture
@@ -30,7 +32,7 @@ def parser(vocab, arc_eager):
 @pytest.fixture
 def model(arc_eager, tok2vec):
-    return Parser.Model(arc_eager.n_moves, token_vector_width=tok2vec.nO)[0]
+    return Parser.Model(arc_eager.n_moves, token_vector_width=tok2vec.get_dim("nO"))[0]

 @pytest.fixture
@@ -53,7 +55,7 @@ def test_build_model(parser):
 def test_predict_doc(parser, tok2vec, model, doc):
-    doc.tensor = tok2vec([doc])[0]
+    doc.tensor = tok2vec.predict([doc])[0]
     parser.model = model
     parser(doc)
@@ -61,8 +63,9 @@ def test_predict_doc(parser, tok2vec, model, doc):
 def test_update_doc(parser, model, doc, gold):
     parser.model = model

-    def optimize(weights, gradient, key=None):
+    def optimize(key, weights, gradient):
         weights -= 0.001 * gradient
+        return weights, gradient

     parser.update((doc, gold), sgd=optimize)


@@ -1,7 +1,25 @@
 import pytest
+from spacy.lang.en import English

 from ..util import get_doc, apply_transition_sequence

+TRAIN_DATA = [
+    (
+        "They trade mortgage-backed securities.",
+        {
+            "heads": [1, 1, 4, 4, 5, 1, 1],
+            "deps": ["nsubj", "ROOT", "compound", "punct", "nmod", "dobj", "punct"],
+        },
+    ),
+    (
+        "I like London and Berlin.",
+        {
+            "heads": [1, 1, 1, 2, 2, 1],
+            "deps": ["nsubj", "ROOT", "dobj", "cc", "conj", "punct"],
+        },
+    ),
+]
+

 def test_parser_root(en_tokenizer):
     text = "i don't have other assistance"
@@ -162,3 +180,27 @@ def test_parser_set_sent_starts(en_vocab):
     for sent in doc.sents:
         for token in sent:
             assert token.head in sent
+
+
+def test_overfitting():
+    # Simple test to try and quickly overfit the dependency parser - ensuring the ML models work correctly
+    nlp = English()
+    parser = nlp.create_pipe("parser")
+    for _, annotations in TRAIN_DATA:
+        for dep in annotations.get("deps", []):
+            parser.add_label(dep)
+    nlp.add_pipe(parser)
+    optimizer = nlp.begin_training()
+
+    for i in range(50):
+        losses = {}
+        nlp.update(TRAIN_DATA, sgd=optimizer, losses=losses)
+    assert losses["parser"] < 0.00001
+
+    # test the trained model
+    test_text = "I like securities."
+    doc = nlp(test_text)
+    assert doc[0].dep_ is "nsubj"
+    assert doc[2].dep_ is "dobj"
+    assert doc[3].dep_ is "punct"


@@ -1,6 +1,6 @@
 import pytest
-from thinc.neural.optimizers import Adam
-from thinc.neural.ops import NumpyOps
+from thinc.optimizers import Adam
+from thinc.backends import NumpyOps
 from spacy.attrs import NORM
 from spacy.gold import GoldParse
 from spacy.vocab import Vocab
@@ -21,7 +21,7 @@ def parser(vocab):
     # parser.add_label('right')
     parser.add_label("left")
     parser.begin_training([], **parser.cfg)
-    sgd = Adam(NumpyOps(), 0.001)
+    sgd = Adam(0.001)
     for i in range(10):
         losses = {}


@@ -1,4 +1,5 @@
 import pytest
+import srsly
 from spacy.language import Language
@@ -8,3 +9,35 @@ def test_label_types():
     nlp.get_pipe("tagger").add_label("A")
     with pytest.raises(ValueError):
         nlp.get_pipe("tagger").add_label(9)
+
+TAG_MAP = {"N": {"pos": "NOUN"}, "V": {"pos": "VERB"}, "J": {"pos": "ADJ"}}
+
+TRAIN_DATA = [
+    ("I like green eggs", {"tags": ["N", "V", "J", "N"]}),
+    ("Eat blue ham", {"tags": ["V", "J", "N"]}),
+]
+
+
+def test_overfitting():
+    # Simple test to try and quickly overfit the tagger - ensuring the ML models work correctly
+    nlp = Language()
+    tagger = nlp.create_pipe("tagger")
+    for tag, values in TAG_MAP.items():
+        tagger.add_label(tag, values)
+    nlp.add_pipe(tagger)
+    optimizer = nlp.begin_training()
+
+    for i in range(50):
+        losses = {}
+        nlp.update(TRAIN_DATA, sgd=optimizer, losses=losses)
+    assert losses["tagger"] < 0.00001
+
+    # test the trained model
+    test_text = "I like blue eggs"
+    doc = nlp(test_text)
+    assert doc[0].tag_ is "N"
+    assert doc[1].tag_ is "V"
+    assert doc[2].tag_ is "J"
+    assert doc[3].tag_ is "N"


@@ -6,6 +6,11 @@ from spacy.pipeline import TextCategorizer
 from spacy.tokens import Doc
 from spacy.gold import GoldParse

+TRAIN_DATA = [
+    ("I'm so happy.", {"cats": {"POSITIVE": 1.0, "NEGATIVE": 0.0}}),
+    ("I'm so angry", {"cats": {"POSITIVE": 0.0, "NEGATIVE": 1.0}}),
+]
+

 @pytest.mark.skip(reason="Test is flakey when run with others")
 def test_simple_train():
@@ -67,3 +72,26 @@ def test_label_types():
     nlp.get_pipe("textcat").add_label("answer")
     with pytest.raises(ValueError):
         nlp.get_pipe("textcat").add_label(9)
+
+
+def test_overfitting():
+    # Simple test to try and quickly overfit the textcat component - ensuring the ML models work correctly
+    nlp = Language()
+    textcat = nlp.create_pipe("textcat")
+    for _, annotations in TRAIN_DATA:
+        for label, value in annotations.get("cats").items():
+            textcat.add_label(label)
+    nlp.add_pipe(textcat)
+    optimizer = nlp.begin_training()
+
+    for i in range(50):
+        losses = {}
+        nlp.update(TRAIN_DATA, sgd=optimizer, losses=losses)
+    assert losses["textcat"] < 0.00001
+
+    # test the trained model
+    test_text = "I am happy."
+    doc = nlp(test_text)
+    cats = doc.cats
+    assert cats["POSITIVE"] > 0.9
+    assert cats["POSITIVE"] + cats["NEGATIVE"] == pytest.approx(1.0, 0.001)


@@ -8,7 +8,7 @@ from spacy.matcher import Matcher
 from spacy.tokens import Doc, Span
 from spacy.vocab import Vocab
 from spacy.compat import pickle
-from spacy._ml import link_vectors_to_models
+from spacy.util import link_vectors_to_models
 import numpy
 import random


@@ -32,7 +32,7 @@ def test_issue3611():
     # training the network
     with nlp.disable_pipes([p for p in nlp.pipe_names if p != "textcat"]):
-        optimizer = nlp.begin_training()
+        optimizer = nlp.begin_training(X=x_train, Y=y_train)
         for i in range(3):
             losses = {}
             batches = minibatch(train_data, size=compounding(4.0, 32.0, 1.001))


@@ -1,12 +1,12 @@
 import pytest
 from spacy import registry
-from thinc.v2v import Affine
+from thinc.layers import Linear
 from catalogue import RegistryError

 @registry.architectures.register("my_test_function")
 def create_model(nr_in, nr_out):
-    return Affine(nr_in, nr_out)
+    return Linear(nr_in, nr_out)

 def test_get_architecture():


@@ -5,7 +5,8 @@ from pathlib import Path
 from spacy import util
 from spacy import prefer_gpu, require_gpu
 from spacy.compat import symlink_to, symlink_remove, is_windows
-from spacy._ml import PrecomputableAffine
+from spacy.ml._layers import PrecomputableAffine
+from spacy.ml._layers import _backprop_precomputable_affine_padding
 from subprocess import CalledProcessError
@@ -67,28 +68,30 @@ def test_util_get_package_path(package):
 def test_PrecomputableAffine(nO=4, nI=5, nF=3, nP=2):
     model = PrecomputableAffine(nO=nO, nI=nI, nF=nF, nP=nP)
-    assert model.W.shape == (nF, nO, nP, nI)
-    tensor = model.ops.allocate((10, nI))
+    assert model.get_param("W").shape == (nF, nO, nP, nI)
+    tensor = model.ops.alloc((10, nI))
     Y, get_dX = model.begin_update(tensor)
     assert Y.shape == (tensor.shape[0] + 1, nF, nO, nP)
-    assert model.d_pad.shape == (1, nF, nO, nP)
-    dY = model.ops.allocate((15, nO, nP))
-    ids = model.ops.allocate((15, nF))
+    dY = model.ops.alloc((15, nO, nP))
+    ids = model.ops.alloc((15, nF))
     ids[1, 2] = -1
     dY[1] = 1
-    assert model.d_pad[0, 2, 0, 0] == 0.0
-    model._backprop_padding(dY, ids)
-    assert model.d_pad[0, 2, 0, 0] == 1.0
-    model.d_pad.fill(0.0)
+    assert not model.has_grad("pad")
+    d_pad = _backprop_precomputable_affine_padding(model, dY, ids)
+    assert d_pad[0, 2, 0, 0] == 1.0
     ids.fill(0.0)
     dY.fill(0.0)
-    ids[1, 2] = -1
+    dY[0] = 0
+    ids[1, 2] = 0
     ids[1, 1] = -1
     ids[1, 0] = -1
     dY[1] = 1
-    assert model.d_pad[0, 2, 0, 0] == 0.0
-    model._backprop_padding(dY, ids)
-    assert model.d_pad[0, 2, 0, 0] == 3.0
+    ids[2, 0] = -1
+    dY[2] = 5
+    d_pad = _backprop_precomputable_affine_padding(model, dY, ids)
+    assert d_pad[0, 0, 0, 0] == 6
+    assert d_pad[0, 1, 0, 0] == 1
+    assert d_pad[0, 2, 0, 0] == 0

 def test_prefer_gpu():
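Reading the expected values in the assertions above: _backprop_precomputable_affine_padding sums dY over the rows whose ids entry is -1 (i.e. rows that looked up the padding vector), per feature slot. Feature 0 is padded in rows 1 and 2, so it collects dY[1] + dY[2] = 1 + 5 = 6; feature 1 is padded only in row 1, giving 1; feature 2 is never padded, giving 0.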


@@ -1,6 +1,6 @@
 import pytest
-from spacy._ml import Tok2Vec
+from spacy.ml.component_models import Tok2Vec
 from spacy.vocab import Vocab
 from spacy.tokens import Doc
@@ -10,7 +10,7 @@ def get_batch(batch_size):
     docs = []
     start = 0
     for size in range(1, batch_size + 1):
-        # Make the words numbers, so that they're distnct
+        # Make the words numbers, so that they're distinct
         # across the batch, and easy to track.
         numbers = [str(i) for i in range(start, start + size)]
         docs.append(Doc(vocab, words=numbers))
@@ -37,6 +37,7 @@ def test_empty_doc():
 def test_tok2vec_batch_sizes(batch_size, width, embed_size):
     batch = get_batch(batch_size)
     tok2vec = Tok2Vec(width, embed_size)
+    tok2vec.initialize()
     vectors, backprop = tok2vec.begin_update(batch)
     assert len(vectors) == len(batch)
     for doc_vec, doc in zip(vectors, batch):
@@ -56,6 +57,7 @@ def test_tok2vec_batch_sizes(batch_size, width, embed_size):
 def test_tok2vec_configs(tok2vec_config):
     docs = get_batch(3)
     tok2vec = Tok2Vec(**tok2vec_config)
+    tok2vec.initialize()
     vectors, backprop = tok2vec.begin_update(docs)
     assert len(vectors) == len(docs)
     assert vectors[0].shape == (len(docs[0]), tok2vec_config["width"])


@@ -1,14 +1,13 @@
 import pytest
 import numpy
 from numpy.testing import assert_allclose
-from spacy._ml import cosine
 from spacy.vocab import Vocab
 from spacy.vectors import Vectors
 from spacy.tokenizer import Tokenizer
 from spacy.strings import hash_string
 from spacy.tokens import Doc
-from ..util import add_vecs_to_vocab
+from ..util import add_vecs_to_vocab, get_cosine

 @pytest.fixture
@@ -311,4 +310,4 @@ def test_vocab_prune_vectors():
     assert list(remap.keys()) == ["kitten"]
     neighbour, similarity = list(remap.values())[0]
     assert neighbour == "cat", remap
-    assert_allclose(similarity, cosine(data[0], data[2]), atol=1e-4, rtol=1e-3)
+    assert_allclose(similarity, get_cosine(data[0], data[2]), atol=1e-4, rtol=1e-3)


@@ -4,7 +4,7 @@
 from libc.string cimport memcpy, memset
 from libc.stdlib cimport malloc, free
 from cymem.cymem cimport Pool
-from thinc.neural.util import get_array_module
+from thinc.util import get_array_module
 import numpy


@@ -1,7 +1,7 @@
 import numpy
 import zlib
 import srsly
-from thinc.neural.ops import NumpyOps
+from thinc.backends import NumpyOps
 from ..compat import copy_reg
 from ..tokens import Doc


@@ -11,7 +11,7 @@ import numpy
 import numpy.linalg
 import struct
 import srsly
-from thinc.neural.util import get_array_module, copy_array
+from thinc.util import get_array_module, copy_array
 from .span cimport Span
 from .token cimport Token


@@ -3,7 +3,7 @@ from libc.math cimport sqrt
 import numpy
 import numpy.linalg
-from thinc.neural.util import get_array_module
+from thinc.util import get_array_module
 from collections import defaultdict
 from .doc cimport token_by_start, token_by_end, get_token_attr, _get_lca_matrix


@@ -7,7 +7,7 @@ cimport numpy as np
 np.import_array()
 import numpy
-from thinc.neural.util import get_array_module
+from thinc.util import get_array_module
 from ..typedefs cimport hash_t
 from ..lexeme cimport Lexeme


@@ -2,7 +2,9 @@ from libc.stdint cimport uint16_t, uint32_t, uint64_t, uintptr_t, int32_t
 from libc.stdint cimport uint8_t

+ctypedef float weight_t
 ctypedef uint64_t hash_t
+ctypedef uint64_t class_t
 ctypedef char* utf8_t
 ctypedef uint64_t attr_t
 ctypedef uint64_t flags_t


@@ -4,8 +4,14 @@ import importlib.util
 import re
 from pathlib import Path
 import random
-from thinc.neural._classes.model import Model
-from thinc.neural.ops import NumpyOps
+from typing import List
+import thinc
+import thinc.config
+from thinc.backends import NumpyOps, get_current_ops
+from thinc.optimizers import Adam
+from thinc.util import require_gpu
 import functools
 import itertools
 import numpy.random
@@ -13,6 +19,7 @@ import srsly
 import catalogue
 import sys

 try:
     import cupy.random
 except ImportError:
@@ -20,14 +27,13 @@ except ImportError:
 from .symbols import ORTH
 from .compat import cupy, CudaStream
-from .errors import Errors, Warnings, deprecation_warning
+from .errors import Errors, Warnings, deprecation_warning, user_warning

 _data_path = Path(__file__).parent / "data"
 _PRINT_ENV = False

-class registry(object):
+class registry(thinc.registry):
     languages = catalogue.create("spacy", "languages", entry_points=True)
     architectures = catalogue.create("spacy", "architectures", entry_points=True)
     lookups = catalogue.create("spacy", "lookups", entry_points=True)
@@ -219,6 +225,23 @@ def load_model_from_init_py(init_file, **overrides):
     return load_model_from_path(data_path, meta, **overrides)

+
+def load_from_config(path, create_objects=False):
+    """Load a Thinc-formatted config file, optionally filling in objects where
+    the config references registry entries. See "Thinc config files" for details.
+
+    path (unicode or Path): Path to the config file
+    create_objects (bool): Whether to automatically create objects when the config
+        references registry entries. Defaults to False.
+
+    RETURNS (dict): The objects from the config file.
+    """
+    config = thinc.config.Config().from_disk(path)
+    if create_objects:
+        return registry.make_from_config(config, validate=True)
+    else:
+        return config
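Hypothetical usage of the new load_from_config helper (the file name here is made up; with create_objects=True the registry entries referenced in the config are resolved into objects):

from spacy.util import load_from_config

config = load_from_config("training.cfg")                        # plain config data
objects = load_from_config("training.cfg", create_objects=True)  # registry entries resolved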
 def get_model_meta(path):
     """Get model meta.json from a directory path and validate its contents.
@@ -293,9 +316,10 @@ def get_component_name(component):

 def get_cuda_stream(require=False, non_blocking=True):
+    ops = get_current_ops()
     if CudaStream is None:
         return None
-    elif isinstance(Model.ops, NumpyOps):
+    elif isinstance(ops, NumpyOps):
         return None
     else:
         return CudaStream(non_blocking=non_blocking)
@@ -310,6 +334,14 @@ def get_async(stream, numpy_array):
         return array

+
+def eg2doc(example):
+    """Get a Doc object from an Example (or if it's a Doc, use it directly)"""
+    # Put the import here to avoid circular import problems
+    from .tokens.doc import Doc
+
+    return example if isinstance(example, Doc) else example.doc
+
+
 def env_opt(name, default=None):
     if type(default) is float:
         type_convert = float
@@ -532,6 +564,8 @@ def minibatch_by_words(examples, size, tuples=True, count_words=len):
     """Create minibatches of a given number of words."""
     if isinstance(size, int):
         size_ = itertools.repeat(size)
+    if isinstance(size, List):
+        size_ = iter(size)
     else:
         size_ = size
     examples = iter(examples)
@@ -680,17 +714,7 @@ def escape_html(text):

 def use_gpu(gpu_id):
-    try:
-        import cupy.cuda.device
-    except ImportError:
-        return None
-    from thinc.neural.ops import CupyOps
-    device = cupy.cuda.device.Device(gpu_id)
-    device.use()
-    Model.ops = CupyOps()
-    Model.Ops = CupyOps
-    return device
+    return require_gpu(gpu_id)


 def fix_random_seed(seed=0):
@@ -747,3 +771,33 @@ class DummyTokenizer(object):
     def from_disk(self, _path, **kwargs):
         return self
+
+
+def link_vectors_to_models(vocab):
+    vectors = vocab.vectors
+    if vectors.name is None:
+        vectors.name = VECTORS_KEY
+        if vectors.data.size != 0:
+            user_warning(Warnings.W020.format(shape=vectors.data.shape))
+    for word in vocab:
+        if word.orth in vectors.key2row:
+            word.rank = vectors.key2row[word.orth]
+        else:
+            word.rank = 0
+
+
+VECTORS_KEY = "spacy_pretrained_vectors"
+
+
+def create_default_optimizer():
+    ops = get_current_ops()
+    learn_rate = env_opt("learn_rate", 0.001)
+    beta1 = env_opt("optimizer_B1", 0.9)
+    beta2 = env_opt("optimizer_B2", 0.999)
+    eps = env_opt("optimizer_eps", 1e-8)
+    L2 = env_opt("L2_penalty", 1e-6)
+    max_grad_norm = env_opt("grad_norm_clip", 1.0)
+    optimizer = Adam(learn_rate, L2=L2, beta1=beta1, beta2=beta2, eps=eps, ops=ops)
+    optimizer.max_grad_norm = max_grad_norm
+    optimizer.device = ops.device_type
+    return optimizer
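Illustrative usage of the new create_default_optimizer: it takes no arguments any more, picks up the backend from thinc's get_current_ops, and reads its hyperparameters through env_opt, so they can be overridden without touching code.

from spacy.util import create_default_optimizer

optimizer = create_default_optimizer()  # Adam configured from env_opt defaults
print(optimizer.max_grad_norm)          # 1.0 unless grad_norm_clip is overridden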


@@ -5,8 +5,8 @@ from libcpp.set cimport set as cppset
 import functools
 import numpy
 import srsly
-from thinc.neural.util import get_array_module
-from thinc.neural._classes.model import Model
+from thinc.util import get_array_module
+from thinc.backends import get_current_ops

 from .strings cimport StringStore
@@ -426,9 +426,9 @@ cdef class Vectors:
                 self.add(key, row=i)

         def load_vectors(path):
-            xp = Model.ops.xp
+            ops = get_current_ops()
             if path.exists():
-                self.data = xp.load(str(path))
+                self.data = ops.xp.load(str(path))

         serializers = {
             "key2row": load_key2row,


@@ -2,7 +2,7 @@
 from libc.string cimport memcpy
 import srsly
-from thinc.neural.util import get_array_module
+from thinc.util import get_array_module
 from .lexeme cimport EMPTY_LEXEME
 from .lexeme cimport Lexeme
@@ -16,7 +16,7 @@ from .errors import Errors
 from .lemmatizer import Lemmatizer
 from .attrs import intify_attrs, NORM
 from .vectors import Vectors
-from ._ml import link_vectors_to_models
+from .util import link_vectors_to_models
 from .lookups import Lookups
 from . import util