diff --git a/spacy/pipeline/tagger.pyx b/spacy/pipeline/tagger.pyx index af24bf336..820bde7d3 100644 --- a/spacy/pipeline/tagger.pyx +++ b/spacy/pipeline/tagger.pyx @@ -287,7 +287,13 @@ class Tagger(Pipe): self.add_label(tag) self.set_output(len(self.labels)) if self.labels: - self.model.initialize(X=doc_sample) + label_sample = [ + self.model.ops.alloc2f(len(doc), len(self.labels)) + for doc in docs + ] + for y in label_sample: + y[:, 0] = 1.0 + self.model.initialize(X=doc_sample, Y=label_sample) else: self.model.initialize() if sgd is None: