Merge branch 'develop' of https://github.com/explosion/spaCy into develop

This commit is contained in:
Matthew Honnibal 2018-05-03 14:02:10 +02:00
commit a8e70a4187
3 changed files with 12 additions and 9 deletions

View File

@@ -195,14 +195,15 @@ class PrecomputableAffine(Model):
size=tokvecs.size).reshape(tokvecs.shape) size=tokvecs.size).reshape(tokvecs.shape)
def predict(ids, tokvecs): def predict(ids, tokvecs):
# nS ids. nW tokvecs # nS ids. nW tokvecs. Exclude the padding array.
hiddens = model(tokvecs) # (nW, f, o, p) hiddens = model(tokvecs[:-1]) # (nW, f, o, p)
vectors = model.ops.allocate((ids.shape[0], model.nO * model.nP), dtype='f')
# need nS vectors # need nS vectors
vectors = model.ops.allocate((ids.shape[0], model.nO, model.nP)) hiddens = hiddens.reshape((hiddens.shape[0] * model.nF, model.nO * model.nP))
for i, feats in enumerate(ids): model.ops.scatter_add(vectors, ids.flatten(), hiddens)
for j, id_ in enumerate(feats): vectors = vectors.reshape((vectors.shape[0], model.nO, model.nP))
vectors[i] += hiddens[id_, j]
vectors += model.b vectors += model.b
vectors = model.ops.asarray(vectors)
if model.nP >= 2: if model.nP >= 2:
return model.ops.maxout(vectors)[0] return model.ops.maxout(vectors)[0]
else: else:

View File

@@ -314,8 +314,8 @@ cdef cppclass StateC:
this._stack[this._s_i] = this.B(0) this._stack[this._s_i] = this.B(0)
this._s_i += 1 this._s_i += 1
this._b_i += 1 this._b_i += 1
if this.B_(0).sent_start == 1: if this.safe_get(this.B_(0).l_edge).sent_start == 1:
this.set_break(this.B(0)) this.set_break(this.B_(0).l_edge)
if this._b_i > this._break: if this._b_i > this._break:
this._break = -1 this._break = -1

View File

@@ -20,6 +20,7 @@ from .transition_system cimport move_cost_func_t, label_cost_func_t
from ..gold cimport GoldParse, GoldParseC from ..gold cimport GoldParse, GoldParseC
from ..structs cimport TokenC from ..structs cimport TokenC
from ..errors import Errors from ..errors import Errors
from ..tokens.doc cimport Doc, set_children_from_heads
# Calculate cost as gold/not gold. We don't use scalar value anyway. # Calculate cost as gold/not gold. We don't use scalar value anyway.
cdef int BINARY_COSTS = 1 cdef int BINARY_COSTS = 1
@@ -530,8 +531,9 @@ cdef class ArcEager(TransitionSystem):
if st._sent[i].head == 0: if st._sent[i].head == 0:
st._sent[i].dep = self.root_label st._sent[i].dep = self.root_label
def finalize_doc(self, doc): def finalize_doc(self, Doc doc):
doc.is_parsed = True doc.is_parsed = True
set_children_from_heads(doc.c, doc.length)
cdef int set_valid(self, int* output, const StateC* st) nogil: cdef int set_valid(self, int* output, const StateC* st) nogil:
cdef bint[N_MOVES] is_valid cdef bint[N_MOVES] is_valid