mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-11 04:08:09 +03:00
Fix MultitaskObjective
This commit is contained in:
parent
82135d85b7
commit
61a051f2c0
|
@ -652,10 +652,7 @@ class MultitaskObjective(Tagger):
|
|||
self.labels[label] = len(self.labels)
|
||||
if self.model is True:
|
||||
token_vector_width = util.env_opt('token_vector_width')
|
||||
self.model = chain(
|
||||
tok2vec,
|
||||
Softmax(len(self.labels), token_vector_width)
|
||||
)
|
||||
self.model = self.Model(len(self.labels), tok2vec=tok2vec)
|
||||
link_vectors_to_models(self.vocab)
|
||||
if sgd is None:
|
||||
sgd = self.create_optimizer()
|
||||
|
@ -663,7 +660,20 @@ class MultitaskObjective(Tagger):
|
|||
|
||||
@classmethod
|
||||
def Model(cls, n_tags, tok2vec=None, **cfg):
|
||||
return build_tagger_model(n_tags, tok2vec=tok2vec, **cfg)
|
||||
token_vector_width = util.env_opt('token_vector_width', 128)
|
||||
softmax = Softmax(n_tags, token_vector_width)
|
||||
model = chain(
|
||||
tok2vec,
|
||||
softmax
|
||||
)
|
||||
model.tok2vec = tok2vec
|
||||
model.softmax = softmax
|
||||
return model
|
||||
|
||||
def predict(self, docs):
|
||||
tokvecs = self.model.tok2vec(docs)
|
||||
scores = self.model.softmax(tokvecs)
|
||||
return tokvecs, scores
|
||||
|
||||
def get_loss(self, docs, golds, scores):
|
||||
cdef int idx = 0
|
||||
|
|
Loading…
Reference in New Issue
Block a user