mirror of https://github.com/explosion/spaCy.git
Fix label initialization of textcat component (#6190)
parent 989a96308f
commit dd542ec6a4

spacy/errors.py

@@ -497,8 +497,9 @@ class Errors:
     E919 = ("A textcat 'positive_label' '{pos_label}' was provided for training "
             "data that does not appear to be a binary classification problem "
             "with two labels. Labels found: {labels}")
-    E920 = ("The textcat's 'positive_label' config setting '{pos_label}' "
-            "does not match any label in the training data. Labels found: {labels}")
+    E920 = ("The textcat's 'positive_label' setting '{pos_label}' "
+            "does not match any label in the training data or provided during "
+            "initialization. Available labels: {labels}")
     E921 = ("The method 'set_output' can only be called on components that have "
             "a Model with a 'resize_output' attribute. Otherwise, the output "
             "layer can not be dynamically changed.")
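
Both E919 and E920 now surface from `TextCategorizer.initialize` rather than from config validation (see the changes to `spacy/training/initialize.py` below). A minimal sketch of the E920 case, with hypothetical texts and labels (the tests further down exercise the real versions):

```python
import spacy
from spacy.training import Example

nlp = spacy.blank("en")
textcat = nlp.add_pipe("textcat")
# hypothetical two-label training data
examples = [
    Example.from_dict(nlp.make_doc("good"), {"cats": {"SOME": 1.0, "THING": 0.0}}),
    Example.from_dict(nlp.make_doc("bad"), {"cats": {"SOME": 0.0, "THING": 1.0}}),
]
# "POS" matches neither provided label, so this raises ValueError with E920
textcat.initialize(lambda: examples, labels=["SOME", "THING"], positive_label="POS")
```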
spacy/pipeline/senter.pyx

@@ -71,10 +71,6 @@ class SentenceRecognizer(Tagger):
         # are 0
         return tuple(["I", "S"])
 
-    @property
-    def label_data(self):
-        return self.labels
-
     def set_annotations(self, docs, batch_tag_ids):
         """Modify a batch of documents, using pre-computed scores.
 
spacy/pipeline/textcat.py

@@ -56,12 +56,7 @@ subword_features = true
 @Language.factory(
     "textcat",
     assigns=["doc.cats"],
-    default_config={
-        "labels": [],
-        "threshold": 0.5,
-        "positive_label": None,
-        "model": DEFAULT_TEXTCAT_MODEL,
-    },
+    default_config={"threshold": 0.5, "model": DEFAULT_TEXTCAT_MODEL},
     default_score_weights={
         "cats_score": 1.0,
         "cats_score_desc": None,

@@ -75,12 +70,7 @@ subword_features = true
     },
 )
 def make_textcat(
-    nlp: Language,
-    name: str,
-    model: Model[List[Doc], List[Floats2d]],
-    labels: List[str],
-    threshold: float,
-    positive_label: Optional[str],
+    nlp: Language, name: str, model: Model[List[Doc], List[Floats2d]], threshold: float,
 ) -> "TextCategorizer":
     """Create a TextCategorizer compoment. The text categorizer predicts categories
     over a whole document. It can learn one or more labels, and the labels can

@@ -90,19 +80,9 @@ def make_textcat(
 
     model (Model[List[Doc], List[Floats2d]]): A model instance that predicts
         scores for each category.
-    labels (list): A list of categories to learn. If empty, the model infers the
-        categories from the data.
     threshold (float): Cutoff to consider a prediction "positive".
-    positive_label (Optional[str]): The positive label for a binary task with exclusive classes, None otherwise.
     """
-    return TextCategorizer(
-        nlp.vocab,
-        model,
-        name,
-        labels=labels,
-        threshold=threshold,
-        positive_label=positive_label,
-    )
+    return TextCategorizer(nlp.vocab, model, name, threshold=threshold)
 
 
 class TextCategorizer(Pipe):

@@ -112,14 +92,7 @@ class TextCategorizer(Pipe):
     """
 
     def __init__(
-        self,
-        vocab: Vocab,
-        model: Model,
-        name: str = "textcat",
-        *,
-        labels: List[str],
-        threshold: float,
-        positive_label: Optional[str],
+        self, vocab: Vocab, model: Model, name: str = "textcat", *, threshold: float
     ) -> None:
         """Initialize a text categorizer.
 

@@ -127,9 +100,7 @@ class TextCategorizer(Pipe):
         model (thinc.api.Model): The Thinc Model powering the pipeline component.
         name (str): The component instance name, used to add entries to the
             losses during training.
-        labels (List[str]): The labels to use.
         threshold (float): Cutoff to consider a prediction "positive".
-        positive_label (Optional[str]): The positive label for a binary task with exclusive classes, None otherwise.
 
         DOCS: https://nightly.spacy.io/api/textcategorizer#init
         """

@@ -137,11 +108,7 @@ class TextCategorizer(Pipe):
         self.model = model
         self.name = name
         self._rehearsal_model = None
-        cfg = {
-            "labels": labels,
-            "threshold": threshold,
-            "positive_label": positive_label,
-        }
+        cfg = {"labels": [], "threshold": threshold, "positive_label": None}
         self.cfg = dict(cfg)
 
     @property

@@ -348,6 +315,7 @@ class TextCategorizer(Pipe):
         *,
         nlp: Optional[Language] = None,
         labels: Optional[Dict] = None,
+        positive_label: Optional[str] = None,
     ):
         """Initialize the pipe for training, using a representative set
         of data examples.

@@ -369,6 +337,14 @@ class TextCategorizer(Pipe):
         else:
             for label in labels:
                 self.add_label(label)
+        if positive_label is not None:
+            if positive_label not in self.labels:
+                err = Errors.E920.format(pos_label=positive_label, labels=self.labels)
+                raise ValueError(err)
+            if len(self.labels) != 2:
+                err = Errors.E919.format(pos_label=positive_label, labels=self.labels)
+                raise ValueError(err)
+            self.cfg["positive_label"] = positive_label
         subbatch = list(islice(get_examples(), 10))
         doc_sample = [eg.reference for eg in subbatch]
         label_sample, _ = self._examples_to_truth(subbatch)
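
Taken together, these hunks move `labels` and `positive_label` out of the factory config and into `initialize`. A minimal sketch of the new flow, using hypothetical training data (the tests below use the shared `TRAIN_DATA` instead):

```python
import spacy
from spacy.training import Example

nlp = spacy.blank("en")
# the factory config is now just "threshold" and "model"
textcat = nlp.add_pipe("textcat", config={"threshold": 0.5})

train_examples = [
    Example.from_dict(nlp.make_doc("I am happy"), {"cats": {"POS": 1.0, "NEG": 0.0}}),
    Example.from_dict(nlp.make_doc("I am sad"), {"cats": {"POS": 0.0, "NEG": 1.0}}),
]
# labels are supplied here (or inferred from the examples if omitted),
# and positive_label is validated against them
textcat.initialize(lambda: train_examples, labels=["POS", "NEG"], positive_label="POS")
assert textcat.labels == ("POS", "NEG")
```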
spacy/tests/pipeline/test_textcat.py

@@ -10,7 +10,6 @@ from spacy.tokens import Doc
 from spacy.pipeline.tok2vec import DEFAULT_TOK2VEC_MODEL
 from spacy.scorer import Scorer
 from spacy.training import Example
-from spacy.training.initialize import verify_textcat_config
 
 from ..util import make_tempdir
 

@@ -21,6 +20,17 @@ TRAIN_DATA = [
 ]
 
 
+def make_get_examples(nlp):
+    train_examples = []
+    for t in TRAIN_DATA:
+        train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))
+
+    def get_examples():
+        return train_examples
+
+    return get_examples
+
+
 @pytest.mark.skip(reason="Test is flakey when run with others")
 def test_simple_train():
     nlp = Language()

@@ -92,10 +102,7 @@ def test_no_label():
 def test_implicit_label():
     nlp = Language()
     nlp.add_pipe("textcat")
-    train_examples = []
-    for t in TRAIN_DATA:
-        train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))
-    nlp.initialize(get_examples=lambda: train_examples)
+    nlp.initialize(get_examples=make_get_examples(nlp))
 
 
 def test_no_resize():

@@ -113,29 +120,26 @@ def test_no_resize():
 def test_initialize_examples():
     nlp = Language()
     textcat = nlp.add_pipe("textcat")
-    train_examples = []
     for text, annotations in TRAIN_DATA:
-        train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))
         for label, value in annotations.get("cats").items():
             textcat.add_label(label)
     # you shouldn't really call this more than once, but for testing it should be fine
     nlp.initialize()
-    nlp.initialize(get_examples=lambda: train_examples)
+    get_examples = make_get_examples(nlp)
+    nlp.initialize(get_examples=get_examples)
     with pytest.raises(ValueError):
         nlp.initialize(get_examples=lambda: None)
     with pytest.raises(ValueError):
-        nlp.initialize(get_examples=train_examples)
+        nlp.initialize(get_examples=get_examples())
 
 
 def test_overfitting_IO():
     # Simple test to try and quickly overfit the textcat component - ensuring the ML models work correctly
     fix_random_seed(0)
     nlp = English()
+    nlp.config["initialize"]["components"]["textcat"] = {"positive_label": "POSITIVE"}
     # Set exclusive labels
-    textcat = nlp.add_pipe(
-        "textcat",
-        config={"model": {"exclusive_classes": True}, "positive_label": "POSITIVE"},
-    )
+    textcat = nlp.add_pipe("textcat", config={"model": {"exclusive_classes": True}},)
     train_examples = []
     for text, annotations in TRAIN_DATA:
         train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))

@@ -203,28 +207,26 @@ def test_textcat_configs(textcat_config):
 
 def test_positive_class():
     nlp = English()
-    pipe_config = {"positive_label": "POS", "labels": ["POS", "NEG"]}
-    textcat = nlp.add_pipe("textcat", config=pipe_config)
+    textcat = nlp.add_pipe("textcat")
+    get_examples = make_get_examples(nlp)
+    textcat.initialize(get_examples, labels=["POS", "NEG"], positive_label="POS")
     assert textcat.labels == ("POS", "NEG")
-    verify_textcat_config(nlp, pipe_config)
 
 
 def test_positive_class_not_present():
     nlp = English()
-    pipe_config = {"positive_label": "POS", "labels": ["SOME", "THING"]}
-    textcat = nlp.add_pipe("textcat", config=pipe_config)
-    assert textcat.labels == ("SOME", "THING")
+    textcat = nlp.add_pipe("textcat")
+    get_examples = make_get_examples(nlp)
     with pytest.raises(ValueError):
-        verify_textcat_config(nlp, pipe_config)
+        textcat.initialize(get_examples, labels=["SOME", "THING"], positive_label="POS")
 
 
 def test_positive_class_not_binary():
     nlp = English()
-    pipe_config = {"positive_label": "POS", "labels": ["SOME", "THING", "POS"]}
-    textcat = nlp.add_pipe("textcat", config=pipe_config)
-    assert textcat.labels == ("SOME", "THING", "POS")
+    textcat = nlp.add_pipe("textcat")
+    get_examples = make_get_examples(nlp)
     with pytest.raises(ValueError):
-        verify_textcat_config(nlp, pipe_config)
+        textcat.initialize(get_examples, labels=["SOME", "THING", "POS"], positive_label="POS")
 
 
 def test_textcat_evaluation():
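
The new `make_get_examples` helper simply packages the `get_examples` contract that `initialize` expects: a zero-argument callable returning `Example` objects. A hypothetical standalone variant of the same pattern, parameterized over the data:

```python
from spacy.training import Example

def make_get_examples_for(nlp, data):
    # data: iterable of (text, annotations) pairs, like TRAIN_DATA above
    examples = [Example.from_dict(nlp.make_doc(text), annots) for text, annots in data]

    def get_examples():
        return examples

    return get_examples
```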
spacy/tests/serialize/test_serialize_pipeline.py

@@ -136,13 +136,7 @@ def test_serialize_textcat_empty(en_vocab):
     # See issue #1105
     cfg = {"model": DEFAULT_TEXTCAT_MODEL}
     model = registry.resolve(cfg, validate=True)["model"]
-    textcat = TextCategorizer(
-        en_vocab,
-        model,
-        labels=["ENTITY", "ACTION", "MODIFIER"],
-        threshold=0.5,
-        positive_label=None,
-    )
+    textcat = TextCategorizer(en_vocab, model, threshold=0.5)
     textcat.to_bytes(exclude=["vocab"])
 
 
spacy/training/initialize.py

@@ -50,9 +50,6 @@ def init_nlp(config: Config, *, use_gpu: int = -1) -> "Language":
     with nlp.select_pipes(disable=[*frozen_components, *resume_components]):
         nlp.initialize(lambda: train_corpus(nlp), sgd=optimizer)
     logger.info("Initialized pipeline components")
-    # Verify the config after calling 'initialize' to ensure labels
-    # are properly initialized
-    verify_config(nlp)
     return nlp
 
 

@@ -152,33 +149,6 @@ def init_tok2vec(
     return False
 
 
-def verify_config(nlp: "Language") -> None:
-    """Perform additional checks based on the config, loaded nlp object and training data."""
-    # TODO: maybe we should validate based on the actual components, the list
-    # in config["nlp"]["pipeline"] instead?
-    for pipe_config in nlp.config["components"].values():
-        # We can't assume that the component name == the factory
-        factory = pipe_config["factory"]
-        if factory == "textcat":
-            verify_textcat_config(nlp, pipe_config)
-
-
-def verify_textcat_config(nlp: "Language", pipe_config: Dict[str, Any]) -> None:
-    # if 'positive_label' is provided: double check whether it's in the data and
-    # the task is binary
-    if pipe_config.get("positive_label"):
-        textcat_labels = nlp.get_pipe("textcat").labels
-        pos_label = pipe_config.get("positive_label")
-        if pos_label not in textcat_labels:
-            raise ValueError(
-                Errors.E920.format(pos_label=pos_label, labels=textcat_labels)
-            )
-        if len(list(textcat_labels)) != 2:
-            raise ValueError(
-                Errors.E919.format(pos_label=pos_label, labels=textcat_labels)
-            )
-
-
 def get_sourced_components(config: Union[Dict[str, Any], Config]) -> List[str]:
     """RETURNS (List[str]): All sourced components in the original config,
     e.g. {"source": "en_core_web_sm"}. If the config contains a key
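
With `verify_config` and `verify_textcat_config` removed, config validation no longer special-cases textcat; the positive label travels through the `[initialize]` block of the training config instead, e.g. (mirroring the documentation change below):

```ini
[initialize.components.textcat]
positive_label = "POS"
```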
website/docs/api/textcategorizer.md

@@ -35,11 +35,10 @@ architectures and their arguments and hyperparameters.
 > nlp.add_pipe("textcat", config=config)
 > ```
 
-| Setting          | Description                                                                                                                                                        |
-| ---------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
-| `threshold`      | Cutoff to consider a prediction "positive", relevant when printing accuracy results. ~~float~~                                                                     |
-| `positive_label` | The positive label for a binary task with exclusive classes, None otherwise and by default. ~~Optional[str]~~                                                      |
-| `model`          | A model instance that predicts scores for each category. Defaults to [TextCatEnsemble](/api/architectures#TextCatEnsemble). ~~Model[List[Doc], List[Floats2d]]~~   |
+| Setting     | Description                                                                                                                                                        |
+| ----------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
+| `threshold` | Cutoff to consider a prediction "positive", relevant when printing accuracy results. ~~float~~                                                                     |
+| `model`     | A model instance that predicts scores for each category. Defaults to [TextCatEnsemble](/api/architectures#TextCatEnsemble). ~~Model[List[Doc], List[Floats2d]]~~   |
 
 ```python
 %%GITHUB_SPACY/spacy/pipeline/textcat.py

@@ -59,21 +58,20 @@ architectures and their arguments and hyperparameters.
 >
 > # Construction from class
 > from spacy.pipeline import TextCategorizer
-> textcat = TextCategorizer(nlp.vocab, model, threshold=0.5, positive_label="POS")
+> textcat = TextCategorizer(nlp.vocab, model, threshold=0.5)
 > ```
 
 Create a new pipeline instance. In your application, you would normally use a
 shortcut for this and instantiate the component using its string name and
 [`nlp.add_pipe`](/api/language#create_pipe).
 
-| Name             | Description                                                                                                                  |
-| ---------------- | ---------------------------------------------------------------------------------------------------------------------------- |
-| `vocab`          | The shared vocabulary. ~~Vocab~~                                                                                             |
-| `model`          | The Thinc [`Model`](https://thinc.ai/docs/api-model) powering the pipeline component. ~~Model[List[Doc], List[Floats2d]]~~   |
-| `name`           | String name of the component instance. Used to add entries to the `losses` during training. ~~str~~                          |
-| _keyword-only_   |                                                                                                                              |
-| `threshold`      | Cutoff to consider a prediction "positive", relevant when printing accuracy results. ~~float~~                               |
-| `positive_label` | The positive label for a binary task with exclusive classes, None otherwise. ~~Optional[str]~~                               |
+| Name           | Description                                                                                                                  |
+| -------------- | ---------------------------------------------------------------------------------------------------------------------------- |
+| `vocab`        | The shared vocabulary. ~~Vocab~~                                                                                             |
+| `model`        | The Thinc [`Model`](https://thinc.ai/docs/api-model) powering the pipeline component. ~~Model[List[Doc], List[Floats2d]]~~   |
+| `name`         | String name of the component instance. Used to add entries to the `losses` during training. ~~str~~                          |
+| _keyword-only_ |                                                                                                                              |
+| `threshold`    | Cutoff to consider a prediction "positive", relevant when printing accuracy results. ~~float~~                               |
 
 ## TextCategorizer.\_\_call\_\_ {#call tag="method"}
 

@@ -152,18 +150,20 @@ This method was previously called `begin_training`.
 > ```ini
 > ### config.cfg
 > [initialize.components.textcat]
+> positive_label = "POS"
 >
 > [initialize.components.textcat.labels]
 > @readers = "spacy.read_labels.v1"
 > path = "corpus/labels/textcat.json"
 > ```
 
-| Name           | Description                                                                                                                            |
-| -------------- | ---------------------------------------------------------------------------------------------------------------------------------------- |
-| `get_examples` | Function that returns gold-standard annotations in the form of [`Example`](/api/example) objects. ~~Callable[[], Iterable[Example]]~~ |
-| _keyword-only_ |                                                                                                                                        |
-| `nlp`          | The current `nlp` object. Defaults to `None`. ~~Optional[Language]~~                                                                   |
-| `labels`       | The label information to add to the component, as provided by the [`label_data`](#label_data) property after initialization. To generate a reusable JSON file from your data, you should run the [`init labels`](/api/cli#init-labels) command. If no labels are provided, the `get_examples` callback is used to extract the labels from the data, which may be a lot slower. ~~Optional[Iterable[str]]~~ |
+| Name             | Description                                                                                                                            |
+| ---------------- | ---------------------------------------------------------------------------------------------------------------------------------------- |
+| `get_examples`   | Function that returns gold-standard annotations in the form of [`Example`](/api/example) objects. ~~Callable[[], Iterable[Example]]~~ |
+| _keyword-only_   |                                                                                                                                        |
+| `nlp`            | The current `nlp` object. Defaults to `None`. ~~Optional[Language]~~                                                                   |
+| `labels`         | The label information to add to the component, as provided by the [`label_data`](#label_data) property after initialization. To generate a reusable JSON file from your data, you should run the [`init labels`](/api/cli#init-labels) command. If no labels are provided, the `get_examples` callback is used to extract the labels from the data, which may be a lot slower. ~~Optional[Iterable[str]]~~ |
+| `positive_label` | The positive label for a binary task with exclusive classes, None otherwise and by default. ~~Optional[str]~~                          |
 
 ## TextCategorizer.predict {#predict tag="method"}
 
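
Programmatically, the same `[initialize]` setting can be written to `nlp.config` before calling `nlp.initialize`, which is what the updated `test_overfitting_IO` does. A sketch with hypothetical data:

```python
import spacy
from spacy.training import Example

nlp = spacy.blank("en")
nlp.add_pipe("textcat")
# equivalent of the [initialize.components.textcat] block in config.cfg
nlp.config["initialize"]["components"]["textcat"] = {"positive_label": "POS"}

train_examples = [
    Example.from_dict(nlp.make_doc("I am happy"), {"cats": {"POS": 1.0, "NEG": 0.0}}),
    Example.from_dict(nlp.make_doc("I am sad"), {"cats": {"POS": 0.0, "NEG": 1.0}}),
]
nlp.initialize(get_examples=lambda: train_examples)
```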