diff --git a/website/docs/usage/layers-architectures.md b/website/docs/usage/layers-architectures.md index 641db02f5..eb6f8b288 100644 --- a/website/docs/usage/layers-architectures.md +++ b/website/docs/usage/layers-architectures.md @@ -502,7 +502,7 @@ with Model.define_operators({">>": chain}): ## Create new trainable components {#components} -In addition to [swapping out](#swap-architectures) default models in built-in +In addition to [swapping out](#swap-architectures) layers in existing components, you can also implement an entirely new, [trainable](/usage/processing-pipelines#trainable-components) pipeline component from scratch. This can be done by creating a new class inheriting from @@ -523,20 +523,28 @@ overview of the `TrainablePipe` methods used by This section outlines an example use-case of implementing a **novel relation extraction component** from scratch. We'll implement a binary relation extraction method that determines whether or not **two entities** in a document -are related, and if so, what type of relation. We'll allow multiple types of -relations between two such entities (multi-label setting). There are two major -steps required: +are related, and if so, what type of relation connects them. We allow multiple +types of relations between two such entities (a multi-label setting). There are +two major steps required: 1. Implement a [machine learning model](#component-rel-model) specific to this - task. It will have to extract candidates from a [`Doc`](/api/doc) and predict - a relation for the available candidate pairs. -2. Implement a custom [pipeline component](#component-rel-pipe) powered by the - machine learning model that sets annotations on the [`Doc`](/api/doc) passing - through the pipeline. + task. It will have to extract candidate relation instances from a + [`Doc`](/api/doc) and predict the corresponding scores for each relation + label. +2. Implement a custom [pipeline component](#component-rel-pipe) - powered by the + machine learning model from step 1 - that translates the predicted scores + into annotations that are stored on the [`Doc`](/api/doc) objects as they + pass through the `nlp` pipeline. - + +Run this example use-case by using our project template. It includes all the +code to create the ML model and the pipeline component from scratch. +It also contains two config files to train the model: +one to run on CPU with a Tok2Vec layer, and one for the GPU using a transformer. +The project applies the relation extraction component to identify biomolecular +interactions in a sample dataset, but you can easily swap in your own dataset +for your experiments in any other domain. + #### Step 1: Implementing the Model {#component-rel-model} @@ -552,41 +560,17 @@ matrix** (~~Floats2d~~) of predictions: > for details. ```python -### Register the model architecture -@registry.architectures.register("rel_model.v1") +### The model architecture +@spacy.registry.architectures.register("rel_model.v1") def create_relation_model(...) -> Model[List[Doc], Floats2d]: model = ... # 👈 model will go here return model ``` -The first layer in this model will typically be an -[embedding layer](/usage/embeddings-transformers) such as a -[`Tok2Vec`](/api/tok2vec) component or a [`Transformer`](/api/transformer). This -layer is assumed to be of type ~~Model[List[Doc], List[Floats2d]]~~ as it -transforms each **document into a list of tokens**, with each token being -represented by its embedding in the vector space. - -Next, we need a method that **generates pairs of entities** that we want to -classify as being related or not. As these candidate pairs are typically formed -within one document, this function takes a [`Doc`](/api/doc) as input and -outputs a `List` of `Span` tuples. For instance, a very straightforward -implementation would be to just take any two entities from the same document: - -```python -### Simple candiate generation -def get_candidates(doc: Doc) -> List[Tuple[Span, Span]]: - candidates = [] - for ent1 in doc.ents: - for ent2 in doc.ents: - candidates.append((ent1, ent2)) - return candidates -``` - -But we could also refine this further by **excluding relations** of an entity -with itself, and posing a **maximum distance** (in number of tokens) between two -entities. We register this function in the -[`@misc` registry](/api/top-level#registry) so we can refer to it from the -config, and easily swap it out for any other candidate generation function. +We adapt a **modular approach** to the definition of this relation model, and +define it as chaining two layers together: the first layer that generates an +instance tensor from a given set of documents, and the second layer that +transforms the instance tensor into a final tensor holding the predictions: > #### config.cfg (excerpt) > @@ -594,18 +578,159 @@ config, and easily swap it out for any other candidate generation function. > [model] > @architectures = "rel_model.v1" > -> [model.tok2vec] +> [model.create_instance_tensor] > # ... > -> [model.get_candidates] -> @misc = "rel_cand_generator.v1" -> max_length = 20 +> [model.classification_layer] +> # ... > ``` ```python -### Extended candidate generation {highlight="1,2,7,8"} -@registry.misc.register("rel_cand_generator.v1") -def create_candidate_indices(max_length: int) -> Callable[[Doc], List[Tuple[Span, Span]]]: +### The model architecture {highlight="6"} +@spacy.registry.architectures.register("rel_model.v1") +def create_relation_model( + create_instance_tensor: Model[List[Doc], Floats2d], + classification_layer: Model[Floats2d, Floats2d], +) -> Model[List[Doc], Floats2d]: + model = chain(create_instance_tensor, classification_layer) + return model +``` + +The `classification_layer` could be something like a +[Linear](https://thinc.ai/docs/api-layers#linear) layer followed by a +[logistic](https://thinc.ai/docs/api-layers#logistic) activation function: + +> #### config.cfg (excerpt) +> +> ```ini +> [model.classification_layer] +> @architectures = "rel_classification_layer.v1" +> nI = null +> nO = null +> ``` + +```python +### The classification layer +@spacy.registry.architectures.register("rel_classification_layer.v1") +def create_classification_layer( + nO: int = None, nI: int = None +) -> Model[Floats2d, Floats2d]: + return chain(Linear(nO=nO, nI=nI), Logistic()) +``` + +The first layer that **creates the instance tensor** can be defined by +implementing a +[custom forward function](https://thinc.ai/docs/usage-models#weights-layers-forward) +with an appropriate backpropagation callback. We also define an +[initialization method](https://thinc.ai/docs/usage-models#weights-layers-init) +that ensures that the layer is properly set up for training. + +We omit some of the implementation details here, and refer to the +[spaCy project](https://github.com/explosion/projects/tree/v3/tutorials/rel_component) +that has the full implementation. + +> #### config.cfg (excerpt) +> +> ```ini +> [model.create_instance_tensor] +> @architectures = "rel_instance_tensor.v1" +> +> [model.create_instance_tensor.tok2vec] +> @architectures = "spacy.HashEmbedCNN.v1" +> # ... +> +> [model.create_instance_tensor.pooling] +> @layers = "reduce_mean.v1" +> +> [model.create_instance_tensor.get_instances] +> # ... +> ``` + +```python +### The layer that creates the instance tensor +@spacy.registry.architectures.register("rel_instance_tensor.v1") +def create_tensors( + tok2vec: Model[List[Doc], List[Floats2d]], + pooling: Model[Ragged, Floats2d], + get_instances: Callable[[Doc], List[Tuple[Span, Span]]], +) -> Model[List[Doc], Floats2d]: + + return Model( + "instance_tensors", + instance_forward, + init=instance_init, + layers=[tok2vec, pooling], + refs={"tok2vec": tok2vec, "pooling": pooling}, + attrs={"get_instances": get_instances}, + ) + + +# The custom forward function +def instance_forward( + model: Model[List[Doc], Floats2d], + docs: List[Doc], + is_train: bool, +) -> Tuple[Floats2d, Callable]: + tok2vec = model.get_ref("tok2vec") + tokvecs, bp_tokvecs = tok2vec(docs, is_train) + get_instances = model.attrs["get_instances"] + all_instances = [get_instances(doc) for doc in docs] + pooling = model.get_ref("pooling") + relations = ... + + def backprop(d_relations: Floats2d) -> List[Doc]: + d_tokvecs = ... + return bp_tokvecs(d_tokvecs) + + return relations, backprop + + +# The custom initialization method +def instance_init( + model: Model, + X: List[Doc] = None, + Y: Floats2d = None, +) -> Model: + tok2vec = model.get_ref("tok2vec") + tok2vec.initialize(X) + return model + +``` + +This custom layer uses an [embedding layer](/usage/embeddings-transformers) such +as a [`Tok2Vec`](/api/tok2vec) component or a [`Transformer`](/api/transformer). +This layer is assumed to be of type ~~Model[List[Doc], List[Floats2d]]~~ as it +transforms each **document into a list of tokens**, with each token being +represented by its embedding in the vector space. + +The `pooling` layer will be applied to summarize the token vectors into **entity +vectors**, as named entities (represented by ~~Span~~ objects) can consist of +one or multiple tokens. For instance, the pooling layer could resort to +calculating the average of all token vectors in an entity. Thinc provides +several +[built-in pooling operators](https://thinc.ai/docs/api-layers#reduction-ops) for +this purpose. + +Finally, we need a `get_instances` method that **generates pairs of entities** +that we want to classify as being related or not. As these candidate pairs are +typically formed within one document, this function takes a [`Doc`](/api/doc) as +input and outputs a `List` of `Span` tuples. For instance, the following +implementation takes any two entities from the same document, as long as they +are within a **maximum distance** (in number of tokens) of eachother: + +> #### config.cfg (excerpt) +> +> ```ini +> +> [model.create_instance_tensor.get_instances] +> @misc = "rel_instance_generator.v1" +> max_length = 100 +> ``` + +```python +### Candidate generation +@spacy.registry.misc.register("rel_instance_generator.v1") +def create_instances(max_length: int) -> Callable[[Doc], List[Tuple[Span, Span]]]: def get_candidates(doc: "Doc") -> List[Tuple[Span, Span]]: candidates = [] for ent1 in doc.ents: @@ -617,45 +742,39 @@ def create_candidate_indices(max_length: int) -> Callable[[Doc], List[Tuple[Span return get_candidates ``` -Finally, we require a method that transforms the candidate entity pairs into a -2D tensor using the specified [`Tok2Vec`](/api/tok2vec) or -[`Transformer`](/api/transformer). The resulting ~~Floats2~~ object will then be -processed by a final `output_layer` of the network. Putting all this together, -we can define our relation model in a config file as such: +This function in added to the [`@misc` registry](/api/top-level#registry) so we +can refer to it from the config, and easily swap it out for any other candidate +generation function. -```ini -### config.cfg -[model] -@architectures = "rel_model.v1" -# ... +#### Intermezzo: define how to store the relations data {#component-rel-attribute} -[model.tok2vec] -# ... +> #### Example output +> +> ```python +> doc = nlp("Amsterdam is the capital of the Netherlands.") +> print("spans", [(e.start, e.text, e.label_) for e in doc.ents]) +> for value, rel_dict in doc._.rel.items(): +> print(f"{value}: {rel_dict}") +> +> # spans [(0, 'Amsterdam', 'LOC'), (6, 'Netherlands', 'LOC')] +> # (0, 6): {'CAPITAL_OF': 0.89, 'LOCATED_IN': 0.75, 'UNRELATED': 0.002} +> # (6, 0): {'CAPITAL_OF': 0.01, 'LOCATED_IN': 0.13, 'UNRELATED': 0.017} +> ``` -[model.get_candidates] -@misc = "rel_cand_generator.v1" -max_length = 20 - -[model.create_candidate_tensor] -@misc = "rel_cand_tensor.v1" - -[model.output_layer] -@architectures = "rel_output_layer.v1" -# ... -``` - - - - -When creating this model, we store the custom functions as -[attributes](https://thinc.ai/docs/api-model#properties) and the sublayers as -references, so we can access them easily: +For our new relation extraction component, we will use a custom +[extension attribute](/usage/processing-pipelines#custom-components-attributes) +`doc._.rel` in which we store relation data. The attribute refers to a +dictionary, keyed by the **start offsets of each entity** involved in the +candidate relation. The values in the dictionary refer to another dictionary +where relation labels are mapped to values between 0 and 1. We assume anything +above 0.5 to be a `True` relation. The ~~Example~~ instances that we'll use as +training data, will include their gold-standard relation annotations in +`example.reference._.rel`. ```python -tok2vec_layer = model.get_ref("tok2vec") -output_layer = model.get_ref("output_layer") -create_candidate_tensor = model.attrs["create_candidate_tensor"] -get_candidates = model.attrs["get_candidates"] +### Registering the extension attribute +from spacy.tokens import Doc +Doc.set_extension("rel", default={}) ``` #### Step 2: Implementing the pipeline component {#component-rel-pipe} @@ -698,19 +817,44 @@ class RelationExtractor(TrainablePipe): ... ``` -Before the model can be used, it needs to be -[initialized](/usage/training#initialization). This function receives a callback -to access the full **training data set**, or a representative sample. This data -set can be used to deduce all **relevant labels**. Alternatively, a list of -labels can be provided to `initialize`, or you can call -`RelationExtractor.add_label` directly. The number of labels defines the output -dimensionality of the network, and will be used to do +Typically, the **constructor** defines the vocab, the Machine Learning model, +and the name of this component. Additionally, this component, just like the +`textcat` and the `tagger`, stores an **internal list of labels**. The ML model +will predict scores for each label. We add convenience methods to easily +retrieve and add to them. + +```python +### The constructor (continued) + def __init__(self, vocab, model, name="rel"): + """Create a component instance.""" + # ... + self.cfg = {"labels": []} + + @property + def labels(self) -> Tuple[str]: + """Returns the labels currently added to the component.""" + return tuple(self.cfg["labels"]) + + def add_label(self, label: str): + """Add a new label to the pipe.""" + self.cfg["labels"] = list(self.labels) + [label] +``` + +After creation, the component needs to be +[initialized](/usage/training#initialization). This method can define the +relevant labels in two ways: explicitely by setting the `labels` argument in the +[`initialize` block](/api/data-formats#config-initialize) of the config, or +implicately by deducing them from the `get_examples` callback that generates the +full **training data set**, or a representative sample. + +The final number of labels defines the output dimensionality of the network, and +will be used to do [shape inference](https://thinc.ai/docs/usage-models#validation) throughout the layers of the neural network. This is triggered by calling [`Model.initialize`](https://thinc.ai/api/model#initialize). ```python -### The initialize method {highlight="12,18,22"} +### The initialize method {highlight="12,15,18,22"} from itertools import islice def initialize( @@ -741,7 +885,7 @@ Typically, this happens when the pipeline is set up before training in [`spacy train`](/api/cli#training). After initialization, the pipeline component and its internal model can be trained and used to make predictions. -During training, the function [`update`](/api/pipe#update) is invoked which +During training, the method [`update`](/api/pipe#update) is invoked which delegates to [`Model.begin_update`](https://thinc.ai/docs/api-model#begin_update) and a [`get_loss`](/api/pipe#get_loss) function that **calculates the loss** for a @@ -761,18 +905,18 @@ def update( sgd: Optional[Optimizer] = None, losses: Optional[Dict[str, float]] = None, ) -> Dict[str, float]: - ... - docs = [ex.predicted for ex in examples] + # ... + docs = [eg.predicted for eg in examples] predictions, backprop = self.model.begin_update(docs) loss, gradient = self.get_loss(examples, predictions) backprop(gradient) losses[self.name] += loss - ... + # ... return losses ``` -When the internal model is trained, the component can be used to make novel -**predictions**. The [`predict`](/api/pipe#predict) function needs to be +After training the model, the component can be used to make novel +**predictions**. The [`predict`](/api/pipe#predict) method needs to be implemented for each subclass of `TrainablePipe`. In our case, we can simply delegate to the internal model's [predict](https://thinc.ai/docs/api-model#predict) function that takes a batch @@ -788,42 +932,21 @@ def predict(self, docs: Iterable[Doc]) -> Floats2d: The final method that needs to be implemented, is [`set_annotations`](/api/pipe#set_annotations). This function takes the predictions, and modifies the given `Doc` object in place to store them. For our -relation extraction component, we store the data as a dictionary in a custom -[extension attribute](/usage/processing-pipelines#custom-components-attributes) -`doc._.rel`. As keys, we represent the candidate pair by the **start offsets of -each entity**, as this defines an entity pair uniquely within one document. +relation extraction component, we store the data in the +[custom attribute](#component-rel-attribute)`doc._.rel`. To interpret the scores predicted by the relation extraction model correctly, we -need to refer to the model's `get_candidates` function that defined which pairs +need to refer to the model's `get_instances` function that defined which pairs of entities were relevant candidates, so that the predictions can be linked to those exact entities: -> #### Example output -> -> ```python -> doc = nlp("Amsterdam is the capital of the Netherlands.") -> print("spans", [(e.start, e.text, e.label_) for e in doc.ents]) -> for value, rel_dict in doc._.rel.items(): -> print(f"{value}: {rel_dict}") -> -> # spans [(0, 'Amsterdam', 'LOC'), (6, 'Netherlands', 'LOC')] -> # (0, 6): {'CAPITAL_OF': 0.89, 'LOCATED_IN': 0.75, 'UNRELATED': 0.002} -> # (6, 0): {'CAPITAL_OF': 0.01, 'LOCATED_IN': 0.13, 'UNRELATED': 0.017} -> ``` - -```python -### Registering the extension attribute -from spacy.tokens import Doc -Doc.set_extension("rel", default={}) -``` - ```python ### The set_annotations method {highlight="5-6,10"} def set_annotations(self, docs: Iterable[Doc], predictions: Floats2d): c = 0 - get_candidates = self.model.attrs["get_candidates"] + get_instances = self.model.attrs["get_instances"] for doc in docs: - for (e1, e2) in get_candidates(doc): + for (e1, e2) in get_instances(doc): offset = (e1.start, e2.start) if offset not in doc._.rel: doc._.rel[offset] = {} @@ -837,15 +960,15 @@ Under the hood, when the pipe is applied to a document, it delegates to the ```python ### The __call__ method -def __call__(self, Doc doc): +def __call__(self, doc: Doc): predictions = self.predict([doc]) self.set_annotations([doc], predictions) return doc ``` -There is one more optional method to implement: [`score`](/api/pipe#score) -calculates the performance of your component on a set of examples, and -returns the results as a dictionary: +There is one more optional method to implement: [`score`](/api/pipe#score) +calculates the performance of your component on a set of examples, and returns +the results as a dictionary: ```python ### The score method @@ -861,8 +984,8 @@ def score(self, examples: Iterable[Example]) -> Dict[str, Any]: } ``` -This is particularly useful to see the scores on the development corpus -when training the component with [`spacy train`](/api/cli#training). +This is particularly useful for calculating relevant scores on the development +corpus when training the component with [`spacy train`](/api/cli#training). Once our `TrainablePipe` subclass is fully implemented, we can [register](/usage/processing-pipelines#custom-components-factories) the @@ -879,14 +1002,8 @@ assigns it a name and lets you create the component with > > [components.relation_extractor.model] > @architectures = "rel_model.v1" -> -> [components.relation_extractor.model.tok2vec] > # ... > -> [components.relation_extractor.model.get_candidates] -> @misc = "rel_cand_generator.v1" -> max_length = 20 -> > [training.score_weights] > rel_micro_p = 0.0 > rel_micro_r = 0.0 @@ -902,8 +1019,8 @@ def make_relation_extractor(nlp, name, model): return RelationExtractor(nlp.vocab, model, name) ``` -You can extend the decorator to include information such as the type of -annotations that are required for this component to run, the type of annotations +You can extend the decorator to include information such as the type of +annotations that are required for this component to run, the type of annotations it produces, and the scores that can be calculated: ```python @@ -924,6 +1041,12 @@ def make_relation_extractor(nlp, name, model): return RelationExtractor(nlp.vocab, model, name) ``` - + +Run this example use-case by using our project template. It includes all the +code to create the ML model and the pipeline component from scratch. +It contains two config files to train the model: +one to run on CPU with a Tok2Vec layer, and one for the GPU using a transformer. +The project applies the relation extraction component to identify biomolecular +interactions, but you can easily swap in your own dataset for your experiments +in any other domain. +