2021-01-29 03:51:21 +03:00
|
|
|
from typing import Sequence, Iterable, Optional, Dict, Callable, List
|
2020-07-22 14:42:59 +03:00
|
|
|
from thinc.api import Model, set_dropout_rate, Optimizer, Config
|
2020-09-08 23:44:25 +03:00
|
|
|
from itertools import islice
|
2020-02-18 17:38:18 +03:00
|
|
|
|
2020-10-08 22:33:49 +03:00
|
|
|
from .trainable_pipe import TrainablePipe
|
|
|
|
from ..training import Example, validate_examples, validate_get_examples
|
2020-01-29 19:06:46 +03:00
|
|
|
from ..tokens import Doc
|
|
|
|
from ..vocab import Vocab
|
2020-07-22 14:42:59 +03:00
|
|
|
from ..language import Language
|
2020-07-25 16:01:15 +03:00
|
|
|
from ..errors import Errors
|
2020-01-29 19:06:46 +03:00
|
|
|
|
2020-07-22 14:42:59 +03:00
|
|
|
# Default settings for the "tok2vec" factory's ``model`` argument, written in
# Thinc's config format. Only the ``[model]`` section is kept and passed to the
# factory below as its default config.
default_model_config = """
[model]
@architectures = "spacy.HashEmbedCNN.v2"
pretrained_vectors = null
width = 96
depth = 4
embed_size = 2000
window_size = 1
maxout_pieces = 3
subword_features = true
"""
DEFAULT_TOK2VEC_MODEL = Config().from_str(default_model_config)["model"]
|
|
|
|
|
|
|
|
|
|
|
|
@Language.factory(
    "tok2vec",
    assigns=["doc.tensor"],
    default_config={"model": DEFAULT_TOK2VEC_MODEL},
)
def make_tok2vec(nlp: Language, name: str, model: Model) -> "Tok2Vec":
    """Construct the "tok2vec" pipeline component.

    nlp (Language): The pipeline the component is being added to; only its
        shared vocab is used here.
    name (str): The component instance name.
    model (Model): The Thinc model that the component wraps.
    RETURNS (Tok2Vec): The newly created component.
    """
    return Tok2Vec(vocab=nlp.vocab, model=model, name=name)
|
|
|
|
|
2020-01-29 19:06:46 +03:00
|
|
|
|
2020-10-08 22:33:49 +03:00
|
|
|
class Tok2Vec(TrainablePipe):
    """Apply a "token-to-vector" model and set its outputs in the doc.tensor
    attribute. This is mostly useful to share a single subnetwork between multiple
    components, e.g. to have one embedding and CNN network shared between a
    parser, tagger and NER.

    In order to use the `Tok2Vec` predictions, subsequent components should use
    the `Tok2VecListener` layer as the tok2vec subnetwork of their model. This
    layer will read data from the `doc.tensor` attribute during prediction.
    During training, the `Tok2Vec` component will save its prediction and backprop
    callback for each batch, so that the subsequent components can backpropagate
    to the shared weights. This implementation is used because it allows us to
    avoid relying on object identity within the models to achieve the parameter
    sharing.
    """

    def __init__(self, vocab: Vocab, model: Model, name: str = "tok2vec") -> None:
        """Initialize a tok2vec component.

        vocab (Vocab): The shared vocabulary.
        model (thinc.api.Model[List[Doc], List[Floats2d]]):
            The Thinc Model powering the pipeline component. It should take
            a list of Doc objects as input, and output a list of 2d float arrays.
        name (str): The component instance name.

        DOCS: https://spacy.io/api/tok2vec#init
        """
        self.vocab = vocab
        self.model = model
        self.name = name
        # Downstream component name -> listener layers registered for it
        # (populated by add_listener / find_listeners).
        self.listener_map: Dict[str, List["Tok2VecListener"]] = {}
        self.cfg = {}

    @property
    def listeners(self) -> List["Tok2VecListener"]:
        """RETURNS (List[Tok2VecListener]): The listener models listening to this
        component, in registration order of their components. Usually internals.
        """
        return [m for c in self.listening_components for m in self.listener_map[c]]

    @property
    def listening_components(self) -> List[str]:
        """RETURNS (List[str]): The downstream components listening to this
        component. Usually internals.
        """
        return list(self.listener_map.keys())

    def add_listener(self, listener: "Tok2VecListener", component_name: str) -> None:
        """Add a listener for a downstream component. Usually internals.

        Adding the same listener twice for the same component is a no-op.
        """
        self.listener_map.setdefault(component_name, [])
        if listener not in self.listener_map[component_name]:
            self.listener_map[component_name].append(listener)

    def remove_listener(self, listener: "Tok2VecListener", component_name: str) -> bool:
        """Remove a listener for a downstream component. Usually internals.

        RETURNS (bool): True if the listener was registered for the component
            and has been removed, False otherwise.
        """
        if component_name in self.listener_map:
            if listener in self.listener_map[component_name]:
                self.listener_map[component_name].remove(listener)
                # If no listeners are left, remove entry
                if not self.listener_map[component_name]:
                    del self.listener_map[component_name]
                return True
        return False

    def find_listeners(self, component) -> None:
        """Walk over a model of a processing component, looking for layers that
        are Tok2vecListener subclasses that have an upstream_name that matches
        this component. Listeners can also set their upstream_name attribute to
        the wildcard string '*' to match any `Tok2Vec`.

        You're unlikely to ever need multiple `Tok2Vec` components, so it's
        fine to leave your listeners upstream_name on '*'.
        """
        names = ("*", self.name)
        if isinstance(getattr(component, "model", None), Model):
            for node in component.model.walk():
                if isinstance(node, Tok2VecListener) and node.upstream_name in names:
                    self.add_listener(node, component.name)

    def predict(self, docs: Iterable[Doc]):
        """Apply the pipeline's model to a batch of docs, without modifying them.
        Returns a single tensor for a batch of documents.

        docs (Iterable[Doc]): The documents to predict.
        RETURNS: Vector representations for each token in the documents.

        DOCS: https://spacy.io/api/tok2vec#predict
        """
        # The batch is consumed twice (by the model and for the batch ID), so
        # materialise it once; a one-shot iterator would otherwise be exhausted
        # by the model and yield a bogus batch ID.
        docs = list(docs)
        tokvecs = self.model.predict(docs)
        listeners = self.listeners
        if listeners:
            batch_id = Tok2VecListener.get_batch_id(docs)
            for listener in listeners:
                # No gradient at prediction time; _empty_backprop is a
                # module-level function so it can be pickled.
                listener.receive(batch_id, tokvecs, _empty_backprop)
        return tokvecs

    def set_annotations(self, docs: Sequence[Doc], tokvecses) -> None:
        """Modify a batch of documents, using pre-computed scores.

        docs (Sequence[Doc]): The documents to modify.
        tokvecses: The tensors to set, produced by Tok2Vec.predict. Each must
            have one row per token of the corresponding doc.

        DOCS: https://spacy.io/api/tok2vec#set_annotations
        """
        for doc, tokvecs in zip(docs, tokvecses):
            assert tokvecs.shape[0] == len(doc)
            doc.tensor = tokvecs

    def update(
        self,
        examples: Iterable[Example],
        *,
        drop: float = 0.0,
        sgd: Optional[Optimizer] = None,
        losses: Optional[Dict[str, float]] = None,
    ) -> Dict[str, float]:
        """Learn from a batch of documents and gold-standard information,
        updating the pipe's model.

        The forward pass runs here, but the backward pass is deferred: the
        outputs are handed to every registered listener, and the last listener
        receives the callback that runs the backward pass (and the optimizer
        step, if `sgd` is given) once all gradients have been accumulated.

        examples (Iterable[Example]): A batch of Example objects.
        drop (float): The dropout rate.
        sgd (thinc.api.Optimizer): The optimizer.
        losses (Dict[str, float]): Optional record of the loss during training.
            Updated using the component name as the key.
        RETURNS (Dict[str, float]): The updated losses dictionary.

        DOCS: https://spacy.io/api/tok2vec#update
        """
        if losses is None:
            losses = {}
        validate_examples(examples, "Tok2Vec.update")
        docs = [eg.predicted for eg in examples]
        set_dropout_rate(self.model, drop)
        tokvecs, bp_tokvecs = self.model.begin_update(docs)
        # Zeroed per-doc gradient buffers that all listeners' gradients are
        # summed into before the single backward pass.
        d_tokvecs = [self.model.ops.alloc2f(*t2v.shape) for t2v in tokvecs]
        losses.setdefault(self.name, 0.0)

        def accumulate_gradient(one_d_tokvecs):
            """Accumulate tok2vec loss and gradient. This is passed as a callback
            to all but the last listener. Only the last one does the backprop.
            """
            for i, d_doc in enumerate(one_d_tokvecs):
                d_tokvecs[i] += d_doc
                losses[self.name] += float((d_doc ** 2).sum())

        def backprop(one_d_tokvecs):
            """Callback to actually do the backprop. Passed to last listener."""
            accumulate_gradient(one_d_tokvecs)
            d_docs = bp_tokvecs(d_tokvecs)
            if sgd is not None:
                self.finish_update(sgd)
            return d_docs

        # Resolve the property once instead of rebuilding the list per use.
        listeners = self.listeners
        if listeners:
            batch_id = Tok2VecListener.get_batch_id(docs)
            for listener in listeners[:-1]:
                listener.receive(batch_id, tokvecs, accumulate_gradient)
            listeners[-1].receive(batch_id, tokvecs, backprop)
        return losses

    def get_loss(self, examples, scores) -> None:
        """Not used: this component's loss is accumulated from the listeners'
        gradients inside `update`.
        """
        pass

    def initialize(
        self,
        get_examples: Callable[[], Iterable[Example]],
        *,
        nlp: Optional[Language] = None,
    ):
        """Initialize the pipe for training, using a representative set
        of data examples.

        get_examples (Callable[[], Iterable[Example]]): Function that
            returns a representative sample of gold-standard Example objects.
        nlp (Language): The current nlp object the component is part of.

        DOCS: https://spacy.io/api/tok2vec#initialize
        """
        validate_get_examples(get_examples, "Tok2Vec.initialize")
        # Up to the first 10 examples' docs are passed to the model as sample input.
        doc_sample = []
        for example in islice(get_examples(), 10):
            doc_sample.append(example.x)
        assert doc_sample, Errors.E923.format(name=self.name)
        self.model.initialize(X=doc_sample)

    def add_label(self, label):
        """Not supported: a tok2vec component has no labels."""
        raise NotImplementedError
|
|
|
|
|
2020-01-29 19:06:46 +03:00
|
|
|
|
|
|
|
class Tok2VecListener(Model):
    """A layer that gets fed its answers from an upstream connection,
    for instance from a component earlier in the pipeline.

    The Tok2VecListener layer is used as a sublayer within a component such
    as a parser, NER or text categorizer. Usually you'll have multiple listeners
    connecting to a single upstream Tok2Vec component, that's earlier in the
    pipeline. The Tok2VecListener layers act as proxies, passing the predictions
    from the Tok2Vec component into downstream components, and communicating
    gradients back upstream.
    """

    name = "tok2vec-listener"

    def __init__(self, upstream_name: str, width: int) -> None:
        """
        upstream_name (str): A string to identify the 'upstream' Tok2Vec component
            to communicate with. The upstream name should either be the wildcard
            string '*', or the name of the `Tok2Vec` component. You'll almost
            never have multiple upstream Tok2Vec components, so the wildcard
            string will almost always be fine.
        width (int):
            The width of the vectors produced by the upstream tok2vec component.
        """
        Model.__init__(self, name=self.name, forward=forward, dims={"nO": width})
        self.upstream_name = upstream_name
        # Filled in by receive(): the ID of the batch the upstream component
        # last processed, its outputs, and the matching backprop callback.
        self._batch_id = None
        self._outputs = None
        self._backprop = None

    @classmethod
    def get_batch_id(cls, inputs: List[Doc]) -> int:
        """Calculate a content-sensitive hash of the batch of documents, to check
        whether the next batch of documents is unexpected.

        The hash is the sum of the `orth` IDs of every token in the batch.
        """
        return sum(token.orth for doc in inputs for token in doc)

    def receive(self, batch_id: int, outputs, backprop) -> None:
        """Store a batch of training predictions and a backprop callback. The
        predictions and callback are produced by the upstream Tok2Vec component,
        and later will be used when the listener's component's model is called.
        """
        self._batch_id, self._outputs, self._backprop = batch_id, outputs, backprop

    def verify_inputs(self, inputs) -> bool:
        """Check that the batch of Doc objects matches the ones we have a
        prediction for.

        RETURNS (bool): True when the batch matches.
        RAISES ValueError: If nothing has been received yet (E954) or the
            batch ID of `inputs` differs from the stored one (E953).
        """
        if not (self._batch_id is not None or self._outputs is not None):
            raise ValueError(Errors.E954)
        current_id = self.get_batch_id(inputs)
        if current_id != self._batch_id:
            raise ValueError(Errors.E953.format(id1=current_id, id2=self._batch_id))
        return True
|
|
|
|
|
|
|
|
|
2020-07-25 16:01:15 +03:00
|
|
|
def forward(model: Tok2VecListener, inputs, is_train: bool):
    """Supply the outputs from the upstream Tok2Vec component.

    When training, returns the outputs and backprop callback last passed to
    `model.receive`, after `model.verify_inputs` has checked they belong to
    this batch of docs. Otherwise, returns each doc's `doc.tensor` (or a
    zero-filled array of shape (len(doc), nO) when the tensor is empty)
    together with a no-op backprop callback.
    """
    if is_train:
        model.verify_inputs(inputs)
        return model._outputs, model._backprop
    # This is pretty grim, but it's hard to do better :(.
    # It's hard to avoid relying on the doc.tensor attribute, because the
    # pipeline components can batch the data differently during prediction.
    # That doesn't happen in update, where the nlp object works on batches
    # of data.
    # When the components batch differently, we don't receive a matching
    # prediction from the upstream, so we can't predict.
    outputs = []
    width = model.get_dim("nO")
    for doc in inputs:
        if doc.tensor.size == 0:
            # But we do need to do *something* if the tensor hasn't been set.
            # The compromise is to at least return data of the right shape,
            # so the output is valid.
            outputs.append(model.ops.alloc2f(len(doc), width))
        else:
            outputs.append(doc.tensor)
    # Use the module-level no-op callback (as Tok2Vec.predict does) rather
    # than a lambda, so the callback can be pickled.
    return outputs, _empty_backprop
|
2021-02-11 03:37:39 +03:00
|
|
|
|
|
|
|
|
|
|
|
def _empty_backprop(dX):  # for pickling
    """Backprop callback that ignores its gradient and returns a new empty list.

    Defined at module level (instead of as a lambda) so it can be pickled.
    """
    no_gradients = list()
    return no_gradients
|