Fix begin_training

This commit is contained in:
Matthw Honnibal 2020-05-21 20:46:21 +02:00
parent d507ac28d8
commit bc94fdabd0
2 changed files with 21 additions and 9 deletions

View File

@ -532,10 +532,14 @@ class Tagger(Pipe):
exc=vocab.morphology.exc) exc=vocab.morphology.exc)
self.set_output(len(self.labels)) self.set_output(len(self.labels))
doc_sample = [Doc(self.vocab, words=["hello", "world"])] doc_sample = [Doc(self.vocab, words=["hello", "world"])]
for name, component in pipeline: if pipeline is not None:
if component is self: for name, component in pipeline:
break if component is self:
doc_sample = list(component.pipe(doc_sample)) break
if hasattr(component, "pipe"):
doc_sample = list(component.pipe(doc_sample))
else:
doc_sample = [component(doc) for doc in doc_sample]
self.model.initialize(X=doc_sample) self.model.initialize(X=doc_sample)
# Get batch of example docs, example outputs to call begin_training(). # Get batch of example docs, example outputs to call begin_training().
# This lets the model infer shapes. # This lets the model infer shapes.

View File

@ -630,11 +630,19 @@ cdef class Parser:
if len(doc): if len(doc):
doc_sample.append(doc) doc_sample.append(doc)
gold_sample.append(gold) gold_sample.append(gold)
for name, component in pipeline:
if component is self: if pipeline is not None:
break for name, component in pipeline:
doc_sample = list(component.pipe(doc_sample)) if component is self:
self.model.initialize(doc_sample, gold_sample) break
if hasattr(component, "pipe"):
doc_sample = list(component.pipe(doc_sample))
else:
doc_sample = [component(doc) for doc in doc_sample]
if doc_sample:
self.model.initialize(doc_sample)
else:
self.model.initialize()
if pipeline is not None: if pipeline is not None:
self.init_multitask_objectives(get_examples, pipeline, sgd=sgd, **self.cfg) self.init_multitask_objectives(get_examples, pipeline, sgd=sgd, **self.cfg)
link_vectors_to_models(self.vocab) link_vectors_to_models(self.vocab)