mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-11 17:56:30 +03:00
* Tmp commit of ner code
This commit is contained in:
parent
33c421bcf8
commit
5c3016bac8
|
@ -6,6 +6,7 @@ cdef int begin_entity(State* s, label) except -1
|
|||
cdef int end_entity(State* s) except -1
|
||||
|
||||
cdef State* init_state(Pool mem, int sent_length) except NULL
|
||||
cdef int copy_state(Pool mem, State* dest, State* source) except -1
|
||||
|
||||
cdef bint entity_is_open(State *s) except -1
|
||||
|
||||
|
|
|
@ -1,44 +1,50 @@
|
|||
from .bilou_moves cimport BEGIN, UNIT
|
||||
|
||||
|
||||
cdef int begin_entity(State* s, label) except -1:
|
||||
s.curr.start = s.i
|
||||
s.curr.label = label
|
||||
|
||||
|
||||
cdef int end_entity(State* s) except -1:
|
||||
s.curr.end = s.i
|
||||
s.ents[s.j] = s.curr
|
||||
cdef void begin_entity(State* s, label):
|
||||
s.j += 1
|
||||
s.curr.start = 0
|
||||
s.curr.label = -1
|
||||
s.curr.end = 0
|
||||
s.ents[s.j].start = s.i
|
||||
s.ents[s.j].tag = label
|
||||
s.ents[s.j].end = s.i + 1
|
||||
|
||||
|
||||
cdef void end_entity(State* s):
|
||||
s.ents[s.j].end = s.i + 1
|
||||
|
||||
|
||||
cdef State* init_state(Pool mem, int sent_length) except NULL:
|
||||
s = <State*>mem.alloc(1, sizeof(State))
|
||||
s.j = 0
|
||||
s.ents = <Entity*>mem.alloc(sent_length, sizeof(Entity))
|
||||
for i in range(sent_length):
|
||||
s.ents[i].label = -1
|
||||
s.curr.label = -1
|
||||
s.tags = <int*>mem.alloc(sent_length, sizeof(int))
|
||||
s.length = sent_length
|
||||
return s
|
||||
|
||||
|
||||
cdef bint entity_is_open(State *s) except -1:
|
||||
return s.curr.label != -1
|
||||
cdef bint entity_is_open(State *s):
|
||||
return s.ents[s.j].start != 0
|
||||
|
||||
|
||||
cdef bint entity_is_sunk(State *s, Move* golds) except -1:
|
||||
cdef bint entity_is_sunk(State *s, Move* golds):
|
||||
if not entity_is_open(s):
|
||||
return False
|
||||
|
||||
cdef Move* gold = &golds[s.curr.start]
|
||||
cdef Entity* ent = &s.ents[s.j]
|
||||
cdef Move* gold = &golds[ent.start]
|
||||
if gold.action != BEGIN and gold.action != UNIT:
|
||||
return True
|
||||
elif gold.label != s.curr.label:
|
||||
elif gold.label != ent.label:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
cdef int copy_state(Pool mem, State* dest, State* source) except -1:
|
||||
'''Copy state source into state dest.'''
|
||||
if source.length > dest.length:
|
||||
dest.ents = <Entity*>mem.realloc(dest.ents, source.length * sizeof(Entity))
|
||||
dest.tags = <int*>mem.realloc(dest.tags, source.length * sizeof(int))
|
||||
memcpy(dest.ents, source.ents, source.length * sizeof(Entity))
|
||||
memcpy(dest.tags, source.tags, source.length * sizeof(int))
|
||||
dest.length = source.length
|
||||
dest.i = source.i
|
||||
dest.j = source.j
|
||||
dest.curr = source.curr
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -1,6 +1,3 @@
|
|||
from __future__ import division
|
||||
from __future__ import unicode_literals
|
||||
|
||||
cimport cython
|
||||
import random
|
||||
import os
|
||||
|
@ -10,56 +7,14 @@ import json
|
|||
|
||||
from thinc.features cimport ConjFeat
|
||||
|
||||
from .context cimport fill_context
|
||||
from .context cimport N_FIELDS
|
||||
from .structs cimport Move, State
|
||||
from .io_moves cimport fill_moves, transition, best_accepted
|
||||
from .io_moves cimport set_accept_if_valid, set_accept_if_oracle
|
||||
from .io_moves import get_n_moves
|
||||
from ..context cimport fill_context
|
||||
from ..context cimport N_FIELDS
|
||||
from .moves cimport Move
|
||||
from .moves cimport fill_moves, transition, best_accepted
|
||||
from .moves cimport set_accept_if_valid, set_accept_if_oracle
|
||||
from .moves import get_n_moves
|
||||
from ._state cimport State
|
||||
from ._state cimport init_state
|
||||
from ._state cimport entity_is_open
|
||||
from ._state cimport end_entity
|
||||
from .annot cimport NERAnnotation
|
||||
|
||||
|
||||
def setup_model_dir(entity_types, templates, model_dir):
|
||||
if path.exists(model_dir):
|
||||
shutil.rmtree(model_dir)
|
||||
os.mkdir(model_dir)
|
||||
config = {
|
||||
'templates': templates,
|
||||
'entity_types': entity_types,
|
||||
}
|
||||
with open(path.join(model_dir, 'config.json'), 'w') as file_:
|
||||
json.dump(config, file_)
|
||||
|
||||
|
||||
def train(train_sents, model_dir, nr_iter=10):
|
||||
cdef Tokens tokens
|
||||
cdef NERAnnotation gold_ner
|
||||
parser = NERParser(model_dir)
|
||||
for _ in range(nr_iter):
|
||||
tp = 0
|
||||
fp = 0
|
||||
fn = 0
|
||||
for i, (tokens, gold_ner) in enumerate(train_sents):
|
||||
#print [tokens[i].string for i in range(tokens.length)]
|
||||
test_ents = set(parser.train(tokens, gold_ner))
|
||||
#print 'Test', test_ents
|
||||
gold_ents = set(gold_ner.entities)
|
||||
#print 'Gold', set(gold_ner.entities)
|
||||
tp += len(gold_ents.intersection(test_ents))
|
||||
fp += len(test_ents - gold_ents)
|
||||
fn += len(gold_ents - test_ents)
|
||||
p = tp / (tp + fp)
|
||||
r = tp / (tp + fn)
|
||||
f = 2 * ((p * r) / (p + r))
|
||||
print 'P: %.3f' % p,
|
||||
print 'R: %.3f' % r,
|
||||
print 'F: %.3f' % f
|
||||
random.shuffle(train_sents)
|
||||
parser.model.end_training()
|
||||
parser.model.dump(path.join(model_dir, 'model'))
|
||||
|
||||
|
||||
cdef class NERParser:
|
||||
|
@ -67,12 +22,12 @@ cdef class NERParser:
|
|||
self.mem = Pool()
|
||||
cfg = json.load(open(path.join(model_dir, 'config.json')))
|
||||
templates = cfg['templates']
|
||||
self.extractor = Extractor(templates, [ConjFeat] * len(templates))
|
||||
self.entity_types = cfg['entity_types']
|
||||
self.extractor = Extractor(templates, [ConjFeat] * len(templates))
|
||||
self.n_classes = get_n_moves(len(self.entity_types))
|
||||
self._moves = <Move*>self.mem.alloc(self.n_classes, sizeof(Move))
|
||||
fill_moves(self._moves, self.n_classes, self.entity_types)
|
||||
self.model = LinearModel(self.n_classes)
|
||||
fill_moves(self._moves, len(self.entity_types))
|
||||
self.model = LinearModel(len(self.tag_names))
|
||||
if path.exists(path.join(model_dir, 'model')):
|
||||
self.model.load(path.join(model_dir, 'model'))
|
||||
|
||||
|
@ -81,59 +36,46 @@ cdef class NERParser:
|
|||
self._values = <weight_t*>self.mem.alloc(self.extractor.n+1, sizeof(weight_t))
|
||||
self._scores = <weight_t*>self.mem.alloc(self.model.nr_class, sizeof(weight_t))
|
||||
|
||||
cpdef list train(self, Tokens tokens, NERAnnotation annot):
|
||||
cpdef int train(self, Tokens tokens, gold_classes):
|
||||
cdef Pool mem = Pool()
|
||||
cdef State* s = init_state(mem, tokens.length)
|
||||
cdef Move* golds = <Move*>mem.alloc(len(gold_classes), sizeof(Move))
|
||||
for i, clas in enumerate(gold_classes):
|
||||
golds[i] = self.moves[clas - 1]
|
||||
assert golds[i].id == clas
|
||||
cdef Move* guess
|
||||
cdef Move* oracle_move
|
||||
n_correct = 0
|
||||
cdef int f = 0
|
||||
while s.i < tokens.length:
|
||||
fill_context(self._context, s, tokens)
|
||||
fill_context(self._context, s.i, tokens)
|
||||
self.extractor.extract(self._feats, self._values, self._context, NULL)
|
||||
self.model.score(self._scores, self._feats, self._values)
|
||||
|
||||
set_accept_if_valid(self._moves, self.n_classes, s)
|
||||
guess = best_accepted(self._moves, self._scores, self.n_classes)
|
||||
assert guess.clas != 0
|
||||
set_accept_if_oracle(self._moves, self.n_classes, s,
|
||||
annot.starts, annot.ends, annot.labels)
|
||||
oracle_move = best_accepted(self._moves, self._scores, self.n_classes)
|
||||
assert oracle_move.clas != 0
|
||||
if guess.clas == oracle_move.clas:
|
||||
counts = {}
|
||||
n_correct += 1
|
||||
else:
|
||||
counts = {guess.clas: {}, oracle_move.clas: {}}
|
||||
self.extractor.count(counts[oracle_move.clas], self._feats, 1)
|
||||
|
||||
set_accept_if_oracle(self._moves, golds, self.n_classes, s) # TODO
|
||||
gold = best_accepted(self._moves, self._scores, self.n_classes)
|
||||
|
||||
if guess.clas == gold.clas:
|
||||
self.model.update({})
|
||||
return 0
|
||||
|
||||
counts = {guess.clas: {}, gold.clas: {}}
|
||||
self.extractor.count(counts[gold.clas], self._feats, 1)
|
||||
self.extractor.count(counts[guess.clas], self._feats, -1)
|
||||
self.model.update(counts)
|
||||
|
||||
transition(s, guess)
|
||||
tokens.ner[s.i-1] = s.tags[s.i-1]
|
||||
if entity_is_open(s):
|
||||
s.curr.label = annot.labels[s.curr.start]
|
||||
end_entity(s)
|
||||
entities = []
|
||||
for i in range(s.j):
|
||||
entities.append((s.ents[i].start, s.ents[i].end, s.ents[i].label))
|
||||
return entities
|
||||
|
||||
cpdef list set_tags(self, Tokens tokens):
|
||||
cpdef int set_tags(self, Tokens tokens) except -1:
|
||||
cdef Pool mem = Pool()
|
||||
cdef State* s = init_state(mem, tokens.length)
|
||||
cdef Move* move
|
||||
while s.i < tokens.length:
|
||||
fill_context(self._context, s, tokens)
|
||||
fill_context(self._context, s.i, tokens)
|
||||
self.extractor.extract(self._feats, self._values, self._context, NULL)
|
||||
self.model.score(self._scores, self._feats, self._values)
|
||||
set_accept_if_valid(self._moves, self.n_classes, s)
|
||||
move = best_accepted(self._moves, self._scores, self.n_classes)
|
||||
transition(s, move)
|
||||
tokens.ner[s.i-1] = s.tags[s.i-1]
|
||||
if entity_is_open(s):
|
||||
s.curr.label = move.label
|
||||
end_entity(s)
|
||||
entities = []
|
||||
for i in range(s.j):
|
||||
entities.append((s.ents[i].start, s.ents[i].end, s.ents[i].label))
|
||||
return entities
|
||||
|
|
|
@ -77,11 +77,16 @@ cdef int set_accept_if_valid(Move* moves, int n, State* s) except 0:
|
|||
moves[0].accept = False
|
||||
for i in range(1, n):
|
||||
if moves[i].action == SHIFT:
|
||||
moves[i].accept = moves[i].label == s.curr.label or not entity_is_open(s)
|
||||
if s.i >= s.length:
|
||||
moves[i].accept = False
|
||||
elif open_ent and moves[i].label != s.curr.label:
|
||||
moves[i].accept = False
|
||||
else:
|
||||
moves[i].accept = True
|
||||
elif moves[i].action == REDUCE:
|
||||
moves[i].accept = open_ent
|
||||
elif moves[i].action == OUT:
|
||||
moves[i].accept = not open_ent
|
||||
moves[i].accept = s.i < s.length and not open_ent
|
||||
n_accept += moves[i].accept
|
||||
return n_accept
|
||||
|
||||
|
@ -150,3 +155,7 @@ cdef int fill_moves(Move* moves, int n, list entity_types) except -1:
|
|||
moves[i].clas = i
|
||||
moves[i].label = 0
|
||||
i += 1
|
||||
|
||||
|
||||
cdef bint is_final(State* s):
|
||||
return s.i == s.length and not entity_is_open(s)
|
||||
|
|
Loading…
Reference in New Issue
Block a user