mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-13 02:36:32 +03:00
* Have oracle functions take a struct instead of a Python object
This commit is contained in:
parent
d1b55310a1
commit
a513ec500f
|
@ -5,9 +5,20 @@ from .syntax.transition_system cimport Transition
|
||||||
|
|
||||||
cimport numpy
|
cimport numpy
|
||||||
|
|
||||||
|
|
||||||
|
cdef struct GoldParseC:
|
||||||
|
int* tags
|
||||||
|
int* heads
|
||||||
|
int* labels
|
||||||
|
int** brackets
|
||||||
|
Transition* ner
|
||||||
|
|
||||||
|
|
||||||
cdef class GoldParse:
|
cdef class GoldParse:
|
||||||
cdef Pool mem
|
cdef Pool mem
|
||||||
|
|
||||||
|
cdef GoldParseC c
|
||||||
|
|
||||||
cdef int length
|
cdef int length
|
||||||
cdef readonly int loss
|
cdef readonly int loss
|
||||||
cdef readonly list tags
|
cdef readonly list tags
|
||||||
|
@ -22,8 +33,4 @@ cdef class GoldParse:
|
||||||
cdef readonly list gold_to_cand
|
cdef readonly list gold_to_cand
|
||||||
cdef readonly list orig_annot
|
cdef readonly list orig_annot
|
||||||
|
|
||||||
cdef int* c_tags
|
|
||||||
cdef int* c_heads
|
|
||||||
cdef int* c_labels
|
|
||||||
cdef int** c_brackets
|
|
||||||
cdef Transition* c_ner
|
|
||||||
|
|
|
@ -169,13 +169,13 @@ cdef class GoldParse:
|
||||||
self.length = len(tokens)
|
self.length = len(tokens)
|
||||||
|
|
||||||
# These are filled by the tagger/parser/entity recogniser
|
# These are filled by the tagger/parser/entity recogniser
|
||||||
self.c_tags = <int*>self.mem.alloc(len(tokens), sizeof(int))
|
self.c.tags = <int*>self.mem.alloc(len(tokens), sizeof(int))
|
||||||
self.c_heads = <int*>self.mem.alloc(len(tokens), sizeof(int))
|
self.c.heads = <int*>self.mem.alloc(len(tokens), sizeof(int))
|
||||||
self.c_labels = <int*>self.mem.alloc(len(tokens), sizeof(int))
|
self.c.labels = <int*>self.mem.alloc(len(tokens), sizeof(int))
|
||||||
self.c_ner = <Transition*>self.mem.alloc(len(tokens), sizeof(Transition))
|
self.c.ner = <Transition*>self.mem.alloc(len(tokens), sizeof(Transition))
|
||||||
self.c_brackets = <int**>self.mem.alloc(len(tokens), sizeof(int*))
|
self.c.brackets = <int**>self.mem.alloc(len(tokens), sizeof(int*))
|
||||||
for i in range(len(tokens)):
|
for i in range(len(tokens)):
|
||||||
self.c_brackets[i] = <int*>self.mem.alloc(len(tokens), sizeof(int))
|
self.c.brackets[i] = <int*>self.mem.alloc(len(tokens), sizeof(int))
|
||||||
|
|
||||||
self.tags = [None] * len(tokens)
|
self.tags = [None] * len(tokens)
|
||||||
self.heads = [None] * len(tokens)
|
self.heads = [None] * len(tokens)
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
# cython: profile=True
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
from ._state cimport State
|
from ._state cimport State
|
||||||
|
@ -11,6 +12,7 @@ from ..structs cimport TokenC
|
||||||
|
|
||||||
from .transition_system cimport do_func_t, get_cost_func_t
|
from .transition_system cimport do_func_t, get_cost_func_t
|
||||||
from ..gold cimport GoldParse
|
from ..gold cimport GoldParse
|
||||||
|
from ..gold cimport GoldParseC
|
||||||
|
|
||||||
|
|
||||||
DEF NON_MONOTONIC = True
|
DEF NON_MONOTONIC = True
|
||||||
|
@ -65,14 +67,14 @@ cdef class ArcEager(TransitionSystem):
|
||||||
cdef int preprocess_gold(self, GoldParse gold) except -1:
|
cdef int preprocess_gold(self, GoldParse gold) except -1:
|
||||||
for i in range(gold.length):
|
for i in range(gold.length):
|
||||||
if gold.heads[i] is None: # Missing values
|
if gold.heads[i] is None: # Missing values
|
||||||
gold.c_heads[i] = i
|
gold.c.heads[i] = i
|
||||||
gold.c_labels[i] = -1
|
gold.c.labels[i] = -1
|
||||||
else:
|
else:
|
||||||
gold.c_heads[i] = gold.heads[i]
|
gold.c.heads[i] = gold.heads[i]
|
||||||
gold.c_labels[i] = self.strings[gold.labels[i]]
|
gold.c.labels[i] = self.strings[gold.labels[i]]
|
||||||
for end, brackets in gold.brackets.items():
|
for end, brackets in gold.brackets.items():
|
||||||
for start, label_strs in brackets.items():
|
for start, label_strs in brackets.items():
|
||||||
gold.c_brackets[start][end] = 1
|
gold.c.brackets[start][end] = 1
|
||||||
for label_str in label_strs:
|
for label_str in label_strs:
|
||||||
# Add the encoded label to the set
|
# Add the encoded label to the set
|
||||||
gold.brackets[end][start].add(self.strings[label_str])
|
gold.brackets[end][start].add(self.strings[label_str])
|
||||||
|
@ -214,78 +216,78 @@ cdef int _do_break(const Transition* self, State* state) except -1:
|
||||||
if not at_eol(state):
|
if not at_eol(state):
|
||||||
push_stack(state)
|
push_stack(state)
|
||||||
|
|
||||||
cdef int _shift_cost(const Transition* self, const State* s, GoldParse gold) except -1:
|
cdef int _shift_cost(const Transition* self, const State* s, GoldParseC* gold) except -1:
|
||||||
if not _can_shift(s):
|
if not _can_shift(s):
|
||||||
return 9000
|
return 9000
|
||||||
cost = 0
|
cost = 0
|
||||||
cost += head_in_stack(s, s.i, gold.c_heads)
|
cost += head_in_stack(s, s.i, gold.heads)
|
||||||
cost += children_in_stack(s, s.i, gold.c_heads)
|
cost += children_in_stack(s, s.i, gold.heads)
|
||||||
# If we can break, and there's no cost to doing so, we should
|
# If we can break, and there's no cost to doing so, we should
|
||||||
if _can_break(s) and _break_cost(self, s, gold) == 0:
|
if _can_break(s) and _break_cost(self, s, gold) == 0:
|
||||||
cost += 1
|
cost += 1
|
||||||
return cost
|
return cost
|
||||||
|
|
||||||
cdef int _right_cost(const Transition* self, const State* s, GoldParse gold) except -1:
|
cdef int _right_cost(const Transition* self, const State* s, GoldParseC* gold) except -1:
|
||||||
if not _can_right(s):
|
if not _can_right(s):
|
||||||
return 9000
|
return 9000
|
||||||
cost = 0
|
cost = 0
|
||||||
if gold.c_heads[s.i] == s.stack[0]:
|
if gold.heads[s.i] == s.stack[0]:
|
||||||
cost += self.label != gold.c_labels[s.i]
|
cost += self.label != gold.labels[s.i]
|
||||||
return cost
|
return cost
|
||||||
# This indicates missing head
|
# This indicates missing head
|
||||||
if gold.c_labels[s.i] != -1:
|
if gold.labels[s.i] != -1:
|
||||||
cost += head_in_buffer(s, s.i, gold.c_heads)
|
cost += head_in_buffer(s, s.i, gold.heads)
|
||||||
cost += children_in_stack(s, s.i, gold.c_heads)
|
cost += children_in_stack(s, s.i, gold.heads)
|
||||||
cost += head_in_stack(s, s.i, gold.c_heads)
|
cost += head_in_stack(s, s.i, gold.heads)
|
||||||
return cost
|
return cost
|
||||||
|
|
||||||
|
|
||||||
cdef int _left_cost(const Transition* self, const State* s, GoldParse gold) except -1:
|
cdef int _left_cost(const Transition* self, const State* s, GoldParseC* gold) except -1:
|
||||||
if not _can_left(s):
|
if not _can_left(s):
|
||||||
return 9000
|
return 9000
|
||||||
cost = 0
|
cost = 0
|
||||||
if gold.c_heads[s.stack[0]] == s.i:
|
if gold.heads[s.stack[0]] == s.i:
|
||||||
cost += self.label != gold.c_labels[s.stack[0]]
|
cost += self.label != gold.labels[s.stack[0]]
|
||||||
return cost
|
return cost
|
||||||
# If we're at EOL, then the left arc will add an arc to ROOT.
|
# If we're at EOL, then the left arc will add an arc to ROOT.
|
||||||
elif at_eol(s):
|
elif at_eol(s):
|
||||||
# Are we root?
|
# Are we root?
|
||||||
if gold.c_labels[s.stack[0]] != -1:
|
if gold.labels[s.stack[0]] != -1:
|
||||||
# If we're at EOL, prefer to reduce or break over left-arc
|
# If we're at EOL, prefer to reduce or break over left-arc
|
||||||
if _can_reduce(s) or _can_break(s):
|
if _can_reduce(s) or _can_break(s):
|
||||||
cost += gold.c_heads[s.stack[0]] != s.stack[0]
|
cost += gold.heads[s.stack[0]] != s.stack[0]
|
||||||
# Are we labelling correctly?
|
# Are we labelling correctly?
|
||||||
cost += self.label != gold.c_labels[s.stack[0]]
|
cost += self.label != gold.labels[s.stack[0]]
|
||||||
return cost
|
return cost
|
||||||
|
|
||||||
cost += head_in_buffer(s, s.stack[0], gold.c_heads)
|
cost += head_in_buffer(s, s.stack[0], gold.heads)
|
||||||
cost += children_in_buffer(s, s.stack[0], gold.c_heads)
|
cost += children_in_buffer(s, s.stack[0], gold.heads)
|
||||||
if NON_MONOTONIC and s.stack_len >= 2:
|
if NON_MONOTONIC and s.stack_len >= 2:
|
||||||
cost += gold.c_heads[s.stack[0]] == s.stack[-1]
|
cost += gold.heads[s.stack[0]] == s.stack[-1]
|
||||||
if gold.c_labels[s.stack[0]] != -1:
|
if gold.labels[s.stack[0]] != -1:
|
||||||
cost += gold.c_heads[s.stack[0]] == s.stack[0]
|
cost += gold.heads[s.stack[0]] == s.stack[0]
|
||||||
return cost
|
return cost
|
||||||
|
|
||||||
|
|
||||||
cdef int _reduce_cost(const Transition* self, const State* s, GoldParse gold) except -1:
|
cdef int _reduce_cost(const Transition* self, const State* s, GoldParseC* gold) except -1:
|
||||||
if not _can_reduce(s):
|
if not _can_reduce(s):
|
||||||
return 9000
|
return 9000
|
||||||
cdef int cost = 0
|
cdef int cost = 0
|
||||||
cost += children_in_buffer(s, s.stack[0], gold.c_heads)
|
cost += children_in_buffer(s, s.stack[0], gold.heads)
|
||||||
if NON_MONOTONIC:
|
if NON_MONOTONIC:
|
||||||
cost += head_in_buffer(s, s.stack[0], gold.c_heads)
|
cost += head_in_buffer(s, s.stack[0], gold.heads)
|
||||||
return cost
|
return cost
|
||||||
|
|
||||||
|
|
||||||
cdef int _break_cost(const Transition* self, const State* s, GoldParse gold) except -1:
|
cdef int _break_cost(const Transition* self, const State* s, GoldParseC* gold) except -1:
|
||||||
if not _can_break(s):
|
if not _can_break(s):
|
||||||
return 9000
|
return 9000
|
||||||
# When we break, we Reduce all of the words on the stack.
|
# When we break, we Reduce all of the words on the stack.
|
||||||
cdef int cost = 0
|
cdef int cost = 0
|
||||||
# Number of deps between S0...Sn and N0...Nn
|
# Number of deps between S0...Sn and N0...Nn
|
||||||
for i in range(s.i, s.sent_len):
|
for i in range(s.i, s.sent_len):
|
||||||
cost += children_in_stack(s, i, gold.c_heads)
|
cost += children_in_stack(s, i, gold.heads)
|
||||||
cost += head_in_stack(s, i, gold.c_heads)
|
cost += head_in_stack(s, i, gold.heads)
|
||||||
return cost
|
return cost
|
||||||
|
|
||||||
|
|
||||||
|
@ -360,7 +362,7 @@ cdef inline bint _can_adjust(const State* s) nogil:
|
||||||
#elif b0 >= b1:
|
#elif b0 >= b1:
|
||||||
# return False
|
# return False
|
||||||
|
|
||||||
cdef int _constituent_cost(const Transition* self, const State* s, GoldParse gold) except -1:
|
cdef int _constituent_cost(const Transition* self, const State* s, GoldParseC* gold) except -1:
|
||||||
if not _can_constituent(s):
|
if not _can_constituent(s):
|
||||||
return 9000
|
return 9000
|
||||||
raise Exception("Constituent move should be disabled currently")
|
raise Exception("Constituent move should be disabled currently")
|
||||||
|
@ -388,7 +390,7 @@ cdef int _constituent_cost(const Transition* self, const State* s, GoldParse gol
|
||||||
# loss = 1 # If we see the start position, set loss to 1
|
# loss = 1 # If we see the start position, set loss to 1
|
||||||
#return loss
|
#return loss
|
||||||
|
|
||||||
cdef int _adjust_cost(const Transition* self, const State* s, GoldParse gold) except -1:
|
cdef int _adjust_cost(const Transition* self, const State* s, GoldParseC* gold) except -1:
|
||||||
if not _can_adjust(s):
|
if not _can_adjust(s):
|
||||||
return 9000
|
return 9000
|
||||||
raise Exception("Adjust move should be disabled currently")
|
raise Exception("Adjust move should be disabled currently")
|
||||||
|
|
|
@ -8,6 +8,7 @@ from .transition_system cimport do_func_t
|
||||||
from ..structs cimport TokenC, Entity
|
from ..structs cimport TokenC, Entity
|
||||||
|
|
||||||
from thinc.typedefs cimport weight_t
|
from thinc.typedefs cimport weight_t
|
||||||
|
from ..gold cimport GoldParseC
|
||||||
from ..gold cimport GoldParse
|
from ..gold cimport GoldParse
|
||||||
|
|
||||||
|
|
||||||
|
@ -94,7 +95,7 @@ cdef class BiluoPushDown(TransitionSystem):
|
||||||
|
|
||||||
cdef int preprocess_gold(self, GoldParse gold) except -1:
|
cdef int preprocess_gold(self, GoldParse gold) except -1:
|
||||||
for i in range(gold.length):
|
for i in range(gold.length):
|
||||||
gold.c_ner[i] = self.lookup_transition(gold.ner[i])
|
gold.c.ner[i] = self.lookup_transition(gold.ner[i])
|
||||||
|
|
||||||
cdef Transition lookup_transition(self, object name) except *:
|
cdef Transition lookup_transition(self, object name) except *:
|
||||||
if name == '-':
|
if name == '-':
|
||||||
|
@ -147,13 +148,13 @@ cdef class BiluoPushDown(TransitionSystem):
|
||||||
output[i] = _is_valid(m.move, m.label, s)
|
output[i] = _is_valid(m.move, m.label, s)
|
||||||
|
|
||||||
|
|
||||||
cdef int _get_cost(const Transition* self, const State* s, GoldParse gold) except -1:
|
cdef int _get_cost(const Transition* self, const State* s, GoldParseC* gold) except -1:
|
||||||
if not _is_valid(self.move, self.label, s):
|
if not _is_valid(self.move, self.label, s):
|
||||||
return 9000
|
return 9000
|
||||||
cdef bint is_sunk = _entity_is_sunk(s, gold.c_ner)
|
cdef bint is_sunk = _entity_is_sunk(s, gold.ner)
|
||||||
cdef int next_act = gold.c_ner[s.i+1].move if s.i < s.sent_len else OUT
|
cdef int next_act = gold.ner[s.i+1].move if s.i < s.sent_len else OUT
|
||||||
cdef bint is_gold = _is_gold(self.move, self.label, gold.c_ner[s.i].move,
|
cdef bint is_gold = _is_gold(self.move, self.label, gold.ner[s.i].move,
|
||||||
gold.c_ner[s.i].label, next_act, is_sunk)
|
gold.ner[s.i].label, next_act, is_sunk)
|
||||||
return not is_gold
|
return not is_gold
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -136,7 +136,7 @@ cdef class Parser:
|
||||||
scores = self.model.score(context)
|
scores = self.model.score(context)
|
||||||
guess = self.moves.best_valid(scores, state)
|
guess = self.moves.best_valid(scores, state)
|
||||||
best = self.moves.best_gold(scores, state, gold)
|
best = self.moves.best_gold(scores, state, gold)
|
||||||
cost = guess.get_cost(&guess, state, gold)
|
cost = guess.get_cost(&guess, state, &gold.c)
|
||||||
self.model.update(context, guess.clas, best.clas, cost)
|
self.model.update(context, guess.clas, best.clas, cost)
|
||||||
guess.do(&guess, state)
|
guess.do(&guess, state)
|
||||||
loss += cost
|
loss += cost
|
||||||
|
@ -177,14 +177,14 @@ cdef class Parser:
|
||||||
state = <State*>beam.at(i)
|
state = <State*>beam.at(i)
|
||||||
for j in range(self.moves.n_moves):
|
for j in range(self.moves.n_moves):
|
||||||
move = &self.moves.c[j]
|
move = &self.moves.c[j]
|
||||||
beam.costs[i][j] = move.get_cost(move, state, gold)
|
beam.costs[i][j] = move.get_cost(move, state, &gold.c)
|
||||||
beam.is_valid[i][j] = beam.costs[i][j] == 0
|
beam.is_valid[i][j] = beam.costs[i][j] == 0
|
||||||
elif gold is not None:
|
elif gold is not None:
|
||||||
for i in range(beam.size):
|
for i in range(beam.size):
|
||||||
state = <State*>beam.at(i)
|
state = <State*>beam.at(i)
|
||||||
for j in range(self.moves.n_moves):
|
for j in range(self.moves.n_moves):
|
||||||
move = &self.moves.c[j]
|
move = &self.moves.c[j]
|
||||||
beam.costs[i][j] = move.get_cost(move, state, gold)
|
beam.costs[i][j] = move.get_cost(move, state, &gold.c)
|
||||||
beam.advance(_transition_state, <void*>self.moves.c)
|
beam.advance(_transition_state, <void*>self.moves.c)
|
||||||
state = <State*>beam.at(0)
|
state = <State*>beam.at(0)
|
||||||
if state.sent[state.i].sent_end:
|
if state.sent[state.i].sent_end:
|
||||||
|
|
|
@ -4,6 +4,7 @@ from thinc.typedefs cimport weight_t
|
||||||
from ..structs cimport TokenC
|
from ..structs cimport TokenC
|
||||||
from ._state cimport State
|
from ._state cimport State
|
||||||
from ..gold cimport GoldParse
|
from ..gold cimport GoldParse
|
||||||
|
from ..gold cimport GoldParseC
|
||||||
from ..strings cimport StringStore
|
from ..strings cimport StringStore
|
||||||
|
|
||||||
|
|
||||||
|
@ -14,12 +15,12 @@ cdef struct Transition:
|
||||||
|
|
||||||
weight_t score
|
weight_t score
|
||||||
|
|
||||||
int (*get_cost)(const Transition* self, const State* state, GoldParse gold) except -1
|
int (*get_cost)(const Transition* self, const State* state, GoldParseC* gold) except -1
|
||||||
int (*do)(const Transition* self, State* state) except -1
|
int (*do)(const Transition* self, State* state) except -1
|
||||||
|
|
||||||
|
|
||||||
ctypedef int (*get_cost_func_t)(const Transition* self, const State* state,
|
ctypedef int (*get_cost_func_t)(const Transition* self, const State* state,
|
||||||
GoldParse gold) except -1
|
GoldParseC* gold) except -1
|
||||||
|
|
||||||
ctypedef int (*do_func_t)(const Transition* self, State* state) except -1
|
ctypedef int (*do_func_t)(const Transition* self, State* state) except -1
|
||||||
|
|
||||||
|
|
|
@ -54,7 +54,7 @@ cdef class TransitionSystem:
|
||||||
cdef weight_t score = MIN_SCORE
|
cdef weight_t score = MIN_SCORE
|
||||||
cdef int i
|
cdef int i
|
||||||
for i in range(self.n_moves):
|
for i in range(self.n_moves):
|
||||||
cost = self.c[i].get_cost(&self.c[i], s, gold)
|
cost = self.c[i].get_cost(&self.c[i], s, &gold.c)
|
||||||
if scores[i] > score and cost == 0:
|
if scores[i] > score and cost == 0:
|
||||||
best = self.c[i]
|
best = self.c[i]
|
||||||
score = scores[i]
|
score = scores[i]
|
||||||
|
|
Loading…
Reference in New Issue
Block a user