diff --git a/spacy/syntax/arc_eager.pyx b/spacy/syntax/arc_eager.pyx
index 23e72916e..960f9f2c2 100644
--- a/spacy/syntax/arc_eager.pyx
+++ b/spacy/syntax/arc_eager.pyx
@@ -133,9 +133,11 @@ cdef GoldParseStateC create_gold_state(Pool mem, StateClass stcls, Example examp
     return gs
 
 
-cdef class ArcEagerGoldParse:
+cdef class ArcEagerGold:
     cdef GoldParseStateC c
+    # Pool assigned in __init__; cdef class attributes must be declared.
+    cdef Pool mem
-    def __init__(self, StateClass stcls, Example example):
+    def __init__(self, ArcEager moves, StateClass stcls, Example example):
         self.mem = Pool()
         self.c = create_gold_state(self.mem, stcls, example)
 
@@ -512,9 +514,11 @@ cdef class ArcEager(TransitionSystem):
         states = self.init_batch([eg.predicted for eg in examples])
         keeps = [i for i, s in enumerate(states) if not s.is_final()]
         states = [states[i] for i in keeps]
-        examples = [examples[i] for i in keeps]
+        # `states` is already filtered, so pair it positionally with `keeps`;
+        # indexing it with the original positions would pick the wrong state.
+        golds = [ArcEagerGold(self, st, examples[i]) for st, i in zip(states, keeps)]
         n_steps = sum([len(s.buffer_length()) * 4 for s in states])
-        return states, examples, n_steps
+        return states, golds, n_steps
 
     cdef Transition lookup_transition(self, object name_or_id) except *:
         if isinstance(name_or_id, int):
diff --git a/spacy/syntax/ner.pyx b/spacy/syntax/ner.pyx
index 31f89ef88..53d9c67de 100644
--- a/spacy/syntax/ner.pyx
+++ b/spacy/syntax/ner.pyx
@@ -1,4 +1,6 @@
 from collections import Counter
+from libc.stdint cimport int32_t
+from cymem.cymem cimport Pool
 
 from ..typedefs cimport weight_t
 from .stateclass cimport StateClass
@@ -7,6 +9,8 @@ from .transition_system cimport Transition
 from .transition_system cimport do_func_t
 from ..lexeme cimport Lexeme
 from ..attrs cimport IS_SPACE
+from ..gold.iob_utils import biluo_tags_from_offsets
+from ..gold.example cimport Example
 from ..errors import Errors
 
 
@@ -34,6 +38,70 @@ MOVE_NAMES[ISNT] = 'x'
 
 cdef struct GoldNERStateC:
     Transition* ner
+    int32_t length
+
+
+cdef class BiluoGold:
+    """Gold-standard NER transitions for one example.
+
+    Owns the memory pool backing the `GoldNERStateC` struct built by
+    `create_gold_state`.
+    """
+    cdef Pool mem
+    cdef GoldNERStateC c
+
+    def __init__(self, BiluoPushDown moves, StateClass stcls, Example example):
+        self.mem = Pool()
+        self.c = create_gold_state(self.mem, moves, stcls, example)
+
+
+cdef GoldNERStateC create_gold_state(
+    Pool mem,
+    BiluoPushDown moves,
+    StateClass stcls,
+    Example example
+) except *:
+    """Build the gold NER state for `example`.
+
+    Allocates one `Transition` per token of `example.x` from `mem` and fills
+    each slot by looking up the aligned tag in `moves`.
+    """
+    cdef GoldNERStateC gs
+    # Record the array size alongside the pointer it describes.
+    gs.length = example.x.length
+    gs.ner = <Transition*>mem.alloc(gs.length, sizeof(Transition))
+    ner_tags = get_aligned_ner(example)
+    for i, ner_tag in enumerate(ner_tags):
+        gs.ner[i] = moves.lookup_transition(ner_tag)
+    return gs
+
+
+def get_aligned_ner(Example example):
+    """Project the BILUO tags of `example.y` onto the tokens of `example.x`.
+
+    Returns a list with one entry per token of `example.x`. An entry is
+    None for whitespace tokens and for tokens with no usable alignment.
+    """
+    cand_to_gold = example.alignment.cand_to_gold
+    i2j_multi = example.alignment.i2j_multi
+    y_tags = biluo_tags_from_offsets(
+        example.y,
+        [(e.start_char, e.end_char, e.label_) for e in example.y.ents]
+    )
+    x_tags = [None] * example.x.length
+    for i in range(example.x.length):
+        if example.x[i].is_space:
+            pass
+        elif cand_to_gold[i] is not None:
+            x_tags[i] = y_tags[cand_to_gold[i]]
+        elif i in i2j_multi:
+            # Assign O/- for many-to-one O/- NER tags
+            if y_tags[i2j_multi[i]] in ("O", "-"):
+                x_tags[i] = y_tags[i2j_multi[i]]
+    # Return the tags aligned to example.x; y_tags is indexed by example.y
+    # tokens and may differ in length from the array sized for example.x.
+    return x_tags
+
 
 
 cdef do_func_t[N_MOVES] do_funcs
@@ -90,11 +158,24 @@ cdef class BiluoPushDown(TransitionSystem):
         else:
             return MOVE_NAMES[move] + '-' + self.strings[label]
 
-    def has_gold(self, gold, start=0, end=None):
-        raise NotImplementedError
-
-    def preprocess_gold(self, example):
-        raise NotImplementedError
+    def init_gold_batch(self, examples):
+        """Create parser states and gold-standard objects for a batch.
+
+        Examples whose initial state is already final are dropped.
+        Returns a (states, golds, n_steps) tuple with states and golds
+        in matching order.
+        """
+        states = self.init_batch([eg.predicted for eg in examples])
+        keeps = [i for i, s in enumerate(states) if not s.is_final()]
+        states = [states[i] for i in keeps]
+        examples = [examples[i] for i in keeps]
+        # Both lists are already filtered, so pair them positionally;
+        # re-indexing with `keeps` would pick the wrong items.
+        golds = [BiluoGold(self, st, eg) for st, eg in zip(states, examples)]
+        # NOTE(review): len() is applied to the result of buffer_length();
+        # confirm it returns a sized object rather than an int.
+        n_steps = sum([len(s.buffer_length()) for s in states])
+        return states, golds, n_steps
 
     cdef Transition lookup_transition(self, object name) except *:
         cdef attr_t label
diff --git a/spacy/syntax/transition_system.pyx b/spacy/syntax/transition_system.pyx
index b14198378..687c234d0 100644
--- a/spacy/syntax/transition_system.pyx
+++ b/spacy/syntax/transition_system.pyx
@@ -97,12 +97,6 @@ cdef class TransitionSystem:
     def finalize_doc(self, doc):
         pass
 
-    def preprocess_gold(self, example):
-        raise NotImplementedError
-
-    def is_gold_parse(self, StateClass state, example):
-        raise NotImplementedError
-
     cdef Transition lookup_transition(self, object name) except *:
         raise NotImplementedError
 