Merge branch 'develop' of https://github.com/explosion/spaCy into develop

This commit is contained in:
Matthew Honnibal 2018-05-03 14:02:10 +02:00
commit a8e70a4187
3 changed files with 12 additions and 9 deletions

View File

@ -195,14 +195,15 @@ class PrecomputableAffine(Model):
size=tokvecs.size).reshape(tokvecs.shape)
def predict(ids, tokvecs):
# nS ids. nW tokvecs
hiddens = model(tokvecs) # (nW, f, o, p)
# nS ids. nW tokvecs. Exclude the padding array.
hiddens = model(tokvecs[:-1]) # (nW, f, o, p)
vectors = model.ops.allocate((ids.shape[0], model.nO * model.nP), dtype='f')
# need nS vectors
vectors = model.ops.allocate((ids.shape[0], model.nO, model.nP))
for i, feats in enumerate(ids):
for j, id_ in enumerate(feats):
vectors[i] += hiddens[id_, j]
hiddens = hiddens.reshape((hiddens.shape[0] * model.nF, model.nO * model.nP))
model.ops.scatter_add(vectors, ids.flatten(), hiddens)
vectors = vectors.reshape((vectors.shape[0], model.nO, model.nP))
vectors += model.b
vectors = model.ops.asarray(vectors)
if model.nP >= 2:
return model.ops.maxout(vectors)[0]
else:

View File

@ -314,8 +314,8 @@ cdef cppclass StateC:
this._stack[this._s_i] = this.B(0)
this._s_i += 1
this._b_i += 1
if this.B_(0).sent_start == 1:
this.set_break(this.B(0))
if this.safe_get(this.B_(0).l_edge).sent_start == 1:
this.set_break(this.B_(0).l_edge)
if this._b_i > this._break:
this._break = -1

View File

@ -20,6 +20,7 @@ from .transition_system cimport move_cost_func_t, label_cost_func_t
from ..gold cimport GoldParse, GoldParseC
from ..structs cimport TokenC
from ..errors import Errors
from ..tokens.doc cimport Doc, set_children_from_heads
# Calculate cost as gold/not gold. We don't use scalar value anyway.
cdef int BINARY_COSTS = 1
@ -530,8 +531,9 @@ cdef class ArcEager(TransitionSystem):
if st._sent[i].head == 0:
st._sent[i].dep = self.root_label
def finalize_doc(self, doc):
def finalize_doc(self, Doc doc):
doc.is_parsed = True
set_children_from_heads(doc.c, doc.length)
cdef int set_valid(self, int* output, const StateC* st) nogil:
cdef bint[N_MOVES] is_valid