mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 09:26:27 +03:00
Adapt parser and NER for transformers (#5449)
* Draft layer for BILUO actions * Fixes to biluo layer * WIP on BILUO layer * Add tests for BILUO layer * Format * Fix transitions * Update test * Link in the simple_ner * Update BILUO tagger * Update __init__ * Import simple_ner * Update test * Import * Add files * Add config * Fix label passing for BILUO and tagger * Fix label handling for simple_ner component * Update simple NER test * Update config * Hack train script * Update BILUO layer * Fix SimpleNER component * Update train_from_config * Add biluo_to_iob helper * Add IOB layer * Add IOBTagger model * Update biluo layer * Update SimpleNER tagger * Update BILUO * Read random seed in train-from-config * Update use of normal_init * Fix normalization of gradient in SimpleNER * Update IOBTagger * Remove print * Tweak masking in BILUO * Add dropout in SimpleNER * Update thinc * Tidy up simple_ner * Fix biluo model * Unhack train-from-config * Update setup.cfg and requirements * Add tb_framework.py for parser model * Try to avoid memory leak in BILUO * Move ParserModel into spacy.ml, avoid need for subclass. * Use updated parser model * Remove incorrect call to model.initializre in PrecomputableAffine * Update parser model * Avoid divide by zero in tagger * Add extra dropout layer in tagger * Refine minibatch_by_words function to avoid oom * Fix parser model after refactor * Try to avoid div-by-zero in SimpleNER * Fix infinite loop in minibatch_by_words * Use SequenceCategoricalCrossentropy in Tagger * Fix parser model when hidden layer * Remove extra dropout from tagger * Add extra nan check in tagger * Fix thinc version * Update tests and imports * Fix test * Update test * Update tests * Fix tests * Fix test Co-authored-by: Ines Montani <ines@ines.io>
This commit is contained in:
parent
3100c97e69
commit
333b1a308b
|
@ -4,12 +4,18 @@ limit = 0
|
||||||
dropout = 0.2
|
dropout = 0.2
|
||||||
patience = 10000
|
patience = 10000
|
||||||
eval_frequency = 200
|
eval_frequency = 200
|
||||||
scores = ["ents_f"]
|
scores = ["ents_p", "ents_r", "ents_f"]
|
||||||
score_weights = {"ents_f": 1}
|
score_weights = {"ents_f": 1}
|
||||||
orth_variant_level = 0.0
|
orth_variant_level = 0.0
|
||||||
gold_preproc = true
|
gold_preproc = true
|
||||||
max_length = 0
|
max_length = 0
|
||||||
batch_size = 25
|
|
||||||
|
[training.batch_size]
|
||||||
|
@schedules = "compounding.v1"
|
||||||
|
start = 3000
|
||||||
|
stop = 3000
|
||||||
|
compound = 1.001
|
||||||
|
|
||||||
|
|
||||||
[optimizer]
|
[optimizer]
|
||||||
@optimizers = "Adam.v1"
|
@optimizers = "Adam.v1"
|
||||||
|
@ -21,45 +27,18 @@ beta2 = 0.999
|
||||||
lang = "en"
|
lang = "en"
|
||||||
vectors = null
|
vectors = null
|
||||||
|
|
||||||
[nlp.pipeline.tok2vec]
|
|
||||||
factory = "tok2vec"
|
|
||||||
|
|
||||||
[nlp.pipeline.tok2vec.model]
|
|
||||||
@architectures = "spacy.Tok2Vec.v1"
|
|
||||||
|
|
||||||
[nlp.pipeline.tok2vec.model.extract]
|
|
||||||
@architectures = "spacy.Doc2Feats.v1"
|
|
||||||
columns = ["ID", "NORM", "PREFIX", "SUFFIX", "SHAPE", "ORTH"]
|
|
||||||
|
|
||||||
[nlp.pipeline.tok2vec.model.embed]
|
|
||||||
@architectures = "spacy.MultiHashEmbed.v1"
|
|
||||||
columns = ${nlp.pipeline.tok2vec.model.extract:columns}
|
|
||||||
width = 96
|
|
||||||
rows = 2000
|
|
||||||
use_subwords = true
|
|
||||||
pretrained_vectors = null
|
|
||||||
|
|
||||||
[nlp.pipeline.tok2vec.model.embed.mix]
|
|
||||||
@architectures = "spacy.LayerNormalizedMaxout.v1"
|
|
||||||
width = ${nlp.pipeline.tok2vec.model.embed:width}
|
|
||||||
maxout_pieces = 3
|
|
||||||
|
|
||||||
[nlp.pipeline.tok2vec.model.encode]
|
|
||||||
@architectures = "spacy.MaxoutWindowEncoder.v1"
|
|
||||||
width = ${nlp.pipeline.tok2vec.model.embed:width}
|
|
||||||
window_size = 1
|
|
||||||
maxout_pieces = 3
|
|
||||||
depth = 2
|
|
||||||
|
|
||||||
[nlp.pipeline.ner]
|
[nlp.pipeline.ner]
|
||||||
factory = "ner"
|
factory = "simple_ner"
|
||||||
|
|
||||||
[nlp.pipeline.ner.model]
|
[nlp.pipeline.ner.model]
|
||||||
@architectures = "spacy.TransitionBasedParser.v1"
|
@architectures = "spacy.BiluoTagger.v1"
|
||||||
nr_feature_tokens = 6
|
|
||||||
hidden_width = 64
|
|
||||||
maxout_pieces = 2
|
|
||||||
|
|
||||||
[nlp.pipeline.ner.model.tok2vec]
|
[nlp.pipeline.ner.model.tok2vec]
|
||||||
@architectures = "spacy.Tok2VecTensors.v1"
|
@architectures = "spacy.HashEmbedCNN.v1"
|
||||||
width = ${nlp.pipeline.tok2vec.model.embed:width}
|
width = 128
|
||||||
|
depth = 4
|
||||||
|
embed_size = 7000
|
||||||
|
maxout_pieces = 3
|
||||||
|
window_size = 1
|
||||||
|
subword_features = true
|
||||||
|
pretrained_vectors = null
|
||||||
|
|
|
@ -42,26 +42,28 @@ def main(model=None, output_dir=None, n_iter=100):
|
||||||
|
|
||||||
# create the built-in pipeline components and add them to the pipeline
|
# create the built-in pipeline components and add them to the pipeline
|
||||||
# nlp.create_pipe works for built-ins that are registered with spaCy
|
# nlp.create_pipe works for built-ins that are registered with spaCy
|
||||||
if "ner" not in nlp.pipe_names:
|
if "simple_ner" not in nlp.pipe_names:
|
||||||
ner = nlp.create_pipe("ner")
|
ner = nlp.create_pipe("simple_ner")
|
||||||
nlp.add_pipe(ner, last=True)
|
nlp.add_pipe(ner, last=True)
|
||||||
# otherwise, get it so we can add labels
|
# otherwise, get it so we can add labels
|
||||||
else:
|
else:
|
||||||
ner = nlp.get_pipe("ner")
|
ner = nlp.get_pipe("simple_ner")
|
||||||
|
|
||||||
# add labels
|
# add labels
|
||||||
for _, annotations in TRAIN_DATA:
|
for _, annotations in TRAIN_DATA:
|
||||||
for ent in annotations.get("entities"):
|
for ent in annotations.get("entities"):
|
||||||
|
print("Add label", ent[2])
|
||||||
ner.add_label(ent[2])
|
ner.add_label(ent[2])
|
||||||
|
|
||||||
# get names of other pipes to disable them during training
|
# get names of other pipes to disable them during training
|
||||||
pipe_exceptions = ["ner", "trf_wordpiecer", "trf_tok2vec"]
|
pipe_exceptions = ["simple_ner"]
|
||||||
other_pipes = [pipe for pipe in nlp.pipe_names if pipe not in pipe_exceptions]
|
other_pipes = [pipe for pipe in nlp.pipe_names if pipe not in pipe_exceptions]
|
||||||
with nlp.disable_pipes(*other_pipes): # only train NER
|
with nlp.disable_pipes(*other_pipes): # only train NER
|
||||||
# reset and initialize the weights randomly – but only if we're
|
# reset and initialize the weights randomly – but only if we're
|
||||||
# training a new model
|
# training a new model
|
||||||
if model is None:
|
if model is None:
|
||||||
nlp.begin_training()
|
nlp.begin_training()
|
||||||
|
print("Transitions", list(enumerate(nlp.get_pipe("simple_ner").get_tag_names())))
|
||||||
for itn in range(n_iter):
|
for itn in range(n_iter):
|
||||||
random.shuffle(TRAIN_DATA)
|
random.shuffle(TRAIN_DATA)
|
||||||
losses = {}
|
losses = {}
|
||||||
|
@ -70,7 +72,7 @@ def main(model=None, output_dir=None, n_iter=100):
|
||||||
for batch in batches:
|
for batch in batches:
|
||||||
nlp.update(
|
nlp.update(
|
||||||
batch,
|
batch,
|
||||||
drop=0.5, # dropout - make it harder to memorise data
|
drop=0.0, # dropout - make it harder to memorise data
|
||||||
losses=losses,
|
losses=losses,
|
||||||
)
|
)
|
||||||
print("Losses", losses)
|
print("Losses", losses)
|
||||||
|
|
|
@ -8,6 +8,7 @@ from wasabi import msg
|
||||||
import thinc
|
import thinc
|
||||||
import thinc.schedules
|
import thinc.schedules
|
||||||
from thinc.api import Model
|
from thinc.api import Model
|
||||||
|
import random
|
||||||
|
|
||||||
from ..gold import GoldCorpus
|
from ..gold import GoldCorpus
|
||||||
from .. import util
|
from .. import util
|
||||||
|
@ -119,6 +120,7 @@ class ConfigSchema(BaseModel):
|
||||||
output_path=("Output directory to store model in", "option", "o", Path),
|
output_path=("Output directory to store model in", "option", "o", Path),
|
||||||
meta_path=("Optional path to meta.json to use as base.", "option", "m", Path),
|
meta_path=("Optional path to meta.json to use as base.", "option", "m", Path),
|
||||||
raw_text=("Path to jsonl file with unlabelled text documents.", "option", "rt", Path),
|
raw_text=("Path to jsonl file with unlabelled text documents.", "option", "rt", Path),
|
||||||
|
use_gpu=("Use GPU", "option", "g", int),
|
||||||
# fmt: on
|
# fmt: on
|
||||||
)
|
)
|
||||||
def train_from_config_cli(
|
def train_from_config_cli(
|
||||||
|
@ -130,6 +132,7 @@ def train_from_config_cli(
|
||||||
raw_text=None,
|
raw_text=None,
|
||||||
debug=False,
|
debug=False,
|
||||||
verbose=False,
|
verbose=False,
|
||||||
|
use_gpu=-1
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Train or update a spaCy model. Requires data to be formatted in spaCy's
|
Train or update a spaCy model. Requires data to be formatted in spaCy's
|
||||||
|
@ -147,6 +150,12 @@ def train_from_config_cli(
|
||||||
if output_path is not None and not output_path.exists():
|
if output_path is not None and not output_path.exists():
|
||||||
output_path.mkdir()
|
output_path.mkdir()
|
||||||
|
|
||||||
|
if use_gpu >= 0:
|
||||||
|
msg.info("Using GPU")
|
||||||
|
util.use_gpu(use_gpu)
|
||||||
|
else:
|
||||||
|
msg.info("Using CPU")
|
||||||
|
|
||||||
train_from_config(
|
train_from_config(
|
||||||
config_path,
|
config_path,
|
||||||
{"train": train_path, "dev": dev_path},
|
{"train": train_path, "dev": dev_path},
|
||||||
|
@ -161,13 +170,8 @@ def train_from_config(
|
||||||
):
|
):
|
||||||
msg.info(f"Loading config from: {config_path}")
|
msg.info(f"Loading config from: {config_path}")
|
||||||
config = util.load_config(config_path, create_objects=False)
|
config = util.load_config(config_path, create_objects=False)
|
||||||
|
util.fix_random_seed(config["training"]["seed"])
|
||||||
nlp_config = config["nlp"]
|
nlp_config = config["nlp"]
|
||||||
use_gpu = config["training"]["use_gpu"]
|
|
||||||
if use_gpu >= 0:
|
|
||||||
msg.info("Using GPU")
|
|
||||||
util.use_gpu(use_gpu)
|
|
||||||
else:
|
|
||||||
msg.info("Using CPU")
|
|
||||||
config = util.load_config(config_path, create_objects=True)
|
config = util.load_config(config_path, create_objects=True)
|
||||||
msg.info("Creating nlp from config")
|
msg.info("Creating nlp from config")
|
||||||
nlp = util.load_model_from_config(nlp_config)
|
nlp = util.load_model_from_config(nlp_config)
|
||||||
|
@ -177,7 +181,7 @@ def train_from_config(
|
||||||
msg.info("Loading training corpus")
|
msg.info("Loading training corpus")
|
||||||
corpus = GoldCorpus(data_paths["train"], data_paths["dev"], limit=limit)
|
corpus = GoldCorpus(data_paths["train"], data_paths["dev"], limit=limit)
|
||||||
msg.info("Initializing the nlp pipeline")
|
msg.info("Initializing the nlp pipeline")
|
||||||
nlp.begin_training(lambda: corpus.train_examples, device=use_gpu)
|
nlp.begin_training(lambda: corpus.train_examples)
|
||||||
|
|
||||||
train_batches = create_train_batches(nlp, corpus, training)
|
train_batches = create_train_batches(nlp, corpus, training)
|
||||||
evaluate = create_evaluation_callback(nlp, optimizer, corpus, training)
|
evaluate = create_evaluation_callback(nlp, optimizer, corpus, training)
|
||||||
|
@ -192,6 +196,7 @@ def train_from_config(
|
||||||
training["dropout"],
|
training["dropout"],
|
||||||
training["patience"],
|
training["patience"],
|
||||||
training["eval_frequency"],
|
training["eval_frequency"],
|
||||||
|
training["accumulate_gradient"]
|
||||||
)
|
)
|
||||||
|
|
||||||
msg.info(f"Training. Initial learn rate: {optimizer.learn_rate}")
|
msg.info(f"Training. Initial learn rate: {optimizer.learn_rate}")
|
||||||
|
@ -220,43 +225,50 @@ def train_from_config(
|
||||||
|
|
||||||
def create_train_batches(nlp, corpus, cfg):
|
def create_train_batches(nlp, corpus, cfg):
|
||||||
while True:
|
while True:
|
||||||
train_examples = corpus.train_dataset(
|
train_examples = list(corpus.train_dataset(
|
||||||
nlp,
|
nlp,
|
||||||
noise_level=0.0,
|
noise_level=0.0,
|
||||||
orth_variant_level=cfg["orth_variant_level"],
|
orth_variant_level=cfg["orth_variant_level"],
|
||||||
gold_preproc=cfg["gold_preproc"],
|
gold_preproc=cfg["gold_preproc"],
|
||||||
max_length=cfg["max_length"],
|
max_length=cfg["max_length"],
|
||||||
ignore_misaligned=True,
|
ignore_misaligned=True,
|
||||||
)
|
))
|
||||||
for batch in util.minibatch_by_words(train_examples, size=cfg["batch_size"]):
|
random.shuffle(train_examples)
|
||||||
|
batches = util.minibatch_by_words(train_examples, size=cfg["batch_size"])
|
||||||
|
for batch in batches:
|
||||||
yield batch
|
yield batch
|
||||||
|
|
||||||
|
|
||||||
def create_evaluation_callback(nlp, optimizer, corpus, cfg):
|
def create_evaluation_callback(nlp, optimizer, corpus, cfg):
|
||||||
def evaluate():
|
def evaluate():
|
||||||
with nlp.use_params(optimizer.averages):
|
dev_examples = list(
|
||||||
dev_examples = list(
|
corpus.dev_dataset(
|
||||||
corpus.dev_dataset(
|
nlp, gold_preproc=cfg["gold_preproc"], ignore_misaligned=True
|
||||||
nlp, gold_preproc=cfg["gold_preproc"], ignore_misaligned=True
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
n_words = sum(len(ex.doc) for ex in dev_examples)
|
)
|
||||||
start_time = timer()
|
n_words = sum(len(ex.doc) for ex in dev_examples)
|
||||||
scorer = nlp.evaluate(dev_examples)
|
start_time = timer()
|
||||||
end_time = timer()
|
|
||||||
wps = n_words / (end_time - start_time)
|
if optimizer.averages:
|
||||||
scores = scorer.scores
|
with nlp.use_params(optimizer.averages):
|
||||||
# Calculate a weighted sum based on score_weights for the main score
|
scorer = nlp.evaluate(dev_examples, batch_size=32)
|
||||||
weights = cfg["score_weights"]
|
else:
|
||||||
weighted_score = sum(scores[s] * weights.get(s, 0.0) for s in weights)
|
scorer = nlp.evaluate(dev_examples, batch_size=32)
|
||||||
scores["speed"] = wps
|
end_time = timer()
|
||||||
|
wps = n_words / (end_time - start_time)
|
||||||
|
scores = scorer.scores
|
||||||
|
# Calculate a weighted sum based on score_weights for the main score
|
||||||
|
weights = cfg["score_weights"]
|
||||||
|
weighted_score = sum(scores[s] * weights.get(s, 0.0) for s in weights)
|
||||||
|
scores["speed"] = wps
|
||||||
return weighted_score, scores
|
return weighted_score, scores
|
||||||
|
|
||||||
return evaluate
|
return evaluate
|
||||||
|
|
||||||
|
|
||||||
def train_while_improving(
|
def train_while_improving(
|
||||||
nlp, optimizer, train_data, evaluate, dropout, patience, eval_frequency
|
nlp, optimizer, train_data, evaluate, dropout, patience, eval_frequency,
|
||||||
|
accumulate_gradient
|
||||||
):
|
):
|
||||||
"""Train until an evaluation stops improving. Works as a generator,
|
"""Train until an evaluation stops improving. Works as a generator,
|
||||||
with each iteration yielding a tuple `(batch, info, is_best_checkpoint)`,
|
with each iteration yielding a tuple `(batch, info, is_best_checkpoint)`,
|
||||||
|
@ -303,7 +315,7 @@ def train_while_improving(
|
||||||
losses = {}
|
losses = {}
|
||||||
for step, batch in enumerate(train_data):
|
for step, batch in enumerate(train_data):
|
||||||
dropout = next(dropouts)
|
dropout = next(dropouts)
|
||||||
for subbatch in subdivide_batch(batch):
|
for subbatch in subdivide_batch(batch, accumulate_gradient):
|
||||||
nlp.update(subbatch, drop=dropout, losses=losses, sgd=False)
|
nlp.update(subbatch, drop=dropout, losses=losses, sgd=False)
|
||||||
for name, proc in nlp.pipeline:
|
for name, proc in nlp.pipeline:
|
||||||
if hasattr(proc, "model"):
|
if hasattr(proc, "model"):
|
||||||
|
@ -332,8 +344,19 @@ def train_while_improving(
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|
||||||
def subdivide_batch(batch):
|
def subdivide_batch(batch, accumulate_gradient):
|
||||||
return [batch]
|
batch = list(batch)
|
||||||
|
batch.sort(key=lambda eg: len(eg.doc))
|
||||||
|
sub_len = len(batch) // accumulate_gradient
|
||||||
|
start = 0
|
||||||
|
for i in range(accumulate_gradient):
|
||||||
|
subbatch = batch[start : start + sub_len]
|
||||||
|
if subbatch:
|
||||||
|
yield subbatch
|
||||||
|
start += len(subbatch)
|
||||||
|
subbatch = batch[start : ]
|
||||||
|
if subbatch:
|
||||||
|
yield subbatch
|
||||||
|
|
||||||
|
|
||||||
def setup_printer(training, nlp):
|
def setup_printer(training, nlp):
|
||||||
|
|
|
@ -608,6 +608,14 @@ def iob_to_biluo(tags):
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def biluo_to_iob(tags):
|
||||||
|
out = []
|
||||||
|
for tag in tags:
|
||||||
|
tag = tag.replace("U-", "B-", 1).replace("L-", "I-", 1)
|
||||||
|
out.append(tag)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
def _consume_os(tags):
|
def _consume_os(tags):
|
||||||
while tags and tags[0] == "O":
|
while tags and tags[0] == "O":
|
||||||
yield tags.pop(0)
|
yield tags.pop(0)
|
||||||
|
|
|
@ -195,6 +195,7 @@ class Language(object):
|
||||||
default_senter_config,
|
default_senter_config,
|
||||||
default_tensorizer_config,
|
default_tensorizer_config,
|
||||||
default_tok2vec_config,
|
default_tok2vec_config,
|
||||||
|
default_simple_ner_config
|
||||||
)
|
)
|
||||||
|
|
||||||
self.defaults = {
|
self.defaults = {
|
||||||
|
@ -205,6 +206,7 @@ class Language(object):
|
||||||
"entity_linker": default_nel_config(),
|
"entity_linker": default_nel_config(),
|
||||||
"morphologizer": default_morphologizer_config(),
|
"morphologizer": default_morphologizer_config(),
|
||||||
"senter": default_senter_config(),
|
"senter": default_senter_config(),
|
||||||
|
"simple_ner": default_simple_ner_config(),
|
||||||
"tensorizer": default_tensorizer_config(),
|
"tensorizer": default_tensorizer_config(),
|
||||||
"tok2vec": default_tok2vec_config(),
|
"tok2vec": default_tok2vec_config(),
|
||||||
}
|
}
|
||||||
|
|
109
spacy/ml/_biluo.py
Normal file
109
spacy/ml/_biluo.py
Normal file
|
@ -0,0 +1,109 @@
|
||||||
|
"""Thinc layer to do simpler transition-based parsing, NER, etc."""
|
||||||
|
from typing import List, Tuple, Dict, Optional
|
||||||
|
import numpy
|
||||||
|
from thinc.api import Ops, Model, with_array, softmax_activation, padded2list
|
||||||
|
from thinc.api import to_numpy
|
||||||
|
from thinc.types import Padded, Ints1d, Ints3d, Floats2d, Floats3d
|
||||||
|
|
||||||
|
from ..tokens import Doc
|
||||||
|
|
||||||
|
|
||||||
|
def BILUO() -> Model[Padded, Padded]:
|
||||||
|
return Model(
|
||||||
|
"biluo",
|
||||||
|
forward,
|
||||||
|
init=init,
|
||||||
|
dims={"nO": None},
|
||||||
|
attrs={"get_num_actions": get_num_actions}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def init(model, X: Optional[Padded]=None, Y: Optional[Padded]=None):
|
||||||
|
if X is not None and Y is not None:
|
||||||
|
if X.data.shape != Y.data.shape:
|
||||||
|
# TODO: Fix error
|
||||||
|
raise ValueError("Mismatched shapes (TODO: Fix message)")
|
||||||
|
model.set_dim("nO", X.data.shape[2])
|
||||||
|
elif X is not None:
|
||||||
|
model.set_dim("nO", X.data.shape[2])
|
||||||
|
elif Y is not None:
|
||||||
|
model.set_dim("nO", Y.data.shape[2])
|
||||||
|
elif model.get_dim("nO") is None:
|
||||||
|
raise ValueError("Dimension unset for BILUO: nO")
|
||||||
|
|
||||||
|
|
||||||
|
def forward(model: Model[Padded, Padded], Xp: Padded, is_train: bool):
|
||||||
|
n_labels = (model.get_dim("nO") - 1) // 4
|
||||||
|
n_tokens, n_docs, n_actions = Xp.data.shape
|
||||||
|
# At each timestep, we make a validity mask of shape (n_docs, n_actions)
|
||||||
|
# to indicate which actions are valid next for each sequence. To construct
|
||||||
|
# the mask, we have a state of shape (2, n_actions) and a validity table of
|
||||||
|
# shape (2, n_actions+1, n_actions). The first dimension of the state indicates
|
||||||
|
# whether it's the last token, the second dimension indicates the previous
|
||||||
|
# action, plus a special 'null action' for the first entry.
|
||||||
|
valid_transitions = model.ops.asarray(_get_transition_table(n_labels))
|
||||||
|
prev_actions = model.ops.alloc1i(n_docs)
|
||||||
|
# Initialize as though prev action was O
|
||||||
|
prev_actions.fill(n_actions - 1)
|
||||||
|
Y = model.ops.alloc3f(*Xp.data.shape)
|
||||||
|
masks = model.ops.alloc3f(*Y.shape)
|
||||||
|
max_value = Xp.data.max()
|
||||||
|
for t in range(Xp.data.shape[0]):
|
||||||
|
is_last = (Xp.lengths < (t+2)).astype("i")
|
||||||
|
masks[t] = valid_transitions[is_last, prev_actions]
|
||||||
|
# Don't train the out-of-bounds sequences.
|
||||||
|
masks[t, Xp.size_at_t[t]:] = 0
|
||||||
|
# Valid actions get 0*10e8, invalid get large negative value
|
||||||
|
Y[t] = Xp.data[t] + ((masks[t]-1) * max_value * 10)
|
||||||
|
prev_actions = Y[t].argmax(axis=-1)
|
||||||
|
|
||||||
|
def backprop_biluo(dY: Padded) -> Padded:
|
||||||
|
dY.data *= masks
|
||||||
|
return dY
|
||||||
|
|
||||||
|
return Padded(Y, Xp.size_at_t, Xp.lengths, Xp.indices), backprop_biluo
|
||||||
|
|
||||||
|
|
||||||
|
def get_num_actions(n_labels: int) -> int:
|
||||||
|
# One BEGIN action per label
|
||||||
|
# One IN action per label
|
||||||
|
# One LAST action per label
|
||||||
|
# One UNIT action per label
|
||||||
|
# One OUT action
|
||||||
|
return n_labels + n_labels + n_labels + n_labels + 1
|
||||||
|
|
||||||
|
|
||||||
|
def _get_transition_table(
|
||||||
|
n_labels: int, *, _cache: Dict[int, Floats3d] = {}
|
||||||
|
) -> Floats3d:
|
||||||
|
n_actions = get_num_actions(n_labels)
|
||||||
|
if n_actions in _cache:
|
||||||
|
return _cache[n_actions]
|
||||||
|
table = numpy.zeros((2, n_actions, n_actions), dtype="f")
|
||||||
|
B_start, B_end = (0, n_labels)
|
||||||
|
I_start, I_end = (B_end, B_end + n_labels)
|
||||||
|
L_start, L_end = (I_end, I_end + n_labels)
|
||||||
|
U_start, U_end = (L_end, L_end + n_labels)
|
||||||
|
# Using ranges allows us to set specific cells, which is necessary to express
|
||||||
|
# that only actions of the same label are valid continuations.
|
||||||
|
B_range = numpy.arange(B_start, B_end)
|
||||||
|
I_range = numpy.arange(I_start, I_end)
|
||||||
|
L_range = numpy.arange(L_start, L_end)
|
||||||
|
O_action = U_end
|
||||||
|
# If this is the last token and the previous action was B or I, only L
|
||||||
|
# of that label is valid
|
||||||
|
table[1, B_range, L_range] = 1
|
||||||
|
table[1, I_range, L_range] = 1
|
||||||
|
# If this isn't the last token and the previous action was B or I, only I or
|
||||||
|
# L of that label are valid.
|
||||||
|
table[0, B_range, I_range] = 1
|
||||||
|
table[0, B_range, L_range] = 1
|
||||||
|
table[0, I_range, I_range] = 1
|
||||||
|
table[0, I_range, L_range] = 1
|
||||||
|
# If this isn't the last token and the previous was L, U or O, B is valid
|
||||||
|
table[0, L_start:, :B_end] = 1
|
||||||
|
# Regardless of whether this is the last token, if the previous action was
|
||||||
|
# {L, U, O}, U and O are valid.
|
||||||
|
table[:, L_start:, U_start:] = 1
|
||||||
|
_cache[n_actions] = table
|
||||||
|
return table
|
92
spacy/ml/_iob.py
Normal file
92
spacy/ml/_iob.py
Normal file
|
@ -0,0 +1,92 @@
|
||||||
|
"""Thinc layer to do simpler transition-based parsing, NER, etc."""
|
||||||
|
from typing import List, Tuple, Dict, Optional
|
||||||
|
from thinc.api import Ops, Model, with_array, softmax_activation, padded2list
|
||||||
|
from thinc.types import Padded, Ints1d, Ints3d, Floats2d, Floats3d
|
||||||
|
|
||||||
|
from ..tokens import Doc
|
||||||
|
|
||||||
|
|
||||||
|
def IOB() -> Model[Padded, Padded]:
|
||||||
|
return Model(
|
||||||
|
"biluo",
|
||||||
|
forward,
|
||||||
|
init=init,
|
||||||
|
dims={"nO": None},
|
||||||
|
attrs={"get_num_actions": get_num_actions}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def init(model, X: Optional[Padded]=None, Y: Optional[Padded]=None):
|
||||||
|
if X is not None and Y is not None:
|
||||||
|
if X.data.shape != Y.data.shape:
|
||||||
|
# TODO: Fix error
|
||||||
|
raise ValueError("Mismatched shapes (TODO: Fix message)")
|
||||||
|
model.set_dim("nO", X.data.shape[2])
|
||||||
|
elif X is not None:
|
||||||
|
model.set_dim("nO", X.data.shape[2])
|
||||||
|
elif Y is not None:
|
||||||
|
model.set_dim("nO", Y.data.shape[2])
|
||||||
|
elif model.get_dim("nO") is None:
|
||||||
|
raise ValueError("Dimension unset for BILUO: nO")
|
||||||
|
|
||||||
|
|
||||||
|
def forward(model: Model[Padded, Padded], Xp: Padded, is_train: bool):
|
||||||
|
n_labels = (model.get_dim("nO") - 1) // 2
|
||||||
|
n_tokens, n_docs, n_actions = Xp.data.shape
|
||||||
|
# At each timestep, we make a validity mask of shape (n_docs, n_actions)
|
||||||
|
# to indicate which actions are valid next for each sequence. To construct
|
||||||
|
# the mask, we have a state of shape (2, n_actions) and a validity table of
|
||||||
|
# shape (2, n_actions+1, n_actions). The first dimension of the state indicates
|
||||||
|
# whether it's the last token, the second dimension indicates the previous
|
||||||
|
# action, plus a special 'null action' for the first entry.
|
||||||
|
valid_transitions = _get_transition_table(model.ops, n_labels)
|
||||||
|
prev_actions = model.ops.alloc1i(n_docs)
|
||||||
|
# Initialize as though prev action was O
|
||||||
|
prev_actions.fill(n_actions - 1)
|
||||||
|
Y = model.ops.alloc3f(*Xp.data.shape)
|
||||||
|
masks = model.ops.alloc3f(*Y.shape)
|
||||||
|
for t in range(Xp.data.shape[0]):
|
||||||
|
masks[t] = valid_transitions[prev_actions]
|
||||||
|
# Don't train the out-of-bounds sequences.
|
||||||
|
masks[t, Xp.size_at_t[t]:] = 0
|
||||||
|
# Valid actions get 0*10e8, invalid get -1*10e8
|
||||||
|
Y[t] = Xp.data[t] + ((masks[t]-1) * 10e8)
|
||||||
|
prev_actions = Y[t].argmax(axis=-1)
|
||||||
|
|
||||||
|
def backprop_biluo(dY: Padded) -> Padded:
|
||||||
|
# Masking the gradient seems to do poorly here. But why?
|
||||||
|
#dY.data *= masks
|
||||||
|
return dY
|
||||||
|
|
||||||
|
return Padded(Y, Xp.size_at_t, Xp.lengths, Xp.indices), backprop_biluo
|
||||||
|
|
||||||
|
|
||||||
|
def get_num_actions(n_labels: int) -> int:
|
||||||
|
# One BEGIN action per label
|
||||||
|
# One IN action per label
|
||||||
|
# One LAST action per label
|
||||||
|
# One UNIT action per label
|
||||||
|
# One OUT action
|
||||||
|
return n_labels * 2 + 1
|
||||||
|
|
||||||
|
|
||||||
|
def _get_transition_table(
|
||||||
|
ops: Ops, n_labels: int, _cache: Dict[int, Floats3d] = {}
|
||||||
|
) -> Floats3d:
|
||||||
|
n_actions = get_num_actions(n_labels)
|
||||||
|
if n_actions in _cache:
|
||||||
|
return ops.asarray(_cache[n_actions])
|
||||||
|
table = ops.alloc2f(n_actions, n_actions)
|
||||||
|
B_start, B_end = (0, n_labels)
|
||||||
|
I_start, I_end = (B_end, B_end + n_labels)
|
||||||
|
O_action = I_end
|
||||||
|
B_range = ops.xp.arange(B_start, B_end)
|
||||||
|
I_range = ops.xp.arange(I_start, I_end)
|
||||||
|
# B and O are always valid
|
||||||
|
table[:, B_start : B_end] = 1
|
||||||
|
table[:, O_action] = 1
|
||||||
|
# I can only follow a matching B
|
||||||
|
table[B_range, I_range] = 1
|
||||||
|
|
||||||
|
_cache[n_actions] = table
|
||||||
|
return table
|
|
@ -9,7 +9,6 @@ def PrecomputableAffine(nO, nI, nF, nP):
|
||||||
dims={"nO": nO, "nI": nI, "nF": nF, "nP": nP},
|
dims={"nO": nO, "nI": nI, "nF": nF, "nP": nP},
|
||||||
params={"W": None, "b": None, "pad": None},
|
params={"W": None, "b": None, "pad": None},
|
||||||
)
|
)
|
||||||
model.initialize()
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
@ -110,8 +109,7 @@ def init(model, X=None, Y=None):
|
||||||
pad = model.ops.alloc4f(1, nF, nO, nP)
|
pad = model.ops.alloc4f(1, nF, nO, nP)
|
||||||
|
|
||||||
ops = model.ops
|
ops = model.ops
|
||||||
scale = float(ops.xp.sqrt(1.0 / (nF * nI)))
|
W = normal_init(ops, W.shape, mean=float(ops.xp.sqrt(1.0 / nF * nI)))
|
||||||
W = normal_init(ops, W.shape, mean=scale)
|
|
||||||
model.set_param("W", W)
|
model.set_param("W", W)
|
||||||
model.set_param("b", b)
|
model.set_param("b", b)
|
||||||
model.set_param("pad", pad)
|
model.set_param("pad", pad)
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
from .entity_linker import * # noqa
|
from .entity_linker import * # noqa
|
||||||
from .parser import * # noqa
|
from .parser import * # noqa
|
||||||
|
from .simple_ner import *
|
||||||
from .tagger import * # noqa
|
from .tagger import * # noqa
|
||||||
from .tensorizer import * # noqa
|
from .tensorizer import * # noqa
|
||||||
from .textcat import * # noqa
|
from .textcat import * # noqa
|
||||||
|
|
|
@ -91,3 +91,13 @@ def default_tok2vec_config():
|
||||||
def default_tok2vec():
|
def default_tok2vec():
|
||||||
loc = Path(__file__).parent / "tok2vec_defaults.cfg"
|
loc = Path(__file__).parent / "tok2vec_defaults.cfg"
|
||||||
return util.load_config(loc, create_objects=True)["model"]
|
return util.load_config(loc, create_objects=True)["model"]
|
||||||
|
|
||||||
|
|
||||||
|
def default_simple_ner_config():
|
||||||
|
loc = Path(__file__).parent / "simple_ner_defaults.cfg"
|
||||||
|
return util.load_config(loc, create_objects=False)
|
||||||
|
|
||||||
|
|
||||||
|
def default_simple_ner():
|
||||||
|
loc = Path(__file__).parent / "simple_ner_defaults.cfg"
|
||||||
|
return util.load_config(loc, create_objects=True)["model"]
|
||||||
|
|
12
spacy/ml/models/defaults/simple_ner_defaults.cfg
Normal file
12
spacy/ml/models/defaults/simple_ner_defaults.cfg
Normal file
|
@ -0,0 +1,12 @@
|
||||||
|
[model]
|
||||||
|
@architectures = "spacy.BiluoTagger.v1"
|
||||||
|
|
||||||
|
[model.tok2vec]
|
||||||
|
@architectures = "spacy.HashEmbedCNN.v1"
|
||||||
|
pretrained_vectors = null
|
||||||
|
width = 128
|
||||||
|
depth = 4
|
||||||
|
embed_size = 7000
|
||||||
|
window_size = 1
|
||||||
|
maxout_pieces = 3
|
||||||
|
subword_features = true
|
|
@ -1,9 +1,9 @@
|
||||||
from pydantic import StrictInt
|
from pydantic import StrictInt
|
||||||
from thinc.api import Model, chain, list2array, Linear, zero_init, use_ops
|
from thinc.api import Model, chain, list2array, Linear, zero_init, use_ops, with_array
|
||||||
|
|
||||||
from ...util import registry
|
from ...util import registry
|
||||||
from .._precomputable_affine import PrecomputableAffine
|
from .._precomputable_affine import PrecomputableAffine
|
||||||
from ...syntax._parser_model import ParserModel
|
from ..tb_framework import TransitionModel
|
||||||
|
|
||||||
|
|
||||||
@registry.architectures.register("spacy.TransitionBasedParser.v1")
|
@registry.architectures.register("spacy.TransitionBasedParser.v1")
|
||||||
|
@ -12,21 +12,27 @@ def build_tb_parser_model(
|
||||||
nr_feature_tokens: StrictInt,
|
nr_feature_tokens: StrictInt,
|
||||||
hidden_width: StrictInt,
|
hidden_width: StrictInt,
|
||||||
maxout_pieces: StrictInt,
|
maxout_pieces: StrictInt,
|
||||||
|
use_upper=True,
|
||||||
nO=None,
|
nO=None,
|
||||||
):
|
):
|
||||||
token_vector_width = tok2vec.get_dim("nO")
|
token_vector_width = tok2vec.get_dim("nO")
|
||||||
tok2vec = chain(tok2vec, list2array())
|
tok2vec = chain(
|
||||||
tok2vec.set_dim("nO", token_vector_width)
|
tok2vec,
|
||||||
|
with_array(Linear(hidden_width, token_vector_width)),
|
||||||
|
list2array(),
|
||||||
|
)
|
||||||
|
tok2vec.set_dim("nO", hidden_width)
|
||||||
|
|
||||||
lower = PrecomputableAffine(
|
lower = PrecomputableAffine(
|
||||||
nO=hidden_width,
|
nO=hidden_width if use_upper else nO,
|
||||||
nF=nr_feature_tokens,
|
nF=nr_feature_tokens,
|
||||||
nI=tok2vec.get_dim("nO"),
|
nI=tok2vec.get_dim("nO"),
|
||||||
nP=maxout_pieces,
|
nP=maxout_pieces
|
||||||
)
|
)
|
||||||
lower.set_dim("nP", maxout_pieces)
|
if use_upper:
|
||||||
with use_ops("numpy"):
|
with use_ops("numpy"):
|
||||||
# Initialize weights at zero, as it's a classification layer.
|
# Initialize weights at zero, as it's a classification layer.
|
||||||
upper = Linear(nO=nO, init_W=zero_init)
|
upper = Linear(nO=nO, init_W=zero_init)
|
||||||
model = ParserModel(tok2vec, lower, upper)
|
else:
|
||||||
return model
|
upper = None
|
||||||
|
return TransitionModel(tok2vec, lower, upper)
|
||||||
|
|
82
spacy/ml/models/simple_ner.py
Normal file
82
spacy/ml/models/simple_ner.py
Normal file
|
@ -0,0 +1,82 @@
|
||||||
|
import functools
|
||||||
|
from typing import List, Tuple, Dict, Optional
|
||||||
|
from thinc.api import Ops, Model, Linear, Softmax, with_array, softmax_activation, padded2list
|
||||||
|
from thinc.api import chain, list2padded, configure_normal_init
|
||||||
|
from thinc.api import Dropout
|
||||||
|
from thinc.types import Padded, Ints1d, Ints3d, Floats2d, Floats3d
|
||||||
|
|
||||||
|
from ...tokens import Doc
|
||||||
|
from .._biluo import BILUO
|
||||||
|
from .._iob import IOB
|
||||||
|
from ...util import registry
|
||||||
|
|
||||||
|
|
||||||
|
@registry.architectures.register("spacy.BiluoTagger.v1")
|
||||||
|
def BiluoTagger(tok2vec: Model[List[Doc], List[Floats2d]]) -> Model[List[Doc], List[Floats2d]]:
|
||||||
|
biluo = BILUO()
|
||||||
|
linear = Linear(
|
||||||
|
nO=None,
|
||||||
|
nI=tok2vec.get_dim("nO"),
|
||||||
|
init_W=configure_normal_init(mean=0.02)
|
||||||
|
)
|
||||||
|
model = chain(
|
||||||
|
tok2vec,
|
||||||
|
list2padded(),
|
||||||
|
with_array(chain(Dropout(0.1), linear)),
|
||||||
|
biluo,
|
||||||
|
with_array(softmax_activation()),
|
||||||
|
padded2list()
|
||||||
|
)
|
||||||
|
|
||||||
|
return Model(
|
||||||
|
"biluo-tagger",
|
||||||
|
forward,
|
||||||
|
init=init,
|
||||||
|
layers=[model, linear],
|
||||||
|
refs={"tok2vec": tok2vec, "linear": linear, "biluo": biluo},
|
||||||
|
dims={"nO": None},
|
||||||
|
attrs={"get_num_actions": biluo.attrs["get_num_actions"]}
|
||||||
|
)
|
||||||
|
|
||||||
|
@registry.architectures.register("spacy.IOBTagger.v1")
|
||||||
|
def IOBTagger(tok2vec: Model[List[Doc], List[Floats2d]]) -> Model[List[Doc], List[Floats2d]]:
|
||||||
|
biluo = IOB()
|
||||||
|
linear = Linear(nO=None, nI=tok2vec.get_dim("nO"))
|
||||||
|
model = chain(
|
||||||
|
tok2vec,
|
||||||
|
list2padded(),
|
||||||
|
with_array(linear),
|
||||||
|
biluo,
|
||||||
|
with_array(softmax_activation()),
|
||||||
|
padded2list()
|
||||||
|
)
|
||||||
|
|
||||||
|
return Model(
|
||||||
|
"iob-tagger",
|
||||||
|
forward,
|
||||||
|
init=init,
|
||||||
|
layers=[model],
|
||||||
|
refs={"tok2vec": tok2vec, "linear": linear, "biluo": biluo},
|
||||||
|
dims={"nO": None},
|
||||||
|
attrs={"get_num_actions": biluo.attrs["get_num_actions"]}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def init(model: Model[List[Doc], List[Floats2d]], X=None, Y=None) -> None:
|
||||||
|
if model.get_dim("nO") is None and Y:
|
||||||
|
model.set_dim("nO", Y[0].shape[1])
|
||||||
|
nO = model.get_dim("nO")
|
||||||
|
biluo = model.get_ref("biluo")
|
||||||
|
linear = model.get_ref("linear")
|
||||||
|
biluo.set_dim("nO", nO)
|
||||||
|
if linear.has_dim("nO") is None:
|
||||||
|
linear.set_dim("nO", nO)
|
||||||
|
model.layers[0].initialize(X=X, Y=Y)
|
||||||
|
|
||||||
|
|
||||||
|
def forward(model: Model, X: List[Doc], is_train: bool):
|
||||||
|
return model.layers[0](X, is_train)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["BiluoTagger"]
|
|
@ -1,4 +1,5 @@
|
||||||
from thinc.api import zero_init, with_array, Softmax, chain, Model
|
from thinc.api import zero_init, with_array, Softmax, chain, Model, Dropout
|
||||||
|
from thinc.api import glorot_uniform_init
|
||||||
|
|
||||||
from ...util import registry
|
from ...util import registry
|
||||||
|
|
||||||
|
@ -11,6 +12,6 @@ def build_tagger_model(tok2vec, nO=None) -> Model:
|
||||||
softmax = with_array(output_layer)
|
softmax = with_array(output_layer)
|
||||||
model = chain(tok2vec, softmax)
|
model = chain(tok2vec, softmax)
|
||||||
model.set_ref("tok2vec", tok2vec)
|
model.set_ref("tok2vec", tok2vec)
|
||||||
model.set_ref("softmax", softmax)
|
model.set_ref("softmax", output_layer)
|
||||||
model.set_ref("output_layer", output_layer)
|
model.set_ref("output_layer", output_layer)
|
||||||
return model
|
return model
|
||||||
|
|
86
spacy/ml/tb_framework.py
Normal file
86
spacy/ml/tb_framework.py
Normal file
|
@ -0,0 +1,86 @@
|
||||||
|
from thinc.api import Model, noop, use_ops, Linear
|
||||||
|
from ..syntax._parser_model import ParserStepModel
|
||||||
|
|
||||||
|
|
||||||
|
def TransitionModel(tok2vec, lower, upper, unseen_classes=set()):
|
||||||
|
"""Set up a stepwise transition-based model"""
|
||||||
|
if upper is None:
|
||||||
|
has_upper = False
|
||||||
|
upper = noop()
|
||||||
|
else:
|
||||||
|
has_upper = True
|
||||||
|
# don't define nO for this object, because we can't dynamically change it
|
||||||
|
return Model(
|
||||||
|
name="parser_model",
|
||||||
|
forward=forward,
|
||||||
|
dims={"nI": tok2vec.get_dim("nI") if tok2vec.has_dim("nI") else None},
|
||||||
|
layers=[tok2vec, lower, upper],
|
||||||
|
refs={"tok2vec": tok2vec, "lower": lower, "upper": upper},
|
||||||
|
init=init,
|
||||||
|
attrs={
|
||||||
|
"has_upper": has_upper,
|
||||||
|
"unseen_classes": set(unseen_classes),
|
||||||
|
"resize_output": resize_output
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def forward(model, X, is_train):
|
||||||
|
step_model = ParserStepModel(
|
||||||
|
X,
|
||||||
|
model.layers,
|
||||||
|
unseen_classes=model.attrs["unseen_classes"],
|
||||||
|
train=is_train,
|
||||||
|
has_upper=model.attrs["has_upper"]
|
||||||
|
)
|
||||||
|
|
||||||
|
return step_model, step_model.finish_steps
|
||||||
|
|
||||||
|
|
||||||
|
def init(model, X=None, Y=None):
|
||||||
|
tok2vec = model.get_ref("tok2vec").initialize()
|
||||||
|
lower = model.get_ref("lower").initialize(X=X)
|
||||||
|
if model.attrs["has_upper"]:
|
||||||
|
statevecs = model.ops.alloc2f(2, lower.get_dim("nO"))
|
||||||
|
model.get_ref("upper").initialize(X=statevecs)
|
||||||
|
|
||||||
|
|
||||||
|
def resize_output(model, new_nO):
|
||||||
|
tok2vec = model.get_ref("tok2vec")
|
||||||
|
lower = model.get_ref("lower")
|
||||||
|
upper = model.get_ref("upper")
|
||||||
|
if not model.attrs["has_upper"]:
|
||||||
|
if lower.has_dim("nO") is None:
|
||||||
|
lower.set_dim("nO", new_nO)
|
||||||
|
return
|
||||||
|
elif upper.has_dim("nO") is None:
|
||||||
|
upper.set_dim("nO", new_nO)
|
||||||
|
return
|
||||||
|
elif new_nO == upper.get_dim("nO"):
|
||||||
|
return
|
||||||
|
smaller = upper
|
||||||
|
nI = None
|
||||||
|
if smaller.has_dim("nI"):
|
||||||
|
nI = smaller.get_dim("nI")
|
||||||
|
with use_ops('numpy'):
|
||||||
|
larger = Linear(nO=new_nO, nI=nI)
|
||||||
|
larger.init = smaller.init
|
||||||
|
# it could be that the model is not initialized yet, then skip this bit
|
||||||
|
if nI:
|
||||||
|
larger_W = larger.ops.alloc2f(new_nO, nI)
|
||||||
|
larger_b = larger.ops.alloc1f(new_nO)
|
||||||
|
smaller_W = smaller.get_param("W")
|
||||||
|
smaller_b = smaller.get_param("b")
|
||||||
|
# Weights are stored in (nr_out, nr_in) format, so we're basically
|
||||||
|
# just adding rows here.
|
||||||
|
if smaller.has_dim("nO"):
|
||||||
|
larger_W[:smaller.get_dim("nO")] = smaller_W
|
||||||
|
larger_b[:smaller.get_dim("nO")] = smaller_b
|
||||||
|
for i in range(smaller.get_dim("nO"), new_nO):
|
||||||
|
model.attrs["unseen_classes"].add(i)
|
||||||
|
|
||||||
|
larger.set_param("W", larger_W)
|
||||||
|
larger.set_param("b", larger_b)
|
||||||
|
model._layers[-1] = larger
|
||||||
|
model.set_ref("upper", larger)
|
||||||
|
return model
|
|
@ -1,6 +1,7 @@
|
||||||
from .pipes import Tagger, DependencyParser, EntityRecognizer, EntityLinker
|
from .pipes import Tagger, DependencyParser, EntityRecognizer, EntityLinker
|
||||||
from .pipes import TextCategorizer, Tensorizer, Pipe, Sentencizer
|
from .pipes import TextCategorizer, Tensorizer, Pipe, Sentencizer
|
||||||
from .pipes import SentenceRecognizer
|
from .pipes import SentenceRecognizer
|
||||||
|
from .simple_ner import SimpleNER
|
||||||
from .morphologizer import Morphologizer
|
from .morphologizer import Morphologizer
|
||||||
from .entityruler import EntityRuler
|
from .entityruler import EntityRuler
|
||||||
from .tok2vec import Tok2Vec
|
from .tok2vec import Tok2Vec
|
||||||
|
@ -22,6 +23,7 @@ __all__ = [
|
||||||
"SentenceSegmenter",
|
"SentenceSegmenter",
|
||||||
"SentenceRecognizer",
|
"SentenceRecognizer",
|
||||||
"SimilarityHook",
|
"SimilarityHook",
|
||||||
|
"SimpleNER",
|
||||||
"merge_entities",
|
"merge_entities",
|
||||||
"merge_noun_chunks",
|
"merge_noun_chunks",
|
||||||
"merge_subtokens",
|
"merge_subtokens",
|
||||||
|
|
|
@ -3,7 +3,7 @@ import numpy
|
||||||
import srsly
|
import srsly
|
||||||
import random
|
import random
|
||||||
from thinc.api import CosineDistance, to_categorical, get_array_module
|
from thinc.api import CosineDistance, to_categorical, get_array_module
|
||||||
from thinc.api import set_dropout_rate
|
from thinc.api import set_dropout_rate, SequenceCategoricalCrossentropy
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
from ..tokens.doc cimport Doc
|
from ..tokens.doc cimport Doc
|
||||||
|
@ -464,6 +464,9 @@ class Tagger(Pipe):
|
||||||
return
|
return
|
||||||
set_dropout_rate(self.model, drop)
|
set_dropout_rate(self.model, drop)
|
||||||
tag_scores, bp_tag_scores = self.model.begin_update([ex.doc for ex in examples])
|
tag_scores, bp_tag_scores = self.model.begin_update([ex.doc for ex in examples])
|
||||||
|
for sc in tag_scores:
|
||||||
|
if self.model.ops.xp.isnan(sc.sum()):
|
||||||
|
raise ValueError("nan value in scores")
|
||||||
loss, d_tag_scores = self.get_loss(examples, tag_scores)
|
loss, d_tag_scores = self.get_loss(examples, tag_scores)
|
||||||
bp_tag_scores(d_tag_scores)
|
bp_tag_scores(d_tag_scores)
|
||||||
if sgd not in (None, False):
|
if sgd not in (None, False):
|
||||||
|
@ -497,29 +500,11 @@ class Tagger(Pipe):
|
||||||
losses[self.name] += (gradient**2).sum()
|
losses[self.name] += (gradient**2).sum()
|
||||||
|
|
||||||
def get_loss(self, examples, scores):
|
def get_loss(self, examples, scores):
|
||||||
scores = self.model.ops.flatten(scores)
|
loss_func = SequenceCategoricalCrossentropy(names=self.labels)
|
||||||
tag_index = {tag: i for i, tag in enumerate(self.labels)}
|
truths = [eg.gold.tags for eg in examples]
|
||||||
cdef int idx = 0
|
d_scores, loss = loss_func(scores, truths)
|
||||||
correct = numpy.zeros((scores.shape[0],), dtype="i")
|
if self.model.ops.xp.isnan(loss):
|
||||||
guesses = scores.argmax(axis=1)
|
raise ValueError("nan value when computing loss")
|
||||||
known_labels = numpy.ones((scores.shape[0], 1), dtype="f")
|
|
||||||
for ex in examples:
|
|
||||||
gold = ex.gold
|
|
||||||
for tag in gold.tags:
|
|
||||||
if tag is None:
|
|
||||||
correct[idx] = guesses[idx]
|
|
||||||
elif tag in tag_index:
|
|
||||||
correct[idx] = tag_index[tag]
|
|
||||||
else:
|
|
||||||
correct[idx] = 0
|
|
||||||
known_labels[idx] = 0.
|
|
||||||
idx += 1
|
|
||||||
correct = self.model.ops.xp.array(correct, dtype="i")
|
|
||||||
d_scores = scores - to_categorical(correct, n_classes=scores.shape[1])
|
|
||||||
d_scores *= self.model.ops.asarray(known_labels)
|
|
||||||
loss = (d_scores**2).sum()
|
|
||||||
docs = [ex.doc for ex in examples]
|
|
||||||
d_scores = self.model.ops.unflatten(d_scores, [len(d) for d in docs])
|
|
||||||
return float(loss), d_scores
|
return float(loss), d_scores
|
||||||
|
|
||||||
def begin_training(self, get_examples=lambda: [], pipeline=None, sgd=None,
|
def begin_training(self, get_examples=lambda: [], pipeline=None, sgd=None,
|
||||||
|
|
149
spacy/pipeline/simple_ner.py
Normal file
149
spacy/pipeline/simple_ner.py
Normal file
|
@ -0,0 +1,149 @@
|
||||||
|
from typing import List
|
||||||
|
from thinc.types import Floats2d
|
||||||
|
from thinc.api import SequenceCategoricalCrossentropy, set_dropout_rate
|
||||||
|
from thinc.util import to_numpy
|
||||||
|
from ..gold import Example, spans_from_biluo_tags, iob_to_biluo, biluo_to_iob
|
||||||
|
from ..tokens import Doc
|
||||||
|
from ..language import component
|
||||||
|
from ..util import link_vectors_to_models
|
||||||
|
from .pipes import Pipe
|
||||||
|
|
||||||
|
|
||||||
|
@component("simple_ner", assigns=["doc.ents"])
|
||||||
|
class SimpleNER(Pipe):
|
||||||
|
"""Named entity recognition with a tagging model. The model should include
|
||||||
|
validity constraints to ensure that only valid tag sequences are returned."""
|
||||||
|
|
||||||
|
def __init__(self, vocab, model):
|
||||||
|
self.vocab = vocab
|
||||||
|
self.model = model
|
||||||
|
self.cfg = {"labels": []}
|
||||||
|
self.loss_func = SequenceCategoricalCrossentropy(
|
||||||
|
names=self.get_tag_names(),
|
||||||
|
normalize=True,
|
||||||
|
missing_value=None
|
||||||
|
)
|
||||||
|
assert self.model is not None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def labels(self):
|
||||||
|
return self.cfg["labels"]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_biluo(self):
|
||||||
|
return self.model.name.startswith("biluo")
|
||||||
|
|
||||||
|
def add_label(self, label):
|
||||||
|
if label not in self.cfg["labels"]:
|
||||||
|
self.cfg["labels"].append(label)
|
||||||
|
|
||||||
|
def get_tag_names(self):
|
||||||
|
if self.is_biluo:
|
||||||
|
return (
|
||||||
|
[f"B-{label}" for label in self.labels] +
|
||||||
|
[f"I-{label}" for label in self.labels] +
|
||||||
|
[f"L-{label}" for label in self.labels] +
|
||||||
|
[f"U-{label}" for label in self.labels] +
|
||||||
|
["O"]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return (
|
||||||
|
[f"B-{label}" for label in self.labels] +
|
||||||
|
[f"I-{label}" for label in self.labels] +
|
||||||
|
["O"]
|
||||||
|
)
|
||||||
|
|
||||||
|
def predict(self, docs: List[Doc]) -> List[Floats2d]:
|
||||||
|
scores = self.model.predict(docs)
|
||||||
|
return scores
|
||||||
|
|
||||||
|
def set_annotations(self, docs: List[Doc], scores: List[Floats2d], tensors=None):
|
||||||
|
"""Set entities on a batch of documents from a batch of scores."""
|
||||||
|
tag_names = self.get_tag_names()
|
||||||
|
for i, doc in enumerate(docs):
|
||||||
|
actions = to_numpy(scores[i].argmax(axis=1))
|
||||||
|
tags = [tag_names[actions[j]] for j in range(len(doc))]
|
||||||
|
if not self.is_biluo:
|
||||||
|
tags = iob_to_biluo(tags)
|
||||||
|
doc.ents = spans_from_biluo_tags(doc, tags)
|
||||||
|
|
||||||
|
def update(self, examples, set_annotations=False, drop=0.0, sgd=None, losses=None):
|
||||||
|
if not any(_has_ner(eg) for eg in examples):
|
||||||
|
return 0
|
||||||
|
examples = Example.to_example_objects(examples)
|
||||||
|
docs = [ex.doc for ex in examples]
|
||||||
|
set_dropout_rate(self.model, drop)
|
||||||
|
scores, bp_scores = self.model.begin_update(docs)
|
||||||
|
loss, d_scores = self.get_loss(examples, scores)
|
||||||
|
bp_scores(d_scores)
|
||||||
|
if set_annotations:
|
||||||
|
self.set_annotations(docs, scores)
|
||||||
|
if sgd is not None:
|
||||||
|
self.model.finish_update(sgd)
|
||||||
|
if losses is not None:
|
||||||
|
losses.setdefault("ner", 0.0)
|
||||||
|
losses["ner"] += loss
|
||||||
|
return loss
|
||||||
|
|
||||||
|
def get_loss(self, examples, scores):
|
||||||
|
loss = 0
|
||||||
|
d_scores = []
|
||||||
|
truths = []
|
||||||
|
for eg in examples:
|
||||||
|
gold_tags = [(tag if tag != "-" else None) for tag in eg.gold.ner]
|
||||||
|
if not self.is_biluo:
|
||||||
|
gold_tags = biluo_to_iob(gold_tags)
|
||||||
|
truths.append(gold_tags)
|
||||||
|
for i in range(len(scores)):
|
||||||
|
if len(scores[i]) != len(truths[i]):
|
||||||
|
raise ValueError(
|
||||||
|
f"Mismatched output and gold sizes.\n"
|
||||||
|
f"Output: {len(scores[i])}, gold: {len(truths[i])}."
|
||||||
|
f"Input: {len(examples[i].doc)}"
|
||||||
|
)
|
||||||
|
d_scores, loss = self.loss_func(scores, truths)
|
||||||
|
return loss, d_scores
|
||||||
|
|
||||||
|
def begin_training(self, get_examples, pipeline=None, sgd=None, **kwargs):
|
||||||
|
self.cfg.update(kwargs)
|
||||||
|
if not hasattr(get_examples, '__call__'):
|
||||||
|
gold_tuples = get_examples
|
||||||
|
get_examples = lambda: gold_tuples
|
||||||
|
labels = _get_labels(get_examples())
|
||||||
|
for label in _get_labels(get_examples()):
|
||||||
|
self.add_label(label)
|
||||||
|
labels = self.labels
|
||||||
|
n_actions = self.model.attrs["get_num_actions"](len(labels))
|
||||||
|
self.model.set_dim("nO", n_actions)
|
||||||
|
self.model.initialize()
|
||||||
|
if pipeline is not None:
|
||||||
|
self.init_multitask_objectives(get_examples, pipeline, sgd=sgd, **self.cfg)
|
||||||
|
link_vectors_to_models(self.vocab)
|
||||||
|
self.loss_func = SequenceCategoricalCrossentropy(
|
||||||
|
names=self.get_tag_names(),
|
||||||
|
normalize=True,
|
||||||
|
missing_value=None
|
||||||
|
)
|
||||||
|
|
||||||
|
return sgd
|
||||||
|
|
||||||
|
def init_multitask_objectives(self, *args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def _has_ner(eg):
|
||||||
|
for ner_tag in eg.gold.ner:
|
||||||
|
if ner_tag != "-" and ner_tag != None:
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _get_labels(examples):
|
||||||
|
labels = set()
|
||||||
|
for eg in examples:
|
||||||
|
for ner_tag in eg.token_annotation.entities:
|
||||||
|
if ner_tag != 'O' and ner_tag != '-':
|
||||||
|
_, label = ner_tag.split('-', 1)
|
||||||
|
labels.add(label)
|
||||||
|
return list(sorted(labels))
|
|
@ -12,7 +12,7 @@ cimport blis.cy
|
||||||
|
|
||||||
import numpy
|
import numpy
|
||||||
import numpy.random
|
import numpy.random
|
||||||
from thinc.api import Linear, Model, CupyOps, NumpyOps, use_ops
|
from thinc.api import Linear, Model, CupyOps, NumpyOps, use_ops, noop
|
||||||
|
|
||||||
from ..typedefs cimport weight_t, class_t, hash_t
|
from ..typedefs cimport weight_t, class_t, hash_t
|
||||||
from ..tokens.doc cimport Doc
|
from ..tokens.doc cimport Doc
|
||||||
|
@ -219,112 +219,27 @@ cdef int arg_max_if_valid(const weight_t* scores, const int* is_valid, int n) no
|
||||||
return best
|
return best
|
||||||
|
|
||||||
|
|
||||||
class ParserModel(Model):
|
|
||||||
def __init__(self, tok2vec, lower_model, upper_model, unseen_classes=None):
|
|
||||||
# don't define nO for this object, because we can't dynamically change it
|
|
||||||
Model.__init__(self, name="parser_model", forward=forward, dims={"nI": None})
|
|
||||||
if tok2vec.has_dim("nI"):
|
|
||||||
self.set_dim("nI", tok2vec.get_dim("nI"))
|
|
||||||
self._layers = [tok2vec, lower_model]
|
|
||||||
if upper_model is not None:
|
|
||||||
self._layers.append(upper_model)
|
|
||||||
self.unseen_classes = set()
|
|
||||||
if unseen_classes:
|
|
||||||
for class_ in unseen_classes:
|
|
||||||
self.unseen_classes.add(class_)
|
|
||||||
self.set_ref("tok2vec", tok2vec)
|
|
||||||
|
|
||||||
def predict(self, docs):
|
|
||||||
step_model = ParserStepModel(docs, self._layers,
|
|
||||||
unseen_classes=self.unseen_classes, train=False)
|
|
||||||
return step_model
|
|
||||||
|
|
||||||
def resize_output(self, new_nO):
|
|
||||||
if len(self._layers) == 2:
|
|
||||||
return
|
|
||||||
if self.upper.has_dim("nO") and (new_nO == self.upper.get_dim("nO")):
|
|
||||||
return
|
|
||||||
smaller = self.upper
|
|
||||||
nI = None
|
|
||||||
if smaller.has_dim("nI"):
|
|
||||||
nI = smaller.get_dim("nI")
|
|
||||||
with use_ops('numpy'):
|
|
||||||
larger = Linear(nO=new_nO, nI=nI)
|
|
||||||
larger.init = smaller.init
|
|
||||||
# it could be that the model is not initialized yet, then skip this bit
|
|
||||||
if nI:
|
|
||||||
larger_W = larger.ops.alloc2f(new_nO, nI)
|
|
||||||
larger_b = larger.ops.alloc1f(new_nO)
|
|
||||||
smaller_W = smaller.get_param("W")
|
|
||||||
smaller_b = smaller.get_param("b")
|
|
||||||
# Weights are stored in (nr_out, nr_in) format, so we're basically
|
|
||||||
# just adding rows here.
|
|
||||||
if smaller.has_dim("nO"):
|
|
||||||
larger_W[:smaller.get_dim("nO")] = smaller_W
|
|
||||||
larger_b[:smaller.get_dim("nO")] = smaller_b
|
|
||||||
for i in range(smaller.get_dim("nO"), new_nO):
|
|
||||||
self.unseen_classes.add(i)
|
|
||||||
|
|
||||||
larger.set_param("W", larger_W)
|
|
||||||
larger.set_param("b", larger_b)
|
|
||||||
self._layers[-1] = larger
|
|
||||||
|
|
||||||
def initialize(self, X=None, Y=None):
|
|
||||||
self.tok2vec.initialize()
|
|
||||||
self.lower.initialize(X=X, Y=Y)
|
|
||||||
if self.upper is not None:
|
|
||||||
# In case we need to trigger the callbacks
|
|
||||||
statevecs = self.ops.alloc((2, self.lower.get_dim("nO")))
|
|
||||||
self.upper.initialize(X=statevecs)
|
|
||||||
|
|
||||||
def finish_update(self, optimizer):
|
|
||||||
self.tok2vec.finish_update(optimizer)
|
|
||||||
self.lower.finish_update(optimizer)
|
|
||||||
if self.upper is not None:
|
|
||||||
self.upper.finish_update(optimizer)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def tok2vec(self):
|
|
||||||
return self._layers[0]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def lower(self):
|
|
||||||
return self._layers[1]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def upper(self):
|
|
||||||
return self._layers[2]
|
|
||||||
|
|
||||||
|
|
||||||
def forward(model:ParserModel, X, is_train):
|
|
||||||
step_model = ParserStepModel(X, model._layers, unseen_classes=model.unseen_classes,
|
|
||||||
train=is_train)
|
|
||||||
|
|
||||||
return step_model, step_model.finish_steps
|
|
||||||
|
|
||||||
|
|
||||||
class ParserStepModel(Model):
|
class ParserStepModel(Model):
|
||||||
def __init__(self, docs, layers, unseen_classes=None, train=True):
|
def __init__(self, docs, layers, *, has_upper, unseen_classes=None, train=True):
|
||||||
Model.__init__(self, name="parser_step_model", forward=step_forward)
|
Model.__init__(self, name="parser_step_model", forward=step_forward)
|
||||||
|
self.attrs["has_upper"] = has_upper
|
||||||
self.tokvecs, self.bp_tokvecs = layers[0](docs, is_train=train)
|
self.tokvecs, self.bp_tokvecs = layers[0](docs, is_train=train)
|
||||||
if layers[1].get_dim("nP") >= 2:
|
if layers[1].get_dim("nP") >= 2:
|
||||||
activation = "maxout"
|
activation = "maxout"
|
||||||
elif len(layers) == 2:
|
elif has_upper:
|
||||||
activation = None
|
activation = None
|
||||||
else:
|
else:
|
||||||
activation = "relu"
|
activation = "relu"
|
||||||
self.state2vec = precompute_hiddens(len(docs), self.tokvecs, layers[1],
|
self.state2vec = precompute_hiddens(len(docs), self.tokvecs, layers[1],
|
||||||
activation=activation, train=train)
|
activation=activation, train=train)
|
||||||
if len(layers) == 3:
|
if has_upper:
|
||||||
self.vec2scores = layers[-1]
|
self.vec2scores = layers[-1]
|
||||||
else:
|
else:
|
||||||
self.vec2scores = None
|
self.vec2scores = None
|
||||||
self.cuda_stream = util.get_cuda_stream(non_blocking=True)
|
self.cuda_stream = util.get_cuda_stream(non_blocking=True)
|
||||||
self.backprops = []
|
self.backprops = []
|
||||||
if self.vec2scores is None:
|
self._class_mask = numpy.zeros((self.nO,), dtype='f')
|
||||||
self._class_mask = numpy.zeros((self.state2vec.nO,), dtype='f')
|
|
||||||
else:
|
|
||||||
self._class_mask = numpy.zeros((self.vec2scores.get_dim("nO"),), dtype='f')
|
|
||||||
self._class_mask.fill(1)
|
self._class_mask.fill(1)
|
||||||
if unseen_classes is not None:
|
if unseen_classes is not None:
|
||||||
for class_ in unseen_classes:
|
for class_ in unseen_classes:
|
||||||
|
@ -332,7 +247,10 @@ class ParserStepModel(Model):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def nO(self):
|
def nO(self):
|
||||||
return self.state2vec.nO
|
if self.attrs["has_upper"]:
|
||||||
|
return self.vec2scores.get_dim("nO")
|
||||||
|
else:
|
||||||
|
return self.state2vec.get_dim("nO")
|
||||||
|
|
||||||
def class_is_unseen(self, class_):
|
def class_is_unseen(self, class_):
|
||||||
return self._class_mask[class_]
|
return self._class_mask[class_]
|
||||||
|
@ -378,7 +296,7 @@ class ParserStepModel(Model):
|
||||||
def step_forward(model: ParserStepModel, states, is_train):
|
def step_forward(model: ParserStepModel, states, is_train):
|
||||||
token_ids = model.get_token_ids(states)
|
token_ids = model.get_token_ids(states)
|
||||||
vector, get_d_tokvecs = model.state2vec(token_ids, is_train)
|
vector, get_d_tokvecs = model.state2vec(token_ids, is_train)
|
||||||
if model.vec2scores is not None:
|
if model.attrs["has_upper"]:
|
||||||
scores, get_d_vector = model.vec2scores(vector, is_train)
|
scores, get_d_vector = model.vec2scores(vector, is_train)
|
||||||
else:
|
else:
|
||||||
scores = NumpyOps().asarray(vector)
|
scores = NumpyOps().asarray(vector)
|
||||||
|
|
|
@ -36,7 +36,6 @@ from ..util import link_vectors_to_models, create_default_optimizer, registry
|
||||||
from ..compat import copy_array
|
from ..compat import copy_array
|
||||||
from ..errors import Errors, Warnings
|
from ..errors import Errors, Warnings
|
||||||
from .. import util
|
from .. import util
|
||||||
from ._parser_model import ParserModel
|
|
||||||
from . import _beam_utils
|
from . import _beam_utils
|
||||||
from . import nonproj
|
from . import nonproj
|
||||||
|
|
||||||
|
@ -69,7 +68,8 @@ cdef class Parser:
|
||||||
cfg.setdefault('beam_width', 1)
|
cfg.setdefault('beam_width', 1)
|
||||||
cfg.setdefault('beam_update_prob', 1.0) # or 0.5 (both defaults were previously used)
|
cfg.setdefault('beam_update_prob', 1.0) # or 0.5 (both defaults were previously used)
|
||||||
self.model = model
|
self.model = model
|
||||||
self.set_output(self.moves.n_moves)
|
if self.moves.n_moves != 0:
|
||||||
|
self.set_output(self.moves.n_moves)
|
||||||
self.cfg = cfg
|
self.cfg = cfg
|
||||||
self._multitasks = []
|
self._multitasks = []
|
||||||
self._rehearsal_model = None
|
self._rehearsal_model = None
|
||||||
|
@ -105,7 +105,7 @@ cdef class Parser:
|
||||||
@property
|
@property
|
||||||
def tok2vec(self):
|
def tok2vec(self):
|
||||||
'''Return the embedding and convolutional layer of the model.'''
|
'''Return the embedding and convolutional layer of the model.'''
|
||||||
return self.model.tok2vec
|
return self.model.get_ref("tok2vec")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def postprocesses(self):
|
def postprocesses(self):
|
||||||
|
@ -122,9 +122,11 @@ cdef class Parser:
|
||||||
self._resize()
|
self._resize()
|
||||||
|
|
||||||
def _resize(self):
|
def _resize(self):
|
||||||
self.model.resize_output(self.moves.n_moves)
|
self.model.attrs["resize_output"](self.model, self.moves.n_moves)
|
||||||
if self._rehearsal_model not in (True, False, None):
|
if self._rehearsal_model not in (True, False, None):
|
||||||
self._rehearsal_model.resize_output(self.moves.n_moves)
|
self._rehearsal_model.attrs["resize_output"](
|
||||||
|
self._rehearsal_model, self.moves.n_moves
|
||||||
|
)
|
||||||
|
|
||||||
def add_multitask_objective(self, target):
|
def add_multitask_objective(self, target):
|
||||||
# Defined in subclasses, to avoid circular import
|
# Defined in subclasses, to avoid circular import
|
||||||
|
@ -216,7 +218,6 @@ cdef class Parser:
|
||||||
# expand our model output.
|
# expand our model output.
|
||||||
self._resize()
|
self._resize()
|
||||||
model = self.model.predict(docs)
|
model = self.model.predict(docs)
|
||||||
W_param = model.vec2scores.get_param("W")
|
|
||||||
weights = get_c_weights(model)
|
weights = get_c_weights(model)
|
||||||
for state in batch:
|
for state in batch:
|
||||||
if not state.is_final():
|
if not state.is_final():
|
||||||
|
@ -237,7 +238,7 @@ cdef class Parser:
|
||||||
# if labels are missing. We therefore have to check whether we need to
|
# if labels are missing. We therefore have to check whether we need to
|
||||||
# expand our model output.
|
# expand our model output.
|
||||||
self._resize()
|
self._resize()
|
||||||
cdef int nr_feature = self.model.lower.get_dim("nF")
|
cdef int nr_feature = self.model.get_ref("lower").get_dim("nF")
|
||||||
model = self.model.predict(docs)
|
model = self.model.predict(docs)
|
||||||
token_ids = numpy.zeros((len(docs) * beam_width, nr_feature),
|
token_ids = numpy.zeros((len(docs) * beam_width, nr_feature),
|
||||||
dtype='i', order='C')
|
dtype='i', order='C')
|
||||||
|
@ -370,13 +371,16 @@ cdef class Parser:
|
||||||
beam_density=self.cfg.get('beam_density', 0.001))
|
beam_density=self.cfg.get('beam_density', 0.001))
|
||||||
|
|
||||||
set_dropout_rate(self.model, drop)
|
set_dropout_rate(self.model, drop)
|
||||||
# Chop sequences into lengths of this many transitions, to make the
|
cut_gold = True
|
||||||
# batch uniform length.
|
if cut_gold:
|
||||||
cut_gold = numpy.random.choice(range(20, 100))
|
# Chop sequences into lengths of this many transitions, to make the
|
||||||
states, golds, max_steps = self._init_gold_batch(examples, max_length=cut_gold)
|
# batch uniform length.
|
||||||
|
cut_gold = numpy.random.choice(range(20, 100))
|
||||||
|
states, golds, max_steps = self._init_gold_batch(examples, max_length=cut_gold)
|
||||||
|
else:
|
||||||
|
states, golds, max_steps = self._init_gold_batch_no_cut(examples)
|
||||||
states_golds = [(s, g) for (s, g) in zip(states, golds)
|
states_golds = [(s, g) for (s, g) in zip(states, golds)
|
||||||
if not s.is_final() and g is not None]
|
if not s.is_final() and g is not None]
|
||||||
|
|
||||||
# Prepare the stepwise model, and get the callback for finishing the batch
|
# Prepare the stepwise model, and get the callback for finishing the batch
|
||||||
model, backprop_tok2vec = self.model.begin_update([ex.doc for ex in examples])
|
model, backprop_tok2vec = self.model.begin_update([ex.doc for ex in examples])
|
||||||
all_states = list(states)
|
all_states = list(states)
|
||||||
|
@ -456,9 +460,17 @@ cdef class Parser:
|
||||||
set_dropout_rate(self.model, drop)
|
set_dropout_rate(self.model, drop)
|
||||||
model, backprop_tok2vec = self.model.begin_update(docs)
|
model, backprop_tok2vec = self.model.begin_update(docs)
|
||||||
states_d_scores, backprops, beams = _beam_utils.update_beam(
|
states_d_scores, backprops, beams = _beam_utils.update_beam(
|
||||||
self.moves, self.model.lower.get_dim("nF"), 10000, states, golds,
|
self.moves,
|
||||||
model.state2vec, model.vec2scores, width, losses=losses,
|
self.model.get_ref("lower").get_dim("nF"),
|
||||||
beam_density=beam_density)
|
10000,
|
||||||
|
states,
|
||||||
|
golds,
|
||||||
|
model.state2vec,
|
||||||
|
model.vec2scores,
|
||||||
|
width,
|
||||||
|
losses=losses,
|
||||||
|
beam_density=beam_density
|
||||||
|
)
|
||||||
for i, d_scores in enumerate(states_d_scores):
|
for i, d_scores in enumerate(states_d_scores):
|
||||||
losses[self.name] += (d_scores**2).mean()
|
losses[self.name] += (d_scores**2).mean()
|
||||||
ids, bp_vectors, bp_scores = backprops[i]
|
ids, bp_vectors, bp_scores = backprops[i]
|
||||||
|
@ -497,6 +509,24 @@ cdef class Parser:
|
||||||
queue.extend(node._layers)
|
queue.extend(node._layers)
|
||||||
return gradients
|
return gradients
|
||||||
|
|
||||||
|
def _init_gold_batch_no_cut(self, whole_examples):
|
||||||
|
states = self.moves.init_batch([eg.doc for eg in whole_examples])
|
||||||
|
good_docs = []
|
||||||
|
good_golds = []
|
||||||
|
good_states = []
|
||||||
|
for i, eg in enumerate(whole_examples):
|
||||||
|
doc = eg.doc
|
||||||
|
gold = self.moves.preprocess_gold(eg.gold)
|
||||||
|
if gold is not None and self.moves.has_gold(gold):
|
||||||
|
good_docs.append(doc)
|
||||||
|
good_golds.append(gold)
|
||||||
|
good_states.append(states[i])
|
||||||
|
n_moves = []
|
||||||
|
for doc, gold in zip(good_docs, good_golds):
|
||||||
|
oracle_actions = self.moves.get_oracle_sequence(doc, gold)
|
||||||
|
n_moves.append(len(oracle_actions))
|
||||||
|
return good_states, good_golds, max(n_moves, default=0) * 2
|
||||||
|
|
||||||
def _init_gold_batch(self, whole_examples, min_length=5, max_length=500):
|
def _init_gold_batch(self, whole_examples, min_length=5, max_length=500):
|
||||||
"""Make a square batch, of length equal to the shortest doc. A long
|
"""Make a square batch, of length equal to the shortest doc. A long
|
||||||
doc will get multiple states. Let's say we have a doc of length 2*N,
|
doc will get multiple states. Let's say we have a doc of length 2*N,
|
||||||
|
@ -550,16 +580,19 @@ cdef class Parser:
|
||||||
cdef np.ndarray d_scores = numpy.zeros((len(states), self.moves.n_moves),
|
cdef np.ndarray d_scores = numpy.zeros((len(states), self.moves.n_moves),
|
||||||
dtype='f', order='C')
|
dtype='f', order='C')
|
||||||
c_d_scores = <float*>d_scores.data
|
c_d_scores = <float*>d_scores.data
|
||||||
|
unseen_classes = self.model.attrs["unseen_classes"]
|
||||||
for i, (state, gold) in enumerate(zip(states, golds)):
|
for i, (state, gold) in enumerate(zip(states, golds)):
|
||||||
memset(is_valid, 0, self.moves.n_moves * sizeof(int))
|
memset(is_valid, 0, self.moves.n_moves * sizeof(int))
|
||||||
memset(costs, 0, self.moves.n_moves * sizeof(float))
|
memset(costs, 0, self.moves.n_moves * sizeof(float))
|
||||||
self.moves.set_costs(is_valid, costs, state, gold)
|
self.moves.set_costs(is_valid, costs, state, gold)
|
||||||
for j in range(self.moves.n_moves):
|
for j in range(self.moves.n_moves):
|
||||||
if costs[j] <= 0.0 and j in self.model.unseen_classes:
|
if costs[j] <= 0.0 and j in unseen_classes:
|
||||||
self.model.unseen_classes.remove(j)
|
unseen_classes.remove(j)
|
||||||
cpu_log_loss(c_d_scores,
|
cpu_log_loss(c_d_scores,
|
||||||
costs, is_valid, &scores[i, 0], d_scores.shape[1])
|
costs, is_valid, &scores[i, 0], d_scores.shape[1])
|
||||||
c_d_scores += d_scores.shape[1]
|
c_d_scores += d_scores.shape[1]
|
||||||
|
if len(states):
|
||||||
|
d_scores /= len(states)
|
||||||
if losses is not None:
|
if losses is not None:
|
||||||
losses.setdefault(self.name, 0.)
|
losses.setdefault(self.name, 0.)
|
||||||
losses[self.name] += (d_scores**2).sum()
|
losses[self.name] += (d_scores**2).sum()
|
||||||
|
@ -569,8 +602,7 @@ cdef class Parser:
|
||||||
return create_default_optimizer()
|
return create_default_optimizer()
|
||||||
|
|
||||||
def set_output(self, nO):
|
def set_output(self, nO):
|
||||||
if self.model.upper.has_dim("nO") is None:
|
self.model.attrs["resize_output"](self.model, nO)
|
||||||
self.model.upper.set_dim("nO", nO)
|
|
||||||
|
|
||||||
def begin_training(self, get_examples, pipeline=None, sgd=None, **kwargs):
|
def begin_training(self, get_examples, pipeline=None, sgd=None, **kwargs):
|
||||||
self.cfg.update(kwargs)
|
self.cfg.update(kwargs)
|
||||||
|
@ -597,7 +629,6 @@ cdef class Parser:
|
||||||
for doc, gold in parses:
|
for doc, gold in parses:
|
||||||
doc_sample.append(doc)
|
doc_sample.append(doc)
|
||||||
gold_sample.append(gold)
|
gold_sample.append(gold)
|
||||||
|
|
||||||
self.model.initialize(doc_sample, gold_sample)
|
self.model.initialize(doc_sample, gold_sample)
|
||||||
if pipeline is not None:
|
if pipeline is not None:
|
||||||
self.init_multitask_objectives(get_examples, pipeline, sgd=sgd, **self.cfg)
|
self.init_multitask_objectives(get_examples, pipeline, sgd=sgd, **self.cfg)
|
||||||
|
|
|
@ -65,7 +65,7 @@ def test_add_label_deserializes_correctly():
|
||||||
ner2 = EntityRecognizer(Vocab(), default_ner())
|
ner2 = EntityRecognizer(Vocab(), default_ner())
|
||||||
|
|
||||||
# the second model needs to be resized before we can call from_bytes
|
# the second model needs to be resized before we can call from_bytes
|
||||||
ner2.model.resize_output(ner1.moves.n_moves)
|
ner2.model.attrs["resize_output"](ner2.model, ner1.moves.n_moves)
|
||||||
ner2.from_bytes(ner1.to_bytes())
|
ner2.from_bytes(ner1.to_bytes())
|
||||||
assert ner1.moves.n_moves == ner2.moves.n_moves
|
assert ner1.moves.n_moves == ner2.moves.n_moves
|
||||||
for i in range(ner1.moves.n_moves):
|
for i in range(ner1.moves.n_moves):
|
||||||
|
|
|
@ -3,9 +3,9 @@ from spacy.ml.models.defaults import default_parser, default_tok2vec
|
||||||
from spacy.vocab import Vocab
|
from spacy.vocab import Vocab
|
||||||
from spacy.syntax.arc_eager import ArcEager
|
from spacy.syntax.arc_eager import ArcEager
|
||||||
from spacy.syntax.nn_parser import Parser
|
from spacy.syntax.nn_parser import Parser
|
||||||
from spacy.syntax._parser_model import ParserModel
|
|
||||||
from spacy.tokens.doc import Doc
|
from spacy.tokens.doc import Doc
|
||||||
from spacy.gold import GoldParse
|
from spacy.gold import GoldParse
|
||||||
|
from thinc.api import Model
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
@ -34,7 +34,7 @@ def parser(vocab, arc_eager):
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def model(arc_eager, tok2vec, vocab):
|
def model(arc_eager, tok2vec, vocab):
|
||||||
model = default_parser()
|
model = default_parser()
|
||||||
model.resize_output(arc_eager.n_moves)
|
model.attrs["resize_output"](model, arc_eager.n_moves)
|
||||||
model.initialize()
|
model.initialize()
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
@ -50,7 +50,7 @@ def gold(doc):
|
||||||
|
|
||||||
|
|
||||||
def test_can_init_nn_parser(parser):
|
def test_can_init_nn_parser(parser):
|
||||||
assert isinstance(parser.model, ParserModel)
|
assert isinstance(parser.model, Model)
|
||||||
|
|
||||||
|
|
||||||
def test_build_model(parser, vocab):
|
def test_build_model(parser, vocab):
|
||||||
|
|
417
spacy/tests/pipeline/test_simple_ner.py
Normal file
417
spacy/tests/pipeline/test_simple_ner.py
Normal file
|
@ -0,0 +1,417 @@
|
||||||
|
import pytest
|
||||||
|
from collections import namedtuple
|
||||||
|
|
||||||
|
from thinc.api import NumpyOps
|
||||||
|
from spacy.ml._biluo import BILUO, _get_transition_table
|
||||||
|
from spacy.pipeline.simple_ner import SimpleNER
|
||||||
|
import spacy
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(params=[
|
||||||
|
["PER", "ORG", "LOC", "MISC"],
|
||||||
|
["GPE", "PERSON", "NUMBER", "CURRENCY", "EVENT"]
|
||||||
|
])
|
||||||
|
def labels(request):
|
||||||
|
return request.param
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def ops():
|
||||||
|
return NumpyOps()
|
||||||
|
|
||||||
|
def _get_actions(labels):
|
||||||
|
action_names = (
|
||||||
|
[f"B{label}" for label in labels] + \
|
||||||
|
[f"I{label}" for label in labels] + \
|
||||||
|
[f"L{label}" for label in labels] + \
|
||||||
|
[f"U{label}" for label in labels] + \
|
||||||
|
["O"]
|
||||||
|
)
|
||||||
|
A = namedtuple("actions", action_names)
|
||||||
|
return A(**{name: i for i, name in enumerate(action_names)})
|
||||||
|
|
||||||
|
|
||||||
|
def test_init_biluo_layer(labels):
|
||||||
|
model = BILUO()
|
||||||
|
model.set_dim("nO", model.attrs["get_num_actions"](len(labels)))
|
||||||
|
model.initialize()
|
||||||
|
assert model.get_dim("nO") == len(labels) * 4 + 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_transition_table(ops):
|
||||||
|
labels = ["per", "loc", "org"]
|
||||||
|
table = _get_transition_table(len(labels))
|
||||||
|
a = _get_actions(labels)
|
||||||
|
assert table.shape == (2, len(a), len(a))
|
||||||
|
# Not last token, prev action was B
|
||||||
|
assert table[0, a.Bper, a.Bper] == 0
|
||||||
|
assert table[0, a.Bper, a.Bloc] == 0
|
||||||
|
assert table[0, a.Bper, a.Borg] == 0
|
||||||
|
assert table[0, a.Bper, a.Iper] == 1
|
||||||
|
assert table[0, a.Bper, a.Iloc] == 0
|
||||||
|
assert table[0, a.Bper, a.Iorg] == 0
|
||||||
|
assert table[0, a.Bper, a.Lper] == 1
|
||||||
|
assert table[0, a.Bper, a.Lloc] == 0
|
||||||
|
assert table[0, a.Bper, a.Lorg] == 0
|
||||||
|
assert table[0, a.Bper, a.Uper] == 0
|
||||||
|
assert table[0, a.Bper, a.Uloc] == 0
|
||||||
|
assert table[0, a.Bper, a.Uorg] == 0
|
||||||
|
assert table[0, a.Bper, a.O] == 0
|
||||||
|
|
||||||
|
assert table[0, a.Bloc, a.Bper] == 0
|
||||||
|
assert table[0, a.Bloc, a.Bloc] == 0
|
||||||
|
assert table[0, a.Bloc, a.Borg] == 0
|
||||||
|
assert table[0, a.Bloc, a.Iper] == 0
|
||||||
|
assert table[0, a.Bloc, a.Iloc] == 1
|
||||||
|
assert table[0, a.Bloc, a.Iorg] == 0
|
||||||
|
assert table[0, a.Bloc, a.Lper] == 0
|
||||||
|
assert table[0, a.Bloc, a.Lloc] == 1
|
||||||
|
assert table[0, a.Bloc, a.Lorg] == 0
|
||||||
|
assert table[0, a.Bloc, a.Uper] == 0
|
||||||
|
assert table[0, a.Bloc, a.Uloc] == 0
|
||||||
|
assert table[0, a.Bloc, a.Uorg] == 0
|
||||||
|
assert table[0, a.Bloc, a.O] == 0
|
||||||
|
|
||||||
|
assert table[0, a.Borg, a.Bper] == 0
|
||||||
|
assert table[0, a.Borg, a.Bloc] == 0
|
||||||
|
assert table[0, a.Borg, a.Borg] == 0
|
||||||
|
assert table[0, a.Borg, a.Iper] == 0
|
||||||
|
assert table[0, a.Borg, a.Iloc] == 0
|
||||||
|
assert table[0, a.Borg, a.Iorg] == 1
|
||||||
|
assert table[0, a.Borg, a.Lper] == 0
|
||||||
|
assert table[0, a.Borg, a.Lloc] == 0
|
||||||
|
assert table[0, a.Borg, a.Lorg] == 1
|
||||||
|
assert table[0, a.Borg, a.Uper] == 0
|
||||||
|
assert table[0, a.Borg, a.Uloc] == 0
|
||||||
|
assert table[0, a.Borg, a.Uorg] == 0
|
||||||
|
assert table[0, a.Borg, a.O] == 0
|
||||||
|
|
||||||
|
# Not last token, prev action was I
|
||||||
|
assert table[0, a.Iper, a.Bper] == 0
|
||||||
|
assert table[0, a.Iper, a.Bloc] == 0
|
||||||
|
assert table[0, a.Iper, a.Borg] == 0
|
||||||
|
assert table[0, a.Iper, a.Iper] == 1
|
||||||
|
assert table[0, a.Iper, a.Iloc] == 0
|
||||||
|
assert table[0, a.Iper, a.Iorg] == 0
|
||||||
|
assert table[0, a.Iper, a.Lper] == 1
|
||||||
|
assert table[0, a.Iper, a.Lloc] == 0
|
||||||
|
assert table[0, a.Iper, a.Lorg] == 0
|
||||||
|
assert table[0, a.Iper, a.Uper] == 0
|
||||||
|
assert table[0, a.Iper, a.Uloc] == 0
|
||||||
|
assert table[0, a.Iper, a.Uorg] == 0
|
||||||
|
assert table[0, a.Iper, a.O] == 0
|
||||||
|
|
||||||
|
assert table[0, a.Iloc, a.Bper] == 0
|
||||||
|
assert table[0, a.Iloc, a.Bloc] == 0
|
||||||
|
assert table[0, a.Iloc, a.Borg] == 0
|
||||||
|
assert table[0, a.Iloc, a.Iper] == 0
|
||||||
|
assert table[0, a.Iloc, a.Iloc] == 1
|
||||||
|
assert table[0, a.Iloc, a.Iorg] == 0
|
||||||
|
assert table[0, a.Iloc, a.Lper] == 0
|
||||||
|
assert table[0, a.Iloc, a.Lloc] == 1
|
||||||
|
assert table[0, a.Iloc, a.Lorg] == 0
|
||||||
|
assert table[0, a.Iloc, a.Uper] == 0
|
||||||
|
assert table[0, a.Iloc, a.Uloc] == 0
|
||||||
|
assert table[0, a.Iloc, a.Uorg] == 0
|
||||||
|
assert table[0, a.Iloc, a.O] == 0
|
||||||
|
|
||||||
|
assert table[0, a.Iorg, a.Bper] == 0
|
||||||
|
assert table[0, a.Iorg, a.Bloc] == 0
|
||||||
|
assert table[0, a.Iorg, a.Borg] == 0
|
||||||
|
assert table[0, a.Iorg, a.Iper] == 0
|
||||||
|
assert table[0, a.Iorg, a.Iloc] == 0
|
||||||
|
assert table[0, a.Iorg, a.Iorg] == 1
|
||||||
|
assert table[0, a.Iorg, a.Lper] == 0
|
||||||
|
assert table[0, a.Iorg, a.Lloc] == 0
|
||||||
|
assert table[0, a.Iorg, a.Lorg] == 1
|
||||||
|
assert table[0, a.Iorg, a.Uper] == 0
|
||||||
|
assert table[0, a.Iorg, a.Uloc] == 0
|
||||||
|
assert table[0, a.Iorg, a.Uorg] == 0
|
||||||
|
assert table[0, a.Iorg, a.O] == 0
|
||||||
|
|
||||||
|
# Not last token, prev action was L
|
||||||
|
assert table[0, a.Lper, a.Bper] == 1
|
||||||
|
assert table[0, a.Lper, a.Bloc] == 1
|
||||||
|
assert table[0, a.Lper, a.Borg] == 1
|
||||||
|
assert table[0, a.Lper, a.Iper] == 0
|
||||||
|
assert table[0, a.Lper, a.Iloc] == 0
|
||||||
|
assert table[0, a.Lper, a.Iorg] == 0
|
||||||
|
assert table[0, a.Lper, a.Lper] == 0
|
||||||
|
assert table[0, a.Lper, a.Lloc] == 0
|
||||||
|
assert table[0, a.Lper, a.Lorg] == 0
|
||||||
|
assert table[0, a.Lper, a.Uper] == 1
|
||||||
|
assert table[0, a.Lper, a.Uloc] == 1
|
||||||
|
assert table[0, a.Lper, a.Uorg] == 1
|
||||||
|
assert table[0, a.Lper, a.O] == 1
|
||||||
|
|
||||||
|
assert table[0, a.Lloc, a.Bper] == 1
|
||||||
|
assert table[0, a.Lloc, a.Bloc] == 1
|
||||||
|
assert table[0, a.Lloc, a.Borg] == 1
|
||||||
|
assert table[0, a.Lloc, a.Iper] == 0
|
||||||
|
assert table[0, a.Lloc, a.Iloc] == 0
|
||||||
|
assert table[0, a.Lloc, a.Iorg] == 0
|
||||||
|
assert table[0, a.Lloc, a.Lper] == 0
|
||||||
|
assert table[0, a.Lloc, a.Lloc] == 0
|
||||||
|
assert table[0, a.Lloc, a.Lorg] == 0
|
||||||
|
assert table[0, a.Lloc, a.Uper] == 1
|
||||||
|
assert table[0, a.Lloc, a.Uloc] == 1
|
||||||
|
assert table[0, a.Lloc, a.Uorg] == 1
|
||||||
|
assert table[0, a.Lloc, a.O] == 1
|
||||||
|
|
||||||
|
assert table[0, a.Lorg, a.Bper] == 1
|
||||||
|
assert table[0, a.Lorg, a.Bloc] == 1
|
||||||
|
assert table[0, a.Lorg, a.Borg] == 1
|
||||||
|
assert table[0, a.Lorg, a.Iper] == 0
|
||||||
|
assert table[0, a.Lorg, a.Iloc] == 0
|
||||||
|
assert table[0, a.Lorg, a.Iorg] == 0
|
||||||
|
assert table[0, a.Lorg, a.Lper] == 0
|
||||||
|
assert table[0, a.Lorg, a.Lloc] == 0
|
||||||
|
assert table[0, a.Lorg, a.Lorg] == 0
|
||||||
|
assert table[0, a.Lorg, a.Uper] == 1
|
||||||
|
assert table[0, a.Lorg, a.Uloc] == 1
|
||||||
|
assert table[0, a.Lorg, a.Uorg] == 1
|
||||||
|
assert table[0, a.Lorg, a.O] == 1
|
||||||
|
|
||||||
|
# Not last token, prev action was U
|
||||||
|
assert table[0, a.Uper, a.Bper] == 1
|
||||||
|
assert table[0, a.Uper, a.Bloc] == 1
|
||||||
|
assert table[0, a.Uper, a.Borg] == 1
|
||||||
|
assert table[0, a.Uper, a.Iper] == 0
|
||||||
|
assert table[0, a.Uper, a.Iloc] == 0
|
||||||
|
assert table[0, a.Uper, a.Iorg] == 0
|
||||||
|
assert table[0, a.Uper, a.Lper] == 0
|
||||||
|
assert table[0, a.Uper, a.Lloc] == 0
|
||||||
|
assert table[0, a.Uper, a.Lorg] == 0
|
||||||
|
assert table[0, a.Uper, a.Uper] == 1
|
||||||
|
assert table[0, a.Uper, a.Uloc] == 1
|
||||||
|
assert table[0, a.Uper, a.Uorg] == 1
|
||||||
|
assert table[0, a.Uper, a.O] == 1
|
||||||
|
|
||||||
|
assert table[0, a.Uloc, a.Bper] == 1
|
||||||
|
assert table[0, a.Uloc, a.Bloc] == 1
|
||||||
|
assert table[0, a.Uloc, a.Borg] == 1
|
||||||
|
assert table[0, a.Uloc, a.Iper] == 0
|
||||||
|
assert table[0, a.Uloc, a.Iloc] == 0
|
||||||
|
assert table[0, a.Uloc, a.Iorg] == 0
|
||||||
|
assert table[0, a.Uloc, a.Lper] == 0
|
||||||
|
assert table[0, a.Uloc, a.Lloc] == 0
|
||||||
|
assert table[0, a.Uloc, a.Lorg] == 0
|
||||||
|
assert table[0, a.Uloc, a.Uper] == 1
|
||||||
|
assert table[0, a.Uloc, a.Uloc] == 1
|
||||||
|
assert table[0, a.Uloc, a.Uorg] == 1
|
||||||
|
assert table[0, a.Uloc, a.O] == 1
|
||||||
|
|
||||||
|
assert table[0, a.Uorg, a.Bper] == 1
|
||||||
|
assert table[0, a.Uorg, a.Bloc] == 1
|
||||||
|
assert table[0, a.Uorg, a.Borg] == 1
|
||||||
|
assert table[0, a.Uorg, a.Iper] == 0
|
||||||
|
assert table[0, a.Uorg, a.Iloc] == 0
|
||||||
|
assert table[0, a.Uorg, a.Iorg] == 0
|
||||||
|
assert table[0, a.Uorg, a.Lper] == 0
|
||||||
|
assert table[0, a.Uorg, a.Lloc] == 0
|
||||||
|
assert table[0, a.Uorg, a.Lorg] == 0
|
||||||
|
assert table[0, a.Uorg, a.Uper] == 1
|
||||||
|
assert table[0, a.Uorg, a.Uloc] == 1
|
||||||
|
assert table[0, a.Uorg, a.Uorg] == 1
|
||||||
|
assert table[0, a.Uorg, a.O] == 1
|
||||||
|
|
||||||
|
# Not last token, prev action was O
|
||||||
|
assert table[0, a.O, a.Bper] == 1
|
||||||
|
assert table[0, a.O, a.Bloc] == 1
|
||||||
|
assert table[0, a.O, a.Borg] == 1
|
||||||
|
assert table[0, a.O, a.Iper] == 0
|
||||||
|
assert table[0, a.O, a.Iloc] == 0
|
||||||
|
assert table[0, a.O, a.Iorg] == 0
|
||||||
|
assert table[0, a.O, a.Lper] == 0
|
||||||
|
assert table[0, a.O, a.Lloc] == 0
|
||||||
|
assert table[0, a.O, a.Lorg] == 0
|
||||||
|
assert table[0, a.O, a.Uper] == 1
|
||||||
|
assert table[0, a.O, a.Uloc] == 1
|
||||||
|
assert table[0, a.O, a.Uorg] == 1
|
||||||
|
assert table[0, a.O, a.O] == 1
|
||||||
|
|
||||||
|
# Last token, prev action was B
|
||||||
|
assert table[1, a.Bper, a.Bper] == 0
|
||||||
|
assert table[1, a.Bper, a.Bloc] == 0
|
||||||
|
assert table[1, a.Bper, a.Borg] == 0
|
||||||
|
assert table[1, a.Bper, a.Iper] == 0
|
||||||
|
assert table[1, a.Bper, a.Iloc] == 0
|
||||||
|
assert table[1, a.Bper, a.Iorg] == 0
|
||||||
|
assert table[1, a.Bper, a.Lper] == 1
|
||||||
|
assert table[1, a.Bper, a.Lloc] == 0
|
||||||
|
assert table[1, a.Bper, a.Lorg] == 0
|
||||||
|
assert table[1, a.Bper, a.Uper] == 0
|
||||||
|
assert table[1, a.Bper, a.Uloc] == 0
|
||||||
|
assert table[1, a.Bper, a.Uorg] == 0
|
||||||
|
assert table[1, a.Bper, a.O] == 0
|
||||||
|
|
||||||
|
assert table[1, a.Bloc, a.Bper] == 0
|
||||||
|
assert table[1, a.Bloc, a.Bloc] == 0
|
||||||
|
assert table[0, a.Bloc, a.Borg] == 0
|
||||||
|
assert table[1, a.Bloc, a.Iper] == 0
|
||||||
|
assert table[1, a.Bloc, a.Iloc] == 0
|
||||||
|
assert table[1, a.Bloc, a.Iorg] == 0
|
||||||
|
assert table[1, a.Bloc, a.Lper] == 0
|
||||||
|
assert table[1, a.Bloc, a.Lloc] == 1
|
||||||
|
assert table[1, a.Bloc, a.Lorg] == 0
|
||||||
|
assert table[1, a.Bloc, a.Uper] == 0
|
||||||
|
assert table[1, a.Bloc, a.Uloc] == 0
|
||||||
|
assert table[1, a.Bloc, a.Uorg] == 0
|
||||||
|
assert table[1, a.Bloc, a.O] == 0
|
||||||
|
|
||||||
|
assert table[1, a.Borg, a.Bper] == 0
|
||||||
|
assert table[1, a.Borg, a.Bloc] == 0
|
||||||
|
assert table[1, a.Borg, a.Borg] == 0
|
||||||
|
assert table[1, a.Borg, a.Iper] == 0
|
||||||
|
assert table[1, a.Borg, a.Iloc] == 0
|
||||||
|
assert table[1, a.Borg, a.Iorg] == 0
|
||||||
|
assert table[1, a.Borg, a.Lper] == 0
|
||||||
|
assert table[1, a.Borg, a.Lloc] == 0
|
||||||
|
assert table[1, a.Borg, a.Lorg] == 1
|
||||||
|
assert table[1, a.Borg, a.Uper] == 0
|
||||||
|
assert table[1, a.Borg, a.Uloc] == 0
|
||||||
|
assert table[1, a.Borg, a.Uorg] == 0
|
||||||
|
assert table[1, a.Borg, a.O] == 0
|
||||||
|
|
||||||
|
# Last token, prev action was I
|
||||||
|
assert table[1, a.Iper, a.Bper] == 0
|
||||||
|
assert table[1, a.Iper, a.Bloc] == 0
|
||||||
|
assert table[1, a.Iper, a.Borg] == 0
|
||||||
|
assert table[1, a.Iper, a.Iper] == 0
|
||||||
|
assert table[1, a.Iper, a.Iloc] == 0
|
||||||
|
assert table[1, a.Iper, a.Iorg] == 0
|
||||||
|
assert table[1, a.Iper, a.Lper] == 1
|
||||||
|
assert table[1, a.Iper, a.Lloc] == 0
|
||||||
|
assert table[1, a.Iper, a.Lorg] == 0
|
||||||
|
assert table[1, a.Iper, a.Uper] == 0
|
||||||
|
assert table[1, a.Iper, a.Uloc] == 0
|
||||||
|
assert table[1, a.Iper, a.Uorg] == 0
|
||||||
|
assert table[1, a.Iper, a.O] == 0
|
||||||
|
|
||||||
|
assert table[1, a.Iloc, a.Bper] == 0
|
||||||
|
assert table[1, a.Iloc, a.Bloc] == 0
|
||||||
|
assert table[1, a.Iloc, a.Borg] == 0
|
||||||
|
assert table[1, a.Iloc, a.Iper] == 0
|
||||||
|
assert table[1, a.Iloc, a.Iloc] == 0
|
||||||
|
assert table[1, a.Iloc, a.Iorg] == 0
|
||||||
|
assert table[1, a.Iloc, a.Lper] == 0
|
||||||
|
assert table[1, a.Iloc, a.Lloc] == 1
|
||||||
|
assert table[1, a.Iloc, a.Lorg] == 0
|
||||||
|
assert table[1, a.Iloc, a.Uper] == 0
|
||||||
|
assert table[1, a.Iloc, a.Uloc] == 0
|
||||||
|
assert table[1, a.Iloc, a.Uorg] == 0
|
||||||
|
assert table[1, a.Iloc, a.O] == 0
|
||||||
|
|
||||||
|
assert table[1, a.Iorg, a.Bper] == 0
|
||||||
|
assert table[1, a.Iorg, a.Bloc] == 0
|
||||||
|
assert table[1, a.Iorg, a.Borg] == 0
|
||||||
|
assert table[1, a.Iorg, a.Iper] == 0
|
||||||
|
assert table[1, a.Iorg, a.Iloc] == 0
|
||||||
|
assert table[1, a.Iorg, a.Iorg] == 0
|
||||||
|
assert table[1, a.Iorg, a.Lper] == 0
|
||||||
|
assert table[1, a.Iorg, a.Lloc] == 0
|
||||||
|
assert table[1, a.Iorg, a.Lorg] == 1
|
||||||
|
assert table[1, a.Iorg, a.Uper] == 0
|
||||||
|
assert table[1, a.Iorg, a.Uloc] == 0
|
||||||
|
assert table[1, a.Iorg, a.Uorg] == 0
|
||||||
|
assert table[1, a.Iorg, a.O] == 0
|
||||||
|
|
||||||
|
# Last token, prev action was L
|
||||||
|
assert table[1, a.Lper, a.Bper] == 0
|
||||||
|
assert table[1, a.Lper, a.Bloc] == 0
|
||||||
|
assert table[1, a.Lper, a.Borg] == 0
|
||||||
|
assert table[1, a.Lper, a.Iper] == 0
|
||||||
|
assert table[1, a.Lper, a.Iloc] == 0
|
||||||
|
assert table[1, a.Lper, a.Iorg] == 0
|
||||||
|
assert table[1, a.Lper, a.Lper] == 0
|
||||||
|
assert table[1, a.Lper, a.Lloc] == 0
|
||||||
|
assert table[1, a.Lper, a.Lorg] == 0
|
||||||
|
assert table[1, a.Lper, a.Uper] == 1
|
||||||
|
assert table[1, a.Lper, a.Uloc] == 1
|
||||||
|
assert table[1, a.Lper, a.Uorg] == 1
|
||||||
|
assert table[1, a.Lper, a.O] == 1
|
||||||
|
|
||||||
|
assert table[1, a.Lloc, a.Bper] == 0
|
||||||
|
assert table[1, a.Lloc, a.Bloc] == 0
|
||||||
|
assert table[1, a.Lloc, a.Borg] == 0
|
||||||
|
assert table[1, a.Lloc, a.Iper] == 0
|
||||||
|
assert table[1, a.Lloc, a.Iloc] == 0
|
||||||
|
assert table[1, a.Lloc, a.Iorg] == 0
|
||||||
|
assert table[1, a.Lloc, a.Lper] == 0
|
||||||
|
assert table[1, a.Lloc, a.Lloc] == 0
|
||||||
|
assert table[1, a.Lloc, a.Lorg] == 0
|
||||||
|
assert table[1, a.Lloc, a.Uper] == 1
|
||||||
|
assert table[1, a.Lloc, a.Uloc] == 1
|
||||||
|
assert table[1, a.Lloc, a.Uorg] == 1
|
||||||
|
assert table[1, a.Lloc, a.O] == 1
|
||||||
|
|
||||||
|
assert table[1, a.Lorg, a.Bper] == 0
|
||||||
|
assert table[1, a.Lorg, a.Bloc] == 0
|
||||||
|
assert table[1, a.Lorg, a.Borg] == 0
|
||||||
|
assert table[1, a.Lorg, a.Iper] == 0
|
||||||
|
assert table[1, a.Lorg, a.Iloc] == 0
|
||||||
|
assert table[1, a.Lorg, a.Iorg] == 0
|
||||||
|
assert table[1, a.Lorg, a.Lper] == 0
|
||||||
|
assert table[1, a.Lorg, a.Lloc] == 0
|
||||||
|
assert table[1, a.Lorg, a.Lorg] == 0
|
||||||
|
assert table[1, a.Lorg, a.Uper] == 1
|
||||||
|
assert table[1, a.Lorg, a.Uloc] == 1
|
||||||
|
assert table[1, a.Lorg, a.Uorg] == 1
|
||||||
|
assert table[1, a.Lorg, a.O] == 1
|
||||||
|
|
||||||
|
# Last token, prev action was U
|
||||||
|
assert table[1, a.Uper, a.Bper] == 0
|
||||||
|
assert table[1, a.Uper, a.Bloc] == 0
|
||||||
|
assert table[1, a.Uper, a.Borg] == 0
|
||||||
|
assert table[1, a.Uper, a.Iper] == 0
|
||||||
|
assert table[1, a.Uper, a.Iloc] == 0
|
||||||
|
assert table[1, a.Uper, a.Iorg] == 0
|
||||||
|
assert table[1, a.Uper, a.Lper] == 0
|
||||||
|
assert table[1, a.Uper, a.Lloc] == 0
|
||||||
|
assert table[1, a.Uper, a.Lorg] == 0
|
||||||
|
assert table[1, a.Uper, a.Uper] == 1
|
||||||
|
assert table[1, a.Uper, a.Uloc] == 1
|
||||||
|
assert table[1, a.Uper, a.Uorg] == 1
|
||||||
|
assert table[1, a.Uper, a.O] == 1
|
||||||
|
|
||||||
|
assert table[1, a.Uloc, a.Bper] == 0
|
||||||
|
assert table[1, a.Uloc, a.Bloc] == 0
|
||||||
|
assert table[1, a.Uloc, a.Borg] == 0
|
||||||
|
assert table[1, a.Uloc, a.Iper] == 0
|
||||||
|
assert table[1, a.Uloc, a.Iloc] == 0
|
||||||
|
assert table[1, a.Uloc, a.Iorg] == 0
|
||||||
|
assert table[1, a.Uloc, a.Lper] == 0
|
||||||
|
assert table[1, a.Uloc, a.Lloc] == 0
|
||||||
|
assert table[1, a.Uloc, a.Lorg] == 0
|
||||||
|
assert table[1, a.Uloc, a.Uper] == 1
|
||||||
|
assert table[1, a.Uloc, a.Uloc] == 1
|
||||||
|
assert table[1, a.Uloc, a.Uorg] == 1
|
||||||
|
assert table[1, a.Uloc, a.O] == 1
|
||||||
|
|
||||||
|
assert table[1, a.Uorg, a.Bper] == 0
|
||||||
|
assert table[1, a.Uorg, a.Bloc] == 0
|
||||||
|
assert table[1, a.Uorg, a.Borg] == 0
|
||||||
|
assert table[1, a.Uorg, a.Iper] == 0
|
||||||
|
assert table[1, a.Uorg, a.Iloc] == 0
|
||||||
|
assert table[1, a.Uorg, a.Iorg] == 0
|
||||||
|
assert table[1, a.Uorg, a.Lper] == 0
|
||||||
|
assert table[1, a.Uorg, a.Lloc] == 0
|
||||||
|
assert table[1, a.Uorg, a.Lorg] == 0
|
||||||
|
assert table[1, a.Uorg, a.Uper] == 1
|
||||||
|
assert table[1, a.Uorg, a.Uloc] == 1
|
||||||
|
assert table[1, a.Uorg, a.Uorg] == 1
|
||||||
|
assert table[1, a.Uorg, a.O] == 1
|
||||||
|
|
||||||
|
# Last token, prev action was O
|
||||||
|
assert table[1, a.O, a.Bper] == 0
|
||||||
|
assert table[1, a.O, a.Bloc] == 0
|
||||||
|
assert table[1, a.O, a.Borg] == 0
|
||||||
|
assert table[1, a.O, a.Iper] == 0
|
||||||
|
assert table[1, a.O, a.Iloc] == 0
|
||||||
|
assert table[1, a.O, a.Iorg] == 0
|
||||||
|
assert table[1, a.O, a.Lper] == 0
|
||||||
|
assert table[1, a.O, a.Lloc] == 0
|
||||||
|
assert table[1, a.O, a.Lorg] == 0
|
||||||
|
assert table[1, a.O, a.Uper] == 1
|
||||||
|
assert table[1, a.O, a.Uloc] == 1
|
||||||
|
assert table[1, a.O, a.Uorg] == 1
|
||||||
|
assert table[1, a.O, a.O] == 1
|
|
@ -34,7 +34,8 @@ def test_issue2179():
|
||||||
nlp2.add_pipe(nlp2.create_pipe("ner"))
|
nlp2.add_pipe(nlp2.create_pipe("ner"))
|
||||||
|
|
||||||
assert len(nlp2.get_pipe("ner").labels) == 0
|
assert len(nlp2.get_pipe("ner").labels) == 0
|
||||||
nlp2.get_pipe("ner").model.resize_output(nlp.get_pipe("ner").moves.n_moves)
|
model = nlp2.get_pipe("ner").model
|
||||||
|
model.attrs["resize_output"](model, nlp.get_pipe("ner").moves.n_moves)
|
||||||
nlp2.from_bytes(nlp.to_bytes())
|
nlp2.from_bytes(nlp.to_bytes())
|
||||||
assert "extra_labels" not in nlp2.get_pipe("ner").cfg
|
assert "extra_labels" not in nlp2.get_pipe("ner").cfg
|
||||||
assert nlp2.get_pipe("ner").labels == ("CITIZENSHIP",)
|
assert nlp2.get_pipe("ner").labels == ("CITIZENSHIP",)
|
||||||
|
|
|
@ -104,7 +104,8 @@ def test_issue3209():
|
||||||
assert ner.move_names == move_names
|
assert ner.move_names == move_names
|
||||||
nlp2 = English()
|
nlp2 = English()
|
||||||
nlp2.add_pipe(nlp2.create_pipe("ner"))
|
nlp2.add_pipe(nlp2.create_pipe("ner"))
|
||||||
nlp2.get_pipe("ner").model.resize_output(ner.moves.n_moves)
|
model = nlp2.get_pipe("ner").model
|
||||||
|
model.attrs["resize_output"](model, ner.moves.n_moves)
|
||||||
nlp2.from_bytes(nlp.to_bytes())
|
nlp2.from_bytes(nlp.to_bytes())
|
||||||
assert nlp2.get_pipe("ner").move_names == move_names
|
assert nlp2.get_pipe("ner").move_names == move_names
|
||||||
|
|
||||||
|
|
|
@ -110,10 +110,9 @@ def test_serialize_custom_nlp():
|
||||||
nlp2 = spacy.load(d)
|
nlp2 = spacy.load(d)
|
||||||
model = nlp2.get_pipe("parser").model
|
model = nlp2.get_pipe("parser").model
|
||||||
tok2vec = model.get_ref("tok2vec")
|
tok2vec = model.get_ref("tok2vec")
|
||||||
upper = model.upper
|
upper = model.get_ref("upper")
|
||||||
|
|
||||||
# check that we have the correct settings, not the default ones
|
# check that we have the correct settings, not the default ones
|
||||||
assert tok2vec.get_dim("nO") == 321
|
|
||||||
assert upper.get_dim("nI") == 65
|
assert upper.get_dim("nI") == 65
|
||||||
|
|
||||||
|
|
||||||
|
@ -131,8 +130,7 @@ def test_serialize_parser():
|
||||||
nlp2 = spacy.load(d)
|
nlp2 = spacy.load(d)
|
||||||
model = nlp2.get_pipe("parser").model
|
model = nlp2.get_pipe("parser").model
|
||||||
tok2vec = model.get_ref("tok2vec")
|
tok2vec = model.get_ref("tok2vec")
|
||||||
upper = model.upper
|
upper = model.get_ref("upper")
|
||||||
|
|
||||||
# check that we have the correct settings, not the default ones
|
# check that we have the correct settings, not the default ones
|
||||||
assert upper.get_dim("nI") == 66
|
assert upper.get_dim("nI") == 66
|
||||||
assert tok2vec.get_dim("nO") == 333
|
|
||||||
|
|
|
@ -63,7 +63,7 @@ def test_to_from_bytes(parser, blank_parser):
|
||||||
bytes_data = parser.to_bytes(exclude=["vocab"])
|
bytes_data = parser.to_bytes(exclude=["vocab"])
|
||||||
|
|
||||||
# the blank parser needs to be resized before we can call from_bytes
|
# the blank parser needs to be resized before we can call from_bytes
|
||||||
blank_parser.model.resize_output(parser.moves.n_moves)
|
blank_parser.model.attrs["resize_output"](blank_parser.model, parser.moves.n_moves)
|
||||||
blank_parser.from_bytes(bytes_data)
|
blank_parser.from_bytes(bytes_data)
|
||||||
assert blank_parser.model is not True
|
assert blank_parser.model is not True
|
||||||
assert blank_parser.moves.n_moves == parser.moves.n_moves
|
assert blank_parser.moves.n_moves == parser.moves.n_moves
|
||||||
|
|
|
@ -38,7 +38,7 @@ def test_util_get_package_path(package):
|
||||||
|
|
||||||
|
|
||||||
def test_PrecomputableAffine(nO=4, nI=5, nF=3, nP=2):
|
def test_PrecomputableAffine(nO=4, nI=5, nF=3, nP=2):
|
||||||
model = PrecomputableAffine(nO=nO, nI=nI, nF=nF, nP=nP)
|
model = PrecomputableAffine(nO=nO, nI=nI, nF=nF, nP=nP).initialize()
|
||||||
assert model.get_param("W").shape == (nF, nO, nP, nI)
|
assert model.get_param("W").shape == (nF, nO, nP, nI)
|
||||||
tensor = model.ops.alloc((10, nI))
|
tensor = model.ops.alloc((10, nI))
|
||||||
Y, get_dX = model.begin_update(tensor)
|
Y, get_dX = model.begin_update(tensor)
|
||||||
|
|
|
@ -571,8 +571,10 @@ def decaying(start, stop, decay):
|
||||||
curr -= decay
|
curr -= decay
|
||||||
|
|
||||||
|
|
||||||
def minibatch_by_words(examples, size, tuples=True, count_words=len):
|
def minibatch_by_words(examples, size, tuples=True, count_words=len, tolerance=0.2):
|
||||||
"""Create minibatches of a given number of words."""
|
"""Create minibatches of roughly a given number of words. If any examples
|
||||||
|
are longer than the specified batch length, they will appear in a batch by
|
||||||
|
themselves."""
|
||||||
if isinstance(size, int):
|
if isinstance(size, int):
|
||||||
size_ = itertools.repeat(size)
|
size_ = itertools.repeat(size)
|
||||||
elif isinstance(size, List):
|
elif isinstance(size, List):
|
||||||
|
@ -580,18 +582,36 @@ def minibatch_by_words(examples, size, tuples=True, count_words=len):
|
||||||
else:
|
else:
|
||||||
size_ = size
|
size_ = size
|
||||||
examples = iter(examples)
|
examples = iter(examples)
|
||||||
|
oversize = []
|
||||||
while True:
|
while True:
|
||||||
batch_size = next(size_)
|
batch_size = next(size_)
|
||||||
|
tol_size = batch_size * 0.2
|
||||||
batch = []
|
batch = []
|
||||||
while batch_size >= 0:
|
if oversize:
|
||||||
|
example = oversize.pop(0)
|
||||||
|
n_words = count_words(example.doc)
|
||||||
|
batch.append(example)
|
||||||
|
batch_size -= n_words
|
||||||
|
while batch_size >= 1:
|
||||||
try:
|
try:
|
||||||
example = next(examples)
|
example = next(examples)
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
if batch:
|
if oversize:
|
||||||
yield batch
|
examples = iter(oversize)
|
||||||
return
|
oversize = []
|
||||||
batch_size -= count_words(example.doc)
|
if batch:
|
||||||
batch.append(example)
|
yield batch
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
if batch:
|
||||||
|
yield batch
|
||||||
|
return
|
||||||
|
n_words = count_words(example.doc)
|
||||||
|
if n_words < (batch_size + tol_size):
|
||||||
|
batch_size -= n_words
|
||||||
|
batch.append(example)
|
||||||
|
else:
|
||||||
|
oversize.append(example)
|
||||||
if batch:
|
if batch:
|
||||||
yield batch
|
yield batch
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user