Make pipe base class a bit less presumptuous

This commit is contained in:
Matthew Honnibal 2019-07-28 17:56:11 +02:00
parent 16b5144095
commit 06eb428ed1

View File

@ -83,8 +83,12 @@ class Pipe(object):
""" """
for docs in util.minibatch(stream, size=batch_size): for docs in util.minibatch(stream, size=batch_size):
docs = list(docs) docs = list(docs)
scores, tensors = self.predict(docs) predictions = self.predict(docs)
if isinstance(predictions, tuple) and len(predictions) == 2:
scores, tensors = predictions
self.set_annotations(docs, scores, tensor=tensors) self.set_annotations(docs, scores, tensor=tensors)
else:
self.set_annotations(docs, predictions)
yield from docs yield from docs
def predict(self, docs): def predict(self, docs):
@ -133,6 +137,7 @@ class Pipe(object):
If no model has been initialized yet, the model is added.""" If no model has been initialized yet, the model is added."""
if self.model is True: if self.model is True:
self.model = self.Model(**self.cfg) self.model = self.Model(**self.cfg)
if hasattr(self, "vocab"):
link_vectors_to_models(self.vocab) link_vectors_to_models(self.vocab)
if sgd is None: if sgd is None:
sgd = self.create_optimizer() sgd = self.create_optimizer()
@ -153,6 +158,7 @@ class Pipe(object):
serialize["cfg"] = lambda: srsly.json_dumps(self.cfg) serialize["cfg"] = lambda: srsly.json_dumps(self.cfg)
if self.model not in (True, False, None): if self.model not in (True, False, None):
serialize["model"] = self.model.to_bytes serialize["model"] = self.model.to_bytes
if hasattr(self, "vocab"):
serialize["vocab"] = self.vocab.to_bytes serialize["vocab"] = self.vocab.to_bytes
exclude = util.get_serialization_exclude(serialize, exclude, kwargs) exclude = util.get_serialization_exclude(serialize, exclude, kwargs)
return util.to_bytes(serialize, exclude) return util.to_bytes(serialize, exclude)
@ -173,6 +179,7 @@ class Pipe(object):
deserialize = OrderedDict() deserialize = OrderedDict()
deserialize["cfg"] = lambda b: self.cfg.update(srsly.json_loads(b)) deserialize["cfg"] = lambda b: self.cfg.update(srsly.json_loads(b))
if hasattr(self, "vocab"):
deserialize["vocab"] = lambda b: self.vocab.from_bytes(b) deserialize["vocab"] = lambda b: self.vocab.from_bytes(b)
deserialize["model"] = load_model deserialize["model"] = load_model
exclude = util.get_serialization_exclude(deserialize, exclude, kwargs) exclude = util.get_serialization_exclude(deserialize, exclude, kwargs)