From 06eb428ed163ec570cc3bddfda45188d2ea501ab Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Sun, 28 Jul 2019 17:56:11 +0200 Subject: [PATCH] Make pipe base class a bit less presumptuous --- spacy/pipeline/pipes.pyx | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/spacy/pipeline/pipes.pyx b/spacy/pipeline/pipes.pyx index df923fb70..3b5e3d41c 100644 --- a/spacy/pipeline/pipes.pyx +++ b/spacy/pipeline/pipes.pyx @@ -83,8 +83,12 @@ class Pipe(object): """ for docs in util.minibatch(stream, size=batch_size): docs = list(docs) - scores, tensors = self.predict(docs) - self.set_annotations(docs, scores, tensor=tensors) + predictions = self.predict(docs) + if isinstance(predictions, tuple) and len(tuple) == 2: + scores, tensors = predictions + self.set_annotations(docs, scores, tensor=tensors) + else: + self.set_annotations(docs, predictions) yield from docs def predict(self, docs): @@ -133,7 +137,8 @@ class Pipe(object): If no model has been initialized yet, the model is added.""" if self.model is True: self.model = self.Model(**self.cfg) - link_vectors_to_models(self.vocab) + if hasattr(self, "vocab"): + link_vectors_to_models(self.vocab) if sgd is None: sgd = self.create_optimizer() return sgd @@ -153,7 +158,8 @@ class Pipe(object): serialize["cfg"] = lambda: srsly.json_dumps(self.cfg) if self.model not in (True, False, None): serialize["model"] = self.model.to_bytes - serialize["vocab"] = self.vocab.to_bytes + if hasattr(self, "vocab"): + serialize["vocab"] = self.vocab.to_bytes exclude = util.get_serialization_exclude(serialize, exclude, kwargs) return util.to_bytes(serialize, exclude) @@ -173,7 +179,8 @@ class Pipe(object): deserialize = OrderedDict() deserialize["cfg"] = lambda b: self.cfg.update(srsly.json_loads(b)) - deserialize["vocab"] = lambda b: self.vocab.from_bytes(b) + if hasattr(self, "vocab"): + deserialize["vocab"] = lambda b: self.vocab.from_bytes(b) deserialize["model"] = load_model exclude = util.get_serialization_exclude(deserialize, exclude, kwargs) util.from_bytes(bytes_data, deserialize, exclude)