Add model, logic changes and documentation

This commit is contained in:
parent 27fac7df2e
commit 74dbb65bca
@@ -1,4 +1,5 @@
 from .entity_linker import *  # noqa
+from .lemmatizer import *  # noqa
 from .multi_task import *  # noqa
 from .parser import *  # noqa
 from .spancat import *  # noqa
spacy/ml/models/lemmatizer.py (new file, 44 lines)

@@ -0,0 +1,44 @@
+from typing import Optional, List
+from thinc.api import glorot_uniform_init, with_array, Softmax_v2, Model
+from thinc.api import Relu, Dropout, Sigmoid, concatenate, chain
+from thinc.types import Floats2d, Union
+
+from ...util import registry
+from ...tokens import Doc
+
+
+@registry.architectures("spacy.Lemmatizer.v1")
+def build_lemmatizer_model(
+    tok2vec: Model[List[Doc], List[Floats2d]],
+    nO: Optional[int] = None,
+    normalize=False,
+    lowercasing=True,
+) -> Model[List[Doc], Union[List[Floats2d]]]:
+    """Build a model for the edit-tree lemmatizer, using a provided token-to-vector component.
+    A linear layer with softmax activation is added to predict scores
+    given the token vectors (this part corresponds to the *Tagger* architecture).
+
+    Additionally, if *lowercasing==True*, a single sigmoid output is learned for each
+    token input: this is used in the context of the lemmatizer to determine when to convert inputs
+    to lowercase before processing them with the edit trees. A single normalized linear layer
+    is placed between the token-to-vector component and the sigmoid output neuron, as this has
+    been found to produce slightly better results than a direct connection. A possible intuition
+    explaining this is that it reduces the interference between the two learning goals,
+    which are linked but nonetheless distinct.
+
+    tok2vec (Model[List[Doc], List[Floats2d]]): The token-to-vector subnetwork.
+    nO (int or None): The number of tags to output in the main model. Inferred from the data if None.
+    lowercasing: if *True*, the additional sigmoid appendage is created.
+    """
+    with Model.define_operators({">>": chain, "|": concatenate}):
+        t2v_width = tok2vec.get_dim("nO") if tok2vec.has_dim("nO") else None
+        softmax = Softmax_v2(
+            nO, t2v_width, init_W=glorot_uniform_init, normalize_outputs=normalize
+        )
+        model = tok2vec >> with_array(softmax)
+        if lowercasing:
+            lowercasing_output = Sigmoid(1)
+            sigmoid_appendage = Relu(50) >> Dropout(0.2) >> lowercasing_output
+            model |= tok2vec >> with_array(sigmoid_appendage)
+            model.set_ref("lowercasing_output", lowercasing_output)
+    return model
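For intuition: with `lowercasing=True` the concatenated model yields, per document, an array holding the tree-label scores from the softmax head plus one trailing sigmoid column, which the pipeline code later slices apart again. A minimal NumPy sketch of that layout (shapes and random values here are purely illustrative, not taken from the diff):

```python
import numpy as np

# Illustrative shapes: 5 tokens, 3 edit-tree labels.
n_tokens, n_labels = 5, 3

# Stand-ins for the two heads' outputs on one doc.
tree_scores = np.random.rand(n_tokens, n_labels).astype("float32")  # softmax head
lowercase_prob = np.random.rand(n_tokens, 1).astype("float32")      # sigmoid head

# Concatenation gives one (n_tokens, n_labels + 1) array per doc ...
combined = np.hstack([tree_scores, lowercase_prob])
assert combined.shape == (n_tokens, n_labels + 1)

# ... which can be split again: all but the last column are tree-label
# scores, the last column is the lowercasing probability.
scores, flags = combined[:, :-1], combined[:, -1:]
assert scores.shape == (n_tokens, n_labels)
assert flags.shape == (n_tokens, 1)
```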
@@ -5,8 +5,9 @@ from itertools import islice
 import numpy as np
 
 import srsly
-from thinc.api import Config, Model, SequenceCategoricalCrossentropy
-from thinc.types import Floats2d, Ints1d, Ints2d
+from thinc.api import Config, Model
+from thinc.api import SequenceCategoricalCrossentropy, L2Distance
+from thinc.types import Floats2d, Ints2d
 
 from ._edit_tree_internals.edit_trees import EditTrees
 from ._edit_tree_internals.schemas import validate_edit_tree
@@ -22,7 +23,8 @@ from .. import util
 
 default_model_config = """
 [model]
-@architectures = "spacy.Tagger.v2"
+@architectures = "spacy.Lemmatizer.v1"
+lowercasing = true
 
 [model.tok2vec]
 @architectures = "spacy.HashEmbedCNN.v2"
@@ -115,29 +117,67 @@ class EditTreeLemmatizer(TrainablePipe):
 
         self.cfg: Dict[str, Any] = {"labels": []}
         self.scorer = scorer
+        self.lowercasing = model.has_ref("lowercasing_output")
 
     def get_loss(
         self, examples: Iterable[Example], scores: List[Floats2d]
     ) -> Tuple[float, List[Floats2d]]:
         validate_examples(examples, "EditTreeLemmatizer.get_loss")
-        loss_func = SequenceCategoricalCrossentropy(normalize=False, missing_value=-1)
+        tree_loss_func = SequenceCategoricalCrossentropy(
+            normalize=False, missing_value=-1
+        )
+        lowercasing_loss_func = L2Distance(normalize=False)
 
-        truths = []
+        tree_truths = []
+        lowercasing_truths = []
         for eg in examples:
-            eg_truths = []
+            eg_tree_truths = []
+            eg_lowercasing_truths = []
             for (predicted, gold_lemma) in zip(
                 eg.predicted, eg.get_aligned("LEMMA", as_string=True)
             ):
                 if gold_lemma is None:
                     label = -1
+                    eg_lowercasing_truths.append([0])
                 else:
-                    tree_id = self.trees.add(predicted.text, gold_lemma)
+                    if self.lowercasing and _should_lowercased(
+                        predicted.text, gold_lemma
+                    ):
+                        eg_lowercasing_truths.append([1])
+                        text = predicted.text.lower()
+                    else:
+                        eg_lowercasing_truths.append([0])
+                        text = predicted.text
+                    tree_id = self.trees.add(text, gold_lemma)
                     label = self.tree2label.get(tree_id, 0)
-                eg_truths.append(label)
+                eg_tree_truths.append(label)
 
-            truths.append(eg_truths)
+            tree_truths.append(eg_tree_truths)
+            lowercasing_truths.append(eg_lowercasing_truths)
 
-        d_scores, loss = loss_func(scores, truths)
+        if self.lowercasing:
+            tree_scores, lowercasing_flags = _split_predictions(scores)
+            d_tree_scores, loss = tree_loss_func(tree_scores, tree_truths)
+
+            d_scores: List[Floats2d] = []
+            for i, doc_d_tree_scores in enumerate(d_tree_scores):
+                eg_lowercasing_flags = lowercasing_flags[i]
+                eg_d_lowercasing_flags, eg_lowercasing_loss = lowercasing_loss_func(
+                    eg_lowercasing_flags, self.model.ops.asarray(lowercasing_truths[i])
+                )
+                d_scores.append(
+                    self.model.ops.xp.hstack(
+                        [
+                            doc_d_tree_scores,
+                            self.model.ops.asarray(
+                                eg_d_lowercasing_flags, dtype="float32"
+                            ),
+                        ]
+                    )
+                )
+                loss += eg_lowercasing_loss
+        else:
+            d_scores, loss = tree_loss_func(scores, tree_truths)
         if self.model.ops.xp.isnan(loss):
            raise ValueError(Errors.E910.format(name=self.name))
 
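The training signal therefore combines two terms: a sequence cross-entropy over edit-tree labels and an L2 penalty pulling the sigmoid output toward the 0/1 lowercasing truth. A rough, self-contained sketch of the two Thinc loss objects on toy inputs (the shapes and truth values are made up; only the constructor arguments mirror the diff):

```python
import numpy as np
from thinc.api import L2Distance, SequenceCategoricalCrossentropy

tree_loss_func = SequenceCategoricalCrossentropy(normalize=False, missing_value=-1)
lowercasing_loss_func = L2Distance(normalize=False)

# One "document" with three tokens and four tree labels.
tree_scores = [np.random.rand(3, 4).astype("float32")]
tree_truths = [[2, 0, -1]]  # -1 marks a token without a gold lemma
d_tree_scores, tree_loss = tree_loss_func(tree_scores, tree_truths)

# Sigmoid outputs vs. 0/1 lowercasing truths for the same three tokens.
flags = np.random.rand(3, 1).astype("float32")
flag_truths = np.asarray([[1.0], [0.0], [0.0]], dtype="float32")
d_flags, flag_loss = lowercasing_loss_func(flags, flag_truths)

total_loss = tree_loss + flag_loss
```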
@@ -151,15 +191,21 @@ class EditTreeLemmatizer(TrainablePipe):
             guesses: List[Ints2d] = [self.model.ops.alloc2i(0, n_labels) for _ in docs]
             assert len(guesses) == n_docs
             return guesses
-        scores = self.model.predict(docs)
+        if self.lowercasing:
+            scores, lowercasing_flags = _split_predictions(self.model.predict(docs))
+        else:
+            scores = self.model.predict(docs)
+            lowercasing_flags = None
         assert len(scores) == n_docs
-        guesses = self._scores2guesses(docs, scores)
+        guesses = self._scores2guesses(docs, scores, lowercasing_flags)
         assert len(guesses) == n_docs
         return guesses
 
-    def _scores2guesses(self, docs, scores):
+    def _scores2guesses(
+        self, docs, scores, lowercasing_flags: Optional[List[Floats2d]]
+    ):
         guesses = []
-        for doc, doc_scores in zip(docs, scores):
+        for i, (doc, doc_scores) in enumerate(zip(docs, scores)):
             if self.top_k == 1:
                 doc_guesses = doc_scores.argmax(axis=1).reshape(-1, 1)
             else:
@@ -169,15 +215,20 @@
                 doc_guesses = doc_guesses.get()
 
             doc_compat_guesses = []
-            for token, candidates in zip(doc, doc_guesses):
+            for j, (token, candidates) in enumerate(zip(doc, doc_guesses)):
+                text = token.text
+                to_lowercase = False
+                if lowercasing_flags is not None and lowercasing_flags[i][j] > 0.5:
+                    to_lowercase = True
+                    text = text.lower()
                 tree_id = -1
                 for candidate in candidates:
                     candidate_tree_id = self.cfg["labels"][candidate]
 
-                    if self.trees.apply(candidate_tree_id, token.text) is not None:
+                    if self.trees.apply(candidate_tree_id, text) is not None:
                         tree_id = candidate_tree_id
                         break
-                doc_compat_guesses.append(tree_id)
+                doc_compat_guesses.append((tree_id, to_lowercase))
 
             guesses.append(np.array(doc_compat_guesses))
 
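Decoding thus picks the best applicable tree per token and, independently of that choice, lowercases the form whenever the sigmoid column exceeds 0.5. A toy NumPy sketch of that thresholding (token texts and score values are invented):

```python
import numpy as np

# Per-token tree-label scores and lowercasing probabilities for one doc.
doc_scores = np.asarray([[0.1, 0.7, 0.2],
                         [0.6, 0.3, 0.1]], dtype="float32")
doc_flags = np.asarray([[0.92], [0.08]], dtype="float32")

# top_k == 1: take the single best label per token.
doc_guesses = doc_scores.argmax(axis=1).reshape(-1, 1)

tokens = ["The", "cats"]
for token, candidates, flag in zip(tokens, doc_guesses, doc_flags):
    text = token.lower() if flag[0] > 0.5 else token
    print(text, "-> best label", int(candidates[0]))
# prints: "the -> best label 1" and "cats -> best label 0"
```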
@@ -188,7 +239,7 @@ class EditTreeLemmatizer(TrainablePipe):
             doc_tree_ids = batch_tree_ids[i]
             if hasattr(doc_tree_ids, "get"):
                 doc_tree_ids = doc_tree_ids.get()
-            for j, tree_id in enumerate(doc_tree_ids):
+            for j, (tree_id, to_lowercase) in enumerate(doc_tree_ids):
                 if self.overwrite or doc[j].lemma == 0:
                     # If no applicable tree could be found during prediction,
                     # the special identifier -1 is used. Otherwise the tree
@@ -197,7 +248,8 @@ class EditTreeLemmatizer(TrainablePipe):
                         if self.backoff is not None:
                             doc[j].lemma = getattr(doc[j], self.backoff)
                     else:
-                        lemma = self.trees.apply(tree_id, doc[j].text)
+                        text = doc[j].text.lower() if to_lowercase else doc[j].text
+                        lemma = self.trees.apply(tree_id, text)
                         doc[j].lemma_ = lemma
 
     @property
@@ -349,9 +401,15 @@ class EditTreeLemmatizer(TrainablePipe):
         for example in get_examples():
             for token in example.reference:
                 if token.lemma != 0:
-                    tree_id = trees.add(token.text, token.lemma_)
+                    if self.lowercasing and _should_lowercased(
+                        token.text, token.lemma_
+                    ):
+                        text = token.text.lower()
+                    else:
+                        text = token.text
+                    tree_id = trees.add(text, token.lemma_)
                     tree_freqs[tree_id] += 1
-                    repr_pairs[tree_id] = (token.text, token.lemma_)
+                    repr_pairs[tree_id] = (text, token.lemma_)
 
         # Construct trees that make the frequency cut-off using representative
         # form - token pairs.
@@ -374,3 +432,15 @@ class EditTreeLemmatizer(TrainablePipe):
             self.tree2label[tree_id] = len(self.cfg["labels"])
             self.cfg["labels"].append(tree_id)
         return self.tree2label[tree_id]
+
+
+def _should_lowercased(form: str, lemma: str) -> bool:
+    return form.lower() != form and lemma.lower() == lemma
+
+
+def _split_predictions(
+    raw_predictions: List[Floats2d],
+) -> Tuple[List[Floats2d], List[Floats2d]]:
+    tree_scores = [rp[:, :-1] for rp in raw_predictions]
+    lowercasing_flags = [rp[:, -1:] for rp in raw_predictions]
+    return tree_scores, lowercasing_flags
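The heuristic fires only when the surface form contains uppercase characters while the gold lemma is entirely lowercase. A few illustrative calls (the helper is restated here so the snippet runs on its own):

```python
def _should_lowercased(form: str, lemma: str) -> bool:
    # Same one-liner as in the diff, repeated for a self-contained example.
    return form.lower() != form and lemma.lower() == lemma

assert _should_lowercased("The", "the")          # sentence-initial capitalization
assert not _should_lowercased("Paris", "Paris")  # lemma keeps its casing
assert not _should_lowercased("cats", "cat")     # form is already lowercase
```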
@@ -648,6 +648,43 @@ The other arguments are shared between all versions.
 
 </Accordion>
 
+### spacy.Lemmatizer.v1 {#Lemmatizer}
+
+> #### Example Config
+>
+> ```ini
+> [model]
+> @architectures = "spacy.Lemmatizer.v1"
+> nO = null
+> normalize = false
+> lowercasing = true
+>
+> [model.tok2vec]
+> # ...
+> ```
+
+Build a model for the edit-tree lemmatizer, using a provided token-to-vector
+component. A linear layer with softmax activation is added to predict scores
+given the token vectors (this part corresponds to the
+[Tagger](/api/architectures#tagger) architecture).
+
+Additionally, if `lowercasing = true`, a single sigmoid output is learned for
+each token input: this is used in the context of the lemmatizer to determine
+when to convert inputs to lowercase before processing them with the edit trees.
+A single normalized linear layer is placed between the token-to-vector component
+and the sigmoid output neuron, as this has been found to produce slightly better
+results than a direct connection. A possible intuition explaining this is that
+it reduces the interference between the two learning goals, which are linked but
+nonetheless distinct.
+
+| Name          | Description                                                                                 |
+| ------------- | ------------------------------------------------------------------------------------------- |
+| `tok2vec`     | Subnetwork to map tokens into vector representations. ~~Model[List[Doc], List[Floats2d]]~~   |
+| `nO`          | The number of tags to output. Inferred from the data if `None`. ~~Optional[int]~~            |
+| `normalize`   | Normalize probabilities during inference. Defaults to `False`. ~~bool~~                      |
+| `lowercasing` | If `True`, the additional sigmoid appendage is created. Defaults to `True`. ~~bool~~         |
+| **CREATES**   | The model using the architecture. ~~Model[List[Doc], List[Floats2d]]~~                       |
+
 ## Text classification architectures {#textcat source="spacy/ml/models/textcat.py"}
 
 A text classification architecture needs to take a [`Doc`](/api/doc) as input,
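Assuming a spaCy checkout with this commit applied, so that `spacy.Lemmatizer.v1` is actually registered, the documented config could be resolved into a Thinc model roughly like this (the `HashEmbedCNN` settings below are illustrative values, not taken from the diff):

```python
from thinc.api import Config
from spacy.util import registry

CONFIG = """
[model]
@architectures = "spacy.Lemmatizer.v1"
nO = null
normalize = false
lowercasing = true

[model.tok2vec]
@architectures = "spacy.HashEmbedCNN.v2"
pretrained_vectors = null
width = 96
depth = 4
embed_size = 2000
window_size = 1
maxout_pieces = 3
subword_features = true
"""

# Resolve the [model] block into a thinc Model; with lowercasing enabled the
# model should carry a "lowercasing_output" reference to the sigmoid head.
model = registry.resolve(Config().from_str(CONFIG))["model"]
print(model.has_ref("lowercasing_output"))
```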