mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-11 04:08:09 +03:00
Fix tagger training
This commit is contained in:
parent
a2357cce3f
commit
386c1a5bd8
|
@ -343,6 +343,7 @@ class NeuralTagger(BaseThincComponent):
|
|||
|
||||
tag_scores, bp_tag_scores = self.model.begin_update(docs, drop=drop)
|
||||
loss, d_tag_scores = self.get_loss(docs, golds, tag_scores)
|
||||
bp_tag_scores(d_tag_scores, sgd=sgd)
|
||||
|
||||
if losses is not None:
|
||||
losses[self.name] += loss
|
||||
|
@ -386,15 +387,13 @@ class NeuralTagger(BaseThincComponent):
|
|||
vocab.morphology = Morphology(vocab.strings, new_tag_map,
|
||||
vocab.morphology.lemmatizer,
|
||||
exc=vocab.morphology.exc)
|
||||
token_vector_width = pipeline[0].model.nO
|
||||
if self.model is True:
|
||||
self.model = self.Model(self.vocab.morphology.n_tags, token_vector_width,
|
||||
self.model = self.Model(self.vocab.morphology.n_tags,
|
||||
pretrained_dims=self.vocab.vectors_length)
|
||||
|
||||
@classmethod
|
||||
def Model(cls, n_tags, token_vector_width, pretrained_dims=0, **cfg):
|
||||
return build_tagger_model(n_tags, token_vector_width,
|
||||
pretrained_dims, **cfg)
|
||||
def Model(cls, n_tags, **cfg):
|
||||
return build_tagger_model(n_tags, **cfg)
|
||||
|
||||
def use_params(self, params):
|
||||
with self.model.use_params(params):
|
||||
|
|
Loading…
Reference in New Issue
Block a user