mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-26 18:06:29 +03:00
* Refactor sense tagger to get rid of intermediary layers
This commit is contained in:
parent
6735439abf
commit
2fbcdd0ea8
|
@ -1,13 +1,17 @@
|
||||||
from thinc.api cimport Example
|
|
||||||
from thinc.typedefs cimport atom_t
|
|
||||||
|
|
||||||
from .typedefs cimport flags_t
|
from .typedefs cimport flags_t
|
||||||
from .structs cimport TokenC
|
from .structs cimport TokenC
|
||||||
from .strings cimport StringStore
|
from .strings cimport StringStore
|
||||||
from .tokens cimport Tokens
|
from .tokens cimport Tokens
|
||||||
from ._ml cimport Model
|
|
||||||
from .senses cimport POS_SENSES, N_SENSES, encode_sense_strs
|
from .senses cimport POS_SENSES, N_SENSES, encode_sense_strs
|
||||||
from .gold cimport GoldParse
|
from .gold cimport GoldParse
|
||||||
|
from .parts_of_speech cimport NOUN, VERB
|
||||||
|
|
||||||
|
from thinc.learner cimport LinearModel
|
||||||
|
from thinc.features cimport Extractor
|
||||||
|
|
||||||
|
from thinc.typedefs cimport atom_t, weight_t, feat_t
|
||||||
|
|
||||||
|
from os import path
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -173,6 +177,8 @@ cdef int fill_token(atom_t* ctxt, const TokenC* token) except -1:
|
||||||
|
|
||||||
|
|
||||||
cdef int fill_context(atom_t* ctxt, const TokenC* token) except -1:
|
cdef int fill_context(atom_t* ctxt, const TokenC* token) except -1:
|
||||||
|
# NB: we have padding to keep us safe here
|
||||||
|
# See tokens.pyx
|
||||||
fill_token(&ctxt[P2W], token - 2)
|
fill_token(&ctxt[P2W], token - 2)
|
||||||
fill_token(&ctxt[P1W], token - 1)
|
fill_token(&ctxt[P1W], token - 1)
|
||||||
|
|
||||||
|
@ -185,62 +191,79 @@ cdef int fill_context(atom_t* ctxt, const TokenC* token) except -1:
|
||||||
|
|
||||||
cdef class SenseTagger:
|
cdef class SenseTagger:
|
||||||
cdef readonly StringStore strings
|
cdef readonly StringStore strings
|
||||||
cdef readonly Model model
|
cdef readonly LinearModel model
|
||||||
|
cdef readonly Extractor extractor
|
||||||
|
cdef readonly model_dir
|
||||||
|
|
||||||
def __init__(self, StringStore strings, model_dir):
|
def __init__(self, StringStore strings, model_dir):
|
||||||
self.strings = strings
|
if model_dir is not None and path.isdir(model_dir):
|
||||||
|
model_dir = path.join(model_dir, 'model')
|
||||||
|
|
||||||
templates = unigrams + bigrams + trigrams
|
templates = unigrams + bigrams + trigrams
|
||||||
self.model = Model(N_SENSES, templates, model_dir)
|
self.extractor = Extractor(templates)
|
||||||
|
self.model = LinearModel(N_SENSES, self.extractor.n_templ)
|
||||||
|
self.model_dir = model_dir
|
||||||
|
if self.model_dir and path.exists(self.model_dir):
|
||||||
|
self.model.load(self.model_dir, freq_thresh=0)
|
||||||
|
self.strings = strings
|
||||||
|
|
||||||
def __call__(self, Tokens tokens):
|
def __call__(self, Tokens tokens):
|
||||||
eg = Example(self.model.n_classes, CONTEXT_SIZE, self.model.n_feats,
|
cdef atom_t[CONTEXT_SIZE] context
|
||||||
self.model.n_feats)
|
cdef int i, guess, n_feats
|
||||||
cdef int i
|
cdef const TokenC* token
|
||||||
for i in range(tokens.length):
|
for i in range(tokens.length):
|
||||||
n_valid = self._set_valid(<bint*>eg.c.is_valid, pos_senses(&tokens.data[i]))
|
token = &tokens.data[i]
|
||||||
if n_valid >= 1:
|
if token.pos in (NOUN, VERB):
|
||||||
fill_context(eg.c.atoms, &tokens.data[i])
|
fill_context(context, token)
|
||||||
self.model.predict(eg)
|
feats = self.extractor.get_feats(context, &n_feats)
|
||||||
tokens.data[i].sense = eg.c.guess
|
scores = self.model.get_scores(feats, n_feats)
|
||||||
|
tokens.data[i].sense = self.best_in_set(scores, POS_SENSES[<int>token.pos])
|
||||||
|
|
||||||
def train(self, Tokens tokens, GoldParse gold):
|
def train(self, Tokens tokens, GoldParse gold):
|
||||||
eg = Example(self.model.n_classes, CONTEXT_SIZE, self.model.n_feats+1,
|
cdef int i, j
|
||||||
self.model.n_feats+1)
|
|
||||||
cdef int i
|
|
||||||
for i, ssenses in enumerate(gold.ssenses):
|
for i, ssenses in enumerate(gold.ssenses):
|
||||||
if ssenses:
|
if ssenses:
|
||||||
gold.c.ssenses[i] = encode_sense_strs(ssenses)
|
gold.c.ssenses[i] = encode_sense_strs(ssenses)
|
||||||
else:
|
else:
|
||||||
gold.c.ssenses[i] = pos_senses(&tokens.data[i])
|
gold.c.ssenses[i] = pos_senses(&tokens.data[i])
|
||||||
|
|
||||||
|
cdef atom_t[CONTEXT_SIZE] context
|
||||||
|
cdef int n_feats
|
||||||
|
cdef feat_t f_key
|
||||||
|
cdef int f_i
|
||||||
cdef int cost = 0
|
cdef int cost = 0
|
||||||
for i in range(tokens.length):
|
for i in range(tokens.length):
|
||||||
if tokens.data[i].lex.senses == 0 or tokens.data[i].lex.senses == 1:
|
token = &tokens.data[i]
|
||||||
continue
|
if token.pos in (NOUN, VERB) \
|
||||||
self._set_costs(<bint*>eg.c.is_valid, eg.c.costs, gold.c.ssenses[i])
|
and token.lex.senses >= 2 \
|
||||||
fill_context(eg.c.atoms, &tokens.data[i])
|
and gold.c.ssenses[i] >= 2:
|
||||||
|
fill_context(context, token)
|
||||||
self.model.train(eg)
|
feats = self.extractor.get_feats(context, &n_feats)
|
||||||
|
scores = self.model.get_scores(feats, n_feats)
|
||||||
tokens.data[i].sense = eg.c.guess
|
token.sense = self.best_in_set(scores, POS_SENSES[<int>token.pos])
|
||||||
cost += eg.c.cost
|
best = self.best_in_set(scores, gold.c.ssenses[i])
|
||||||
|
guess_counts = {}
|
||||||
|
gold_counts = {}
|
||||||
|
if token.sense != best:
|
||||||
|
for j in range(n_feats):
|
||||||
|
f_key = feats[j].key
|
||||||
|
f_i = feats[j].i
|
||||||
|
feat = (f_i, f_key)
|
||||||
|
gold_counts[feat] = gold_counts.get(feat, 0) + 1.0
|
||||||
|
guess_counts[feat] = guess_counts.get(feat, 0) - 1.0
|
||||||
|
#self.model.update({token.sense: guess_counts, best: gold_counts})
|
||||||
return cost
|
return cost
|
||||||
|
|
||||||
cdef int _set_valid(self, bint* is_valid, flags_t senses) except -1:
|
cdef int best_in_set(self, const weight_t* scores, flags_t senses) except -1:
|
||||||
cdef int n_valid
|
cdef weight_t max_ = 0
|
||||||
cdef flags_t bit
|
cdef int argmax = -1
|
||||||
is_valid[0] = False
|
cdef flags_t i
|
||||||
for bit in range(1, N_SENSES):
|
for i in range(N_SENSES):
|
||||||
is_valid[bit] = senses & (1 << bit)
|
if (senses & (1 << i)) and (argmax == -1 or scores[i] > max_):
|
||||||
n_valid += is_valid[bit]
|
max_ = scores[i]
|
||||||
return n_valid
|
argmax = i
|
||||||
|
assert argmax >= 0
|
||||||
cdef int _set_costs(self, bint* is_valid, int* costs, flags_t senses):
|
return argmax
|
||||||
cdef flags_t bit
|
|
||||||
is_valid[0] = False
|
|
||||||
costs[0] = 1
|
|
||||||
for bit in range(1, N_SENSES):
|
|
||||||
is_valid[bit] = True
|
|
||||||
costs[bit] = 0 if (senses & (1 << bit)) else 1
|
|
||||||
|
|
||||||
|
|
||||||
cdef flags_t pos_senses(const TokenC* token) nogil:
|
cdef flags_t pos_senses(const TokenC* token) nogil:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user