mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-26 18:06:29 +03:00
* Tmp commit of NER refactoring
This commit is contained in:
parent
49df1b7002
commit
135756ac3d
|
@ -1,4 +1,4 @@
|
||||||
from libc.stdint cimport int8_t, uint8_t, uint16_t, uint32_t
|
from libc.stdint cimport uint8_t, uint32_t, int32_t
|
||||||
|
|
||||||
from .typedefs cimport flags_t, attr_t, id_t, hash_t
|
from .typedefs cimport flags_t, attr_t, id_t, hash_t
|
||||||
from .parts_of_speech cimport univ_pos_t
|
from .parts_of_speech cimport univ_pos_t
|
||||||
|
@ -42,28 +42,9 @@ cdef struct PosTag:
|
||||||
univ_pos_t pos
|
univ_pos_t pos
|
||||||
|
|
||||||
|
|
||||||
# Start and end will be offsets: i + ent.start will always take you to the
|
|
||||||
# "next" entity start. If inside an entity, ent.start will be negative ---
|
|
||||||
# the next entity is the start of the one the token is inside. If i _is_
|
|
||||||
# the start of an entity, then ent.start will be the beginning of the next one.
|
|
||||||
#
|
|
||||||
# The same/inverse is true for end. If ent.end has a negative value, we are either
|
|
||||||
# at the end of an entity, or outside one. If we're inside an entity, ent.end
|
|
||||||
# will have a positive value.
|
|
||||||
#
|
|
||||||
# This allows us to easily find the span of an entity we might be inside, while
|
|
||||||
# naturally sharing an API with iterating through all entities in the sentence
|
|
||||||
cdef struct Entity:
|
|
||||||
int32_t tag
|
|
||||||
uint16_t flags
|
|
||||||
int8_t start
|
|
||||||
int8_t end
|
|
||||||
|
|
||||||
|
|
||||||
cdef struct TokenC:
|
cdef struct TokenC:
|
||||||
const LexemeC* lex
|
const LexemeC* lex
|
||||||
Morphology morph
|
Morphology morph
|
||||||
Entity ent
|
|
||||||
univ_pos_t pos
|
univ_pos_t pos
|
||||||
int tag
|
int tag
|
||||||
int idx
|
int idx
|
||||||
|
|
|
@ -123,7 +123,7 @@ cdef int _break_cost(const State* s, const int* gold) except -1:
|
||||||
return cost
|
return cost
|
||||||
|
|
||||||
|
|
||||||
cdef class TransitionSystem:
|
cdef class ArcEager(TransitionSystem):
|
||||||
def __init__(self, list left_labels, list right_labels):
|
def __init__(self, list left_labels, list right_labels):
|
||||||
self.mem = Pool()
|
self.mem = Pool()
|
||||||
left_labels.sort()
|
left_labels.sort()
|
||||||
|
@ -163,7 +163,7 @@ cdef class TransitionSystem:
|
||||||
moves[i].label = 0
|
moves[i].label = 0
|
||||||
moves[i].clas = i
|
moves[i].clas = i
|
||||||
i += 1
|
i += 1
|
||||||
self._moves = moves
|
self.c = moves
|
||||||
|
|
||||||
cdef int transition(self, State *s, const Transition* t) except -1:
|
cdef int transition(self, State *s, const Transition* t) except -1:
|
||||||
if t.move == SHIFT:
|
if t.move == SHIFT:
|
||||||
|
|
|
@ -83,7 +83,7 @@ cdef class GreedyParser:
|
||||||
fill_context(context, state)
|
fill_context(context, state)
|
||||||
scores = self.model.score(context)
|
scores = self.model.score(context)
|
||||||
guess = self.moves.best_valid(scores, state)
|
guess = self.moves.best_valid(scores, state)
|
||||||
self.moves.transition(state, &guess)
|
guess.do(&guess, state)
|
||||||
# Messily tell Tokens object the string names of the dependency labels
|
# Messily tell Tokens object the string names of the dependency labels
|
||||||
dep_strings = [None] * len(self.moves.label_ids)
|
dep_strings = [None] * len(self.moves.label_ids)
|
||||||
for label, id_ in self.moves.label_ids.items():
|
for label, id_ in self.moves.label_ids.items():
|
||||||
|
@ -129,9 +129,9 @@ cdef class GreedyParser:
|
||||||
history.append((py_moves[best.move], print_state(state, py_words)))
|
history.append((py_moves[best.move], print_state(state, py_words)))
|
||||||
self.model.update(context, guess.clas, best.clas, guess.cost)
|
self.model.update(context, guess.clas, best.clas, guess.cost)
|
||||||
if force_gold:
|
if force_gold:
|
||||||
self.moves.transition(state, &best)
|
best.do(&best, state)
|
||||||
else:
|
else:
|
||||||
self.moves.transition(state, &guess)
|
guess.do(&guess, state)
|
||||||
cdef int n_corr = 0
|
cdef int n_corr = 0
|
||||||
for i in range(tokens.length):
|
for i in range(tokens.length):
|
||||||
if gold_heads[i] != -1:
|
if gold_heads[i] != -1:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user