mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-26 17:24:41 +03:00
* Tmp commit
This commit is contained in:
parent
9568ebed08
commit
03a6626545
|
@ -53,6 +53,7 @@ cdef struct Constituent:
|
|||
int start
|
||||
int end
|
||||
int label
|
||||
bint on_stack
|
||||
|
||||
|
||||
cdef struct TokenC:
|
||||
|
|
|
@ -14,6 +14,7 @@ cdef struct State:
|
|||
int sent_len
|
||||
int stack_len
|
||||
int ents_len
|
||||
int ctnt_len
|
||||
|
||||
|
||||
cdef int add_dep(const State *s, const int head, const int child, const int label) except -1
|
||||
|
|
|
@ -1,10 +1,11 @@
|
|||
from __future__ import unicode_literals
|
||||
|
||||
from ._state cimport State
|
||||
from ._state cimport has_head, get_idx, get_s0, get_n0
|
||||
from ._state cimport has_head, get_idx, get_s0, get_n0, get_left, get_right
|
||||
from ._state cimport is_final, at_eol, pop_stack, push_stack, add_dep
|
||||
from ._state cimport head_in_buffer, children_in_buffer
|
||||
from ._state cimport head_in_stack, children_in_stack
|
||||
from ._state cimport count_left_kids
|
||||
|
||||
from ..structs cimport TokenC
|
||||
|
||||
|
@ -24,15 +25,23 @@ cdef enum:
|
|||
REDUCE
|
||||
LEFT
|
||||
RIGHT
|
||||
|
||||
BREAK
|
||||
|
||||
CONSTITUENT
|
||||
ADJUST
|
||||
|
||||
N_MOVES
|
||||
|
||||
|
||||
MOVE_NAMES = [None] * N_MOVES
|
||||
MOVE_NAMES[SHIFT] = 'S'
|
||||
MOVE_NAMES[REDUCE] = 'D'
|
||||
MOVE_NAMES[LEFT] = 'L'
|
||||
MOVE_NAMES[RIGHT] = 'R'
|
||||
MOVE_NAMES[BREAK] = 'B'
|
||||
MOVE_NAMES[CONSTITUENT] = 'C'
|
||||
MOVE_NAMES[ADJUST] = 'A'
|
||||
|
||||
|
||||
cdef do_func_t[N_MOVES] do_funcs
|
||||
|
@ -43,20 +52,29 @@ cdef class ArcEager(TransitionSystem):
|
|||
@classmethod
|
||||
def get_labels(cls, gold_parses):
|
||||
move_labels = {SHIFT: {'': True}, REDUCE: {'': True}, RIGHT: {},
|
||||
LEFT: {'ROOT': True}, BREAK: {'ROOT': True}}
|
||||
for raw_text, segmented, (ids, words, tags, heads, labels, iob) in gold_parses:
|
||||
LEFT: {'ROOT': True}, BREAK: {'ROOT': True},
|
||||
CONSTITUENT: {}, ADJUST: {'': True}}
|
||||
for raw_text, segmented, (ids, words, tags, heads, labels, iob), ctnts in gold_parses:
|
||||
for child, head, label in zip(ids, heads, labels):
|
||||
if label != 'ROOT':
|
||||
if head < child:
|
||||
move_labels[RIGHT][label] = True
|
||||
elif head > child:
|
||||
move_labels[LEFT][label] = True
|
||||
for start, end, label in ctnts:
|
||||
move_labels[CONSTITUENT][label] = True
|
||||
return move_labels
|
||||
|
||||
cdef int preprocess_gold(self, GoldParse gold) except -1:
|
||||
for i in range(gold.length):
|
||||
gold.c_heads[i] = gold.heads[i]
|
||||
gold.c_labels[i] = self.strings[gold.labels[i]]
|
||||
for end, brackets in gold.brackets.items():
|
||||
for start, label_strs in brackets.items():
|
||||
gold.c_brackets[start][end] = 1
|
||||
for label_str in label_strs:
|
||||
# Add the encoded label to the set
|
||||
gold.brackets[end][start].add(self.strings[label_str])
|
||||
|
||||
cdef Transition lookup_transition(self, object name) except *:
|
||||
if '-' in name:
|
||||
|
@ -104,6 +122,8 @@ cdef class ArcEager(TransitionSystem):
|
|||
is_valid[LEFT] = _can_left(s)
|
||||
is_valid[RIGHT] = _can_right(s)
|
||||
is_valid[BREAK] = _can_break(s)
|
||||
is_valid[CONSTITUENT] = _can_constituent(s)
|
||||
is_valid[ADJUST] = _can_adjust(s)
|
||||
cdef Transition best
|
||||
cdef weight_t score = MIN_SCORE
|
||||
cdef int i
|
||||
|
@ -162,11 +182,42 @@ cdef int _do_break(const Transition* self, State* state) except -1:
|
|||
push_stack(state)
|
||||
|
||||
|
||||
cdef int _do_constituent(const Transition* self, State* state) except -1:
|
||||
cdef const TokenC* s0 = get_s0(state)
|
||||
if state.ctnt.head == get_idx(state, s0):
|
||||
start = state.ctnt.start
|
||||
else:
|
||||
start = get_idx(state, s0)
|
||||
state.ctnt += 1
|
||||
state.ctnt.start = start
|
||||
state.ctnt.end = s0.r_edge
|
||||
state.ctnt.head = get_idx(state, s0)
|
||||
state.ctnt.label = self.label
|
||||
|
||||
|
||||
cdef int _do_adjust(const Transition* self, State* state) except -1:
|
||||
cdef const TokenC* child
|
||||
cdef const TokenC* s0 = get_s0(state)
|
||||
cdef int n_left = count_left_kids(s0)
|
||||
for i in range(1, n_left):
|
||||
child = get_left(state, s0, i)
|
||||
assert child is not NULL
|
||||
if child.l_edge < state.ctnt.start:
|
||||
state.ctnt.start = child.l_edge
|
||||
break
|
||||
else:
|
||||
msg = ("Error moving bracket --- Move should be invalid if "
|
||||
"no left edge to move to.")
|
||||
raise Exception(msg)
|
||||
|
||||
|
||||
do_funcs[SHIFT] = _do_shift
|
||||
do_funcs[REDUCE] = _do_reduce
|
||||
do_funcs[LEFT] = _do_left
|
||||
do_funcs[RIGHT] = _do_right
|
||||
do_funcs[BREAK] = _do_break
|
||||
do_funcs[CONSTITUENT] = _do_constituent
|
||||
do_funcs[ADJUST] = _do_adjust
|
||||
|
||||
|
||||
cdef int _shift_cost(const Transition* self, const State* s, GoldParse gold) except -1:
|
||||
|
@ -243,11 +294,72 @@ cdef int _break_cost(const Transition* self, const State* s, GoldParse gold) exc
|
|||
return cost
|
||||
|
||||
|
||||
cdef int _constituent_cost(const Transition* self, const State* s, GoldParse gold) except -1:
|
||||
if not _can_constituent(s):
|
||||
return 9000
|
||||
# The gold standard is indexed by end, then by start, then a set of labels
|
||||
brackets = gold.brackets(get_s0(s).r_edge, {})
|
||||
if not brackets:
|
||||
return 2 # 2 loss for bad bracket, only 1 for good bracket bad label
|
||||
# Index the current brackets in the state
|
||||
existing = set()
|
||||
for i in range(s.ctnt_len):
|
||||
if ctnt.end == s.r_edge and ctnt.label == self.label:
|
||||
existing.add(ctnt.start)
|
||||
cdef int loss = 2
|
||||
cdef const TokenC* child
|
||||
cdef const TokenC* s0 = get_s0(s)
|
||||
cdef int n_left = count_left_kids(s0)
|
||||
# Iterate over the possible start positions, and check whether we have a
|
||||
# (start, end, label) match to the gold tree
|
||||
for i in range(1, n_left):
|
||||
child = get_left(s, s0, i)
|
||||
if child.l_edge in brackets and child.l_edge not in existing:
|
||||
if self.label in brackets[child.l_edge]
|
||||
return 0
|
||||
else:
|
||||
loss = 1 # If we see the start position, set loss to 1
|
||||
return loss
|
||||
|
||||
|
||||
cdef int _adjust_cost(const Transition* self, const State* s, GoldParse gold) except -1:
|
||||
if not _can_adjust(s):
|
||||
return 9000
|
||||
# The gold standard is indexed by end, then by start, then a set of labels
|
||||
gold_starts = gold.brackets(get_s0(s).r_edge, {})
|
||||
# Case 1: There are 0 brackets ending at this word.
|
||||
# --> Cost is sunk, but must allow brackets to begin
|
||||
if not gold_starts:
|
||||
return 0
|
||||
# Is the top bracket correct?
|
||||
gold_labels = gold_starts.get(s.ctnt.start, set())
|
||||
# TODO: Case where we have a unary rule
|
||||
# TODO: Case where two brackets end on this word, with top bracket starting
|
||||
# before
|
||||
|
||||
cdef const TokenC* child
|
||||
cdef const TokenC* s0 = get_s0(s)
|
||||
cdef int n_left = count_left_kids(s0)
|
||||
cdef int i
|
||||
# Iterate over the possible start positions, and check whether we have a
|
||||
# (start, end, label) match to the gold tree
|
||||
for i in range(1, n_left):
|
||||
child = get_left(s, s0, i)
|
||||
if child.l_edge in brackets:
|
||||
if self.label in brackets[child.l_edge]:
|
||||
return 0
|
||||
else:
|
||||
loss = 1 # If we see the start position, set loss to 1
|
||||
return loss
|
||||
|
||||
|
||||
get_cost_funcs[SHIFT] = _shift_cost
|
||||
get_cost_funcs[REDUCE] = _reduce_cost
|
||||
get_cost_funcs[LEFT] = _left_cost
|
||||
get_cost_funcs[RIGHT] = _right_cost
|
||||
get_cost_funcs[BREAK] = _break_cost
|
||||
get_cost_funcs[CONSTITUENT] = _constituent_cost
|
||||
get_cost_funcs[ADJUST] = _adjust_cost
|
||||
|
||||
|
||||
cdef inline bint _can_shift(const State* s) nogil:
|
||||
|
@ -288,3 +400,21 @@ cdef inline bint _can_break(const State* s) nogil:
|
|||
else:
|
||||
seen_headless = True
|
||||
return True
|
||||
|
||||
|
||||
cdef inline bint _can_constituent(const State* s) nogil:
|
||||
return s.stack_len >= 1
|
||||
|
||||
|
||||
cdef inline bint _can_adjust(const State* s) nogil:
|
||||
# Need a left child to move the bracket to
|
||||
cdef const TokenC* child
|
||||
cdef const TokenC* s0 = get_s0(s)
|
||||
cdef int n_left = count_left_kids(s0)
|
||||
cdef int i
|
||||
for i in range(1, n_left):
|
||||
child = get_left(s, s0, i)
|
||||
if child.l_edge < s.ctnt.start:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
|
|
@ -16,10 +16,12 @@ cdef class GoldParse:
|
|||
cdef readonly dict orths
|
||||
cdef readonly list ner
|
||||
cdef readonly list ents
|
||||
cdef readonly dict brackets
|
||||
|
||||
cdef int* c_tags
|
||||
cdef int* c_heads
|
||||
cdef int* c_labels
|
||||
cdef int** c_brackets
|
||||
cdef Transition* c_ner
|
||||
|
||||
cdef int heads_correct(self, TokenC* tokens, bint score_punct=?) except -1
|
||||
|
|
|
@ -30,7 +30,7 @@ def read_json_file(loc):
|
|||
paragraphs.append((paragraph['raw'],
|
||||
tokenized,
|
||||
(ids, words, tags, heads, labels, _iob_to_biluo(iob_ents)),
|
||||
brackets))
|
||||
paragraph.get('brackets', [])))
|
||||
return paragraphs
|
||||
|
||||
|
||||
|
@ -145,7 +145,7 @@ def _parse_line(line):
|
|||
|
||||
|
||||
cdef class GoldParse:
|
||||
def __init__(self, tokens, annot_tuples):
|
||||
def __init__(self, tokens, annot_tuples, brackets=(,)):
|
||||
self.mem = Pool()
|
||||
self.loss = 0
|
||||
self.length = len(tokens)
|
||||
|
@ -155,6 +155,9 @@ cdef class GoldParse:
|
|||
self.c_heads = <int*>self.mem.alloc(len(tokens), sizeof(int))
|
||||
self.c_labels = <int*>self.mem.alloc(len(tokens), sizeof(int))
|
||||
self.c_ner = <Transition*>self.mem.alloc(len(tokens), sizeof(Transition))
|
||||
self.c_brackets = <int**>self.mem.alloc(len(tokens), sizeof(int*))
|
||||
for i in range(len(tokens)):
|
||||
self.c_brackets[i] = <int*>self.mem.alloc(len(tokens), sizeof(int))
|
||||
|
||||
self.tags = [None] * len(tokens)
|
||||
self.heads = [-1] * len(tokens)
|
||||
|
@ -199,6 +202,14 @@ cdef class GoldParse:
|
|||
self.ner[i] = 'I-%s' % label
|
||||
self.ner[end-1] = 'L-%s' % label
|
||||
|
||||
self.brackets = {}
|
||||
for (start_idx, end_idx, label_str) in brackets:
|
||||
if start_idx in idx_map and end_idx in idx_map:
|
||||
start = idx_map[start_idx]
|
||||
end = idx_map[end_idx]
|
||||
self.brackets.setdefault(end, {}).setdefault(start, set())
|
||||
self.brackets[end][start].add(label)
|
||||
|
||||
def __len__(self):
|
||||
return self.length
|
||||
|
||||
|
|
|
@ -95,6 +95,7 @@ cdef class GreedyParser:
|
|||
return 0
|
||||
|
||||
def train(self, Tokens tokens, GoldParse gold):
|
||||
py_words = [w.orth_ for w in tokens]
|
||||
self.moves.preprocess_gold(gold)
|
||||
cdef Pool mem = Pool()
|
||||
cdef State* state = new_state(mem, tokens.data, tokens.length)
|
||||
|
|
Loading…
Reference in New Issue
Block a user