* Move to thinc 5.0

This commit is contained in:
Matthew Honnibal 2016-01-29 03:58:55 +01:00
parent 9721502c81
commit b0718b6ee1
2 changed files with 26 additions and 30 deletions

View File

@ -1,14 +1,13 @@
from thinc.api cimport AveragedPerceptron
from thinc.api cimport ExampleC
from thinc.linear.avgtron cimport AveragedPerceptron
from thinc.extra.eg cimport Example
from thinc.structs cimport ExampleC
from .structs cimport TokenC
from .vocab cimport Vocab
cdef class TaggerModel(AveragedPerceptron):
cdef void set_features(self, ExampleC* eg, const TokenC* tokens, int i) except *
cdef void set_costs(self, ExampleC* eg, int gold) except *
cdef void update(self, ExampleC* eg) except *
cdef void set_featuresC(self, ExampleC* eg, const TokenC* tokens, int i) except *
cdef class Tagger:

View File

@ -5,8 +5,10 @@ from libc.string cimport memset
from cymem.cymem cimport Pool
from thinc.typedefs cimport atom_t, weight_t
from thinc.api cimport Example, ExampleC
from thinc.features cimport ConjunctionExtracter
from thinc.extra.eg cimport Example
from thinc.structs cimport ExampleC
from thinc.linear.avgtron cimport AveragedPerceptron
from thinc.linalg cimport VecVec
from .typedefs cimport attr_t
from .tokens.doc cimport Doc
@ -69,7 +71,7 @@ cpdef enum:
cdef class TaggerModel(AveragedPerceptron):
cdef void set_features(self, ExampleC* eg, const TokenC* tokens, int i) except *:
cdef void set_featuresC(self, ExampleC* eg, const TokenC* tokens, int i) except *:
_fill_from_token(&eg.atoms[P2_orth], &tokens[i-2])
_fill_from_token(&eg.atoms[P1_orth], &tokens[i-1])
_fill_from_token(&eg.atoms[W_orth], &tokens[i])
@ -78,9 +80,6 @@ cdef class TaggerModel(AveragedPerceptron):
eg.nr_feat = self.extracter.set_features(eg.features, eg.atoms)
cdef void update(self, ExampleC* eg) except *:
self.updater.update(eg)
cdef inline void _fill_from_token(atom_t* context, const TokenC* t) nogil:
context[0] = t.lex.lower
@ -143,8 +142,7 @@ cdef class Tagger:
@classmethod
def blank(cls, vocab, templates):
model = TaggerModel(vocab.morphology.n_tags,
ConjunctionExtracter(N_CONTEXT_FIELDS, templates))
model = TaggerModel(N_CONTEXT_FIELDS, templates)
return cls(vocab, model)
@classmethod
@ -159,13 +157,9 @@ cdef class Tagger:
# 'pos', 'templates.json',
# default=cls.default_templates())
model = TaggerModel(vocab.morphology.n_tags,
ConjunctionExtracter(N_CONTEXT_FIELDS, templates))
if pkg.has_file('pos', 'model'): # TODO: really optional?
model = TaggerModel(templates)
if pkg.has_file('pos', 'model'):
model.load(pkg.file_path('pos', 'model'))
return cls(vocab, model)
def __init__(self, Vocab vocab, TaggerModel model):
@ -202,15 +196,16 @@ cdef class Tagger:
return 0
cdef Pool mem = Pool()
cdef ExampleC eg
cdef int i, tag
cdef Example eg = Example(self.vocab.morphology.n_tags)
for i in range(tokens.length):
if tokens.c[i].pos == 0:
eg = self.model.allocate(mem)
self.model.set_features(&eg, tokens.c, i)
self.model.set_prediction(&eg)
self.vocab.morphology.assign_tag(&tokens.c[i], eg.guess)
self.model.set_featuresC(&eg.c, tokens.c, i)
self.model.set_scoresC(eg.c.scores,
eg.c.features, eg.c.nr_feat)
guess = VecVec.arg_max_if_true(eg.c.scores, eg.c.is_valid, eg.c.nr_class)
self.vocab.morphology.assign_tag(&tokens.c[i], guess)
tokens.is_tagged = True
tokens._py_tokens = [None] * tokens.length
@ -219,18 +214,20 @@ cdef class Tagger:
golds = [self.tag_names.index(g) if g is not None else -1 for g in gold_tag_strs]
cdef int correct = 0
cdef Pool mem = Pool()
cdef ExampleC eg
cdef Example eg = Example(self.vocab.morphology.n_tags)
for i in range(tokens.length):
eg = self.model.allocate(mem)
self.model.set_features(&eg, tokens.c, i)
self.model.set_costs(&eg, golds[i])
self.model.set_prediction(&eg)
self.model.update(&eg)
self.model.set_featuresC(&eg.c, tokens.c, i)
eg.set_label(golds[i])
self.model.set_scoresC(eg.c.scores,
eg.c.features, eg.c.nr_feat)
self.model.updateC(&eg.c)
self.vocab.morphology.assign_tag(&tokens.c[i], eg.guess)
correct += eg.cost == 0
self.freqs[TAG][tokens.c[i].tag] += 1
eg.wipe(tuple())
tokens.is_tagged = True
tokens._py_tokens = [None] * tokens.length
return correct