From 74dbb65bcacfaec82499fa430063e6114bc265c3 Mon Sep 17 00:00:00 2001
From: richardpaulhudson
Date: Wed, 7 Dec 2022 16:59:22 +0100
Subject: [PATCH] Add model, logic changes and documentation

---
 spacy/ml/models/__init__.py            |   1 +
 spacy/ml/models/lemmatizer.py          |  45 +++++++++
 spacy/pipeline/edit_tree_lemmatizer.py | 114 ++++++++++++++++++++-----
 website/docs/api/architectures.md      |  37 ++++
 4 files changed, 175 insertions(+), 22 deletions(-)
 create mode 100644 spacy/ml/models/lemmatizer.py

diff --git a/spacy/ml/models/__init__.py b/spacy/ml/models/__init__.py
index 9b7628f0e..cc0114aaa 100644
--- a/spacy/ml/models/__init__.py
+++ b/spacy/ml/models/__init__.py
@@ -1,4 +1,5 @@
 from .entity_linker import *  # noqa
+from .lemmatizer import *  # noqa
 from .multi_task import *  # noqa
 from .parser import *  # noqa
 from .spancat import *  # noqa
diff --git a/spacy/ml/models/lemmatizer.py b/spacy/ml/models/lemmatizer.py
new file mode 100644
index 000000000..56abcbbca
--- /dev/null
+++ b/spacy/ml/models/lemmatizer.py
@@ -0,0 +1,45 @@
+from typing import Optional, List
+from thinc.api import glorot_uniform_init, with_array, Softmax_v2, Model
+from thinc.api import Relu, Dropout, Sigmoid, concatenate, chain
+from thinc.types import Floats2d
+
+from ...util import registry
+from ...tokens import Doc
+
+
+@registry.architectures("spacy.Lemmatizer.v1")
+def build_lemmatizer_model(
+    tok2vec: Model[List[Doc], List[Floats2d]],
+    nO: Optional[int] = None,
+    normalize: bool = False,
+    lowercasing: bool = True,
+) -> Model[List[Doc], List[Floats2d]]:
+    """Build a model for the edit-tree lemmatizer, using a provided token-to-vector component.
+    A linear layer with softmax activation is added to predict scores
+    given the token vectors (this part corresponds to the *Tagger* architecture).
+
+    Additionally, if *lowercasing* is *True*, a single sigmoid output is learned for each
+    token input: this is used in the context of the lemmatizer to determine when to convert inputs
+    to lowercase before processing them with the edit trees. A single normalized linear layer
+    is placed between the token-to-vector component and the sigmoid output neuron, as this has
+    been found to produce slightly better results than a direct connection. A possible intuition
+    explaining this is that it reduces the interference between the two learning goals,
+    which are linked but nonetheless distinct.
+
+    tok2vec (Model[List[Doc], List[Floats2d]]): The token-to-vector subnetwork.
+    nO (int or None): The number of tags to output in the main model. Inferred from the data if None.
+    normalize (bool): Normalize probabilities during inference. Defaults to *False*.
+    lowercasing (bool): If *True*, the additional sigmoid appendage is created. Defaults to *True*.
+ """ + with Model.define_operators({">>": chain, "|": concatenate}): + t2v_width = tok2vec.get_dim("nO") if tok2vec.has_dim("nO") else None + softmax = Softmax_v2( + nO, t2v_width, init_W=glorot_uniform_init, normalize_outputs=normalize + ) + model = tok2vec >> with_array(softmax) + if lowercasing: + lowercasing_output = Sigmoid(1) + sigmoid_appendage = Relu(50) >> Dropout(0.2) >> lowercasing_output + model |= tok2vec >> with_array(sigmoid_appendage) + model.set_ref("lowercasing_output", lowercasing_output) + return model diff --git a/spacy/pipeline/edit_tree_lemmatizer.py b/spacy/pipeline/edit_tree_lemmatizer.py index a56c9975e..6feaeb465 100644 --- a/spacy/pipeline/edit_tree_lemmatizer.py +++ b/spacy/pipeline/edit_tree_lemmatizer.py @@ -5,8 +5,9 @@ from itertools import islice import numpy as np import srsly -from thinc.api import Config, Model, SequenceCategoricalCrossentropy -from thinc.types import Floats2d, Ints1d, Ints2d +from thinc.api import Config, Model +from thinc.api import SequenceCategoricalCrossentropy, L2Distance +from thinc.types import Floats2d, Ints2d from ._edit_tree_internals.edit_trees import EditTrees from ._edit_tree_internals.schemas import validate_edit_tree @@ -22,7 +23,8 @@ from .. import util default_model_config = """ [model] -@architectures = "spacy.Tagger.v2" +@architectures = "spacy.Lemmatizer.v1" +lowercasing = true [model.tok2vec] @architectures = "spacy.HashEmbedCNN.v2" @@ -115,29 +117,67 @@ class EditTreeLemmatizer(TrainablePipe): self.cfg: Dict[str, Any] = {"labels": []} self.scorer = scorer + self.lowercasing = model.has_ref("lowercasing_output") def get_loss( self, examples: Iterable[Example], scores: List[Floats2d] ) -> Tuple[float, List[Floats2d]]: validate_examples(examples, "EditTreeLemmatizer.get_loss") - loss_func = SequenceCategoricalCrossentropy(normalize=False, missing_value=-1) - - truths = [] + tree_loss_func = SequenceCategoricalCrossentropy( + normalize=False, missing_value=-1 + ) + lowercasing_loss_func = L2Distance(normalize=False) + tree_truths = [] + lowercasing_truths = [] for eg in examples: - eg_truths = [] + eg_tree_truths = [] + eg_lowercasing_truths = [] for (predicted, gold_lemma) in zip( eg.predicted, eg.get_aligned("LEMMA", as_string=True) ): if gold_lemma is None: label = -1 + eg_lowercasing_truths.append([0]) else: - tree_id = self.trees.add(predicted.text, gold_lemma) + if self.lowercasing and _should_lowercased( + predicted.text, gold_lemma + ): + eg_lowercasing_truths.append([1]) + text = predicted.text.lower() + else: + eg_lowercasing_truths.append([0]) + text = predicted.text + tree_id = self.trees.add(text, gold_lemma) label = self.tree2label.get(tree_id, 0) - eg_truths.append(label) - truths.append(eg_truths) + eg_tree_truths.append(label) - d_scores, loss = loss_func(scores, truths) + tree_truths.append(eg_tree_truths) + lowercasing_truths.append(eg_lowercasing_truths) + + if self.lowercasing: + tree_scores, lowercasing_flags = _split_predictions(scores) + d_tree_scores, loss = tree_loss_func(tree_scores, tree_truths) + + d_scores: List[Floats2d] = [] + for i, doc_d_tree_scores in enumerate(d_tree_scores): + eg_lowercasing_flags = lowercasing_flags[i] + eg_d_lowercasing_flags, eg_lowercasing_loss = lowercasing_loss_func( + eg_lowercasing_flags, self.model.ops.asarray(lowercasing_truths[i]) + ) + d_scores.append( + self.model.ops.xp.hstack( + [ + doc_d_tree_scores, + self.model.ops.asarray( + eg_d_lowercasing_flags, dtype="float32" + ), + ] + ) + ) + loss += eg_lowercasing_loss + else: + d_scores, loss = 
 
         if self.model.ops.xp.isnan(loss):
             raise ValueError(Errors.E910.format(name=self.name))
 
         return float(loss), d_scores
@@ -151,15 +191,21 @@ class EditTreeLemmatizer(TrainablePipe):
             guesses: List[Ints2d] = [self.model.ops.alloc2i(0, n_labels) for _ in docs]
             assert len(guesses) == n_docs
             return guesses
-        scores = self.model.predict(docs)
+        if self.lowercasing:
+            scores, lowercasing_flags = _split_predictions(self.model.predict(docs))
+        else:
+            scores = self.model.predict(docs)
+            lowercasing_flags = None
         assert len(scores) == n_docs
-        guesses = self._scores2guesses(docs, scores)
+        guesses = self._scores2guesses(docs, scores, lowercasing_flags)
         assert len(guesses) == n_docs
         return guesses
 
-    def _scores2guesses(self, docs, scores):
+    def _scores2guesses(
+        self, docs, scores, lowercasing_flags: Optional[List[Floats2d]]
+    ):
         guesses = []
-        for doc, doc_scores in zip(docs, scores):
+        for i, (doc, doc_scores) in enumerate(zip(docs, scores)):
             if self.top_k == 1:
                 doc_guesses = doc_scores.argmax(axis=1).reshape(-1, 1)
             else:
@@ -169,15 +215,20 @@ class EditTreeLemmatizer(TrainablePipe):
                 doc_guesses = doc_guesses.get()
 
             doc_compat_guesses = []
-            for token, candidates in zip(doc, doc_guesses):
+            for j, (token, candidates) in enumerate(zip(doc, doc_guesses)):
+                text = token.text
+                to_lowercase = False
+                if lowercasing_flags is not None and lowercasing_flags[i][j] > 0.5:
+                    to_lowercase = True
+                    text = text.lower()
                 tree_id = -1
                 for candidate in candidates:
                     candidate_tree_id = self.cfg["labels"][candidate]
-                    if self.trees.apply(candidate_tree_id, token.text) is not None:
+                    if self.trees.apply(candidate_tree_id, text) is not None:
                         tree_id = candidate_tree_id
                         break
-                doc_compat_guesses.append(tree_id)
+                doc_compat_guesses.append((tree_id, to_lowercase))
 
             guesses.append(np.array(doc_compat_guesses))
 
         return guesses
@@ -188,16 +239,17 @@ class EditTreeLemmatizer(TrainablePipe):
             doc_tree_ids = batch_tree_ids[i]
             if hasattr(doc_tree_ids, "get"):
                 doc_tree_ids = doc_tree_ids.get()
-            for j, tree_id in enumerate(doc_tree_ids):
+            for j, (tree_id, to_lowercase) in enumerate(doc_tree_ids):
                 if self.overwrite or doc[j].lemma == 0:
                     # If no applicable tree could be found during prediction,
                     # the special identifier -1 is used. Otherwise the tree
                     # is guaranteed to be applicable.
                     if tree_id == -1:
                         if self.backoff is not None:
                             doc[j].lemma = getattr(doc[j], self.backoff)
                     else:
-                        lemma = self.trees.apply(tree_id, doc[j].text)
+                        text = doc[j].text.lower() if to_lowercase else doc[j].text
+                        lemma = self.trees.apply(tree_id, text)
                         doc[j].lemma_ = lemma
 
     @property
@@ -349,9 +401,15 @@ class EditTreeLemmatizer(TrainablePipe):
         for example in get_examples():
             for token in example.reference:
                 if token.lemma != 0:
-                    tree_id = trees.add(token.text, token.lemma_)
+                    if self.lowercasing and _should_lowercase(
+                        token.text, token.lemma_
+                    ):
+                        text = token.text.lower()
+                    else:
+                        text = token.text
+                    tree_id = trees.add(text, token.lemma_)
                     tree_freqs[tree_id] += 1
-                    repr_pairs[tree_id] = (token.text, token.lemma_)
+                    repr_pairs[tree_id] = (text, token.lemma_)
 
         # Construct trees that make the frequency cut-off using representative
         # form - token pairs.
@@ -374,3 +432,15 @@ class EditTreeLemmatizer(TrainablePipe):
         self.tree2label[tree_id] = len(self.cfg["labels"])
         self.cfg["labels"].append(tree_id)
         return self.tree2label[tree_id]
+
+
+def _should_lowercase(form: str, lemma: str) -> bool:
+    return form.lower() != form and lemma.lower() == lemma
+
+
+def _split_predictions(
+    raw_predictions: List[Floats2d],
+) -> Tuple[List[Floats2d], List[Floats2d]]:
+    tree_scores = [rp[:, :-1] for rp in raw_predictions]
+    lowercasing_flags = [rp[:, -1:] for rp in raw_predictions]
+    return tree_scores, lowercasing_flags
diff --git a/website/docs/api/architectures.md b/website/docs/api/architectures.md
index 4c5447f75..85aa3f0a5 100644
--- a/website/docs/api/architectures.md
+++ b/website/docs/api/architectures.md
@@ -648,6 +648,43 @@ The other arguments are shared between all versions.
 
+### spacy.Lemmatizer.v1 {#Lemmatizer}
+
+> #### Example Config
+>
+> ```ini
+> [model]
+> @architectures = "spacy.Lemmatizer.v1"
+> nO = null
+> normalize = false
+> lowercasing = true
+>
+> [model.tok2vec]
+> # ...
+> ```
+
+Build a model for the edit-tree lemmatizer, using a provided token-to-vector
+component. A linear layer with softmax activation is added to predict scores
+given the token vectors (this part corresponds to the
+[Tagger](/api/architectures#tagger) architecture).
+
+Additionally, if `lowercasing = true`, a single sigmoid output is learned for
+each token input: this is used in the context of the lemmatizer to determine
+when to convert inputs to lowercase before processing them with the edit trees.
+A single normalized linear layer is placed between the token-to-vector component
+and the sigmoid output neuron, as this has been found to produce slightly better
+results than a direct connection. A possible intuition explaining this is that
+it reduces the interference between the two learning goals, which are linked but
+nonetheless distinct.
+
+| Name          | Description                                                                                 |
+| ------------- | ------------------------------------------------------------------------------------------- |
+| `tok2vec`     | Subnetwork to map tokens into vector representations. ~~Model[List[Doc], List[Floats2d]]~~  |
+| `nO`          | The number of tags to output. Inferred from the data if `None`. ~~Optional[int]~~           |
+| `normalize`   | Normalize probabilities during inference. Defaults to `False`. ~~bool~~                     |
+| `lowercasing` | If `True`, the additional sigmoid appendage is created. Defaults to `True`. ~~bool~~        |
+| **CREATES**   | The model using the architecture. ~~Model[List[Doc], List[Floats2d]]~~                      |
+
 ## Text classification architectures {#textcat source="spacy/ml/models/textcat.py"}
 
 A text classification architecture needs to take a [`Doc`](/api/doc) as input,
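
To make the lowercasing appendage concrete, the sketch below reproduces the model's output layout outside of spaCy. It is a rough, self-contained illustration rather than part of the patch: the dummy tok2vec is a hypothetical stand-in for a real subnetwork such as `spacy.HashEmbedCNN.v2`, the dimensions are arbitrary, and it assumes a thinc version whose `concatenate` combiner supports list-of-array outputs, which the patch itself also relies on.

```python
from typing import List

from thinc.api import Model, Relu, Sigmoid, Softmax_v2, chain, concatenate, with_array

n_labels = 10  # hypothetical number of edit-tree labels
width = 8      # hypothetical tok2vec output width


def dummy_forward(model: Model, token_counts: List[int], is_train: bool):
    # One zero vector per "token" of each "doc"; the backprop callback is a no-op.
    Ys = [model.ops.alloc2f(n, width) for n in token_counts]
    return Ys, lambda dYs: token_counts


# Stand-in for a real tok2vec subnetwork.
tok2vec = Model("dummy-tok2vec", dummy_forward, dims={"nO": width})

with Model.define_operators({">>": chain, "|": concatenate}):
    # Main branch: per-token edit-tree scores (the Tagger-like part).
    model = tok2vec >> with_array(Softmax_v2(n_labels, width))
    # Appendage: normalized hidden layer feeding a single sigmoid lowercasing flag.
    model |= tok2vec >> with_array(Relu(50, width, normalize=True) >> Sigmoid(1, 50))

model.initialize()  # all dims are given explicitly, so no sample data is needed
(scores,) = model.predict([5])  # one "doc" with five tokens

# The last column is the lowercasing flag and the rest are tree scores;
# this is exactly the layout that _split_predictions slices apart per doc.
assert scores.shape == (5, n_labels + 1)
tree_scores, lowercasing_flags = scores[:, :-1], scores[:, -1:]
```

A flag above 0.5 means the token is lowercased before an edit tree is applied, so a sentence-initial form like "Dogs" can reuse the same tree as "dogs", while a proper noun such as "Berlin", whose gold lemma keeps its casing, is left alone by `_should_lowercase`.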
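For reference, a user training config could select the new architecture for the existing `trainable_lemmatizer` factory in the same way as the patch's `default_model_config`. The component name `lemmatizer` and the tok2vec settings below are illustrative values mirroring spaCy's defaults:

```ini
[components.lemmatizer]
factory = "trainable_lemmatizer"

[components.lemmatizer.model]
@architectures = "spacy.Lemmatizer.v1"
nO = null
normalize = false
lowercasing = true

[components.lemmatizer.model.tok2vec]
@architectures = "spacy.HashEmbedCNN.v2"
pretrained_vectors = null
width = 96
depth = 4
embed_size = 2000
window_size = 1
maxout_pieces = 3
subword_features = true
```

Because `EditTreeLemmatizer` detects the appendage through the model's `lowercasing_output` reference, setting `lowercasing = false` simply drops the extra output and restores the previous Tagger-style behaviour.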