mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-26 17:24:41 +03:00
Merge branch 'v2' into develop
* Move v2 parser into nn_parser.pyx * New TokenVectorEncoder class in pipeline.pyx * New spacy/_ml.py module Currently the two parsers live side-by-side, until we figure out how to organize them.
This commit is contained in:
commit
4b9d69f428
|
@ -1,4 +1,4 @@
|
|||
from __future__ import unicode_literals
|
||||
from __future__ import unicode_literals, print_function
|
||||
import plac
|
||||
import json
|
||||
import random
|
||||
|
@ -9,13 +9,27 @@ from spacy.syntax.nonproj import PseudoProjectivity
|
|||
from spacy.language import Language
|
||||
from spacy.gold import GoldParse
|
||||
from spacy.tagger import Tagger
|
||||
from spacy.pipeline import DependencyParser, BeamDependencyParser
|
||||
from spacy.pipeline import DependencyParser, TokenVectorEncoder
|
||||
from spacy.syntax.parser import get_templates
|
||||
from spacy.syntax.arc_eager import ArcEager
|
||||
from spacy.scorer import Scorer
|
||||
from spacy.language_data.tag_map import TAG_MAP as DEFAULT_TAG_MAP
|
||||
import spacy.attrs
|
||||
import io
|
||||
from thinc.neural.ops import CupyOps
|
||||
from thinc.neural import Model
|
||||
from spacy.es import Spanish
|
||||
from spacy.attrs import POS
|
||||
|
||||
|
||||
from thinc.neural import Model
|
||||
|
||||
|
||||
try:
|
||||
import cupy
|
||||
from thinc.neural.ops import CupyOps
|
||||
except:
|
||||
cupy = None
|
||||
|
||||
|
||||
def read_conllx(loc, n=0):
|
||||
|
@ -36,10 +50,11 @@ def read_conllx(loc, n=0):
|
|||
try:
|
||||
id_ = int(id_) - 1
|
||||
head = (int(head) - 1) if head != '0' else id_
|
||||
dep = 'ROOT' if dep == 'root' else dep
|
||||
dep = 'ROOT' if dep == 'root' else dep #'unlabelled'
|
||||
tag = pos+'__'+dep+'__'+morph
|
||||
Spanish.Defaults.tag_map[tag] = {POS: pos}
|
||||
tokens.append((id_, word, tag, head, dep, 'O'))
|
||||
except:
|
||||
print(line)
|
||||
raise
|
||||
tuples = [list(t) for t in zip(*tokens)]
|
||||
yield (None, [[tuples, []]])
|
||||
|
@ -48,22 +63,43 @@ def read_conllx(loc, n=0):
|
|||
break
|
||||
|
||||
|
||||
def score_model(vocab, tagger, parser, gold_docs, verbose=False):
|
||||
def score_model(vocab, encoder, parser, Xs, ys, verbose=False):
|
||||
scorer = Scorer()
|
||||
for _, gold_doc in gold_docs:
|
||||
for (ids, words, tags, heads, deps, entities), _ in gold_doc:
|
||||
doc = Doc(vocab, words=words)
|
||||
tagger(doc)
|
||||
parser(doc)
|
||||
PseudoProjectivity.deprojectivize(doc)
|
||||
gold = GoldParse(doc, tags=tags, heads=heads, deps=deps)
|
||||
scorer.score(doc, gold, verbose=verbose)
|
||||
correct = 0.
|
||||
total = 0.
|
||||
for doc, gold in zip(Xs, ys):
|
||||
doc = Doc(vocab, words=[w.text for w in doc])
|
||||
encoder(doc)
|
||||
parser(doc)
|
||||
PseudoProjectivity.deprojectivize(doc)
|
||||
scorer.score(doc, gold, verbose=verbose)
|
||||
for token, tag in zip(doc, gold.tags):
|
||||
if '_' in token.tag_:
|
||||
univ_guess, _ = token.tag_.split('_', 1)
|
||||
else:
|
||||
univ_guess = ''
|
||||
univ_truth, _ = tag.split('_', 1)
|
||||
correct += univ_guess == univ_truth
|
||||
total += 1
|
||||
return scorer
|
||||
|
||||
|
||||
def organize_data(vocab, train_sents):
|
||||
Xs = []
|
||||
ys = []
|
||||
for _, doc_sents in train_sents:
|
||||
for (ids, words, tags, heads, deps, ner), _ in doc_sents:
|
||||
doc = Doc(vocab, words=words)
|
||||
gold = GoldParse(doc, tags=tags, heads=heads, deps=deps)
|
||||
Xs.append(doc)
|
||||
ys.append(gold)
|
||||
return Xs, ys
|
||||
|
||||
|
||||
def main(lang_name, train_loc, dev_loc, model_dir, clusters_loc=None):
|
||||
LangClass = spacy.util.get_lang_class(lang_name)
|
||||
train_sents = list(read_conllx(train_loc))
|
||||
dev_sents = list(read_conllx(dev_loc))
|
||||
train_sents = PseudoProjectivity.preprocess_training_data(train_sents)
|
||||
|
||||
actions = ArcEager.get_actions(gold_parses=train_sents)
|
||||
|
@ -112,28 +148,54 @@ def main(lang_name, train_loc, dev_loc, model_dir, clusters_loc=None):
|
|||
_ = vocab[tag]
|
||||
if vocab.morphology.tag_map:
|
||||
for tag in tags:
|
||||
assert tag in vocab.morphology.tag_map, repr(tag)
|
||||
vocab.morphology.tag_map[tag] = {POS: tag.split('__', 1)[0]}
|
||||
tagger = Tagger(vocab)
|
||||
encoder = TokenVectorEncoder(vocab, width=64)
|
||||
parser = DependencyParser(vocab, actions=actions, features=features, L1=0.0)
|
||||
|
||||
for itn in range(30):
|
||||
loss = 0.
|
||||
for _, doc_sents in train_sents:
|
||||
for (ids, words, tags, heads, deps, ner), _ in doc_sents:
|
||||
doc = Doc(vocab, words=words)
|
||||
gold = GoldParse(doc, tags=tags, heads=heads, deps=deps)
|
||||
tagger(doc)
|
||||
loss += parser.update(doc, gold, itn=itn)
|
||||
doc = Doc(vocab, words=words)
|
||||
tagger.update(doc, gold)
|
||||
random.shuffle(train_sents)
|
||||
scorer = score_model(vocab, tagger, parser, read_conllx(dev_loc))
|
||||
print('%d:\t%.3f\t%.3f\t%.3f' % (itn, loss, scorer.uas, scorer.tags_acc))
|
||||
nlp = LangClass(vocab=vocab, tagger=tagger, parser=parser)
|
||||
nlp.end_training(model_dir)
|
||||
scorer = score_model(vocab, tagger, parser, read_conllx(dev_loc))
|
||||
Xs, ys = organize_data(vocab, train_sents)
|
||||
dev_Xs, dev_ys = organize_data(vocab, dev_sents)
|
||||
with encoder.model.begin_training(Xs[:100], ys[:100]) as (trainer, optimizer):
|
||||
docs = list(Xs)
|
||||
for doc in docs:
|
||||
encoder(doc)
|
||||
nn_loss = [0.]
|
||||
def track_progress():
|
||||
with encoder.tagger.use_params(optimizer.averages):
|
||||
with parser.model.use_params(optimizer.averages):
|
||||
scorer = score_model(vocab, encoder, parser, dev_Xs, dev_ys)
|
||||
itn = len(nn_loss)
|
||||
print('%d:\t%.3f\t%.3f\t%.3f' % (itn, nn_loss[-1], scorer.uas, scorer.tags_acc))
|
||||
nn_loss.append(0.)
|
||||
track_progress()
|
||||
trainer.each_epoch.append(track_progress)
|
||||
trainer.batch_size = 24
|
||||
trainer.nb_epoch = 40
|
||||
for docs, golds in trainer.iterate(Xs, ys, progress_bar=True):
|
||||
docs = [Doc(vocab, words=[w.text for w in doc]) for doc in docs]
|
||||
tokvecs, upd_tokvecs = encoder.begin_update(docs)
|
||||
for doc, tokvec in zip(docs, tokvecs):
|
||||
doc.tensor = tokvec
|
||||
d_tokvecs = parser.update(docs, golds, sgd=optimizer)
|
||||
upd_tokvecs(d_tokvecs, sgd=optimizer)
|
||||
encoder.update(docs, golds, sgd=optimizer)
|
||||
nlp = LangClass(vocab=vocab, parser=parser)
|
||||
scorer = score_model(vocab, encoder, parser, read_conllx(dev_loc))
|
||||
print('%d:\t%.3f\t%.3f\t%.3f' % (itn, scorer.uas, scorer.las, scorer.tags_acc))
|
||||
#nlp.end_training(model_dir)
|
||||
#scorer = score_model(vocab, tagger, parser, read_conllx(dev_loc))
|
||||
#print('%d:\t%.3f\t%.3f\t%.3f' % (itn, scorer.uas, scorer.las, scorer.tags_acc))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
import cProfile
|
||||
import pstats
|
||||
if 1:
|
||||
plac.call(main)
|
||||
else:
|
||||
cProfile.runctx("plac.call(main)", globals(), locals(), "Profile.prof")
|
||||
s = pstats.Stats("Profile.prof")
|
||||
s.strip_dirs().sort_stats("time").print_stats()
|
||||
|
||||
|
||||
plac.call(main)
|
||||
|
|
180
spacy/_ml.py
Normal file
180
spacy/_ml.py
Normal file
|
@ -0,0 +1,180 @@
|
|||
from thinc.api import add, layerize, chain, clone, concatenate, with_flatten
|
||||
from thinc.neural import Model, Maxout, Softmax, Affine
|
||||
from thinc.neural._classes.hash_embed import HashEmbed
|
||||
from thinc.neural.ops import NumpyOps, CupyOps
|
||||
|
||||
from thinc.neural._classes.convolution import ExtractWindow
|
||||
from thinc.neural._classes.static_vectors import StaticVectors
|
||||
from thinc.neural._classes.batchnorm import BatchNorm
|
||||
from thinc.neural._classes.resnet import Residual
|
||||
from thinc import describe
|
||||
from thinc.describe import Dimension, Synapses, Biases, Gradient
|
||||
from thinc.neural._classes.affine import _set_dimensions_if_needed
|
||||
from .attrs import ID, LOWER, PREFIX, SUFFIX, SHAPE, TAG, DEP
|
||||
|
||||
import numpy
|
||||
|
||||
|
||||
@describe.on_data(_set_dimensions_if_needed)
|
||||
@describe.attributes(
|
||||
nI=Dimension("Input size"),
|
||||
nF=Dimension("Number of features"),
|
||||
nO=Dimension("Output size"),
|
||||
W=Synapses("Weights matrix",
|
||||
lambda obj: (obj.nO, obj.nF, obj.nI),
|
||||
lambda W, ops: ops.xavier_uniform_init(W)),
|
||||
b=Biases("Bias vector",
|
||||
lambda obj: (obj.nO,)),
|
||||
d_W=Gradient("W"),
|
||||
d_b=Gradient("b")
|
||||
)
|
||||
class PrecomputableAffine(Model):
|
||||
def __init__(self, nO=None, nI=None, nF=None, **kwargs):
|
||||
Model.__init__(self, **kwargs)
|
||||
self.nO = nO
|
||||
self.nI = nI
|
||||
self.nF = nF
|
||||
|
||||
def begin_update(self, X, drop=0.):
|
||||
# X: (b, i)
|
||||
# Xf: (b, f, i)
|
||||
# dY: (b, o)
|
||||
# dYf: (b, f, o)
|
||||
#Yf = numpy.einsum('bi,ofi->bfo', X, self.W)
|
||||
Yf = self.ops.xp.tensordot(
|
||||
X, self.W, axes=[[1], [2]]).transpose((0, 2, 1))
|
||||
Yf += self.b
|
||||
def backward(dY_ids, sgd=None):
|
||||
dY, ids = dY_ids
|
||||
Xf = X[ids]
|
||||
|
||||
#dW = numpy.einsum('bo,bfi->ofi', dY, Xf)
|
||||
dW = self.ops.xp.tensordot(dY, Xf, axes=[[0], [0]])
|
||||
db = dY.sum(axis=0)
|
||||
#dXf = numpy.einsum('bo,ofi->bfi', dY, self.W)
|
||||
dXf = self.ops.xp.tensordot(dY, self.W, axes=[[1], [0]])
|
||||
|
||||
self.d_W += dW
|
||||
self.d_b += db
|
||||
|
||||
if sgd is not None:
|
||||
sgd(self._mem.weights, self._mem.gradient, key=self.id)
|
||||
return dXf
|
||||
return Yf, backward
|
||||
|
||||
|
||||
@describe.on_data(_set_dimensions_if_needed)
|
||||
@describe.attributes(
|
||||
nI=Dimension("Input size"),
|
||||
nF=Dimension("Number of features"),
|
||||
nP=Dimension("Number of pieces"),
|
||||
nO=Dimension("Output size"),
|
||||
W=Synapses("Weights matrix",
|
||||
lambda obj: (obj.nF, obj.nO, obj.nP, obj.nI),
|
||||
lambda W, ops: ops.xavier_uniform_init(W)),
|
||||
b=Biases("Bias vector",
|
||||
lambda obj: (obj.nO, obj.nP)),
|
||||
d_W=Gradient("W"),
|
||||
d_b=Gradient("b")
|
||||
)
|
||||
class PrecomputableMaxouts(Model):
|
||||
def __init__(self, nO=None, nI=None, nF=None, pieces=3, **kwargs):
|
||||
Model.__init__(self, **kwargs)
|
||||
self.nO = nO
|
||||
self.nP = pieces
|
||||
self.nI = nI
|
||||
self.nF = nF
|
||||
|
||||
def begin_update(self, X, drop=0.):
|
||||
# X: (b, i)
|
||||
# Yfp: (b, f, o, p)
|
||||
# Xf: (f, b, i)
|
||||
# dYp: (b, o, p)
|
||||
# W: (f, o, p, i)
|
||||
# b: (o, p)
|
||||
|
||||
# bi,opfi->bfop
|
||||
# bop,fopi->bfi
|
||||
# bop,fbi->opfi : fopi
|
||||
|
||||
tensordot = self.ops.xp.tensordot
|
||||
ascontiguous = self.ops.xp.ascontiguousarray
|
||||
|
||||
Yfp = tensordot(X, self.W, axes=[[1], [3]])
|
||||
Yfp += self.b
|
||||
|
||||
def backward(dYp_ids, sgd=None):
|
||||
dYp, ids = dYp_ids
|
||||
Xf = X[ids]
|
||||
|
||||
dXf = tensordot(dYp, self.W, axes=[[1, 2], [1,2]])
|
||||
dW = tensordot(dYp, Xf, axes=[[0], [0]])
|
||||
|
||||
self.d_W += dW.transpose((2, 0, 1, 3))
|
||||
self.d_b += dYp.sum(axis=0)
|
||||
|
||||
if sgd is not None:
|
||||
sgd(self._mem.weights, self._mem.gradient, key=self.id)
|
||||
return dXf
|
||||
return Yfp, backward
|
||||
|
||||
|
||||
def get_col(idx):
|
||||
def forward(X, drop=0.):
|
||||
if isinstance(X, numpy.ndarray):
|
||||
ops = NumpyOps()
|
||||
else:
|
||||
ops = CupyOps()
|
||||
assert len(X.shape) <= 3
|
||||
output = ops.xp.ascontiguousarray(X[:, idx])
|
||||
def backward(y, sgd=None):
|
||||
dX = ops.allocate(X.shape)
|
||||
dX[:, idx] += y
|
||||
return dX
|
||||
return output, backward
|
||||
return layerize(forward)
|
||||
|
||||
|
||||
def zero_init(model):
|
||||
def _hook(self, X, y=None):
|
||||
self.W.fill(0)
|
||||
model.on_data_hooks.append(_hook)
|
||||
return model
|
||||
|
||||
|
||||
def doc2feats(cols=None):
|
||||
cols = [ID, LOWER, PREFIX, SUFFIX, SHAPE]
|
||||
def forward(docs, drop=0.):
|
||||
feats = [doc.to_array(cols) for doc in docs]
|
||||
feats = [model.ops.asarray(f, dtype='uint64') for f in feats]
|
||||
return feats, None
|
||||
model = layerize(forward)
|
||||
model.cols = cols
|
||||
return model
|
||||
|
||||
def print_shape(prefix):
|
||||
def forward(X, drop=0.):
|
||||
return X, lambda dX, **kwargs: dX
|
||||
return layerize(forward)
|
||||
|
||||
|
||||
@layerize
|
||||
def get_token_vectors(tokens_attrs_vectors, drop=0.):
|
||||
ops = Model.ops
|
||||
tokens, attrs, vectors = tokens_attrs_vectors
|
||||
def backward(d_output, sgd=None):
|
||||
return (tokens, d_output)
|
||||
return vectors, backward
|
||||
|
||||
|
||||
@layerize
|
||||
def flatten(seqs, drop=0.):
|
||||
if isinstance(seqs[0], numpy.ndarray):
|
||||
ops = NumpyOps()
|
||||
else:
|
||||
ops = CupyOps()
|
||||
lengths = [len(seq) for seq in seqs]
|
||||
def finish_update(d_X, sgd=None):
|
||||
return ops.unflatten(d_X, lengths)
|
||||
X = ops.xp.vstack(seqs)
|
||||
return X, finish_update
|
|
@ -2,6 +2,7 @@
|
|||
from __future__ import unicode_literals
|
||||
|
||||
|
||||
|
||||
TAG_MAP = {
|
||||
"ADJ___": {"morph": "_", "pos": "ADJ"},
|
||||
"ADJ__AdpType=Prep": {"morph": "AdpType=Prep", "pos": "ADJ"},
|
||||
|
@ -302,5 +303,5 @@ TAG_MAP = {
|
|||
"VERB__VerbForm=Ger": {"morph": "VerbForm=Ger", "pos": "VERB"},
|
||||
"VERB__VerbForm=Inf": {"morph": "VerbForm=Inf", "pos": "VERB"},
|
||||
"X___": {"morph": "_", "pos": "X"},
|
||||
"SP": {"morph": "_", "pos": "SPACE"}
|
||||
"SP": {"morph": "_", "pos": "SPACE"},
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from .syntax.parser cimport Parser
|
||||
from .syntax.beam_parser cimport BeamParser
|
||||
#from .syntax.beam_parser cimport BeamParser
|
||||
from .syntax.ner cimport BiluoPushDown
|
||||
from .syntax.arc_eager cimport ArcEager
|
||||
from .tagger cimport Tagger
|
||||
|
@ -13,9 +13,9 @@ cdef class DependencyParser(Parser):
|
|||
pass
|
||||
|
||||
|
||||
cdef class BeamEntityRecognizer(BeamParser):
|
||||
pass
|
||||
|
||||
|
||||
cdef class BeamDependencyParser(BeamParser):
|
||||
pass
|
||||
#cdef class BeamEntityRecognizer(BeamParser):
|
||||
# pass
|
||||
#
|
||||
#
|
||||
#cdef class BeamDependencyParser(BeamParser):
|
||||
# pass
|
||||
|
|
|
@ -1,16 +1,112 @@
|
|||
# cython: infer_types=True
|
||||
# cython: profile=True
|
||||
# coding: utf8
|
||||
from __future__ import unicode_literals
|
||||
|
||||
from thinc.api import chain, layerize, with_getitem
|
||||
from thinc.neural import Model, Softmax
|
||||
import numpy
|
||||
cimport numpy as np
|
||||
|
||||
from .tokens.doc cimport Doc
|
||||
from .syntax.parser cimport Parser
|
||||
from .syntax.parser import get_templates as get_feature_templates
|
||||
from .syntax.beam_parser cimport BeamParser
|
||||
from .syntax.ner cimport BiluoPushDown
|
||||
from .syntax.arc_eager cimport ArcEager
|
||||
from .tagger import Tagger
|
||||
from .gold cimport GoldParse
|
||||
|
||||
# TODO: The disorganization here is pretty embarrassing. At least it's only
|
||||
# internals.
|
||||
from .syntax.parser import get_templates as get_feature_templates
|
||||
from .attrs import DEP, ENT_TYPE
|
||||
from thinc.api import add, layerize, chain, clone, concatenate
|
||||
from thinc.neural import Model, Maxout, Softmax, Affine
|
||||
from thinc.neural._classes.hash_embed import HashEmbed
|
||||
from thinc.neural.util import to_categorical
|
||||
|
||||
from thinc.neural._classes.convolution import ExtractWindow
|
||||
from thinc.neural._classes.resnet import Residual
|
||||
from thinc.neural._classes.batchnorm import BatchNorm as BN
|
||||
|
||||
from .attrs import ID, LOWER, PREFIX, SUFFIX, SHAPE, TAG, DEP
|
||||
from ._ml import flatten, get_col, doc2feats
|
||||
|
||||
|
||||
|
||||
class TokenVectorEncoder(object):
|
||||
'''Assign position-sensitive vectors to tokens, using a CNN or RNN.'''
|
||||
def __init__(self, vocab, token_vector_width, **cfg):
|
||||
self.vocab = vocab
|
||||
self.doc2feats = doc2feats()
|
||||
self.model = self.build_model(vocab.lang, token_vector_width, **cfg)
|
||||
self.tagger = chain(
|
||||
self.model,
|
||||
Softmax(self.vocab.morphology.n_tags,
|
||||
token_vector_width))
|
||||
|
||||
def build_model(self, lang, width, embed_size=5000, **cfg):
|
||||
cols = self.doc2feats.cols
|
||||
with Model.define_operators({'>>': chain, '|': concatenate, '**': clone, '+': add}):
|
||||
lower = get_col(cols.index(LOWER)) >> (HashEmbed(width, embed_size)
|
||||
+HashEmbed(width, embed_size))
|
||||
prefix = get_col(cols.index(PREFIX)) >> HashEmbed(width, embed_size//2)
|
||||
suffix = get_col(cols.index(SUFFIX)) >> HashEmbed(width, embed_size//2)
|
||||
shape = get_col(cols.index(SHAPE)) >> HashEmbed(width, embed_size//2)
|
||||
|
||||
tok2vec = (
|
||||
flatten
|
||||
>> (lower | prefix | suffix | shape )
|
||||
>> Maxout(width, pieces=3)
|
||||
>> Residual(ExtractWindow(nW=1) >> Maxout(width, width*3))
|
||||
>> Residual(ExtractWindow(nW=1) >> Maxout(width, width*3))
|
||||
>> Residual(ExtractWindow(nW=1) >> Maxout(width, width*3))
|
||||
>> Residual(ExtractWindow(nW=1) >> Maxout(width, width*3))
|
||||
)
|
||||
return tok2vec
|
||||
|
||||
def pipe(self, docs):
|
||||
docs = list(docs)
|
||||
self.predict_tags(docs)
|
||||
for doc in docs:
|
||||
yield doc
|
||||
|
||||
def __call__(self, doc):
|
||||
self.predict_tags([doc])
|
||||
|
||||
def begin_update(self, feats, drop=0.):
|
||||
tokvecs, bp_tokvecs = self.model.begin_update(feats, drop=drop)
|
||||
return tokvecs, bp_tokvecs
|
||||
|
||||
def predict_tags(self, docs, drop=0.):
|
||||
cdef Doc doc
|
||||
feats = self.doc2feats(docs)
|
||||
scores, finish_update = self.tagger.begin_update(feats, drop=drop)
|
||||
scores, _ = self.tagger.begin_update(feats, drop=drop)
|
||||
idx = 0
|
||||
guesses = scores.argmax(axis=1)
|
||||
if not isinstance(guesses, numpy.ndarray):
|
||||
guesses = guesses.get()
|
||||
for i, doc in enumerate(docs):
|
||||
tag_ids = guesses[idx:idx+len(doc)]
|
||||
for j, tag_id in enumerate(tag_ids):
|
||||
doc.vocab.morphology.assign_tag_id(&doc.c[j], tag_id)
|
||||
idx += 1
|
||||
|
||||
def update(self, docs_feats, golds, drop=0., sgd=None):
|
||||
cdef int i, j, idx
|
||||
cdef GoldParse gold
|
||||
docs, feats = docs_feats
|
||||
scores, finish_update = self.tagger.begin_update(feats, drop=drop)
|
||||
|
||||
tag_index = {tag: i for i, tag in enumerate(docs[0].vocab.morphology.tag_names)}
|
||||
|
||||
idx = 0
|
||||
correct = numpy.zeros((scores.shape[0],), dtype='i')
|
||||
for gold in golds:
|
||||
for tag in gold.tags:
|
||||
correct[idx] = tag_index[tag]
|
||||
idx += 1
|
||||
correct = self.model.ops.xp.array(correct)
|
||||
d_scores = scores - to_categorical(correct, nb_classes=scores.shape[1])
|
||||
finish_update(d_scores, sgd)
|
||||
|
||||
|
||||
cdef class EntityRecognizer(Parser):
|
||||
|
@ -43,7 +139,6 @@ cdef class BeamEntityRecognizer(BeamParser):
|
|||
|
||||
cdef class DependencyParser(Parser):
|
||||
TransitionSystem = ArcEager
|
||||
|
||||
feature_templates = get_feature_templates('basic')
|
||||
|
||||
def add_label(self, label):
|
||||
|
@ -63,4 +158,5 @@ cdef class BeamDependencyParser(BeamParser):
|
|||
label = self.vocab.strings[label]
|
||||
|
||||
|
||||
__all__ = [Tagger, DependencyParser, EntityRecognizer, BeamDependencyParser, BeamEntityRecognizer]
|
||||
__all__ = ['Tagger', 'DependencyParser', 'EntityRecognizer', 'BeamDependencyParser',
|
||||
'BeamEntityRecognizer', 'TokenVectorEnoder']
|
||||
|
|
18
spacy/syntax/nn_parser.pxd
Normal file
18
spacy/syntax/nn_parser.pxd
Normal file
|
@ -0,0 +1,18 @@
|
|||
from thinc.typedefs cimport atom_t
|
||||
|
||||
from .stateclass cimport StateClass
|
||||
from .arc_eager cimport TransitionSystem
|
||||
from ..vocab cimport Vocab
|
||||
from ..tokens.doc cimport Doc
|
||||
from ..structs cimport TokenC
|
||||
from ._state cimport StateC
|
||||
|
||||
|
||||
cdef class Parser:
|
||||
cdef readonly Vocab vocab
|
||||
cdef readonly object model
|
||||
cdef readonly TransitionSystem moves
|
||||
cdef readonly object cfg
|
||||
cdef public object feature_maps
|
||||
|
||||
#cdef int parseC(self, TokenC* tokens, int length, int nr_feat) nogil
|
677
spacy/syntax/nn_parser.pyx
Normal file
677
spacy/syntax/nn_parser.pyx
Normal file
|
@ -0,0 +1,677 @@
|
|||
# cython: infer_types=True
|
||||
# cython: profile=True
|
||||
# coding: utf-8
|
||||
from __future__ import unicode_literals, print_function
|
||||
|
||||
from collections import Counter
|
||||
import ujson
|
||||
|
||||
from libc.math cimport exp
|
||||
cimport cython
|
||||
cimport cython.parallel
|
||||
import cytoolz
|
||||
|
||||
import numpy.random
|
||||
cimport numpy as np
|
||||
|
||||
from cpython.ref cimport PyObject, Py_INCREF, Py_XDECREF
|
||||
from cpython.exc cimport PyErr_CheckSignals
|
||||
from libc.stdint cimport uint32_t, uint64_t
|
||||
from libc.string cimport memset, memcpy
|
||||
from libc.stdlib cimport malloc, calloc, free
|
||||
from thinc.typedefs cimport weight_t, class_t, feat_t, atom_t, hash_t
|
||||
from thinc.linear.avgtron cimport AveragedPerceptron
|
||||
from thinc.linalg cimport VecVec
|
||||
from thinc.structs cimport SparseArrayC, FeatureC, ExampleC
|
||||
from thinc.extra.eg cimport Example
|
||||
from cymem.cymem cimport Pool, Address
|
||||
from murmurhash.mrmr cimport hash64
|
||||
from preshed.maps cimport MapStruct
|
||||
from preshed.maps cimport map_get
|
||||
|
||||
from thinc.api import layerize, chain
|
||||
from thinc.neural import BatchNorm, Model, Affine, ELU, ReLu, Maxout
|
||||
from thinc.neural.ops import NumpyOps
|
||||
|
||||
from ..util import get_cuda_stream
|
||||
from .._ml import zero_init, PrecomputableAffine, PrecomputableMaxouts
|
||||
|
||||
from . import _parse_features
|
||||
from ._parse_features cimport CONTEXT_SIZE
|
||||
from ._parse_features cimport fill_context
|
||||
from .stateclass cimport StateClass
|
||||
from ._state cimport StateC
|
||||
from .nonproj import PseudoProjectivity
|
||||
from .transition_system import OracleError
|
||||
from .transition_system cimport TransitionSystem, Transition
|
||||
from ..structs cimport TokenC
|
||||
from ..tokens.doc cimport Doc
|
||||
from ..strings cimport StringStore
|
||||
from ..gold cimport GoldParse
|
||||
from ..attrs cimport TAG, DEP
|
||||
|
||||
|
||||
def get_templates(*args, **kwargs):
|
||||
return []
|
||||
|
||||
USE_FTRL = True
|
||||
DEBUG = False
|
||||
def set_debug(val):
|
||||
global DEBUG
|
||||
DEBUG = val
|
||||
|
||||
|
||||
def get_greedy_model_for_batch(batch_size, tokvecs, lower_model, cuda_stream=None,
|
||||
drop=0.):
|
||||
'''Allow a model to be "primed" by pre-computing input features in bulk.
|
||||
|
||||
This is used for the parser, where we want to take a batch of documents,
|
||||
and compute vectors for each (token, position) pair. These vectors can then
|
||||
be reused, especially for beam-search.
|
||||
|
||||
Let's say we're using 12 features for each state, e.g. word at start of
|
||||
buffer, three words on stack, their children, etc. In the normal arc-eager
|
||||
system, a document of length N is processed in 2*N states. This means we'll
|
||||
create 2*N*12 feature vectors --- but if we pre-compute, we only need
|
||||
N*12 vector computations. The saving for beam-search is much better:
|
||||
if we have a beam of k, we'll normally make 2*N*12*K computations --
|
||||
so we can save the factor k. This also gives a nice CPU/GPU division:
|
||||
we can do all our hard maths up front, packed into large multiplications,
|
||||
and do the hard-to-program parsing on the CPU.
|
||||
'''
|
||||
gpu_cached, bp_features = lower_model.begin_update(tokvecs, drop=drop)
|
||||
cdef np.ndarray cached
|
||||
if not isinstance(gpu_cached, numpy.ndarray):
|
||||
cached = gpu_cached.get(stream=cuda_stream)
|
||||
else:
|
||||
cached = gpu_cached
|
||||
nF = gpu_cached.shape[1]
|
||||
nO = gpu_cached.shape[2]
|
||||
nP = gpu_cached.shape[3]
|
||||
ops = lower_model.ops
|
||||
features = numpy.zeros((batch_size, nO, nP), dtype='f')
|
||||
synchronized = False
|
||||
|
||||
def forward(token_ids, drop=0.):
|
||||
nonlocal synchronized
|
||||
if not synchronized and cuda_stream is not None:
|
||||
cuda_stream.synchronize()
|
||||
synchronized = True
|
||||
# This is tricky, but:
|
||||
# - Input to forward on CPU
|
||||
# - Output from forward on CPU
|
||||
# - Input to backward on GPU!
|
||||
# - Output from backward on GPU
|
||||
nonlocal features
|
||||
features = features[:len(token_ids)]
|
||||
features.fill(0)
|
||||
cdef float[:, :, ::1] feats = features
|
||||
cdef int[:, ::1] ids = token_ids
|
||||
_sum_features(<float*>&feats[0,0,0],
|
||||
<float*>cached.data, &ids[0,0],
|
||||
token_ids.shape[0], nF, nO*nP)
|
||||
|
||||
if nP >= 2:
|
||||
best, which = ops.maxout(features)
|
||||
else:
|
||||
best = features.reshape((features.shape[0], features.shape[1]))
|
||||
which = None
|
||||
|
||||
def backward(d_best, sgd=None):
|
||||
# This will usually be on GPU
|
||||
if isinstance(d_best, numpy.ndarray):
|
||||
d_best = ops.xp.array(d_best)
|
||||
if nP >= 2:
|
||||
d_features = ops.backprop_maxout(d_best, which, nP)
|
||||
else:
|
||||
d_features = d_best.reshape((d_best.shape[0], d_best.shape[1], 1))
|
||||
d_tokens = bp_features((d_features, token_ids), sgd)
|
||||
return d_tokens
|
||||
|
||||
return best, backward
|
||||
|
||||
return forward
|
||||
|
||||
|
||||
cdef void _sum_features(float* output,
|
||||
const float* cached, const int* token_ids, int B, int F, int O) nogil:
|
||||
cdef int idx, b, f, i
|
||||
cdef const float* feature
|
||||
for b in range(B):
|
||||
for f in range(F):
|
||||
if token_ids[f] < 0:
|
||||
continue
|
||||
idx = token_ids[f] * F * O + f*O
|
||||
feature = &cached[idx]
|
||||
for i in range(O):
|
||||
output[i] += feature[i]
|
||||
output += O
|
||||
token_ids += F
|
||||
|
||||
|
||||
def get_batch_loss(TransitionSystem moves, states, golds, float[:, ::1] scores):
|
||||
cdef StateClass state
|
||||
cdef GoldParse gold
|
||||
cdef Pool mem = Pool()
|
||||
cdef int i
|
||||
is_valid = <int*>mem.alloc(moves.n_moves, sizeof(int))
|
||||
costs = <float*>mem.alloc(moves.n_moves, sizeof(float))
|
||||
cdef np.ndarray d_scores = numpy.zeros((len(states), moves.n_moves), dtype='f',
|
||||
order='c')
|
||||
c_d_scores = <float*>d_scores.data
|
||||
for i, (state, gold) in enumerate(zip(states, golds)):
|
||||
memset(is_valid, 0, moves.n_moves * sizeof(int))
|
||||
memset(costs, 0, moves.n_moves * sizeof(float))
|
||||
moves.set_costs(is_valid, costs, state, gold)
|
||||
cpu_log_loss(c_d_scores, costs, is_valid, &scores[i, 0], d_scores.shape[1])
|
||||
#cpu_regression_loss(c_d_scores,
|
||||
# costs, is_valid, &scores[i, 0], d_scores.shape[1])
|
||||
c_d_scores += d_scores.shape[1]
|
||||
return d_scores
|
||||
|
||||
|
||||
cdef void cpu_log_loss(float* d_scores,
|
||||
const float* costs, const int* is_valid, const float* scores,
|
||||
int O) nogil:
|
||||
"""Do multi-label log loss"""
|
||||
cdef double max_, gmax, Z, gZ
|
||||
best = arg_max_if_gold(scores, costs, is_valid, O)
|
||||
guess = arg_max_if_valid(scores, is_valid, O)
|
||||
Z = 1e-10
|
||||
gZ = 1e-10
|
||||
max_ = scores[guess]
|
||||
gmax = scores[best]
|
||||
for i in range(O):
|
||||
if is_valid[i]:
|
||||
Z += exp(scores[i] - max_)
|
||||
if costs[i] <= costs[best]:
|
||||
gZ += exp(scores[i] - gmax)
|
||||
for i in range(O):
|
||||
if not is_valid[i]:
|
||||
d_scores[i] = 0.
|
||||
elif costs[i] <= costs[best]:
|
||||
d_scores[i] = (exp(scores[i]-max_) / Z) - (exp(scores[i]-gmax)/gZ)
|
||||
else:
|
||||
d_scores[i] = exp(scores[i]-max_) / Z
|
||||
|
||||
|
||||
cdef void cpu_regression_loss(float* d_scores,
|
||||
const float* costs, const int* is_valid, const float* scores,
|
||||
int O) nogil:
|
||||
cdef float eps = 2.
|
||||
best = arg_max_if_gold(scores, costs, is_valid, O)
|
||||
for i in range(O):
|
||||
if not is_valid[i]:
|
||||
d_scores[i] = 0.
|
||||
elif scores[i] < scores[best]:
|
||||
d_scores[i] = 0.
|
||||
else:
|
||||
# I doubt this is correct?
|
||||
# Looking for something like Huber loss
|
||||
diff = scores[i] - -costs[i]
|
||||
if diff > eps:
|
||||
d_scores[i] = eps
|
||||
elif diff < -eps:
|
||||
d_scores[i] = -eps
|
||||
else:
|
||||
d_scores[i] = diff
|
||||
|
||||
|
||||
def init_states(TransitionSystem moves, docs):
|
||||
cdef Doc doc
|
||||
cdef StateClass state
|
||||
offsets = []
|
||||
states = []
|
||||
offset = 0
|
||||
for i, doc in enumerate(docs):
|
||||
state = StateClass.init(doc.c, doc.length)
|
||||
moves.initialize_state(state.c)
|
||||
states.append(state)
|
||||
offsets.append(offset)
|
||||
offset += len(doc)
|
||||
return states, offsets
|
||||
|
||||
|
||||
def extract_token_ids(states, offsets=None, nF=1, nB=0, nS=2, nL=0, nR=0):
|
||||
cdef StateClass state
|
||||
cdef int n_tokens = states[0].nr_context_tokens(nF, nB, nS, nL, nR)
|
||||
ids = numpy.zeros((len(states), n_tokens), dtype='i', order='c')
|
||||
if offsets is None:
|
||||
offsets = [0] * len(states)
|
||||
for i, (state, offset) in enumerate(zip(states, offsets)):
|
||||
state.set_context_tokens(ids[i], nF, nB, nS, nL, nR)
|
||||
ids[i] += (ids[i] >= 0) * offset
|
||||
return ids
|
||||
|
||||
|
||||
_n_iter = 0
|
||||
@layerize
|
||||
def print_mean_variance(X, drop=0.):
|
||||
global _n_iter
|
||||
_n_iter += 1
|
||||
fwd_iter = _n_iter
|
||||
means = X.mean(axis=0)
|
||||
variance = X.var(axis=0)
|
||||
print(fwd_iter, "M", ', '.join(('%.2f' % m) for m in means))
|
||||
print(fwd_iter, "V", ', '.join(('%.2f' % m) for m in variance))
|
||||
def backward(dX, sgd=None):
|
||||
means = dX.mean(axis=0)
|
||||
variance = dX.var(axis=0)
|
||||
print(fwd_iter, "dM", ', '.join(('%.2f' % m) for m in means))
|
||||
print(fwd_iter, "dV", ', '.join(('%.2f' % m) for m in variance))
|
||||
return X, backward
|
||||
|
||||
|
||||
cdef class Parser:
|
||||
"""
|
||||
Base class of the DependencyParser and EntityRecognizer.
|
||||
"""
|
||||
@classmethod
|
||||
def load(cls, path, Vocab vocab, TransitionSystem=None, require=False, **cfg):
|
||||
"""
|
||||
Load the statistical model from the supplied path.
|
||||
|
||||
Arguments:
|
||||
path (Path):
|
||||
The path to load from.
|
||||
vocab (Vocab):
|
||||
The vocabulary. Must be shared by the documents to be processed.
|
||||
require (bool):
|
||||
Whether to raise an error if the files are not found.
|
||||
Returns (Parser):
|
||||
The newly constructed object.
|
||||
"""
|
||||
with (path / 'config.json').open() as file_:
|
||||
cfg = ujson.load(file_)
|
||||
self = cls(vocab, TransitionSystem=TransitionSystem, model=None, **cfg)
|
||||
if (path / 'model').exists():
|
||||
self.model.load(str(path / 'model'))
|
||||
elif require:
|
||||
raise IOError(
|
||||
"Required file %s/model not found when loading" % str(path))
|
||||
return self
|
||||
|
||||
def __init__(self, Vocab vocab, TransitionSystem=None, model=None, **cfg):
|
||||
"""
|
||||
Create a Parser.
|
||||
|
||||
Arguments:
|
||||
vocab (Vocab):
|
||||
The vocabulary object. Must be shared with documents to be processed.
|
||||
model (thinc Model):
|
||||
The statistical model.
|
||||
Returns (Parser):
|
||||
The newly constructed object.
|
||||
"""
|
||||
if TransitionSystem is None:
|
||||
TransitionSystem = self.TransitionSystem
|
||||
self.vocab = vocab
|
||||
cfg['actions'] = TransitionSystem.get_actions(**cfg)
|
||||
self.moves = TransitionSystem(vocab.strings, cfg['actions'])
|
||||
if model is None:
|
||||
self.model, self.feature_maps = self.build_model(**cfg)
|
||||
else:
|
||||
self.model, self.feature_maps = model
|
||||
self.cfg = cfg
|
||||
|
||||
def __reduce__(self):
|
||||
return (Parser, (self.vocab, self.moves, self.model), None, None)
|
||||
|
||||
def build_model(self,
|
||||
hidden_width=128, token_vector_width=96, nr_vector=1000,
|
||||
nF=1, nB=1, nS=1, nL=1, nR=1, **cfg):
|
||||
nr_context_tokens = StateClass.nr_context_tokens(nF, nB, nS, nL, nR)
|
||||
with Model.use_device('cpu'):
|
||||
upper = chain(
|
||||
Maxout(hidden_width, hidden_width),
|
||||
#print_mean_variance,
|
||||
zero_init(Affine(self.moves.n_moves, hidden_width)))
|
||||
assert isinstance(upper.ops, NumpyOps)
|
||||
lower = PrecomputableMaxouts(hidden_width, nF=nr_context_tokens, nI=token_vector_width,
|
||||
pieces=cfg.get('maxout_pieces', 1))
|
||||
lower.begin_training(lower.ops.allocate((500, token_vector_width)))
|
||||
upper.begin_training(upper.ops.allocate((500, hidden_width)))
|
||||
return upper, lower
|
||||
|
||||
def __call__(self, Doc tokens):
|
||||
"""
|
||||
Apply the parser or entity recognizer, setting the annotations onto the Doc object.
|
||||
|
||||
Arguments:
|
||||
doc (Doc): The document to be processed.
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
self.parse_batch([tokens])
|
||||
|
||||
def pipe(self, stream, int batch_size=1000, int n_threads=2):
|
||||
"""
|
||||
Process a stream of documents.
|
||||
|
||||
Arguments:
|
||||
stream: The sequence of documents to process.
|
||||
batch_size (int):
|
||||
The number of documents to accumulate into a working set.
|
||||
n_threads (int):
|
||||
The number of threads with which to work on the buffer in parallel.
|
||||
Yields (Doc): Documents, in order.
|
||||
"""
|
||||
queue = []
|
||||
for doc in stream:
|
||||
queue.append(doc)
|
||||
if len(queue) == batch_size:
|
||||
self.parse_batch(queue)
|
||||
for doc in queue:
|
||||
self.moves.finalize_doc(doc)
|
||||
yield doc
|
||||
queue = []
|
||||
if queue:
|
||||
self.parse_batch(queue)
|
||||
for doc in queue:
|
||||
self.moves.finalize_doc(doc)
|
||||
yield doc
|
||||
|
||||
def parse_batch(self, docs_tokvecs):
|
||||
cdef:
|
||||
int nC
|
||||
Doc doc
|
||||
StateClass state
|
||||
np.ndarray py_scores
|
||||
int[500] is_valid # Hacks for now
|
||||
|
||||
cuda_stream = get_cuda_stream()
|
||||
docs, tokvecs = docs_tokvecs
|
||||
lower_model = get_greedy_model_for_batch(len(docs), tokvecs, self.feature_maps,
|
||||
cuda_stream)
|
||||
upper_model = self.model
|
||||
|
||||
states, offsets = init_states(self.moves, docs)
|
||||
all_states = list(states)
|
||||
todo = [st for st in zip(states, offsets) if not st[0].py_is_final()]
|
||||
|
||||
while todo:
|
||||
states, offsets = zip(*todo)
|
||||
token_ids = extract_token_ids(states, offsets=offsets)
|
||||
|
||||
py_scores = upper_model(lower_model(token_ids)[0])
|
||||
scores = <float*>py_scores.data
|
||||
nC = py_scores.shape[1]
|
||||
for state, offset in zip(states, offsets):
|
||||
self.moves.set_valid(is_valid, state.c)
|
||||
guess = arg_max_if_valid(scores, is_valid, nC)
|
||||
action = self.moves.c[guess]
|
||||
action.do(state.c, action.label)
|
||||
scores += nC
|
||||
todo = [st for st in todo if not st[0].py_is_final()]
|
||||
|
||||
for state, doc in zip(all_states, docs):
|
||||
self.moves.finalize_state(state.c)
|
||||
for i in range(doc.length):
|
||||
doc.c[i] = state.c._sent[i]
|
||||
self.moves.finalize_doc(doc)
|
||||
|
||||
def update(self, docs_tokvecs, golds, drop=0., sgd=None):
|
||||
cdef:
|
||||
int nC
|
||||
Doc doc
|
||||
StateClass state
|
||||
np.ndarray scores
|
||||
|
||||
docs, tokvecs = docs_tokvecs
|
||||
cuda_stream = get_cuda_stream()
|
||||
lower_model = get_greedy_model_for_batch(len(docs),
|
||||
tokvecs, self.feature_maps, cuda_stream=cuda_stream,
|
||||
drop=drop)
|
||||
if isinstance(docs, Doc) and isinstance(golds, GoldParse):
|
||||
return self.update(([docs], tokvecs), [golds], drop=drop)
|
||||
for gold in golds:
|
||||
self.moves.preprocess_gold(gold)
|
||||
|
||||
states, offsets = init_states(self.moves, docs)
|
||||
|
||||
todo = zip(states, offsets, golds)
|
||||
todo = filter(lambda sp: not sp[0].py_is_final(), todo)
|
||||
|
||||
cdef Pool mem = Pool()
|
||||
is_valid = <int*>mem.alloc(len(states) * self.moves.n_moves, sizeof(int))
|
||||
costs = <float*>mem.alloc(len(states) * self.moves.n_moves, sizeof(float))
|
||||
|
||||
upper_model = self.model
|
||||
d_tokens = self.feature_maps.ops.allocate(tokvecs.shape)
|
||||
backprops = []
|
||||
n_tokens = tokvecs.shape[0]
|
||||
nF = self.feature_maps.nF
|
||||
loss = 0.
|
||||
total = 1e-4
|
||||
follow_gold = False
|
||||
cupy = self.feature_maps.ops.xp
|
||||
while len(todo) >= 4:
|
||||
states, offsets, golds = zip(*todo)
|
||||
|
||||
token_ids = extract_token_ids(states, offsets=offsets)
|
||||
lower, bp_lower = lower_model(token_ids, drop=drop)
|
||||
scores, bp_scores = upper_model.begin_update(lower, drop=drop)
|
||||
|
||||
d_scores = get_batch_loss(self.moves, states, golds, scores)
|
||||
loss += numpy.abs(d_scores).sum()
|
||||
total += d_scores.shape[0]
|
||||
d_lower = bp_scores(d_scores, sgd=sgd)
|
||||
|
||||
if isinstance(tokvecs, cupy.ndarray):
|
||||
gpu_tok_ids = cupy.ndarray(token_ids.shape, dtype='i', order='C')
|
||||
gpu_d_lower = cupy.ndarray(d_lower.shape, dtype='f', order='C')
|
||||
gpu_tok_ids.set(token_ids, stream=cuda_stream)
|
||||
gpu_d_lower.set(d_lower, stream=cuda_stream)
|
||||
backprops.append((gpu_tok_ids, gpu_d_lower, bp_lower))
|
||||
else:
|
||||
backprops.append((token_ids, d_lower, bp_lower))
|
||||
|
||||
c_scores = <float*>scores.data
|
||||
for state, gold in zip(states, golds):
|
||||
if follow_gold:
|
||||
self.moves.set_costs(is_valid, costs, state, gold)
|
||||
guess = arg_max_if_gold(c_scores, costs, is_valid, scores.shape[1])
|
||||
else:
|
||||
self.moves.set_valid(is_valid, state.c)
|
||||
guess = arg_max_if_valid(c_scores, is_valid, scores.shape[1])
|
||||
action = self.moves.c[guess]
|
||||
action.do(state.c, action.label)
|
||||
c_scores += scores.shape[1]
|
||||
|
||||
todo = filter(lambda sp: not sp[0].py_is_final(), todo)
|
||||
# This tells CUDA to block --- so we know our copies are complete.
|
||||
cuda_stream.synchronize()
|
||||
for token_ids, d_lower, bp_lower in backprops:
|
||||
d_state_features = bp_lower(d_lower, sgd=sgd)
|
||||
active_feats = token_ids * (token_ids >= 0)
|
||||
active_feats = active_feats.reshape((token_ids.shape[0], token_ids.shape[1], 1))
|
||||
if hasattr(self.feature_maps.ops.xp, 'scatter_add'):
|
||||
self.feature_maps.ops.xp.scatter_add(d_tokens,
|
||||
token_ids, d_state_features * active_feats)
|
||||
else:
|
||||
self.model.ops.xp.add.at(d_tokens,
|
||||
token_ids, d_state_features * active_feats)
|
||||
return d_tokens, loss / total
|
||||
|
||||
def step_through(self, Doc doc, GoldParse gold=None):
|
||||
"""
|
||||
Set up a stepwise state, to introspect and control the transition sequence.
|
||||
|
||||
Arguments:
|
||||
doc (Doc): The document to step through.
|
||||
gold (GoldParse): Optional gold parse
|
||||
Returns (StepwiseState):
|
||||
A state object, to step through the annotation process.
|
||||
"""
|
||||
return StepwiseState(self, doc, gold=gold)
|
||||
|
||||
def from_transition_sequence(self, Doc doc, sequence):
|
||||
"""Control the annotations on a document by specifying a transition sequence
|
||||
to follow.
|
||||
|
||||
Arguments:
|
||||
doc (Doc): The document to annotate.
|
||||
sequence: A sequence of action names, as unicode strings.
|
||||
Returns: None
|
||||
"""
|
||||
with self.step_through(doc) as stepwise:
|
||||
for transition in sequence:
|
||||
stepwise.transition(transition)
|
||||
|
||||
def add_label(self, label):
|
||||
# Doesn't set label into serializer -- subclasses override it to do that.
|
||||
for action in self.moves.action_types:
|
||||
added = self.moves.add_action(action, label)
|
||||
if added:
|
||||
# Important that the labels be stored as a list! We need the
|
||||
# order, or the model goes out of synch
|
||||
self.cfg.setdefault('extra_labels', []).append(label)
|
||||
|
||||
|
||||
cdef class StepwiseState:
|
||||
cdef readonly StateClass stcls
|
||||
cdef readonly Example eg
|
||||
cdef readonly Doc doc
|
||||
cdef readonly GoldParse gold
|
||||
cdef readonly Parser parser
|
||||
|
||||
def __init__(self, Parser parser, Doc doc, GoldParse gold=None):
|
||||
self.parser = parser
|
||||
self.doc = doc
|
||||
if gold is not None:
|
||||
self.gold = gold
|
||||
self.parser.moves.preprocess_gold(self.gold)
|
||||
else:
|
||||
self.gold = GoldParse(doc)
|
||||
self.stcls = StateClass.init(doc.c, doc.length)
|
||||
self.parser.moves.initialize_state(self.stcls.c)
|
||||
self.eg = Example(
|
||||
nr_class=self.parser.moves.n_moves,
|
||||
nr_atom=CONTEXT_SIZE,
|
||||
nr_feat=self.parser.model.nr_feat)
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, type, value, traceback):
|
||||
self.finish()
|
||||
|
||||
@property
|
||||
def is_final(self):
|
||||
return self.stcls.is_final()
|
||||
|
||||
@property
|
||||
def stack(self):
|
||||
return self.stcls.stack
|
||||
|
||||
@property
|
||||
def queue(self):
|
||||
return self.stcls.queue
|
||||
|
||||
@property
|
||||
def heads(self):
|
||||
return [self.stcls.H(i) for i in range(self.stcls.c.length)]
|
||||
|
||||
@property
|
||||
def deps(self):
|
||||
return [self.doc.vocab.strings[self.stcls.c._sent[i].dep]
|
||||
for i in range(self.stcls.c.length)]
|
||||
|
||||
@property
|
||||
def costs(self):
|
||||
"""
|
||||
Find the action-costs for the current state.
|
||||
"""
|
||||
if not self.gold:
|
||||
raise ValueError("Can't set costs: No GoldParse provided")
|
||||
self.parser.moves.set_costs(self.eg.c.is_valid, self.eg.c.costs,
|
||||
self.stcls, self.gold)
|
||||
costs = {}
|
||||
for i in range(self.parser.moves.n_moves):
|
||||
if not self.eg.c.is_valid[i]:
|
||||
continue
|
||||
transition = self.parser.moves.c[i]
|
||||
name = self.parser.moves.move_name(transition.move, transition.label)
|
||||
costs[name] = self.eg.c.costs[i]
|
||||
return costs
|
||||
|
||||
def predict(self):
|
||||
self.eg.reset()
|
||||
#self.eg.c.nr_feat = self.parser.model.set_featuresC(self.eg.c.atoms, self.eg.c.features,
|
||||
# self.stcls.c)
|
||||
self.parser.moves.set_valid(self.eg.c.is_valid, self.stcls.c)
|
||||
#self.parser.model.set_scoresC(self.eg.c.scores,
|
||||
# self.eg.c.features, self.eg.c.nr_feat)
|
||||
|
||||
cdef Transition action = self.parser.moves.c[self.eg.guess]
|
||||
return self.parser.moves.move_name(action.move, action.label)
|
||||
|
||||
def transition(self, action_name=None):
|
||||
if action_name is None:
|
||||
action_name = self.predict()
|
||||
moves = {'S': 0, 'D': 1, 'L': 2, 'R': 3}
|
||||
if action_name == '_':
|
||||
action_name = self.predict()
|
||||
action = self.parser.moves.lookup_transition(action_name)
|
||||
elif action_name == 'L' or action_name == 'R':
|
||||
self.predict()
|
||||
move = moves[action_name]
|
||||
clas = _arg_max_clas(self.eg.c.scores, move, self.parser.moves.c,
|
||||
self.eg.c.nr_class)
|
||||
action = self.parser.moves.c[clas]
|
||||
else:
|
||||
action = self.parser.moves.lookup_transition(action_name)
|
||||
action.do(self.stcls.c, action.label)
|
||||
|
||||
def finish(self):
|
||||
if self.stcls.is_final():
|
||||
self.parser.moves.finalize_state(self.stcls.c)
|
||||
self.doc.set_parse(self.stcls.c._sent)
|
||||
self.parser.moves.finalize_doc(self.doc)
|
||||
|
||||
|
||||
class ParserStateError(ValueError):
|
||||
def __init__(self, doc):
|
||||
ValueError.__init__(self,
|
||||
"Error analysing doc -- no valid actions available. This should "
|
||||
"never happen, so please report the error on the issue tracker. "
|
||||
"Here's the thread to do so --- reopen it if it's closed:\n"
|
||||
"https://github.com/spacy-io/spaCy/issues/429\n"
|
||||
"Please include the text that the parser failed on, which is:\n"
|
||||
"%s" % repr(doc.text))
|
||||
|
||||
|
||||
cdef int arg_max_if_gold(const weight_t* scores, const weight_t* costs, const int* is_valid, int n) nogil:
|
||||
# Find minimum cost
|
||||
cdef float cost = 1
|
||||
for i in range(n):
|
||||
if is_valid[i] and costs[i] < cost:
|
||||
cost = costs[i]
|
||||
# Now find best-scoring with that cost
|
||||
cdef int best = -1
|
||||
for i in range(n):
|
||||
if costs[i] <= cost and is_valid[i]:
|
||||
if best == -1 or scores[i] > scores[best]:
|
||||
best = i
|
||||
return best
|
||||
|
||||
|
||||
cdef int arg_max_if_valid(const weight_t* scores, const int* is_valid, int n) nogil:
|
||||
cdef int best = -1
|
||||
for i in range(n):
|
||||
if is_valid[i] >= 1:
|
||||
if best == -1 or scores[i] > scores[best]:
|
||||
best = i
|
||||
return best
|
||||
|
||||
|
||||
cdef int _arg_max_clas(const weight_t* scores, int move, const Transition* actions,
|
||||
int nr_class) except -1:
|
||||
cdef weight_t score = 0
|
||||
cdef int mode = -1
|
||||
cdef int i
|
||||
for i in range(nr_class):
|
||||
if actions[i].move == move and (mode == -1 or scores[i] >= score):
|
||||
mode = i
|
||||
score = scores[i]
|
||||
return mode
|
|
@ -1,6 +1,7 @@
|
|||
from libc.string cimport memcpy, memset
|
||||
|
||||
from cymem.cymem cimport Pool
|
||||
cimport cython
|
||||
|
||||
from ..structs cimport TokenC, Entity
|
||||
|
||||
|
@ -8,7 +9,7 @@ from ..vocab cimport EMPTY_LEXEME
|
|||
from ._state cimport StateC
|
||||
|
||||
|
||||
|
||||
@cython.final
|
||||
cdef class StateClass:
|
||||
cdef Pool mem
|
||||
cdef StateC* c
|
||||
|
|
|
@ -1,14 +1,17 @@
|
|||
# coding: utf-8
|
||||
# cython: infer_types=True
|
||||
from __future__ import unicode_literals
|
||||
|
||||
from libc.string cimport memcpy, memset
|
||||
from libc.stdint cimport uint32_t
|
||||
from libc.stdint cimport uint32_t, uint64_t
|
||||
|
||||
from ..vocab cimport EMPTY_LEXEME
|
||||
from ..structs cimport Entity
|
||||
from ..lexeme cimport Lexeme
|
||||
from ..symbols cimport punct
|
||||
from ..attrs cimport IS_SPACE
|
||||
from ..attrs cimport attr_id_t
|
||||
from ..tokens.token cimport Token
|
||||
|
||||
|
||||
cdef class StateClass:
|
||||
|
@ -27,6 +30,13 @@ cdef class StateClass:
|
|||
def queue(self):
|
||||
return {self.B(i) for i in range(self.c.buffer_length())}
|
||||
|
||||
@property
|
||||
def token_vector_lenth(self):
|
||||
return self.doc.tensor.shape[1]
|
||||
|
||||
def py_is_final(self):
|
||||
return self.c.is_final()
|
||||
|
||||
def print_state(self, words):
|
||||
words = list(words) + ['_']
|
||||
top = words[self.S(0)] + '_%d' % self.S_(0).head
|
||||
|
@ -35,3 +45,43 @@ cdef class StateClass:
|
|||
n0 = words[self.B(0)]
|
||||
n1 = words[self.B(1)]
|
||||
return ' '.join((third, second, top, '|', n0, n1))
|
||||
|
||||
@classmethod
|
||||
def nr_context_tokens(cls, int nF, int nB, int nS, int nL, int nR):
|
||||
return 13
|
||||
|
||||
def set_context_tokens(self, int[:] output, nF=1, nB=0, nS=2,
|
||||
nL=2, nR=2):
|
||||
output[0] = self.B(0)
|
||||
output[1] = self.B(1)
|
||||
output[2] = self.S(0)
|
||||
output[3] = self.S(1)
|
||||
output[4] = self.S(2)
|
||||
output[5] = self.L(self.S(0), 1)
|
||||
output[6] = self.L(self.S(0), 2)
|
||||
output[6] = self.R(self.S(0), 1)
|
||||
output[7] = self.L(self.B(0), 1)
|
||||
output[8] = self.R(self.S(0), 2)
|
||||
output[9] = self.L(self.S(1), 1)
|
||||
output[10] = self.L(self.S(1), 2)
|
||||
output[11] = self.R(self.S(1), 1)
|
||||
output[12] = self.R(self.S(1), 2)
|
||||
|
||||
def set_attributes(self, uint64_t[:, :] vals, int[:] tokens, int[:] names):
|
||||
cdef int i, j, tok_i
|
||||
for i in range(tokens.shape[0]):
|
||||
tok_i = tokens[i]
|
||||
if tok_i >= 0:
|
||||
token = &self.c._sent[tok_i]
|
||||
for j in range(names.shape[0]):
|
||||
vals[i, j] = Token.get_struct_attr(token, <attr_id_t>names[j])
|
||||
else:
|
||||
vals[i] = 0
|
||||
|
||||
def set_token_vectors(self, tokvecs,
|
||||
all_tokvecs, int[:] indices):
|
||||
for i in range(indices.shape[0]):
|
||||
if indices[i] >= 0:
|
||||
tokvecs[i] = all_tokvecs[indices[i]]
|
||||
else:
|
||||
tokvecs[i] = 0
|
||||
|
|
|
@ -32,7 +32,7 @@ cdef class Doc:
|
|||
cdef public object _vector
|
||||
cdef public object _vector_norm
|
||||
|
||||
cdef public np.ndarray tensor
|
||||
cdef public object tensor
|
||||
cdef public object user_data
|
||||
|
||||
cdef TokenC* c
|
||||
|
|
|
@ -149,6 +149,16 @@ def parse_package_meta(package_path, require=True):
|
|||
return None
|
||||
|
||||
|
||||
def get_cuda_stream(require=False):
|
||||
# TODO: Error and tell to install chainer if not found
|
||||
# Requires GPU
|
||||
try:
|
||||
from cupy.cuda.stream import Stream
|
||||
except ImportError:
|
||||
return None
|
||||
return Stream()
|
||||
|
||||
|
||||
def read_regex(path):
|
||||
path = ensure_path(path)
|
||||
with path.open() as file_:
|
||||
|
|
Loading…
Reference in New Issue
Block a user