mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-26 09:14:32 +03:00
* Fixed greedy NER parsing. With static oracle, replicates accuracy from tagger.
This commit is contained in:
parent
399239760b
commit
0d943ab358
|
@ -13,6 +13,8 @@ cdef class NERParser:
|
|||
cdef Pool mem
|
||||
cdef Extractor extractor
|
||||
cdef LinearModel model
|
||||
cdef readonly list tag_names
|
||||
cdef readonly int n_classes
|
||||
|
||||
cdef Move* _moves
|
||||
cdef atom_t* _context
|
||||
|
@ -21,5 +23,5 @@ cdef class NERParser:
|
|||
cdef weight_t* _scores
|
||||
|
||||
|
||||
cpdef int train(self, Tokens tokens, golds)
|
||||
cpdef int train(self, Tokens tokens, golds) except -1
|
||||
cpdef int set_tags(self, Tokens tokens) except -1
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
from __future__ import division
|
||||
from __future__ import unicode_literals
|
||||
|
||||
cimport cython
|
||||
import random
|
||||
import os
|
||||
|
@ -7,27 +10,58 @@ import json
|
|||
|
||||
from thinc.features cimport ConjFeat
|
||||
|
||||
from ..context cimport fill_context
|
||||
from ..context cimport N_FIELDS
|
||||
from .context cimport fill_context
|
||||
from .context cimport N_FIELDS
|
||||
from .moves cimport Move
|
||||
from .moves cimport fill_moves, transition, best_accepted
|
||||
from .moves cimport set_accept_if_valid, set_accept_if_oracle
|
||||
from ._state cimport entity_is_open
|
||||
from .moves import get_n_moves
|
||||
from ._state cimport State
|
||||
from ._state cimport init_state
|
||||
|
||||
|
||||
def setup_model_dir(tag_names, templates, model_dir):
|
||||
if path.exists(model_dir):
|
||||
shutil.rmtree(model_dir)
|
||||
os.mkdir(model_dir)
|
||||
config = {
|
||||
'templates': templates,
|
||||
'tag_names': tag_names,
|
||||
}
|
||||
with open(path.join(model_dir, 'config.json'), 'w') as file_:
|
||||
json.dump(config, file_)
|
||||
|
||||
|
||||
|
||||
def train(train_sents, model_dir, nr_iter=10):
|
||||
cdef Tokens tokens
|
||||
parser = NERParser(model_dir)
|
||||
for _ in range(nr_iter):
|
||||
n_corr = 0
|
||||
total = 0
|
||||
for i, (tokens, golds) in enumerate(train_sents):
|
||||
if any([g == 0 for g in golds]):
|
||||
continue
|
||||
n_corr += parser.train(tokens, golds)
|
||||
total += len([g for g in golds if g != 0])
|
||||
print('%.4f' % ((n_corr / total) * 100))
|
||||
random.shuffle(train_sents)
|
||||
parser.model.end_training()
|
||||
parser.model.dump(path.join(model_dir, 'model'))
|
||||
|
||||
|
||||
cdef class NERParser:
|
||||
def __init__(self, model_dir):
|
||||
self.mem = Pool()
|
||||
cfg = json.load(open(path.join(model_dir, 'config.json')))
|
||||
templates = cfg['templates']
|
||||
self.entity_types = cfg['entity_types']
|
||||
self.tag_names = cfg['tag_names']
|
||||
self.extractor = Extractor(templates, [ConjFeat] * len(templates))
|
||||
self.n_classes = get_n_moves(len(self.entity_types))
|
||||
self._moves = <Move*>self.mem.alloc(self.n_classes, sizeof(Move))
|
||||
fill_moves(self._moves, len(self.entity_types))
|
||||
self.model = LinearModel(len(self.tag_names))
|
||||
self.n_classes = len(self.tag_names)
|
||||
self._moves = <Move*>self.mem.alloc(len(self.tag_names), sizeof(Move))
|
||||
fill_moves(self._moves, self.tag_names)
|
||||
self.model = LinearModel(self.n_classes)
|
||||
if path.exists(path.join(model_dir, 'model')):
|
||||
self.model.load(path.join(model_dir, 'model'))
|
||||
|
||||
|
@ -36,14 +70,16 @@ cdef class NERParser:
|
|||
self._values = <weight_t*>self.mem.alloc(self.extractor.n+1, sizeof(weight_t))
|
||||
self._scores = <weight_t*>self.mem.alloc(self.model.nr_class, sizeof(weight_t))
|
||||
|
||||
cpdef int train(self, Tokens tokens, gold_classes):
|
||||
cpdef int train(self, Tokens tokens, gold_classes) except -1:
|
||||
cdef Pool mem = Pool()
|
||||
cdef State* s = init_state(mem, tokens.length)
|
||||
cdef Move* golds = <Move*>mem.alloc(len(gold_classes), sizeof(Move))
|
||||
for i, clas in enumerate(gold_classes):
|
||||
golds[i] = self.moves[clas - 1]
|
||||
assert golds[i].id == clas
|
||||
for tok_i, clas in enumerate(gold_classes):
|
||||
golds[tok_i] = self._moves[clas]
|
||||
assert golds[tok_i].clas == clas, '%d vs %d' % (golds[tok_i].clas, clas)
|
||||
cdef Move* guess
|
||||
n_correct = 0
|
||||
cdef int f = 0
|
||||
while s.i < tokens.length:
|
||||
fill_context(self._context, s.i, tokens)
|
||||
self.extractor.extract(self._feats, self._values, self._context, NULL)
|
||||
|
@ -51,21 +87,22 @@ cdef class NERParser:
|
|||
|
||||
set_accept_if_valid(self._moves, self.n_classes, s)
|
||||
guess = best_accepted(self._moves, self._scores, self.n_classes)
|
||||
|
||||
set_accept_if_oracle(self._moves, golds, self.n_classes, s) # TODO
|
||||
assert guess.clas != 0
|
||||
assert gold_classes[s.i] != 0
|
||||
set_accept_if_oracle(self._moves, golds, self.n_classes, s)
|
||||
gold = best_accepted(self._moves, self._scores, self.n_classes)
|
||||
|
||||
if guess.clas == gold.clas:
|
||||
self.model.update({})
|
||||
return 0
|
||||
|
||||
counts = {guess.clas: {}, gold.clas: {}}
|
||||
self.extractor.count(counts[gold.clas], self._feats, 1)
|
||||
self.extractor.count(counts[guess.clas], self._feats, -1)
|
||||
counts = {}
|
||||
n_correct += 1
|
||||
else:
|
||||
counts = {guess.clas: {}, gold.clas: {}}
|
||||
self.extractor.count(counts[gold.clas], self._feats, 1)
|
||||
self.extractor.count(counts[guess.clas], self._feats, -1)
|
||||
self.model.update(counts)
|
||||
|
||||
transition(s, guess)
|
||||
gold_str = self.tag_names[gold.clas]
|
||||
transition(s, gold)
|
||||
tokens.ner[s.i-1] = s.tags[s.i-1]
|
||||
return n_correct
|
||||
|
||||
cpdef int set_tags(self, Tokens tokens) except -1:
|
||||
cdef Pool mem = Pool()
|
||||
|
|
|
@ -6,6 +6,7 @@ from thinc.typedefs cimport weight_t
|
|||
from ._state cimport State
|
||||
|
||||
cpdef enum ActionType:
|
||||
MISSING
|
||||
BEGIN
|
||||
IN
|
||||
LAST
|
||||
|
@ -29,4 +30,4 @@ cdef Move* best_accepted(Move* moves, weight_t* scores, int n) except NULL
|
|||
|
||||
cdef int transition(State *s, Move* m) except -1
|
||||
|
||||
cdef int fill_moves(Move* moves, int n_tags) except -1
|
||||
cdef int fill_moves(Move* moves, list tag_names) except -1
|
||||
|
|
|
@ -7,6 +7,7 @@ from ._state cimport entity_is_sunk
|
|||
|
||||
|
||||
ACTION_NAMES = ['' for _ in range(N_ACTIONS)]
|
||||
ACTION_NAMES[<int>MISSING] = '?'
|
||||
ACTION_NAMES[<int>BEGIN] = 'B'
|
||||
ACTION_NAMES[<int>IN] = 'I'
|
||||
ACTION_NAMES[<int>LAST] = 'L'
|
||||
|
@ -36,6 +37,8 @@ cdef bint can_out(State* s, int label):
|
|||
|
||||
cdef bint is_oracle(ActionType act, int tag, ActionType g_act, int g_tag,
|
||||
ActionType next_act, bint is_sunk):
|
||||
if g_act == MISSING:
|
||||
return True
|
||||
if act == BEGIN:
|
||||
if g_act == BEGIN:
|
||||
# B, Gold B --> Label match
|
||||
|
@ -55,10 +58,10 @@ cdef bint is_oracle(ActionType act, int tag, ActionType g_act, int g_tag,
|
|||
return True
|
||||
elif g_act == LAST:
|
||||
# I, Gold L --> True iff this entity sunk and next tag == O
|
||||
return is_sunk and next_act == OUT
|
||||
return is_sunk and (next_act == OUT or next_act == MISSING)
|
||||
elif g_act == OUT:
|
||||
# I, Gold O --> True iff next tag == O
|
||||
return next_act == OUT
|
||||
return next_act == OUT or next_act == MISSING
|
||||
elif g_act == UNIT:
|
||||
# I, Gold U --> True iff next tag == O
|
||||
return next_act == OUT
|
||||
|
@ -109,7 +112,8 @@ cdef bint is_oracle(ActionType act, int tag, ActionType g_act, int g_tag,
|
|||
cdef int set_accept_if_valid(Move* moves, int n_classes, State* s) except 0:
|
||||
cdef int n_accept = 0
|
||||
cdef Move* m
|
||||
for i in range(n_classes):
|
||||
moves[0].accept = False
|
||||
for i in range(1, n_classes):
|
||||
m = &moves[i]
|
||||
if m.action == BEGIN:
|
||||
m.accept = can_begin(s, m.label)
|
||||
|
@ -134,7 +138,7 @@ cdef int set_accept_if_oracle(Move* moves, Move* golds, int n_classes, State* s)
|
|||
cdef Move* m
|
||||
cdef int n_accept = 0
|
||||
set_accept_if_valid(moves, n_classes, s)
|
||||
for i in range(n_classes):
|
||||
for i in range(1, n_classes):
|
||||
m = &moves[i]
|
||||
if not m.accept:
|
||||
continue
|
||||
|
@ -146,19 +150,20 @@ cdef int set_accept_if_oracle(Move* moves, Move* golds, int n_classes, State* s)
|
|||
|
||||
|
||||
cdef Move* best_accepted(Move* moves, weight_t* scores, int n) except NULL:
|
||||
cdef int first_accept
|
||||
for first_accept in range(n):
|
||||
cdef int first_accept = -1
|
||||
for first_accept in range(1, n):
|
||||
if moves[first_accept].accept:
|
||||
break
|
||||
else:
|
||||
raise StandardError
|
||||
assert first_accept != -1
|
||||
cdef int best = first_accept
|
||||
cdef weight_t score = scores[first_accept]
|
||||
cdef weight_t score = scores[first_accept-1]
|
||||
cdef int i
|
||||
for i in range(first_accept+1, n):
|
||||
if moves[i].accept and scores[i] > score:
|
||||
if moves[i].accept and scores[i-1] > score:
|
||||
best = i
|
||||
score = scores[i]
|
||||
score = scores[i-1]
|
||||
return &moves[best]
|
||||
|
||||
|
||||
|
@ -182,23 +187,21 @@ def get_n_moves(n_tags):
|
|||
return n_tags + n_tags + n_tags + n_tags + 1
|
||||
|
||||
|
||||
cdef int fill_moves(Move* moves, int n_tags) except -1:
|
||||
cdef int i = 0
|
||||
for label in range(n_tags):
|
||||
moves[i].action = BEGIN
|
||||
moves[i].label = label
|
||||
i += 1
|
||||
for label in range(n_tags):
|
||||
moves[i].action = IN
|
||||
moves[i].label = label
|
||||
i += 1
|
||||
for label in range(n_tags):
|
||||
moves[i].action = LAST
|
||||
moves[i].label = label
|
||||
i += 1
|
||||
for label in range(n_tags):
|
||||
moves[i].action = UNIT
|
||||
moves[i].label = label
|
||||
i += 1
|
||||
moves[i].action = OUT
|
||||
moves[i].label = 0
|
||||
cdef int fill_moves(Move* moves, list tag_names) except -1:
|
||||
cdef Move* m
|
||||
label_names = {'-': 0}
|
||||
for i, tag_name in enumerate(tag_names):
|
||||
m = &moves[i]
|
||||
if '-' in tag_name:
|
||||
action_str, label = tag_name.split('-')
|
||||
elif tag_name == 'O':
|
||||
action_str = 'O'
|
||||
label = '-'
|
||||
elif tag_name == 'NULL' or tag_name == 'EOL':
|
||||
action_str = '?'
|
||||
label = '-'
|
||||
else:
|
||||
raise StandardError(tag_name)
|
||||
m.action = ACTION_NAMES.index(action_str)
|
||||
m.label = label_names.setdefault(label, len(label_names))
|
||||
m.clas = i
|
||||
|
|
|
@ -6,7 +6,7 @@ from ._state cimport State
|
|||
|
||||
cdef class PyState:
|
||||
cdef Pool mem
|
||||
cdef readonly list entity_types
|
||||
cdef readonly list tag_names
|
||||
cdef readonly int n_classes
|
||||
cdef readonly dict moves_by_name
|
||||
|
||||
|
|
|
@ -12,26 +12,16 @@ from .moves import ACTION_NAMES
|
|||
cdef class PyState:
|
||||
def __init__(self, tag_names, n_tokens):
|
||||
self.mem = Pool()
|
||||
self.entity_types = tag_names
|
||||
self.n_classes = get_n_moves(len(self.entity_types))
|
||||
self.tag_names = tag_names
|
||||
self.n_classes = len(tag_names)
|
||||
assert self.n_classes != 0
|
||||
self._moves = <Move*>self.mem.alloc(self.n_classes, sizeof(Move))
|
||||
fill_moves(self._moves, len(self.entity_types))
|
||||
fill_moves(self._moves, tag_names)
|
||||
self._s = init_state(self.mem, n_tokens)
|
||||
self.moves_by_name = {}
|
||||
for i in range(self.n_classes):
|
||||
m = &self._moves[i]
|
||||
action_name = ACTION_NAMES[m.action]
|
||||
if action_name == 'O':
|
||||
self.moves_by_name['O'] = i
|
||||
else:
|
||||
tag_name = tag_names[m.label]
|
||||
self.moves_by_name['%s-%s' % (action_name, tag_name)] = i
|
||||
# TODO
|
||||
self._golds = <Move*>self.mem.alloc(n_tokens, sizeof(Move))
|
||||
|
||||
cdef Move* _get_move(self, unicode move_name) except NULL:
|
||||
return &self._moves[self.moves_by_name[move_name]]
|
||||
return &self._moves[self.tag_names.index(move_name)]
|
||||
|
||||
def set_golds(self, list gold_names):
|
||||
cdef Move* m
|
||||
|
@ -49,8 +39,8 @@ cdef class PyState:
|
|||
return m.accept
|
||||
|
||||
def is_gold(self, unicode move_name):
|
||||
set_accept_if_oracle(self._moves, self._golds, self.n_classes, self._s)
|
||||
cdef Move* m = self._get_move(move_name)
|
||||
set_accept_if_oracle(self._moves, self._golds, self.n_classes, self._s)
|
||||
return m.accept
|
||||
|
||||
property ent:
|
||||
|
|
Loading…
Reference in New Issue
Block a user