mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-26 01:04:34 +03:00
Make pipe base class a bit less presumptuous
This commit is contained in:
parent
16b5144095
commit
06eb428ed1
|
@ -83,8 +83,12 @@ class Pipe(object):
|
||||||
"""
|
"""
|
||||||
for docs in util.minibatch(stream, size=batch_size):
|
for docs in util.minibatch(stream, size=batch_size):
|
||||||
docs = list(docs)
|
docs = list(docs)
|
||||||
scores, tensors = self.predict(docs)
|
predictions = self.predict(docs)
|
||||||
|
if isinstance(predictions, tuple) and len(tuple) == 2:
|
||||||
|
scores, tensors = predictions
|
||||||
self.set_annotations(docs, scores, tensor=tensors)
|
self.set_annotations(docs, scores, tensor=tensors)
|
||||||
|
else:
|
||||||
|
self.set_annotations(docs, predictions)
|
||||||
yield from docs
|
yield from docs
|
||||||
|
|
||||||
def predict(self, docs):
|
def predict(self, docs):
|
||||||
|
@ -133,6 +137,7 @@ class Pipe(object):
|
||||||
If no model has been initialized yet, the model is added."""
|
If no model has been initialized yet, the model is added."""
|
||||||
if self.model is True:
|
if self.model is True:
|
||||||
self.model = self.Model(**self.cfg)
|
self.model = self.Model(**self.cfg)
|
||||||
|
if hasattr(self, "vocab"):
|
||||||
link_vectors_to_models(self.vocab)
|
link_vectors_to_models(self.vocab)
|
||||||
if sgd is None:
|
if sgd is None:
|
||||||
sgd = self.create_optimizer()
|
sgd = self.create_optimizer()
|
||||||
|
@ -153,6 +158,7 @@ class Pipe(object):
|
||||||
serialize["cfg"] = lambda: srsly.json_dumps(self.cfg)
|
serialize["cfg"] = lambda: srsly.json_dumps(self.cfg)
|
||||||
if self.model not in (True, False, None):
|
if self.model not in (True, False, None):
|
||||||
serialize["model"] = self.model.to_bytes
|
serialize["model"] = self.model.to_bytes
|
||||||
|
if hasattr(self, "vocab"):
|
||||||
serialize["vocab"] = self.vocab.to_bytes
|
serialize["vocab"] = self.vocab.to_bytes
|
||||||
exclude = util.get_serialization_exclude(serialize, exclude, kwargs)
|
exclude = util.get_serialization_exclude(serialize, exclude, kwargs)
|
||||||
return util.to_bytes(serialize, exclude)
|
return util.to_bytes(serialize, exclude)
|
||||||
|
@ -173,6 +179,7 @@ class Pipe(object):
|
||||||
|
|
||||||
deserialize = OrderedDict()
|
deserialize = OrderedDict()
|
||||||
deserialize["cfg"] = lambda b: self.cfg.update(srsly.json_loads(b))
|
deserialize["cfg"] = lambda b: self.cfg.update(srsly.json_loads(b))
|
||||||
|
if hasattr(self, "vocab"):
|
||||||
deserialize["vocab"] = lambda b: self.vocab.from_bytes(b)
|
deserialize["vocab"] = lambda b: self.vocab.from_bytes(b)
|
||||||
deserialize["model"] = load_model
|
deserialize["model"] = load_model
|
||||||
exclude = util.get_serialization_exclude(deserialize, exclude, kwargs)
|
exclude = util.get_serialization_exclude(deserialize, exclude, kwargs)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user