mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-27 10:26:35 +03:00
Fix MultitaskObjective
This commit is contained in:
parent
82135d85b7
commit
61a051f2c0
|
@ -652,10 +652,7 @@ class MultitaskObjective(Tagger):
|
||||||
self.labels[label] = len(self.labels)
|
self.labels[label] = len(self.labels)
|
||||||
if self.model is True:
|
if self.model is True:
|
||||||
token_vector_width = util.env_opt('token_vector_width')
|
token_vector_width = util.env_opt('token_vector_width')
|
||||||
self.model = chain(
|
self.model = self.Model(len(self.labels), tok2vec=tok2vec)
|
||||||
tok2vec,
|
|
||||||
Softmax(len(self.labels), token_vector_width)
|
|
||||||
)
|
|
||||||
link_vectors_to_models(self.vocab)
|
link_vectors_to_models(self.vocab)
|
||||||
if sgd is None:
|
if sgd is None:
|
||||||
sgd = self.create_optimizer()
|
sgd = self.create_optimizer()
|
||||||
|
@ -663,7 +660,20 @@ class MultitaskObjective(Tagger):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def Model(cls, n_tags, tok2vec=None, **cfg):
|
def Model(cls, n_tags, tok2vec=None, **cfg):
|
||||||
return build_tagger_model(n_tags, tok2vec=tok2vec, **cfg)
|
token_vector_width = util.env_opt('token_vector_width', 128)
|
||||||
|
softmax = Softmax(n_tags, token_vector_width)
|
||||||
|
model = chain(
|
||||||
|
tok2vec,
|
||||||
|
softmax
|
||||||
|
)
|
||||||
|
model.tok2vec = tok2vec
|
||||||
|
model.softmax = softmax
|
||||||
|
return model
|
||||||
|
|
||||||
|
def predict(self, docs):
|
||||||
|
tokvecs = self.model.tok2vec(docs)
|
||||||
|
scores = self.model.softmax(tokvecs)
|
||||||
|
return tokvecs, scores
|
||||||
|
|
||||||
def get_loss(self, docs, golds, scores):
|
def get_loss(self, docs, golds, scores):
|
||||||
cdef int idx = 0
|
cdef int idx = 0
|
||||||
|
|
Loading…
Reference in New Issue
Block a user