diff --git a/spacy/pipeline/tagger.pyx b/spacy/pipeline/tagger.pyx index c94cb6b58..f831caefe 100644 --- a/spacy/pipeline/tagger.pyx +++ b/spacy/pipeline/tagger.pyx @@ -289,7 +289,14 @@ class Tagger(Pipe): err = Errors.E1006.format(name="Tagger") raise ValueError(err) self.set_output(len(self.labels)) - self.model.initialize(X=doc_sample) + if doc_sample: + label_sample = [ + self.model.ops.alloc2f(len(doc), len(self.labels)) + for doc in doc_sample + ] + self.model.initialize(X=doc_sample, Y=label_sample) + else: + self.model.initialize() if sgd is None: sgd = self.create_optimizer() return sgd