mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-26 17:24:41 +03:00
* Have oracle functions take a struct instead of a Python object
This commit is contained in:
parent
d1b55310a1
commit
a513ec500f
|
@ -5,9 +5,20 @@ from .syntax.transition_system cimport Transition
|
|||
|
||||
cimport numpy
|
||||
|
||||
|
||||
cdef struct GoldParseC:
|
||||
int* tags
|
||||
int* heads
|
||||
int* labels
|
||||
int** brackets
|
||||
Transition* ner
|
||||
|
||||
|
||||
cdef class GoldParse:
|
||||
cdef Pool mem
|
||||
|
||||
cdef GoldParseC c
|
||||
|
||||
cdef int length
|
||||
cdef readonly int loss
|
||||
cdef readonly list tags
|
||||
|
@ -22,8 +33,4 @@ cdef class GoldParse:
|
|||
cdef readonly list gold_to_cand
|
||||
cdef readonly list orig_annot
|
||||
|
||||
cdef int* c_tags
|
||||
cdef int* c_heads
|
||||
cdef int* c_labels
|
||||
cdef int** c_brackets
|
||||
cdef Transition* c_ner
|
||||
|
||||
|
|
|
@ -169,13 +169,13 @@ cdef class GoldParse:
|
|||
self.length = len(tokens)
|
||||
|
||||
# These are filled by the tagger/parser/entity recogniser
|
||||
self.c_tags = <int*>self.mem.alloc(len(tokens), sizeof(int))
|
||||
self.c_heads = <int*>self.mem.alloc(len(tokens), sizeof(int))
|
||||
self.c_labels = <int*>self.mem.alloc(len(tokens), sizeof(int))
|
||||
self.c_ner = <Transition*>self.mem.alloc(len(tokens), sizeof(Transition))
|
||||
self.c_brackets = <int**>self.mem.alloc(len(tokens), sizeof(int*))
|
||||
self.c.tags = <int*>self.mem.alloc(len(tokens), sizeof(int))
|
||||
self.c.heads = <int*>self.mem.alloc(len(tokens), sizeof(int))
|
||||
self.c.labels = <int*>self.mem.alloc(len(tokens), sizeof(int))
|
||||
self.c.ner = <Transition*>self.mem.alloc(len(tokens), sizeof(Transition))
|
||||
self.c.brackets = <int**>self.mem.alloc(len(tokens), sizeof(int*))
|
||||
for i in range(len(tokens)):
|
||||
self.c_brackets[i] = <int*>self.mem.alloc(len(tokens), sizeof(int))
|
||||
self.c.brackets[i] = <int*>self.mem.alloc(len(tokens), sizeof(int))
|
||||
|
||||
self.tags = [None] * len(tokens)
|
||||
self.heads = [None] * len(tokens)
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
# cython: profile=True
|
||||
from __future__ import unicode_literals
|
||||
|
||||
from ._state cimport State
|
||||
|
@ -11,6 +12,7 @@ from ..structs cimport TokenC
|
|||
|
||||
from .transition_system cimport do_func_t, get_cost_func_t
|
||||
from ..gold cimport GoldParse
|
||||
from ..gold cimport GoldParseC
|
||||
|
||||
|
||||
DEF NON_MONOTONIC = True
|
||||
|
@ -65,14 +67,14 @@ cdef class ArcEager(TransitionSystem):
|
|||
cdef int preprocess_gold(self, GoldParse gold) except -1:
|
||||
for i in range(gold.length):
|
||||
if gold.heads[i] is None: # Missing values
|
||||
gold.c_heads[i] = i
|
||||
gold.c_labels[i] = -1
|
||||
gold.c.heads[i] = i
|
||||
gold.c.labels[i] = -1
|
||||
else:
|
||||
gold.c_heads[i] = gold.heads[i]
|
||||
gold.c_labels[i] = self.strings[gold.labels[i]]
|
||||
gold.c.heads[i] = gold.heads[i]
|
||||
gold.c.labels[i] = self.strings[gold.labels[i]]
|
||||
for end, brackets in gold.brackets.items():
|
||||
for start, label_strs in brackets.items():
|
||||
gold.c_brackets[start][end] = 1
|
||||
gold.c.brackets[start][end] = 1
|
||||
for label_str in label_strs:
|
||||
# Add the encoded label to the set
|
||||
gold.brackets[end][start].add(self.strings[label_str])
|
||||
|
@ -214,78 +216,78 @@ cdef int _do_break(const Transition* self, State* state) except -1:
|
|||
if not at_eol(state):
|
||||
push_stack(state)
|
||||
|
||||
cdef int _shift_cost(const Transition* self, const State* s, GoldParse gold) except -1:
|
||||
cdef int _shift_cost(const Transition* self, const State* s, GoldParseC* gold) except -1:
|
||||
if not _can_shift(s):
|
||||
return 9000
|
||||
cost = 0
|
||||
cost += head_in_stack(s, s.i, gold.c_heads)
|
||||
cost += children_in_stack(s, s.i, gold.c_heads)
|
||||
cost += head_in_stack(s, s.i, gold.heads)
|
||||
cost += children_in_stack(s, s.i, gold.heads)
|
||||
# If we can break, and there's no cost to doing so, we should
|
||||
if _can_break(s) and _break_cost(self, s, gold) == 0:
|
||||
cost += 1
|
||||
return cost
|
||||
|
||||
cdef int _right_cost(const Transition* self, const State* s, GoldParse gold) except -1:
|
||||
cdef int _right_cost(const Transition* self, const State* s, GoldParseC* gold) except -1:
|
||||
if not _can_right(s):
|
||||
return 9000
|
||||
cost = 0
|
||||
if gold.c_heads[s.i] == s.stack[0]:
|
||||
cost += self.label != gold.c_labels[s.i]
|
||||
if gold.heads[s.i] == s.stack[0]:
|
||||
cost += self.label != gold.labels[s.i]
|
||||
return cost
|
||||
# This indicates missing head
|
||||
if gold.c_labels[s.i] != -1:
|
||||
cost += head_in_buffer(s, s.i, gold.c_heads)
|
||||
cost += children_in_stack(s, s.i, gold.c_heads)
|
||||
cost += head_in_stack(s, s.i, gold.c_heads)
|
||||
if gold.labels[s.i] != -1:
|
||||
cost += head_in_buffer(s, s.i, gold.heads)
|
||||
cost += children_in_stack(s, s.i, gold.heads)
|
||||
cost += head_in_stack(s, s.i, gold.heads)
|
||||
return cost
|
||||
|
||||
|
||||
cdef int _left_cost(const Transition* self, const State* s, GoldParse gold) except -1:
|
||||
cdef int _left_cost(const Transition* self, const State* s, GoldParseC* gold) except -1:
|
||||
if not _can_left(s):
|
||||
return 9000
|
||||
cost = 0
|
||||
if gold.c_heads[s.stack[0]] == s.i:
|
||||
cost += self.label != gold.c_labels[s.stack[0]]
|
||||
if gold.heads[s.stack[0]] == s.i:
|
||||
cost += self.label != gold.labels[s.stack[0]]
|
||||
return cost
|
||||
# If we're at EOL, then the left arc will add an arc to ROOT.
|
||||
elif at_eol(s):
|
||||
# Are we root?
|
||||
if gold.c_labels[s.stack[0]] != -1:
|
||||
if gold.labels[s.stack[0]] != -1:
|
||||
# If we're at EOL, prefer to reduce or break over left-arc
|
||||
if _can_reduce(s) or _can_break(s):
|
||||
cost += gold.c_heads[s.stack[0]] != s.stack[0]
|
||||
cost += gold.heads[s.stack[0]] != s.stack[0]
|
||||
# Are we labelling correctly?
|
||||
cost += self.label != gold.c_labels[s.stack[0]]
|
||||
cost += self.label != gold.labels[s.stack[0]]
|
||||
return cost
|
||||
|
||||
cost += head_in_buffer(s, s.stack[0], gold.c_heads)
|
||||
cost += children_in_buffer(s, s.stack[0], gold.c_heads)
|
||||
cost += head_in_buffer(s, s.stack[0], gold.heads)
|
||||
cost += children_in_buffer(s, s.stack[0], gold.heads)
|
||||
if NON_MONOTONIC and s.stack_len >= 2:
|
||||
cost += gold.c_heads[s.stack[0]] == s.stack[-1]
|
||||
if gold.c_labels[s.stack[0]] != -1:
|
||||
cost += gold.c_heads[s.stack[0]] == s.stack[0]
|
||||
cost += gold.heads[s.stack[0]] == s.stack[-1]
|
||||
if gold.labels[s.stack[0]] != -1:
|
||||
cost += gold.heads[s.stack[0]] == s.stack[0]
|
||||
return cost
|
||||
|
||||
|
||||
cdef int _reduce_cost(const Transition* self, const State* s, GoldParse gold) except -1:
|
||||
cdef int _reduce_cost(const Transition* self, const State* s, GoldParseC* gold) except -1:
|
||||
if not _can_reduce(s):
|
||||
return 9000
|
||||
cdef int cost = 0
|
||||
cost += children_in_buffer(s, s.stack[0], gold.c_heads)
|
||||
cost += children_in_buffer(s, s.stack[0], gold.heads)
|
||||
if NON_MONOTONIC:
|
||||
cost += head_in_buffer(s, s.stack[0], gold.c_heads)
|
||||
cost += head_in_buffer(s, s.stack[0], gold.heads)
|
||||
return cost
|
||||
|
||||
|
||||
cdef int _break_cost(const Transition* self, const State* s, GoldParse gold) except -1:
|
||||
cdef int _break_cost(const Transition* self, const State* s, GoldParseC* gold) except -1:
|
||||
if not _can_break(s):
|
||||
return 9000
|
||||
# When we break, we Reduce all of the words on the stack.
|
||||
cdef int cost = 0
|
||||
# Number of deps between S0...Sn and N0...Nn
|
||||
for i in range(s.i, s.sent_len):
|
||||
cost += children_in_stack(s, i, gold.c_heads)
|
||||
cost += head_in_stack(s, i, gold.c_heads)
|
||||
cost += children_in_stack(s, i, gold.heads)
|
||||
cost += head_in_stack(s, i, gold.heads)
|
||||
return cost
|
||||
|
||||
|
||||
|
@ -360,7 +362,7 @@ cdef inline bint _can_adjust(const State* s) nogil:
|
|||
#elif b0 >= b1:
|
||||
# return False
|
||||
|
||||
cdef int _constituent_cost(const Transition* self, const State* s, GoldParse gold) except -1:
|
||||
cdef int _constituent_cost(const Transition* self, const State* s, GoldParseC* gold) except -1:
|
||||
if not _can_constituent(s):
|
||||
return 9000
|
||||
raise Exception("Constituent move should be disabled currently")
|
||||
|
@ -388,7 +390,7 @@ cdef int _constituent_cost(const Transition* self, const State* s, GoldParse gol
|
|||
# loss = 1 # If we see the start position, set loss to 1
|
||||
#return loss
|
||||
|
||||
cdef int _adjust_cost(const Transition* self, const State* s, GoldParse gold) except -1:
|
||||
cdef int _adjust_cost(const Transition* self, const State* s, GoldParseC* gold) except -1:
|
||||
if not _can_adjust(s):
|
||||
return 9000
|
||||
raise Exception("Adjust move should be disabled currently")
|
||||
|
|
|
@ -8,6 +8,7 @@ from .transition_system cimport do_func_t
|
|||
from ..structs cimport TokenC, Entity
|
||||
|
||||
from thinc.typedefs cimport weight_t
|
||||
from ..gold cimport GoldParseC
|
||||
from ..gold cimport GoldParse
|
||||
|
||||
|
||||
|
@ -94,7 +95,7 @@ cdef class BiluoPushDown(TransitionSystem):
|
|||
|
||||
cdef int preprocess_gold(self, GoldParse gold) except -1:
|
||||
for i in range(gold.length):
|
||||
gold.c_ner[i] = self.lookup_transition(gold.ner[i])
|
||||
gold.c.ner[i] = self.lookup_transition(gold.ner[i])
|
||||
|
||||
cdef Transition lookup_transition(self, object name) except *:
|
||||
if name == '-':
|
||||
|
@ -147,13 +148,13 @@ cdef class BiluoPushDown(TransitionSystem):
|
|||
output[i] = _is_valid(m.move, m.label, s)
|
||||
|
||||
|
||||
cdef int _get_cost(const Transition* self, const State* s, GoldParse gold) except -1:
|
||||
cdef int _get_cost(const Transition* self, const State* s, GoldParseC* gold) except -1:
|
||||
if not _is_valid(self.move, self.label, s):
|
||||
return 9000
|
||||
cdef bint is_sunk = _entity_is_sunk(s, gold.c_ner)
|
||||
cdef int next_act = gold.c_ner[s.i+1].move if s.i < s.sent_len else OUT
|
||||
cdef bint is_gold = _is_gold(self.move, self.label, gold.c_ner[s.i].move,
|
||||
gold.c_ner[s.i].label, next_act, is_sunk)
|
||||
cdef bint is_sunk = _entity_is_sunk(s, gold.ner)
|
||||
cdef int next_act = gold.ner[s.i+1].move if s.i < s.sent_len else OUT
|
||||
cdef bint is_gold = _is_gold(self.move, self.label, gold.ner[s.i].move,
|
||||
gold.ner[s.i].label, next_act, is_sunk)
|
||||
return not is_gold
|
||||
|
||||
|
||||
|
|
|
@ -136,7 +136,7 @@ cdef class Parser:
|
|||
scores = self.model.score(context)
|
||||
guess = self.moves.best_valid(scores, state)
|
||||
best = self.moves.best_gold(scores, state, gold)
|
||||
cost = guess.get_cost(&guess, state, gold)
|
||||
cost = guess.get_cost(&guess, state, &gold.c)
|
||||
self.model.update(context, guess.clas, best.clas, cost)
|
||||
guess.do(&guess, state)
|
||||
loss += cost
|
||||
|
@ -177,14 +177,14 @@ cdef class Parser:
|
|||
state = <State*>beam.at(i)
|
||||
for j in range(self.moves.n_moves):
|
||||
move = &self.moves.c[j]
|
||||
beam.costs[i][j] = move.get_cost(move, state, gold)
|
||||
beam.costs[i][j] = move.get_cost(move, state, &gold.c)
|
||||
beam.is_valid[i][j] = beam.costs[i][j] == 0
|
||||
elif gold is not None:
|
||||
for i in range(beam.size):
|
||||
state = <State*>beam.at(i)
|
||||
for j in range(self.moves.n_moves):
|
||||
move = &self.moves.c[j]
|
||||
beam.costs[i][j] = move.get_cost(move, state, gold)
|
||||
beam.costs[i][j] = move.get_cost(move, state, &gold.c)
|
||||
beam.advance(_transition_state, <void*>self.moves.c)
|
||||
state = <State*>beam.at(0)
|
||||
if state.sent[state.i].sent_end:
|
||||
|
|
|
@ -4,6 +4,7 @@ from thinc.typedefs cimport weight_t
|
|||
from ..structs cimport TokenC
|
||||
from ._state cimport State
|
||||
from ..gold cimport GoldParse
|
||||
from ..gold cimport GoldParseC
|
||||
from ..strings cimport StringStore
|
||||
|
||||
|
||||
|
@ -14,12 +15,12 @@ cdef struct Transition:
|
|||
|
||||
weight_t score
|
||||
|
||||
int (*get_cost)(const Transition* self, const State* state, GoldParse gold) except -1
|
||||
int (*get_cost)(const Transition* self, const State* state, GoldParseC* gold) except -1
|
||||
int (*do)(const Transition* self, State* state) except -1
|
||||
|
||||
|
||||
ctypedef int (*get_cost_func_t)(const Transition* self, const State* state,
|
||||
GoldParse gold) except -1
|
||||
GoldParseC* gold) except -1
|
||||
|
||||
ctypedef int (*do_func_t)(const Transition* self, State* state) except -1
|
||||
|
||||
|
|
|
@ -54,7 +54,7 @@ cdef class TransitionSystem:
|
|||
cdef weight_t score = MIN_SCORE
|
||||
cdef int i
|
||||
for i in range(self.n_moves):
|
||||
cost = self.c[i].get_cost(&self.c[i], s, gold)
|
||||
cost = self.c[i].get_cost(&self.c[i], s, &gold.c)
|
||||
if scores[i] > score and cost == 0:
|
||||
best = self.c[i]
|
||||
score = scores[i]
|
||||
|
|
Loading…
Reference in New Issue
Block a user