mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-26 09:14:32 +03:00
* Refactor context extraction, and start breaking out gold standards into their own functions
This commit is contained in:
parent
602f993af9
commit
f307eb2e36
|
@ -60,9 +60,7 @@ cdef class Slots:
|
|||
cdef int N_FIELDS
|
||||
|
||||
|
||||
cdef hash_t fill_slots(Slots s, int i, Tokens tokens) except 0
|
||||
|
||||
cdef int fill_flat(atom_t* context, Slots s) except -1
|
||||
cdef int fill_context(atom_t* context, int i, Tokens tokens) except -1
|
||||
|
||||
|
||||
cpdef Slots FIELD_IDS
|
||||
|
|
|
@ -60,111 +60,59 @@ cdef void _number_token(Token t, int* n_fields):
|
|||
n_fields[0] = i
|
||||
|
||||
|
||||
cdef int fill_token(Token t, Lexeme* lex, atom_t pos, atom_t ner):
|
||||
t.sic = lex.sic
|
||||
t.cluster = lex.cluster
|
||||
t.norm = lex.norm if (lex.prob != 0 and lex.prob >= -10) else lex.shape
|
||||
t.shape = lex.shape
|
||||
t.asciied = lex.asciied
|
||||
t.prefix = lex.prefix
|
||||
t.suffix = lex.suffix
|
||||
t.length = lex.length
|
||||
cdef int _fill_token(atom_t* c, Token t, Lexeme* lex, atom_t pos, atom_t ner):
|
||||
c[t.sic] = lex.sic
|
||||
c[t.cluster] = lex.cluster
|
||||
c[t.norm] = lex.norm if (lex.prob != 0 and lex.prob >= -10) else lex.shape
|
||||
c[t.shape] = lex.shape
|
||||
c[t.asciied] = lex.asciied
|
||||
c[t.prefix] = lex.prefix
|
||||
c[t.suffix] = lex.suffix
|
||||
c[t.length] = lex.length
|
||||
|
||||
t.postype = lex.postype
|
||||
t.nertype = 0
|
||||
t.sensetype = 0
|
||||
c[t.postype] = lex.postype
|
||||
c[t.nertype] = 0
|
||||
c[t.sensetype] = 0
|
||||
|
||||
t.is_alpha = lex.flags & (1 << IS_ALPHA)
|
||||
t.is_digit = lex.flags & (1 << IS_DIGIT)
|
||||
t.is_lower = lex.flags & (1 << IS_LOWER)
|
||||
t.is_punct = lex.flags & (1 << IS_PUNCT)
|
||||
t.is_space = lex.flags & (1 << IS_SPACE)
|
||||
t.is_title = lex.flags & (1 << IS_TITLE)
|
||||
t.is_upper = lex.flags & (1 << IS_UPPER)
|
||||
t.like_url = lex.flags & (1 << LIKE_URL)
|
||||
t.like_number = lex.flags & (1 << LIKE_NUMBER)
|
||||
t.oft_lower = lex.flags & (1 << OFT_LOWER)
|
||||
t.oft_title = lex.flags & (1 << OFT_TITLE)
|
||||
t.oft_upper = lex.flags & (1 << OFT_UPPER)
|
||||
c[t.is_alpha] = lex.flags & (1 << IS_ALPHA)
|
||||
c[t.is_digit] = lex.flags & (1 << IS_DIGIT)
|
||||
c[t.is_lower] = lex.flags & (1 << IS_LOWER)
|
||||
c[t.is_punct] = lex.flags & (1 << IS_PUNCT)
|
||||
c[t.is_space] = lex.flags & (1 << IS_SPACE)
|
||||
c[t.is_title] = lex.flags & (1 << IS_TITLE)
|
||||
c[t.is_upper] = lex.flags & (1 << IS_UPPER)
|
||||
c[t.like_url] = lex.flags & (1 << LIKE_URL)
|
||||
c[t.like_number] = lex.flags & (1 << LIKE_NUMBER)
|
||||
c[t.oft_lower] = lex.flags & (1 << OFT_LOWER)
|
||||
c[t.oft_title] = lex.flags & (1 << OFT_TITLE)
|
||||
c[t.oft_upper] = lex.flags & (1 << OFT_UPPER)
|
||||
|
||||
t.in_males = lex.flags & (1 << IN_MALES)
|
||||
t.in_females = lex.flags & (1 << IN_FEMALES)
|
||||
t.in_surnames = lex.flags & (1 << IN_SURNAMES)
|
||||
t.in_places = lex.flags & (1 << IN_PLACES)
|
||||
t.in_games = lex.flags & (1 << IN_GAMES)
|
||||
t.in_celebs = lex.flags & (1 << IN_CELEBS)
|
||||
t.in_names = lex.flags & (1 << IN_NAMES)
|
||||
c[t.in_males] = lex.flags & (1 << IN_MALES)
|
||||
c[t.in_females] = lex.flags & (1 << IN_FEMALES)
|
||||
c[t.in_surnames] = lex.flags & (1 << IN_SURNAMES)
|
||||
c[t.in_places] = lex.flags & (1 << IN_PLACES)
|
||||
c[t.in_games] = lex.flags & (1 << IN_GAMES)
|
||||
c[t.in_celebs] = lex.flags & (1 << IN_CELEBS)
|
||||
c[t.in_names] = lex.flags & (1 << IN_NAMES)
|
||||
|
||||
t.pos = pos
|
||||
t.sense = 0
|
||||
t.ner = ner
|
||||
c[t.pos] = pos
|
||||
c[t.sense] = 0
|
||||
c[t.ner] = ner
|
||||
|
||||
|
||||
cdef int _flatten_token(atom_t* context, Token ids, Token vals) except -1:
|
||||
context[ids.sic] = vals.sic
|
||||
context[ids.cluster] = vals.cluster
|
||||
context[ids.norm] = vals.norm
|
||||
context[ids.shape] = vals.shape
|
||||
context[ids.asciied] = vals.asciied
|
||||
context[ids.prefix] = vals.prefix
|
||||
context[ids.suffix] = vals.suffix
|
||||
context[ids.length] = vals.length
|
||||
|
||||
context[ids.postype] = vals.postype
|
||||
context[ids.nertype] = vals.nertype
|
||||
context[ids.sensetype] = vals.sensetype
|
||||
|
||||
context[ids.is_alpha] = vals.is_alpha
|
||||
context[ids.is_ascii] = vals.is_ascii
|
||||
context[ids.is_digit] = vals.is_digit
|
||||
context[ids.is_lower] = vals.is_lower
|
||||
context[ids.is_punct] = vals.is_punct
|
||||
context[ids.is_title] = vals.is_title
|
||||
context[ids.is_upper] = vals.is_upper
|
||||
context[ids.like_url] = vals.like_url
|
||||
context[ids.like_number] = vals.like_number
|
||||
context[ids.oft_lower] = vals.oft_lower
|
||||
context[ids.oft_title] = vals.oft_title
|
||||
context[ids.oft_upper] = vals.oft_upper
|
||||
|
||||
context[ids.in_males] = vals.in_males
|
||||
context[ids.in_females] = vals.in_females
|
||||
context[ids.in_surnames] = vals.in_surnames
|
||||
context[ids.in_places] = vals.in_places
|
||||
context[ids.in_games] = vals.in_games
|
||||
context[ids.in_celebs] = vals.in_celebs
|
||||
context[ids.in_names] = vals.in_names
|
||||
|
||||
context[ids.pos] = vals.pos
|
||||
context[ids.sense] = vals.sense
|
||||
context[ids.ner] = vals.ner
|
||||
|
||||
|
||||
cdef hash_t fill_slots(Slots s, int i, Tokens tokens) except 0:
|
||||
fill_token(s.P4, tokens.lex[i-4], tokens.pos[i-4], tokens.ner[i-4])
|
||||
fill_token(s.P3, tokens.lex[i-3], tokens.pos[i-3], tokens.ner[i-3])
|
||||
fill_token(s.P2, tokens.lex[i-2], tokens.pos[i-2], tokens.ner[i-2])
|
||||
fill_token(s.P1, tokens.lex[i-1], tokens.pos[i-1], tokens.ner[i-1])
|
||||
fill_token(s.N0, tokens.lex[i], tokens.pos[i], tokens.ner[i])
|
||||
fill_token(s.N1, tokens.lex[i+1], tokens.pos[i+1], tokens.ner[i+1])
|
||||
fill_token(s.N2, tokens.lex[i+2], tokens.pos[i+2], tokens.ner[i+2])
|
||||
fill_token(s.N3, tokens.lex[i+3], tokens.pos[i+3], tokens.ner[i+3])
|
||||
fill_token(s.N4, tokens.lex[i+4], tokens.pos[i+4], tokens.ner[i+4])
|
||||
cdef int fill_context(atom_t* context, int i, Tokens tokens) except -1:
|
||||
_fill_token(context, FIELD_IDS.P4, tokens.lex[i-4], tokens.pos[i-4], tokens.ner[i-4])
|
||||
_fill_token(context, FIELD_IDS.P3, tokens.lex[i-3], tokens.pos[i-3], tokens.ner[i-3])
|
||||
_fill_token(context, FIELD_IDS.P2, tokens.lex[i-2], tokens.pos[i-2], tokens.ner[i-2])
|
||||
_fill_token(context, FIELD_IDS.P1, tokens.lex[i-1], tokens.pos[i-1], tokens.ner[i-1])
|
||||
_fill_token(context, FIELD_IDS.N0, tokens.lex[i], tokens.pos[i], tokens.ner[i])
|
||||
_fill_token(context, FIELD_IDS.N1, tokens.lex[i+1], tokens.pos[i+1], tokens.ner[i+1])
|
||||
_fill_token(context, FIELD_IDS.N2, tokens.lex[i+2], tokens.pos[i+2], tokens.ner[i+2])
|
||||
_fill_token(context, FIELD_IDS.N3, tokens.lex[i+3], tokens.pos[i+3], tokens.ner[i+3])
|
||||
_fill_token(context, FIELD_IDS.N4, tokens.lex[i+4], tokens.pos[i+4], tokens.ner[i+4])
|
||||
return 1
|
||||
|
||||
|
||||
cdef int fill_flat(atom_t* context, Slots s) except -1:
|
||||
_flatten_token(context, FIELD_IDS.P4, s.P4)
|
||||
_flatten_token(context, FIELD_IDS.P3, s.P3)
|
||||
_flatten_token(context, FIELD_IDS.P2, s.P2)
|
||||
_flatten_token(context, FIELD_IDS.P1, s.P1)
|
||||
_flatten_token(context, FIELD_IDS.N0, s.N0)
|
||||
_flatten_token(context, FIELD_IDS.N1, s.N1)
|
||||
_flatten_token(context, FIELD_IDS.N2, s.N2)
|
||||
_flatten_token(context, FIELD_IDS.N3, s.N4)
|
||||
_flatten_token(context, FIELD_IDS.N4, s.N4)
|
||||
|
||||
|
||||
N_FIELDS = 0
|
||||
FIELD_IDS = Slots()
|
||||
_number_token(FIELD_IDS.P4, &N_FIELDS)
|
||||
|
|
|
@ -28,8 +28,7 @@ cdef class Tagger:
|
|||
cpdef readonly list tag_names
|
||||
|
||||
cdef class_t _guess
|
||||
cdef atom_t* _context_flat
|
||||
cdef Slots _context_slots
|
||||
cdef atom_t* _context
|
||||
cdef feat_t* _feats
|
||||
cdef weight_t* _values
|
||||
cdef weight_t* _scores
|
||||
|
|
|
@ -11,8 +11,7 @@ import json
|
|||
import cython
|
||||
|
||||
|
||||
from .context cimport fill_slots
|
||||
from .context cimport fill_flat
|
||||
from .context cimport fill_context
|
||||
from .context cimport N_FIELDS
|
||||
|
||||
from thinc.features cimport ConjFeat
|
||||
|
@ -35,19 +34,24 @@ def setup_model_dir(tag_type, tag_names, templates, model_dir):
|
|||
|
||||
|
||||
def train(train_sents, model_dir, nr_iter=10):
|
||||
cdef Tokens tokens
|
||||
tagger = Tagger(model_dir)
|
||||
for _ in range(nr_iter):
|
||||
n_corr = 0
|
||||
total = 0
|
||||
for tokens, golds in train_sents:
|
||||
assert len(tokens) == len(golds), [t.string for t in tokens]
|
||||
for i, gold in enumerate(golds):
|
||||
for i in range(tokens.length):
|
||||
if tagger.tag_type == POS:
|
||||
gold = _get_gold_pos(i, golds, tokens.pos)
|
||||
elif tagger.tag_type == ENTITY:
|
||||
gold = _get_gold_ner(i, golds, tokens.ner)
|
||||
guess = tagger.predict(i, tokens)
|
||||
tokens.set_tag(i, tagger.tag_type, guess)
|
||||
if gold != NULL_TAG:
|
||||
tagger.tell_answer([gold])
|
||||
if gold is not None:
|
||||
tagger.tell_answer(gold)
|
||||
total += 1
|
||||
n_corr += guess == gold
|
||||
n_corr += guess in gold
|
||||
#print('%s\t%d\t%d' % (tokens[i].string, guess, gold))
|
||||
print('%.4f' % ((n_corr / total) * 100))
|
||||
random.shuffle(train_sents)
|
||||
|
@ -55,6 +59,20 @@ def train(train_sents, model_dir, nr_iter=10):
|
|||
tagger.model.dump(path.join(model_dir, 'model'))
|
||||
|
||||
|
||||
cdef object _get_gold_pos(i, golds, int* pred):
|
||||
if golds[i] == 0:
|
||||
return None
|
||||
else:
|
||||
return [golds[i]]
|
||||
|
||||
|
||||
cdef object _get_gold_ner(i, golds, int* ner):
|
||||
if golds[i] == 0:
|
||||
return None
|
||||
else:
|
||||
return [golds[i]]
|
||||
|
||||
|
||||
def evaluate(tagger, sents):
|
||||
n_corr = 0
|
||||
total = 0
|
||||
|
@ -83,8 +101,7 @@ cdef class Tagger:
|
|||
if path.exists(path.join(model_dir, 'model')):
|
||||
self.model.load(path.join(model_dir, 'model'))
|
||||
|
||||
self._context_flat = <atom_t*>self.mem.alloc(N_FIELDS, sizeof(atom_t))
|
||||
self._context_slots = Slots()
|
||||
self._context = <atom_t*>self.mem.alloc(N_FIELDS, sizeof(atom_t))
|
||||
self._feats = <feat_t*>self.mem.alloc(self.extractor.n+1, sizeof(feat_t))
|
||||
self._values = <weight_t*>self.mem.alloc(self.extractor.n+1, sizeof(weight_t))
|
||||
self._scores = <weight_t*>self.mem.alloc(self.model.nr_class, sizeof(weight_t))
|
||||
|
@ -110,9 +127,8 @@ cdef class Tagger:
|
|||
>>> tag = EN.pos_tagger.predict(0, tokens)
|
||||
>>> assert tag == EN.pos_tagger.tag_id('DT') == 5
|
||||
"""
|
||||
cdef hash_t hashed = fill_slots(self._context_slots, i, tokens)
|
||||
fill_flat(self._context_flat, self._context_slots)
|
||||
self.extractor.extract(self._feats, self._values, self._context_flat, NULL)
|
||||
fill_context(self._context, i, tokens)
|
||||
self.extractor.extract(self._feats, self._values, self._context, NULL)
|
||||
self._guess = self.model.score(self._scores, self._feats, self._values)
|
||||
return self._guess
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user