Add model, logic changes and documentation

This commit is contained in:
richardpaulhudson 2022-12-07 16:59:22 +01:00
parent 27fac7df2e
commit 74dbb65bca
4 changed files with 174 additions and 22 deletions

View File

@@ -1,4 +1,5 @@
 from .entity_linker import *  # noqa
+from .lemmatizer import *  # noqa
 from .multi_task import *  # noqa
 from .parser import *  # noqa
 from .spancat import *  # noqa

View File

@@ -0,0 +1,44 @@
from typing import Optional, List
from thinc.api import glorot_uniform_init, with_array, Softmax_v2, Model
from thinc.api import Relu, Dropout, Sigmoid, concatenate, chain
from thinc.types import Floats2d
from ...util import registry
from ...tokens import Doc


@registry.architectures("spacy.Lemmatizer.v1")
def build_lemmatizer_model(
    tok2vec: Model[List[Doc], List[Floats2d]],
    nO: Optional[int] = None,
    normalize=False,
    lowercasing=True,
) -> Model[List[Doc], List[Floats2d]]:
    """Build a model for the edit-tree lemmatizer, using a provided token-to-vector
    component. A linear layer with softmax activation is added to predict scores
    given the token vectors (this part corresponds to the *Tagger* architecture).
    Additionally, if *lowercasing* is *True*, a single sigmoid output is learned for
    each token input: this is used in the context of the lemmatizer to determine when
    to convert inputs to lowercase before processing them with the edit trees. A
    single normalized linear layer is placed between the token-to-vector component
    and the sigmoid output neuron, as this has been found to produce slightly better
    results than a direct connection. A possible intuition explaining this is that it
    reduces the interference between the two learning goals, which are linked but
    nonetheless distinct.

    tok2vec (Model[List[Doc], List[Floats2d]]): The token-to-vector subnetwork.
    nO (int or None): The number of tags to output in the main model. Inferred from
        the data if None.
    normalize (bool): Normalize probabilities during inference. Defaults to False.
    lowercasing (bool): If True, the additional sigmoid appendage is created.
    """
    with Model.define_operators({">>": chain, "|": concatenate}):
        t2v_width = tok2vec.get_dim("nO") if tok2vec.has_dim("nO") else None
        softmax = Softmax_v2(
            nO, t2v_width, init_W=glorot_uniform_init, normalize_outputs=normalize
        )
        model = tok2vec >> with_array(softmax)
        if lowercasing:
            lowercasing_output = Sigmoid(1)
            sigmoid_appendage = Relu(50) >> Dropout(0.2) >> lowercasing_output
            model |= tok2vec >> with_array(sigmoid_appendage)
            model.set_ref("lowercasing_output", lowercasing_output)
    return model
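Not part of the commit itself: the following is a minimal sketch of how this architecture could be instantiated through the registry, assuming this branch is installed so that `spacy.Lemmatizer.v1` is registered. The tok2vec hyperparameters are illustrative, and the `has_ref` check at the end is the same test the pipe uses to detect the extra lowercasing head.

```python
# Hypothetical sketch -- builds the model from a config string and checks the
# reference that EditTreeLemmatizer later uses to detect the lowercasing head.
from thinc.api import Config
from spacy.util import registry

CONFIG_STR = """
[model]
@architectures = "spacy.Lemmatizer.v1"
nO = 50
normalize = false
lowercasing = true

[model.tok2vec]
@architectures = "spacy.HashEmbedCNN.v2"
width = 96
depth = 2
embed_size = 2000
window_size = 1
maxout_pieces = 3
subword_features = true
pretrained_vectors = null
"""

model = registry.resolve(Config().from_str(CONFIG_STR))["model"]
print(model.has_ref("lowercasing_output"))  # True when lowercasing = true
```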

View File

@@ -5,8 +5,9 @@ from itertools import islice

 import numpy as np
 import srsly
-from thinc.api import Config, Model, SequenceCategoricalCrossentropy
-from thinc.types import Floats2d, Ints1d, Ints2d
+from thinc.api import Config, Model
+from thinc.api import SequenceCategoricalCrossentropy, L2Distance
+from thinc.types import Floats2d, Ints2d

 from ._edit_tree_internals.edit_trees import EditTrees
 from ._edit_tree_internals.schemas import validate_edit_tree
@@ -22,7 +23,8 @@ from .. import util

 default_model_config = """
 [model]
-@architectures = "spacy.Tagger.v2"
+@architectures = "spacy.Lemmatizer.v1"
+lowercasing = true

 [model.tok2vec]
 @architectures = "spacy.HashEmbedCNN.v2"
@@ -115,29 +117,67 @@ class EditTreeLemmatizer(TrainablePipe):
         self.cfg: Dict[str, Any] = {"labels": []}
         self.scorer = scorer
+        self.lowercasing = model.has_ref("lowercasing_output")

     def get_loss(
         self, examples: Iterable[Example], scores: List[Floats2d]
     ) -> Tuple[float, List[Floats2d]]:
         validate_examples(examples, "EditTreeLemmatizer.get_loss")
-        loss_func = SequenceCategoricalCrossentropy(normalize=False, missing_value=-1)
-        truths = []
+        tree_loss_func = SequenceCategoricalCrossentropy(
+            normalize=False, missing_value=-1
+        )
+        lowercasing_loss_func = L2Distance(normalize=False)
+        tree_truths = []
+        lowercasing_truths = []
         for eg in examples:
-            eg_truths = []
+            eg_tree_truths = []
+            eg_lowercasing_truths = []
             for (predicted, gold_lemma) in zip(
                 eg.predicted, eg.get_aligned("LEMMA", as_string=True)
             ):
                 if gold_lemma is None:
                     label = -1
+                    eg_lowercasing_truths.append([0])
                 else:
-                    tree_id = self.trees.add(predicted.text, gold_lemma)
+                    if self.lowercasing and _should_lowercased(
+                        predicted.text, gold_lemma
+                    ):
+                        eg_lowercasing_truths.append([1])
+                        text = predicted.text.lower()
+                    else:
+                        eg_lowercasing_truths.append([0])
+                        text = predicted.text
+                    tree_id = self.trees.add(text, gold_lemma)
                     label = self.tree2label.get(tree_id, 0)
-                eg_truths.append(label)
-            truths.append(eg_truths)
-        d_scores, loss = loss_func(scores, truths)
+                eg_tree_truths.append(label)
+            tree_truths.append(eg_tree_truths)
+            lowercasing_truths.append(eg_lowercasing_truths)
+        if self.lowercasing:
+            tree_scores, lowercasing_flags = _split_predictions(scores)
+            d_tree_scores, loss = tree_loss_func(tree_scores, tree_truths)
+            d_scores: List[Floats2d] = []
+            for i, doc_d_tree_scores in enumerate(d_tree_scores):
+                eg_lowercasing_flags = lowercasing_flags[i]
+                eg_d_lowercasing_flags, eg_lowercasing_loss = lowercasing_loss_func(
+                    eg_lowercasing_flags, self.model.ops.asarray(lowercasing_truths[i])
+                )
+                d_scores.append(
+                    self.model.ops.xp.hstack(
+                        [
+                            doc_d_tree_scores,
+                            self.model.ops.asarray(
+                                eg_d_lowercasing_flags, dtype="float32"
+                            ),
+                        ]
+                    )
+                )
+                loss += eg_lowercasing_loss
+        else:
+            d_scores, loss = tree_loss_func(scores, tree_truths)

         if self.model.ops.xp.isnan(loss):
             raise ValueError(Errors.E910.format(name=self.name))
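To make the combined objective concrete, here is a small standalone sketch (not part of the diff) of how the two thinc losses interact, using made-up scores for one three-token doc with four edit-tree labels plus the lowercasing column:

```python
import numpy as np
from thinc.api import SequenceCategoricalCrossentropy, L2Distance

tree_loss_func = SequenceCategoricalCrossentropy(normalize=False, missing_value=-1)
lowercasing_loss_func = L2Distance(normalize=False)

# One doc, 3 tokens, 4 tree labels + 1 lowercasing column (shapes are illustrative).
scores = [np.random.rand(3, 5).astype("float32")]
tree_scores = [s[:, :-1] for s in scores]        # what _split_predictions() returns
lowercasing_flags = [s[:, -1:] for s in scores]

tree_truths = [[1, -1, 2]]                        # -1 marks a missing lemma annotation
lowercasing_truths = [np.asarray([[0], [0], [1]], dtype="float32")]

d_tree_scores, loss = tree_loss_func(tree_scores, tree_truths)
d_flags, lowercasing_loss = lowercasing_loss_func(
    lowercasing_flags[0], lowercasing_truths[0]
)
loss += lowercasing_loss

# The gradients are stacked back so they line up with the model's
# concatenated per-token output.
d_scores = [np.hstack([d_tree_scores[0], d_flags.astype("float32")])]
print(d_scores[0].shape)  # (3, 5)
```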
@@ -151,15 +191,21 @@ class EditTreeLemmatizer(TrainablePipe):
             guesses: List[Ints2d] = [self.model.ops.alloc2i(0, n_labels) for _ in docs]
             assert len(guesses) == n_docs
             return guesses

-        scores = self.model.predict(docs)
+        if self.lowercasing:
+            scores, lowercasing_flags = _split_predictions(self.model.predict(docs))
+        else:
+            scores = self.model.predict(docs)
+            lowercasing_flags = None
         assert len(scores) == n_docs
-        guesses = self._scores2guesses(docs, scores)
+        guesses = self._scores2guesses(docs, scores, lowercasing_flags)
         assert len(guesses) == n_docs
         return guesses

-    def _scores2guesses(self, docs, scores):
+    def _scores2guesses(
+        self, docs, scores, lowercasing_flags: Optional[List[Floats2d]]
+    ):
         guesses = []
-        for doc, doc_scores in zip(docs, scores):
+        for i, (doc, doc_scores) in enumerate(zip(docs, scores)):
             if self.top_k == 1:
                 doc_guesses = doc_scores.argmax(axis=1).reshape(-1, 1)
             else:
@@ -169,15 +215,20 @@ class EditTreeLemmatizer(TrainablePipe):
                 doc_guesses = doc_guesses.get()

             doc_compat_guesses = []
-            for token, candidates in zip(doc, doc_guesses):
+            for j, (token, candidates) in enumerate(zip(doc, doc_guesses)):
+                text = token.text
+                to_lowercase = False
+                if lowercasing_flags is not None and lowercasing_flags[i][j] > 0.5:
+                    to_lowercase = True
+                    text = text.lower()
                 tree_id = -1
                 for candidate in candidates:
                     candidate_tree_id = self.cfg["labels"][candidate]
-                    if self.trees.apply(candidate_tree_id, token.text) is not None:
+                    if self.trees.apply(candidate_tree_id, text) is not None:
                         tree_id = candidate_tree_id
                         break
-                doc_compat_guesses.append(tree_id)
+                doc_compat_guesses.append((tree_id, to_lowercase))

             guesses.append(np.array(doc_compat_guesses))
@@ -188,7 +239,7 @@ class EditTreeLemmatizer(TrainablePipe):
             doc_tree_ids = batch_tree_ids[i]
             if hasattr(doc_tree_ids, "get"):
                 doc_tree_ids = doc_tree_ids.get()
-            for j, tree_id in enumerate(doc_tree_ids):
+            for j, (tree_id, to_lowercase) in enumerate(doc_tree_ids):
                 if self.overwrite or doc[j].lemma == 0:
                     # If no applicable tree could be found during prediction,
                     # the special identifier -1 is used. Otherwise the tree
@@ -197,7 +248,8 @@ class EditTreeLemmatizer(TrainablePipe):
                         if self.backoff is not None:
                             doc[j].lemma = getattr(doc[j], self.backoff)
                     else:
-                        lemma = self.trees.apply(tree_id, doc[j].text)
+                        text = doc[j].text.lower() if to_lowercase else doc[j].text
+                        lemma = self.trees.apply(tree_id, text)
                         doc[j].lemma_ = lemma

     @property
@@ -349,9 +401,15 @@ class EditTreeLemmatizer(TrainablePipe):
         for example in get_examples():
             for token in example.reference:
                 if token.lemma != 0:
-                    tree_id = trees.add(token.text, token.lemma_)
+                    if self.lowercasing and _should_lowercased(
+                        token.text, token.lemma_
+                    ):
+                        text = token.text.lower()
+                    else:
+                        text = token.text
+                    tree_id = trees.add(text, token.lemma_)
                     tree_freqs[tree_id] += 1
-                    repr_pairs[tree_id] = (token.text, token.lemma_)
+                    repr_pairs[tree_id] = (text, token.lemma_)

         # Construct trees that make the frequency cut-off using representative
         # form - token pairs.
@@ -374,3 +432,15 @@ class EditTreeLemmatizer(TrainablePipe):
             self.tree2label[tree_id] = len(self.cfg["labels"])
             self.cfg["labels"].append(tree_id)
         return self.tree2label[tree_id]


+def _should_lowercased(form: str, lemma: str) -> bool:
+    return form.lower() != form and lemma.lower() == lemma


+def _split_predictions(
+    raw_predictions: List[Floats2d],
+) -> Tuple[List[Floats2d], List[Floats2d]]:
+    tree_scores = [rp[:, :-1] for rp in raw_predictions]
+    lowercasing_flags = [rp[:, -1:] for rp in raw_predictions]
+    return tree_scores, lowercasing_flags
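The heuristic behind the training signal is easy to check in isolation; a quick illustration (not part of the diff) of the flags `_should_lowercased` produces:

```python
def _should_lowercased(form: str, lemma: str) -> bool:
    # Copied from the diff above: flag tokens whose surface form carries casing
    # that the gold lemma does not.
    return form.lower() != form and lemma.lower() == lemma

print(_should_lowercased("The", "the"))          # True  -> sigmoid trained towards 1
print(_should_lowercased("Hamburg", "Hamburg"))  # False -> the lemma keeps its casing
print(_should_lowercased("dogs", "dog"))         # False -> form is already lowercase
```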

View File

@@ -648,6 +648,43 @@ The other arguments are shared between all versions.
</Accordion>
### spacy.Lemmatizer.v1 {#Lemmatizer}

> #### Example Config
>
> ```ini
> [model]
> @architectures = "spacy.Lemmatizer.v1"
> nO = null
> normalize = false
> lowercasing = true
>
> [model.tok2vec]
> # ...
> ```

Build a model for the edit-tree lemmatizer, using a provided token-to-vector
component. A linear layer with softmax activation is added to predict scores
given the token vectors (this part corresponds to the
[Tagger](/api/architectures#tagger) architecture).

Additionally, if `lowercasing = true`, a single sigmoid output is learned for
each token input: this is used in the context of the lemmatizer to determine
when to convert inputs to lowercase before processing them with the edit trees.
A single normalized linear layer is placed between the token-to-vector component
and the sigmoid output neuron, as this has been found to produce slightly better
results than a direct connection. A possible intuition explaining this is that
it reduces the interference between the two learning goals, which are linked but
nonetheless distinct.

| Name          | Description                                                                                 |
| ------------- | ------------------------------------------------------------------------------------------- |
| `tok2vec`     | Subnetwork to map tokens into vector representations. ~~Model[List[Doc], List[Floats2d]]~~  |
| `nO`          | The number of tags to output. Inferred from the data if `None`. ~~Optional[int]~~           |
| `normalize`   | Normalize probabilities during inference. Defaults to `False`. ~~bool~~                     |
| `lowercasing` | If `True`, the additional sigmoid appendage is created. Defaults to `True`. ~~bool~~        |
| **CREATES**   | The model using the architecture. ~~Model[List[Doc], List[Floats2d]]~~                      |
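As a usage note beyond what the docs page shows: the commit also switches the trainable lemmatizer's default model config to this architecture, so the extra head is active out of the box. A hedged sketch, assuming this branch is installed:

```python
import spacy

nlp = spacy.blank("en")
# On this branch the pipe's default [model] block already points at
# spacy.Lemmatizer.v1 with lowercasing = true (see the diff above), so no
# explicit model config is needed; pass config={"model": {...}} to override it.
lemmatizer = nlp.add_pipe("trainable_lemmatizer")
print(lemmatizer.lowercasing)  # True -- the extra sigmoid head is wired in
# Training then proceeds as for any trainable pipe (spacy train, or
# nlp.initialize + nlp.update); get_loss and predict handle the extra column.
```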
## Text classification architectures {#textcat source="spacy/ml/models/textcat.py"}

A text classification architecture needs to take a [`Doc`](/api/doc) as input,