Update parser and NER gold stuff

This commit is contained in:
Matthew Honnibal 2020-06-19 02:29:16 +02:00
parent 5ae9e3480d
commit bd29b7b14f
3 changed files with 61 additions and 15 deletions

View File

@ -133,9 +133,9 @@ cdef GoldParseStateC create_gold_state(Pool mem, StateClass stcls, Example examp
return gs return gs
cdef class ArcEagerGoldParse: cdef class ArcEagerGold:
cdef GoldParseStateC c cdef GoldParseStateC c
def __init__(self, StateClass stcls, Example example): def __init__(self, ArcEager moves, StateClass stcls, Example example):
self.mem = Pool() self.mem = Pool()
self.c = create_gold_state(self.mem, stcls, example) self.c = create_gold_state(self.mem, stcls, example)
@ -512,9 +512,9 @@ cdef class ArcEager(TransitionSystem):
states = self.init_batch([eg.predicted for eg in examples]) states = self.init_batch([eg.predicted for eg in examples])
keeps = [i for i, s in enumerate(states) if not s.is_final()] keeps = [i for i, s in enumerate(states) if not s.is_final()]
states = [states[i] for i in keeps] states = [states[i] for i in keeps]
examples = [examples[i] for i in keeps] golds = [ArcEagerGold(self, states[i], examples[i]) for i in keeps]
n_steps = sum([len(s.buffer_length()) * 4 for s in states]) n_steps = sum([len(s.buffer_length()) * 4 for s in states])
return states, examples, n_steps return states, golds, n_steps
cdef Transition lookup_transition(self, object name_or_id) except *: cdef Transition lookup_transition(self, object name_or_id) except *:
if isinstance(name_or_id, int): if isinstance(name_or_id, int):

View File

@ -1,4 +1,6 @@
from collections import Counter from collections import Counter
from libc.stdint cimport int32_t
from cymem.cymem cimport Pool
from ..typedefs cimport weight_t from ..typedefs cimport weight_t
from .stateclass cimport StateClass from .stateclass cimport StateClass
@ -7,6 +9,8 @@ from .transition_system cimport Transition
from .transition_system cimport do_func_t from .transition_system cimport do_func_t
from ..lexeme cimport Lexeme from ..lexeme cimport Lexeme
from ..attrs cimport IS_SPACE from ..attrs cimport IS_SPACE
from ..gold.iob_utils import biluo_tags_from_offsets
from ..gold.example cimport Example
from ..errors import Errors from ..errors import Errors
@ -34,6 +38,51 @@ MOVE_NAMES[ISNT] = 'x'
cdef struct GoldNERStateC: cdef struct GoldNERStateC:
Transition* ner Transition* ner
int32_t length
cdef class BiluoGold:
    """Owner of the C-level gold NER state built for one example.

    The memory pool is stored on the instance next to the struct it backs,
    so the ``ner`` array allocated from that pool stays referenced for as
    long as this object does.
    """
    cdef Pool mem
    cdef GoldNERStateC c

    def __init__(self, BiluoPushDown moves, StateClass stcls, Example example):
        # Create the pool first: the gold-state constructor allocates from it.
        cdef Pool pool = Pool()
        self.mem = pool
        self.c = create_gold_state(pool, moves, stcls, example)
cdef GoldNERStateC create_gold_state(
    Pool mem,
    BiluoPushDown moves,
    StateClass stcls,
    Example example
) except *:
    """Build the C-level gold NER state for one example.

    mem: Pool that the ``ner`` array is allocated from.
    moves: Transition system used to turn each tag into a Transition.
    stcls: Parser state for the example (not read by this function).
    example: Example whose aligned NER tags are converted.
    RETURNS: A GoldNERStateC with ``ner`` holding one Transition per tag
        returned by get_aligned_ner, and ``length`` set to the number of
        tokens in ``example.x``.
    """
    cdef GoldNERStateC gs
    # A C struct declared as a local is not zero-initialised, so ``length``
    # has to be assigned explicitly before the struct is returned.
    gs.length = example.x.length
    gs.ner = <Transition*>mem.alloc(gs.length, sizeof(Transition))
    ner_tags = get_aligned_ner(example)
    for i, ner_tag in enumerate(ner_tags):
        gs.ner[i] = moves.lookup_transition(ner_tag)
    return gs
def get_aligned_ner(Example example):
    """Project the BILUO NER tags of ``example.y`` onto the tokens of ``example.x``.

    The tags for ``example.y`` are computed from its entity character
    offsets and then mapped through the example's alignment.

    example: Example providing ``x``, ``y`` and ``alignment``.
    RETURNS (list): One entry per token of ``example.x``: a tag string, or
        None where no tag could be assigned (whitespace tokens, and tokens
        whose only alignment is many-to-one onto a tag other than "O"/"-").
    """
    cand_to_gold = example.alignment.cand_to_gold
    i2j_multi = example.alignment.i2j_multi
    y_tags = biluo_tags_from_offsets(
        example.y,
        [(e.start_char, e.end_char, e.label_) for e in example.y.ents]
    )
    x_tags = [None] * example.x.length
    for i in range(example.x.length):
        if example.x[i].is_space:
            # Whitespace tokens keep the None placeholder.
            continue
        if cand_to_gold[i] is not None:
            x_tags[i] = y_tags[cand_to_gold[i]]
        elif i in i2j_multi:
            # Assign O/- for many-to-one O/- NER tags
            multi_tag = y_tags[i2j_multi[i]]
            if multi_tag in ("O", "-"):
                x_tags[i] = multi_tag
    # Return the projected tags: they have one entry per token of example.x,
    # whereas y_tags is indexed by the tokens of example.y.
    return x_tags
cdef do_func_t[N_MOVES] do_funcs cdef do_func_t[N_MOVES] do_funcs
@ -90,11 +139,14 @@ cdef class BiluoPushDown(TransitionSystem):
else: else:
return MOVE_NAMES[move] + '-' + self.strings[label] return MOVE_NAMES[move] + '-' + self.strings[label]
def init_gold_batch(self, examples):
    """Create the initial states and gold objects for a batch of examples.

    examples: Sequence of examples; a state is created from each
        ``eg.predicted``.
    RETURNS (tuple): ``(states, golds, n_steps)``. ``states`` holds the
        initial states that are not already final, ``golds`` holds the
        matching BiluoGold for each of them, and ``n_steps`` is summed over
        the kept states.
    """
    all_states = self.init_batch([eg.predicted for eg in examples])
    states = []
    golds = []
    # Pair every state with its own example before filtering, so the
    # indices cannot drift apart once final states are dropped.
    for state, eg in zip(all_states, examples):
        if state.is_final():
            continue
        states.append(state)
        golds.append(BiluoGold(self, state, eg))
    # NOTE(review): if buffer_length() returns an int, len() will raise
    # TypeError here; confirm against StateClass before relying on n_steps.
    n_steps = sum([len(s.buffer_length()) for s in states])
    return states, golds, n_steps
cdef Transition lookup_transition(self, object name) except *: cdef Transition lookup_transition(self, object name) except *:
cdef attr_t label cdef attr_t label

View File

@ -97,12 +97,6 @@ cdef class TransitionSystem:
def finalize_doc(self, doc): def finalize_doc(self, doc):
pass pass
def preprocess_gold(self, example):
raise NotImplementedError
def is_gold_parse(self, StateClass state, example):
raise NotImplementedError
cdef Transition lookup_transition(self, object name) except *: cdef Transition lookup_transition(self, object name) except *:
raise NotImplementedError raise NotImplementedError