* Have oracle functions take a struct instead of a Python object

This commit is contained in:
Matthew Honnibal 2015-06-02 20:01:06 +02:00
parent d1b55310a1
commit a513ec500f
7 changed files with 68 additions and 57 deletions

View File

@ -5,9 +5,20 @@ from .syntax.transition_system cimport Transition
cimport numpy
cdef struct GoldParseC:
int* tags
int* heads
int* labels
int** brackets
Transition* ner
cdef class GoldParse:
cdef Pool mem
cdef GoldParseC c
cdef int length
cdef readonly int loss
cdef readonly list tags
@ -22,8 +33,4 @@ cdef class GoldParse:
cdef readonly list gold_to_cand
cdef readonly list orig_annot
cdef int* c_tags
cdef int* c_heads
cdef int* c_labels
cdef int** c_brackets
cdef Transition* c_ner

View File

@ -169,13 +169,13 @@ cdef class GoldParse:
self.length = len(tokens)
# These are filled by the tagger/parser/entity recogniser
self.c_tags = <int*>self.mem.alloc(len(tokens), sizeof(int))
self.c_heads = <int*>self.mem.alloc(len(tokens), sizeof(int))
self.c_labels = <int*>self.mem.alloc(len(tokens), sizeof(int))
self.c_ner = <Transition*>self.mem.alloc(len(tokens), sizeof(Transition))
self.c_brackets = <int**>self.mem.alloc(len(tokens), sizeof(int*))
self.c.tags = <int*>self.mem.alloc(len(tokens), sizeof(int))
self.c.heads = <int*>self.mem.alloc(len(tokens), sizeof(int))
self.c.labels = <int*>self.mem.alloc(len(tokens), sizeof(int))
self.c.ner = <Transition*>self.mem.alloc(len(tokens), sizeof(Transition))
self.c.brackets = <int**>self.mem.alloc(len(tokens), sizeof(int*))
for i in range(len(tokens)):
self.c_brackets[i] = <int*>self.mem.alloc(len(tokens), sizeof(int))
self.c.brackets[i] = <int*>self.mem.alloc(len(tokens), sizeof(int))
self.tags = [None] * len(tokens)
self.heads = [None] * len(tokens)

View File

@ -1,3 +1,4 @@
# cython: profile=True
from __future__ import unicode_literals
from ._state cimport State
@ -11,6 +12,7 @@ from ..structs cimport TokenC
from .transition_system cimport do_func_t, get_cost_func_t
from ..gold cimport GoldParse
from ..gold cimport GoldParseC
DEF NON_MONOTONIC = True
@ -65,14 +67,14 @@ cdef class ArcEager(TransitionSystem):
cdef int preprocess_gold(self, GoldParse gold) except -1:
for i in range(gold.length):
if gold.heads[i] is None: # Missing values
gold.c_heads[i] = i
gold.c_labels[i] = -1
gold.c.heads[i] = i
gold.c.labels[i] = -1
else:
gold.c_heads[i] = gold.heads[i]
gold.c_labels[i] = self.strings[gold.labels[i]]
gold.c.heads[i] = gold.heads[i]
gold.c.labels[i] = self.strings[gold.labels[i]]
for end, brackets in gold.brackets.items():
for start, label_strs in brackets.items():
gold.c_brackets[start][end] = 1
gold.c.brackets[start][end] = 1
for label_str in label_strs:
# Add the encoded label to the set
gold.brackets[end][start].add(self.strings[label_str])
@ -214,78 +216,78 @@ cdef int _do_break(const Transition* self, State* state) except -1:
if not at_eol(state):
push_stack(state)
cdef int _shift_cost(const Transition* self, const State* s, GoldParse gold) except -1:
cdef int _shift_cost(const Transition* self, const State* s, GoldParseC* gold) except -1:
if not _can_shift(s):
return 9000
cost = 0
cost += head_in_stack(s, s.i, gold.c_heads)
cost += children_in_stack(s, s.i, gold.c_heads)
cost += head_in_stack(s, s.i, gold.heads)
cost += children_in_stack(s, s.i, gold.heads)
# If we can break, and there's no cost to doing so, we should
if _can_break(s) and _break_cost(self, s, gold) == 0:
cost += 1
return cost
cdef int _right_cost(const Transition* self, const State* s, GoldParse gold) except -1:
cdef int _right_cost(const Transition* self, const State* s, GoldParseC* gold) except -1:
if not _can_right(s):
return 9000
cost = 0
if gold.c_heads[s.i] == s.stack[0]:
cost += self.label != gold.c_labels[s.i]
if gold.heads[s.i] == s.stack[0]:
cost += self.label != gold.labels[s.i]
return cost
# This indicates missing head
if gold.c_labels[s.i] != -1:
cost += head_in_buffer(s, s.i, gold.c_heads)
cost += children_in_stack(s, s.i, gold.c_heads)
cost += head_in_stack(s, s.i, gold.c_heads)
if gold.labels[s.i] != -1:
cost += head_in_buffer(s, s.i, gold.heads)
cost += children_in_stack(s, s.i, gold.heads)
cost += head_in_stack(s, s.i, gold.heads)
return cost
cdef int _left_cost(const Transition* self, const State* s, GoldParse gold) except -1:
cdef int _left_cost(const Transition* self, const State* s, GoldParseC* gold) except -1:
if not _can_left(s):
return 9000
cost = 0
if gold.c_heads[s.stack[0]] == s.i:
cost += self.label != gold.c_labels[s.stack[0]]
if gold.heads[s.stack[0]] == s.i:
cost += self.label != gold.labels[s.stack[0]]
return cost
# If we're at EOL, then the left arc will add an arc to ROOT.
elif at_eol(s):
# Are we root?
if gold.c_labels[s.stack[0]] != -1:
if gold.labels[s.stack[0]] != -1:
# If we're at EOL, prefer to reduce or break over left-arc
if _can_reduce(s) or _can_break(s):
cost += gold.c_heads[s.stack[0]] != s.stack[0]
cost += gold.heads[s.stack[0]] != s.stack[0]
# Are we labelling correctly?
cost += self.label != gold.c_labels[s.stack[0]]
cost += self.label != gold.labels[s.stack[0]]
return cost
cost += head_in_buffer(s, s.stack[0], gold.c_heads)
cost += children_in_buffer(s, s.stack[0], gold.c_heads)
cost += head_in_buffer(s, s.stack[0], gold.heads)
cost += children_in_buffer(s, s.stack[0], gold.heads)
if NON_MONOTONIC and s.stack_len >= 2:
cost += gold.c_heads[s.stack[0]] == s.stack[-1]
if gold.c_labels[s.stack[0]] != -1:
cost += gold.c_heads[s.stack[0]] == s.stack[0]
cost += gold.heads[s.stack[0]] == s.stack[-1]
if gold.labels[s.stack[0]] != -1:
cost += gold.heads[s.stack[0]] == s.stack[0]
return cost
cdef int _reduce_cost(const Transition* self, const State* s, GoldParse gold) except -1:
cdef int _reduce_cost(const Transition* self, const State* s, GoldParseC* gold) except -1:
if not _can_reduce(s):
return 9000
cdef int cost = 0
cost += children_in_buffer(s, s.stack[0], gold.c_heads)
cost += children_in_buffer(s, s.stack[0], gold.heads)
if NON_MONOTONIC:
cost += head_in_buffer(s, s.stack[0], gold.c_heads)
cost += head_in_buffer(s, s.stack[0], gold.heads)
return cost
cdef int _break_cost(const Transition* self, const State* s, GoldParse gold) except -1:
cdef int _break_cost(const Transition* self, const State* s, GoldParseC* gold) except -1:
if not _can_break(s):
return 9000
# When we break, we Reduce all of the words on the stack.
cdef int cost = 0
# Number of deps between S0...Sn and N0...Nn
for i in range(s.i, s.sent_len):
cost += children_in_stack(s, i, gold.c_heads)
cost += head_in_stack(s, i, gold.c_heads)
cost += children_in_stack(s, i, gold.heads)
cost += head_in_stack(s, i, gold.heads)
return cost
@ -360,7 +362,7 @@ cdef inline bint _can_adjust(const State* s) nogil:
#elif b0 >= b1:
# return False
cdef int _constituent_cost(const Transition* self, const State* s, GoldParse gold) except -1:
cdef int _constituent_cost(const Transition* self, const State* s, GoldParseC* gold) except -1:
if not _can_constituent(s):
return 9000
raise Exception("Constituent move should be disabled currently")
@ -388,7 +390,7 @@ cdef int _constituent_cost(const Transition* self, const State* s, GoldParse gol
# loss = 1 # If we see the start position, set loss to 1
#return loss
cdef int _adjust_cost(const Transition* self, const State* s, GoldParse gold) except -1:
cdef int _adjust_cost(const Transition* self, const State* s, GoldParseC* gold) except -1:
if not _can_adjust(s):
return 9000
raise Exception("Adjust move should be disabled currently")

View File

@ -8,6 +8,7 @@ from .transition_system cimport do_func_t
from ..structs cimport TokenC, Entity
from thinc.typedefs cimport weight_t
from ..gold cimport GoldParseC
from ..gold cimport GoldParse
@ -94,7 +95,7 @@ cdef class BiluoPushDown(TransitionSystem):
cdef int preprocess_gold(self, GoldParse gold) except -1:
for i in range(gold.length):
gold.c_ner[i] = self.lookup_transition(gold.ner[i])
gold.c.ner[i] = self.lookup_transition(gold.ner[i])
cdef Transition lookup_transition(self, object name) except *:
if name == '-':
@ -147,13 +148,13 @@ cdef class BiluoPushDown(TransitionSystem):
output[i] = _is_valid(m.move, m.label, s)
cdef int _get_cost(const Transition* self, const State* s, GoldParse gold) except -1:
cdef int _get_cost(const Transition* self, const State* s, GoldParseC* gold) except -1:
if not _is_valid(self.move, self.label, s):
return 9000
cdef bint is_sunk = _entity_is_sunk(s, gold.c_ner)
cdef int next_act = gold.c_ner[s.i+1].move if s.i < s.sent_len else OUT
cdef bint is_gold = _is_gold(self.move, self.label, gold.c_ner[s.i].move,
gold.c_ner[s.i].label, next_act, is_sunk)
cdef bint is_sunk = _entity_is_sunk(s, gold.ner)
cdef int next_act = gold.ner[s.i+1].move if s.i < s.sent_len else OUT
cdef bint is_gold = _is_gold(self.move, self.label, gold.ner[s.i].move,
gold.ner[s.i].label, next_act, is_sunk)
return not is_gold

View File

@ -136,7 +136,7 @@ cdef class Parser:
scores = self.model.score(context)
guess = self.moves.best_valid(scores, state)
best = self.moves.best_gold(scores, state, gold)
cost = guess.get_cost(&guess, state, gold)
cost = guess.get_cost(&guess, state, &gold.c)
self.model.update(context, guess.clas, best.clas, cost)
guess.do(&guess, state)
loss += cost
@ -177,14 +177,14 @@ cdef class Parser:
state = <State*>beam.at(i)
for j in range(self.moves.n_moves):
move = &self.moves.c[j]
beam.costs[i][j] = move.get_cost(move, state, gold)
beam.costs[i][j] = move.get_cost(move, state, &gold.c)
beam.is_valid[i][j] = beam.costs[i][j] == 0
elif gold is not None:
for i in range(beam.size):
state = <State*>beam.at(i)
for j in range(self.moves.n_moves):
move = &self.moves.c[j]
beam.costs[i][j] = move.get_cost(move, state, gold)
beam.costs[i][j] = move.get_cost(move, state, &gold.c)
beam.advance(_transition_state, <void*>self.moves.c)
state = <State*>beam.at(0)
if state.sent[state.i].sent_end:

View File

@ -4,6 +4,7 @@ from thinc.typedefs cimport weight_t
from ..structs cimport TokenC
from ._state cimport State
from ..gold cimport GoldParse
from ..gold cimport GoldParseC
from ..strings cimport StringStore
@ -14,12 +15,12 @@ cdef struct Transition:
weight_t score
int (*get_cost)(const Transition* self, const State* state, GoldParse gold) except -1
int (*get_cost)(const Transition* self, const State* state, GoldParseC* gold) except -1
int (*do)(const Transition* self, State* state) except -1
ctypedef int (*get_cost_func_t)(const Transition* self, const State* state,
GoldParse gold) except -1
GoldParseC* gold) except -1
ctypedef int (*do_func_t)(const Transition* self, State* state) except -1

View File

@ -54,7 +54,7 @@ cdef class TransitionSystem:
cdef weight_t score = MIN_SCORE
cdef int i
for i in range(self.n_moves):
cost = self.c[i].get_cost(&self.c[i], s, gold)
cost = self.c[i].get_cost(&self.c[i], s, &gold.c)
if scores[i] > score and cost == 0:
best = self.c[i]
score = scores[i]