* Tmp commit of NER refactoring

This commit is contained in:
Matthew Honnibal 2015-02-18 04:41:06 -05:00
parent 49df1b7002
commit 135756ac3d
3 changed files with 6 additions and 25 deletions

View File

@ -1,4 +1,4 @@
from libc.stdint cimport int8_t, uint8_t, uint16_t, uint32_t from libc.stdint cimport uint8_t, uint32_t, int32_t
from .typedefs cimport flags_t, attr_t, id_t, hash_t from .typedefs cimport flags_t, attr_t, id_t, hash_t
from .parts_of_speech cimport univ_pos_t from .parts_of_speech cimport univ_pos_t
@ -42,28 +42,9 @@ cdef struct PosTag:
univ_pos_t pos univ_pos_t pos
# Start and end are stored as offsets: i + ent.start will always take you to the
# "next" entity start. If token i is inside an entity, ent.start will be negative ---
# the "next" start is then the start of the entity the token is inside. If i _is_
# the start of an entity, then ent.start will be the offset to the beginning of the next one.
#
# The same/inverse is true for end. If ent.end has a negative value, we are either
# at the end of an entity, or outside one. If we're inside an entity, ent.end
# will have a positive value.
#
# This allows us to easily find the span of an entity we might be inside, while
# naturally sharing an API with iterating through all entities in the sentence
cdef struct Entity:
int32_t tag
uint16_t flags
int8_t start
int8_t end
cdef struct TokenC: cdef struct TokenC:
const LexemeC* lex const LexemeC* lex
Morphology morph Morphology morph
Entity ent
univ_pos_t pos univ_pos_t pos
int tag int tag
int idx int idx

View File

@ -123,7 +123,7 @@ cdef int _break_cost(const State* s, const int* gold) except -1:
return cost return cost
cdef class TransitionSystem: cdef class ArcEager(TransitionSystem):
def __init__(self, list left_labels, list right_labels): def __init__(self, list left_labels, list right_labels):
self.mem = Pool() self.mem = Pool()
left_labels.sort() left_labels.sort()
@ -163,7 +163,7 @@ cdef class TransitionSystem:
moves[i].label = 0 moves[i].label = 0
moves[i].clas = i moves[i].clas = i
i += 1 i += 1
self._moves = moves self.c = moves
cdef int transition(self, State *s, const Transition* t) except -1: cdef int transition(self, State *s, const Transition* t) except -1:
if t.move == SHIFT: if t.move == SHIFT:

View File

@ -83,7 +83,7 @@ cdef class GreedyParser:
fill_context(context, state) fill_context(context, state)
scores = self.model.score(context) scores = self.model.score(context)
guess = self.moves.best_valid(scores, state) guess = self.moves.best_valid(scores, state)
self.moves.transition(state, &guess) guess.do(&guess, state)
# Messily tell Tokens object the string names of the dependency labels # Messily tell Tokens object the string names of the dependency labels
dep_strings = [None] * len(self.moves.label_ids) dep_strings = [None] * len(self.moves.label_ids)
for label, id_ in self.moves.label_ids.items(): for label, id_ in self.moves.label_ids.items():
@ -129,9 +129,9 @@ cdef class GreedyParser:
history.append((py_moves[best.move], print_state(state, py_words))) history.append((py_moves[best.move], print_state(state, py_words)))
self.model.update(context, guess.clas, best.clas, guess.cost) self.model.update(context, guess.clas, best.clas, guess.cost)
if force_gold: if force_gold:
self.moves.transition(state, &best) best.do(&best, state)
else: else:
self.moves.transition(state, &guess) guess.do(&guess, state)
cdef int n_corr = 0 cdef int n_corr = 0
for i in range(tokens.length): for i in range(tokens.length):
if gold_heads[i] != -1: if gold_heads[i] != -1: