Mirror of https://github.com/explosion/spaCy.git (synced 2025-01-01 04:46:38 +03:00)

Make the TrainablePipe.store_activations property a bool

This means that we can also bring back the `store_activations` setter.

parent 1cfbb934ed
commit aea53378dc
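In short, `store_activations` no longer accepts a list of activation names: it is now a plain bool that either stores all of a component's activations on the `Doc` or none of them. A minimal usage sketch (the pipeline name is illustrative, and this assumes a pipeline trained on the development branch this commit targets, where components support activation storing):

    import spacy

    nlp = spacy.load("my_pipeline")  # hypothetical pipeline name
    tagger = nlp.get_pipe("tagger")

    # New API: a plain bool property with a setter.
    tagger.store_activations = True
    doc = nlp("This is a test.")
    print(doc.activations["tagger"].keys())  # e.g. dict_keys(['probs', 'guesses'])

    # Disabled (the default): no entry is created on the Doc at all.
    tagger.store_activations = False
    doc = nlp("This is a test.")
    print("tagger" in doc.activations)  # False

    # Old API, removed by this commit:
    #   tagger.set_store_activations(["probs"])  # per-activation selection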
@@ -65,7 +65,7 @@ def make_edit_tree_lemmatizer(
     overwrite: bool,
     top_k: int,
     scorer: Optional[Callable],
-    store_activations: Union[bool, List[str]],
+    store_activations: bool,
 ):
     """Construct an EditTreeLemmatizer component."""
     return EditTreeLemmatizer(
@@ -97,7 +97,7 @@ class EditTreeLemmatizer(TrainablePipe):
         overwrite: bool = False,
         top_k: int = 1,
         scorer: Optional[Callable] = lemmatizer_score,
-        store_activations: Union[bool, List[str]] = False,
+        store_activations: bool = False,
     ):
         """
         Construct an edit tree lemmatizer.
@@ -109,8 +109,7 @@ class EditTreeLemmatizer(TrainablePipe):
             frequency in the training data.
         overwrite (bool): overwrite existing lemma annotations.
         top_k (int): try to apply at most the k most probable edit trees.
-        store_activations (Union[bool, List[str]]): Model activations to store in
-            Doc when annotating. supported activations are: "probs" and "guesses".
+        store_activations (bool): store model activations in Doc when annotating.
         """
         self.vocab = vocab
         self.model = model
@@ -125,7 +124,7 @@ class EditTreeLemmatizer(TrainablePipe):

         self.cfg: Dict[str, Any] = {"labels": []}
         self.scorer = scorer
-        self.set_store_activations(store_activations)
+        self.store_activations = store_activations

     def get_loss(
         self, examples: Iterable[Example], scores: List[Floats2d]
@@ -202,9 +201,10 @@ class EditTreeLemmatizer(TrainablePipe):
     def set_annotations(self, docs: Iterable[Doc], activations: ActivationsT):
         batch_tree_ids = activations["guesses"]
         for i, doc in enumerate(docs):
-            doc.activations[self.name] = {}
-            for activation in self.store_activations:
-                doc.activations[self.name][activation] = activations[activation][i]
+            if self.store_activations:
+                doc.activations[self.name] = {}
+                for act_name, acts in activations.items():
+                    doc.activations[self.name][act_name] = acts[i]
             doc_tree_ids = batch_tree_ids[i]
             if hasattr(doc_tree_ids, "get"):
                 doc_tree_ids = doc_tree_ids.get()
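The same rewrite recurs in every component's set_annotations below: previously an activations dict was always created on each Doc (often left empty) and entries were filtered by name; now the dict is only created when storing is enabled, and then receives every activation the model produced. A standalone sketch of the new pattern (plain Python with illustrative names, dicts standing in for Doc objects):

    from typing import Any, Dict, List

    def set_annotations_sketch(
        docs: List[Dict[str, Any]],
        activations: Dict[str, list],
        name: str,
        store_activations: bool,
    ) -> None:
        for i, doc in enumerate(docs):
            if store_activations:
                # Copy the i-th slice of every activation onto this "doc".
                doc.setdefault("activations", {})[name] = {
                    act_name: acts[i] for act_name, acts in activations.items()
                }

    docs: List[Dict[str, Any]] = [{}, {}]
    set_annotations_sketch(
        docs,
        {"probs": [[0.9, 0.1], [0.2, 0.8]], "guesses": [0, 1]},
        name="trainable_lemmatizer",
        store_activations=True,
    )
    print(docs[0]["activations"]["trainable_lemmatizer"])
    # {'probs': [0.9, 0.1], 'guesses': 0}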

@@ -85,7 +85,7 @@ def make_entity_linker(
     scorer: Optional[Callable],
     use_gold_ents: bool,
     threshold: Optional[float] = None,
-    store_activations: Union[bool, List[str]],
+    store_activations: bool,
 ):
     """Construct an EntityLinker component.

@@ -104,8 +104,7 @@ def make_entity_linker(
         component must provide entity annotations.
     threshold (Optional[float]): Confidence threshold for entity predictions. If confidence is below the threshold,
         prediction is discarded. If None, predictions are not filtered by any threshold.
-    store_activations (Union[bool, List[str]]): Model activations to store in
-        Doc when annotating. supported activations are: "ents" and "scores".
+    store_activations (bool): store model activations in Doc when annotating.
     """

     if not model.attrs.get("include_span_maker", False):
@@ -174,7 +173,7 @@ class EntityLinker(TrainablePipe):
         scorer: Optional[Callable] = entity_linker_score,
         use_gold_ents: bool,
         threshold: Optional[float] = None,
-        store_activations: Union[bool, List[str]] = False,
+        store_activations: bool = False,
     ) -> None:
         """Initialize an entity linker.

@@ -223,7 +222,7 @@ class EntityLinker(TrainablePipe):
         self.scorer = scorer
         self.use_gold_ents = use_gold_ents
         self.threshold = threshold
-        self.set_store_activations(store_activations)
+        self.store_activations = store_activations

     def set_kb(self, kb_loader: Callable[[Vocab], KnowledgeBase]):
         """Define the KB of this pipe by providing a function that will
@@ -551,12 +550,13 @@ class EntityLinker(TrainablePipe):
         i = 0
         overwrite = self.cfg["overwrite"]
         for j, doc in enumerate(docs):
-            doc.activations[self.name] = {}
-            for activation in self.store_activations:
-                # We only copy activations that are Ragged.
-                doc.activations[self.name][activation] = cast(
-                    Ragged, activations[activation][j]
-                )
+            if self.store_activations:
+                doc.activations[self.name] = {}
+                for act_name, acts in activations.items():
+                    if act_name != "kb_ids":
+                        # We only copy activations that are Ragged.
+                        doc.activations[self.name][act_name] = cast(Ragged, acts[j])

             for ent in doc.ents:
                 kb_id = kb_ids[i]
                 i += 1
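One wrinkle in the EntityLinker variant: now that all activations are copied rather than a selected subset, the "kb_ids" activation, which holds plain strings rather than a Ragged array, has to be skipped explicitly. A toy illustration of that filter (plain Python stand-ins, not spaCy types):

    # Hypothetical per-batch activations: "ents" and "scores" stand in for
    # Ragged arrays; "kb_ids" holds strings and must not be cast to Ragged.
    activations = {
        "ents": [[0, 1]],
        "scores": [[0.9, 0.8]],
        "kb_ids": [["Q7304", "NIL"]],
    }
    j = 0  # index of the current doc in the batch
    doc_acts = {
        name: acts[j]
        for name, acts in activations.items()
        if name != "kb_ids"  # only Ragged-like activations are copied
    }
    print(sorted(doc_acts))  # ['ents', 'scores']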
|
@@ -668,7 +668,7 @@ class EntityLinker(TrainablePipe):
         doc_scores: List[Floats1d],
         doc_ents: List[Ints1d],
     ):
-        if len(self.store_activations) == 0:
+        if not self.store_activations:
             return
         ops = self.model.ops
         lengths = ops.asarray1i([s.shape[0] for s in doc_scores])
@@ -683,7 +683,7 @@ class EntityLinker(TrainablePipe):
         scores: Sequence[float],
         ents: Sequence[int],
     ):
-        if len(self.store_activations) == 0:
+        if not self.store_activations:
             return
         ops = self.model.ops
         doc_scores.append(ops.asarray1f(scores))

@@ -69,7 +69,7 @@ def make_morphologizer(
     overwrite: bool,
     extend: bool,
     scorer: Optional[Callable],
-    store_activations: Union[bool, List[str]],
+    store_activations: bool,
 ):
     return Morphologizer(nlp.vocab, model, name, overwrite=overwrite, extend=extend, scorer=scorer,
                          store_activations=store_activations)
@@ -104,7 +104,7 @@ class Morphologizer(Tagger):
         overwrite: bool = BACKWARD_OVERWRITE,
         extend: bool = BACKWARD_EXTEND,
         scorer: Optional[Callable] = morphologizer_score,
-        store_activations: Union[bool, List[str]] = False,
+        store_activations: bool = False,
     ):
         """Initialize a morphologizer.

@@ -115,8 +115,7 @@ class Morphologizer(Tagger):
         scorer (Optional[Callable]): The scoring method. Defaults to
             Scorer.score_token_attr for the attributes "pos" and "morph" and
             Scorer.score_token_attr_per_feat for the attribute "morph".
-        store_activations (Union[bool, List[str]]): Model activations to store in
-            Doc when annotating. supported activations are: "probs" and "guesses".
+        store_activations (bool): store model activations in Doc when annotating.

         DOCS: https://spacy.io/api/morphologizer#init
         """
@@ -136,7 +135,7 @@ class Morphologizer(Tagger):
         }
         self.cfg = dict(sorted(cfg.items()))
         self.scorer = scorer
-        self.set_store_activations(store_activations)
+        self.store_activations = store_activations

     @property
     def labels(self):
@@ -250,9 +249,10 @@ class Morphologizer(Tagger):
         # to allocate a compatible container out of the iterable.
         labels = tuple(self.labels)
         for i, doc in enumerate(docs):
-            doc.activations[self.name] = {}
-            for activation in self.store_activations:
-                doc.activations[self.name][activation] = activations[activation][i]
+            if self.store_activations:
+                doc.activations[self.name] = {}
+                for act_name, acts in activations.items():
+                    doc.activations[self.name][act_name] = acts[i]
             doc_tag_ids = batch_tag_ids[i]
             if hasattr(doc_tag_ids, "get"):
                 doc_tag_ids = doc_tag_ids.get()

@@ -52,7 +52,7 @@ def make_senter(nlp: Language,
                 model: Model,
                 overwrite: bool,
                 scorer: Optional[Callable],
-                store_activations: Union[bool, List[str]]):
+                store_activations: bool):
     return SentenceRecognizer(nlp.vocab, model, name, overwrite=overwrite, scorer=scorer, store_activations=store_activations)


@@ -83,7 +83,7 @@ class SentenceRecognizer(Tagger):
         *,
         overwrite=BACKWARD_OVERWRITE,
         scorer=senter_score,
-        store_activations: Union[bool, List[str]] = False,
+        store_activations: bool = False,
     ):
         """Initialize a sentence recognizer.

@@ -93,8 +93,7 @@ class SentenceRecognizer(Tagger):
             losses during training.
         scorer (Optional[Callable]): The scoring method. Defaults to
             Scorer.score_spans for the attribute "sents".
-        store_activations (Union[bool, List[str]]): Model activations to store in
-            Doc when annotating. supported activations are: "probs" and "guesses".
+        store_activations (bool): store model activations in Doc when annotating.

         DOCS: https://spacy.io/api/sentencerecognizer#init
         """
@@ -104,7 +103,7 @@ class SentenceRecognizer(Tagger):
         self._rehearsal_model = None
         self.cfg = {"overwrite": overwrite}
         self.scorer = scorer
-        self.set_store_activations(store_activations)
+        self.store_activations = store_activations

     @property
     def labels(self):
@@ -136,9 +135,10 @@ class SentenceRecognizer(Tagger):
         cdef Doc doc
         cdef bint overwrite = self.cfg["overwrite"]
         for i, doc in enumerate(docs):
-            doc.activations[self.name] = {}
-            for activation in self.store_activations:
-                doc.activations[self.name][activation] = activations[activation][i]
+            if self.store_activations:
+                doc.activations[self.name] = {}
+                for act_name, acts in activations.items():
+                    doc.activations[self.name][act_name] = acts[i]
             doc_tag_ids = batch_tag_ids[i]
             if hasattr(doc_tag_ids, "get"):
                 doc_tag_ids = doc_tag_ids.get()

@@ -120,7 +120,7 @@ def make_spancat(
     scorer: Optional[Callable],
     threshold: float,
     max_positive: Optional[int],
-    store_activations: Union[bool, List[str]],
+    store_activations: bool,
 ) -> "SpanCategorizer":
     """Create a SpanCategorizer component. The span categorizer consists of two
     parts: a suggester function that proposes candidate spans, and a labeller
@@ -141,8 +141,7 @@ def make_spancat(
         0.5.
     max_positive (Optional[int]): Maximum number of labels to consider positive
         per span. Defaults to None, indicating no limit.
-    store_activations (Union[bool, List[str]]): Model activations to store in
-        Doc when annotating. supported activations are: "indices" and "scores".
+    store_activations (bool): store model activations in Doc when annotating.
     """
     return SpanCategorizer(
         nlp.vocab,
@@ -192,7 +191,7 @@ class SpanCategorizer(TrainablePipe):
         threshold: float = 0.5,
         max_positive: Optional[int] = None,
         scorer: Optional[Callable] = spancat_score,
-        store_activations: Union[bool, List[str]] = False,
+        store_activations: bool = False,
     ) -> None:
         """Initialize the span categorizer.
         vocab (Vocab): The shared vocabulary.
@@ -225,7 +224,7 @@ class SpanCategorizer(TrainablePipe):
         self.model = model
         self.name = name
         self.scorer = scorer
-        self.set_store_activations(store_activations)
+        self.store_activations = store_activations

     @property
     def key(self) -> str:
@@ -317,10 +316,9 @@ class SpanCategorizer(TrainablePipe):
         offset = 0
         for i, doc in enumerate(docs):
             indices_i = indices[i].dataXd
-            doc.activations[self.name] = {}
-            if "indices" in self.store_activations:
-                doc.activations[self.name]["indices"] = indices_i
-            if "scores" in self.store_activations:
-                doc.activations[self.name]["scores"] = scores[
-                    offset : offset + indices.lengths[i]
-                ]
+            if self.store_activations:
+                doc.activations[self.name] = {}
+                doc.activations[self.name]["indices"] = indices_i
+                doc.activations[self.name]["scores"] = scores[
+                    offset : offset + indices.lengths[i]
+                ]
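In the spancat hunk above, the batch's span scores sit in one flat array, so each Doc's slice is taken with a running offset over the per-doc span counts in indices.lengths. A small NumPy illustration of that slicing (toy values):

    import numpy as np

    # Flat scores for 5 candidate spans across 2 docs (3 spans, then 2 spans),
    # with 2 labels per span.
    scores = np.arange(10, dtype="f").reshape(5, 2)
    lengths = [3, 2]  # number of candidate spans per doc

    offset = 0
    for i, n_spans in enumerate(lengths):
        doc_scores = scores[offset : offset + n_spans]
        offset += n_spans
        print(i, doc_scores.shape)  # 0 (3, 2), then 1 (2, 2)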
@@ -61,7 +61,7 @@ def make_tagger(
     overwrite: bool,
     scorer: Optional[Callable],
     neg_prefix: str,
-    store_activations: Union[bool, List[str]],
+    store_activations: bool,
 ):
     """Construct a part-of-speech tagger component.

@@ -97,7 +97,7 @@ class Tagger(TrainablePipe):
         overwrite=BACKWARD_OVERWRITE,
         scorer=tagger_score,
         neg_prefix="!",
-        store_activations: Union[bool, List[str]] = False,
+        store_activations: bool = False,
     ):
         """Initialize a part-of-speech tagger.

@@ -107,8 +107,7 @@ class Tagger(TrainablePipe):
             losses during training.
         scorer (Optional[Callable]): The scoring method. Defaults to
             Scorer.score_token_attr for the attribute "tag".
-        store_activations (Union[bool, List[str]]): Model activations to store in
-            Doc when annotating. supported activations are: "probs" and "guesses".
+        store_activations (bool): store model activations in Doc when annotating.

         DOCS: https://spacy.io/api/tagger#init
         """
@@ -119,7 +118,7 @@ class Tagger(TrainablePipe):
         cfg = {"labels": [], "overwrite": overwrite, "neg_prefix": neg_prefix}
         self.cfg = dict(sorted(cfg.items()))
         self.scorer = scorer
-        self.set_store_activations(store_activations)
+        self.store_activations = store_activations

     @property
     def labels(self):
@@ -183,9 +182,10 @@ class Tagger(TrainablePipe):
         cdef bint overwrite = self.cfg["overwrite"]
         labels = self.labels
         for i, doc in enumerate(docs):
-            doc.activations[self.name] = {}
-            for activation in self.store_activations:
-                doc.activations[self.name][activation] = activations[activation][i]
+            if self.store_activations:
+                doc.activations[self.name] = {}
+                for act_name, acts in activations.items():
+                    doc.activations[self.name][act_name] = acts[i]
             doc_tag_ids = batch_tag_ids[i]
             if hasattr(doc_tag_ids, "get"):
                 doc_tag_ids = doc_tag_ids.get()

@@ -97,7 +97,7 @@ def make_textcat(
     model: Model[List[Doc], List[Floats2d]],
     threshold: float,
     scorer: Optional[Callable],
-    store_activations: Union[bool, List[str]],
+    store_activations: bool,
 ) -> "TextCategorizer":
     """Create a TextCategorizer component. The text categorizer predicts categories
     over a whole document. It can learn one or more labels, and the labels are considered
@@ -107,8 +107,7 @@ def make_textcat(
         scores for each category.
     threshold (float): Cutoff to consider a prediction "positive".
     scorer (Optional[Callable]): The scoring method.
-    store_activations (Union[bool, List[str]]): Model activations to store in
-        Doc when annotating. supported activations is: "probs".
+    store_activations (bool): store model activations in Doc when annotating.
     """
     return TextCategorizer(
         nlp.vocab,
@@ -148,7 +147,7 @@ class TextCategorizer(TrainablePipe):
         *,
         threshold: float,
         scorer: Optional[Callable] = textcat_score,
-        store_activations: Union[bool, List[str]] = False,
+        store_activations: bool = False,
     ) -> None:
         """Initialize a text categorizer for single-label classification.

@@ -169,7 +168,7 @@ class TextCategorizer(TrainablePipe):
         cfg = {"labels": [], "threshold": threshold, "positive_label": None}
         self.cfg = dict(cfg)
         self.scorer = scorer
-        self.set_store_activations(store_activations)
+        self.store_activations = store_activations

     @property
     def support_missing_values(self):
@@ -224,8 +223,8 @@ class TextCategorizer(TrainablePipe):
         """
         probs = activations["probs"]
         for i, doc in enumerate(docs):
-            doc.activations[self.name] = {}
-            if "probs" in self.store_activations:
+            if self.store_activations:
+                doc.activations[self.name] = {}
                 doc.activations[self.name]["probs"] = probs[i]
             for j, label in enumerate(self.labels):
                 doc.cats[label] = float(probs[i, j])

@@ -97,7 +97,7 @@ def make_multilabel_textcat(
     model: Model[List[Doc], List[Floats2d]],
     threshold: float,
     scorer: Optional[Callable],
-    store_activations: Union[bool, List[str]],
+    store_activations: bool,
 ) -> "TextCategorizer":
     """Create a TextCategorizer component. The text categorizer predicts categories
     over a whole document. It can learn one or more labels, and the labels are considered
@@ -146,7 +146,7 @@ class MultiLabel_TextCategorizer(TextCategorizer):
         *,
         threshold: float,
         scorer: Optional[Callable] = textcat_multilabel_score,
-        store_activations: Union[bool, List[str]] = False,
+        store_activations: bool = False,
     ) -> None:
         """Initialize a text categorizer for multi-label classification.

@@ -155,8 +155,7 @@ class MultiLabel_TextCategorizer(TextCategorizer):
         name (str): The component instance name, used to add entries to the
             losses during training.
         threshold (float): Cutoff to consider a prediction "positive".
-        store_activations (Union[bool, List[str]]): Model activations to store in
-            Doc when annotating. supported activations is: "probs".
+        store_activations (bool): store model activations in Doc when annotating.

         DOCS: https://spacy.io/api/textcategorizer#init
         """
@@ -167,7 +166,7 @@ class MultiLabel_TextCategorizer(TextCategorizer):
         cfg = {"labels": [], "threshold": threshold}
         self.cfg = dict(cfg)
         self.scorer = scorer
-        self.set_store_activations(store_activations)
+        self.store_activations = store_activations

     @property
     def support_missing_values(self):

@@ -6,4 +6,4 @@ cdef class TrainablePipe(Pipe):
     cdef public object model
     cdef public object cfg
     cdef public object scorer
-    cdef object _store_activations
+    cdef bint _store_activations

@@ -352,19 +352,6 @@ cdef class TrainablePipe(Pipe):
     def store_activations(self):
         return self._store_activations

-    def set_store_activations(self, activations):
-        known_activations = self.activations
-        if isinstance(activations, list):
-            self._store_activations = []
-            for activation in activations:
-                if activation in known_activations:
-                    self._store_activations.append(activation)
-                else:
-                    warnings.warn(Warnings.W400.format(activation=activation, pipe_name=self.name))
-        elif isinstance(activations, bool):
-            if activations:
-                self._store_activations = list(known_activations)
-            else:
-                self._store_activations = []
-        else:
-            raise ValueError(Errors.E1400)
+    @store_activations.setter
+    def store_activations(self, store_activations: bool):
+        self._store_activations = store_activations
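The result is an ordinary property/setter pair. Sketched in plain Python (the real code is Cython, with `_store_activations` now a C `bint`, i.e. a C-level boolean):

    class TrainablePipeSketch:
        def __init__(self) -> None:
            self._store_activations = False

        @property
        def store_activations(self) -> bool:
            return self._store_activations

        @store_activations.setter
        def store_activations(self, store_activations: bool) -> None:
            self._store_activations = store_activations

    pipe = TrainablePipeSketch()
    pipe.store_activations = True  # replaces pipe.set_store_activations(True)
    print(pipe.store_activations)  # True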

@@ -293,15 +293,10 @@ def test_store_activations():
     nO = lemmatizer.model.get_dim("nO")

     doc = nlp("This is a test.")
-    assert len(list(doc.activations["trainable_lemmatizer"].keys())) == 0
+    assert "trainable_lemmatizer" not in doc.activations

-    lemmatizer.set_store_activations(True)
+    lemmatizer.store_activations = True
     doc = nlp("This is a test.")
     assert list(doc.activations["trainable_lemmatizer"].keys()) == ["probs", "guesses"]
     assert doc.activations["trainable_lemmatizer"]["probs"].shape == (5, nO)
     assert doc.activations["trainable_lemmatizer"]["guesses"].shape == (5,)
-
-    lemmatizer.set_store_activations(["probs"])
-    doc = nlp("This is a test.")
-    assert list(doc.activations["trainable_lemmatizer"].keys()) == ["probs"]
-    assert doc.activations["trainable_lemmatizer"]["probs"].shape == (5, nO)

@@ -1225,9 +1225,9 @@ def test_store_activations():
     ruler.add_patterns(patterns)

     doc = nlp("Russ Cochran was a publisher")
-    assert len(doc.activations["entity_linker"].keys()) == 0
+    assert "entity_linker" not in doc.activations

-    entity_linker.set_store_activations(True)
+    entity_linker.store_activations = True
     doc = nlp("Russ Cochran was a publisher")
     assert set(doc.activations["entity_linker"].keys()) == {"ents", "scores"}
     ents = doc.activations["entity_linker"]["ents"]
@@ -1240,12 +1240,3 @@ def test_store_activations():
     assert scores.data.shape == (2, 1)
     assert scores.data.dtype == "float32"
     assert scores.lengths.shape == (1,)
-
-    entity_linker.set_store_activations(["scores"])
-    doc = nlp("Russ Cochran was a publisher")
-    assert set(doc.activations["entity_linker"].keys()) == {"scores"}
-    scores = doc.activations["entity_linker"]["scores"]
-    assert isinstance(scores, Ragged)
-    assert scores.data.shape == (2, 1)
-    assert scores.data.dtype == "float32"
-    assert scores.lengths.shape == (1,)

@@ -211,17 +211,11 @@ def test_store_activations():
     nlp.initialize(get_examples=lambda: train_examples)

     doc = nlp("This is a test.")
-    assert len(list(doc.activations["morphologizer"].keys())) == 0
+    assert "morphologizer" not in doc.activations

-    morphologizer.set_store_activations(True)
+    morphologizer.store_activations = True
     doc = nlp("This is a test.")
     assert "morphologizer" in doc.activations
     assert set(doc.activations["morphologizer"].keys()) == {"guesses", "probs"}
     assert doc.activations["morphologizer"]["probs"].shape == (5, 6)
     assert doc.activations["morphologizer"]["guesses"].shape == (5,)
-
-    morphologizer.set_store_activations(["probs"])
-    doc = nlp("This is a test.")
-    assert "morphologizer" in doc.activations
-    assert set(doc.activations["morphologizer"].keys()) == {"probs"}
-    assert doc.activations["morphologizer"]["probs"].shape == (5, 6)

@@ -118,17 +118,11 @@ def test_store_activations():
     nO = senter.model.get_dim("nO")

     doc = nlp("This is a test.")
-    assert len(list(doc.activations["senter"].keys())) == 0
+    assert "senter" not in doc.activations

-    senter.set_store_activations(True)
+    senter.store_activations = True
     doc = nlp("This is a test.")
     assert "senter" in doc.activations
     assert set(doc.activations["senter"].keys()) == {"guesses", "probs"}
     assert doc.activations["senter"]["probs"].shape == (5, nO)
     assert doc.activations["senter"]["guesses"].shape == (5,)
-
-    senter.set_store_activations(["probs"])
-    doc = nlp("This is a test.")
-    assert "senter" in doc.activations
-    assert set(doc.activations["senter"].keys()) == {"probs"}
-    assert doc.activations["senter"]["probs"].shape == (5, 2)

@@ -432,15 +432,10 @@ def test_store_activations():
     assert set(spancat.labels) == {"LOC", "PERSON"}

     doc = nlp("This is a test.")
-    assert len(list(doc.activations["spancat"].keys())) == 0
+    assert "spancat" not in doc.activations

-    spancat.set_store_activations(True)
+    spancat.store_activations = True
     doc = nlp("This is a test.")
     assert set(doc.activations["spancat"].keys()) == {"indices", "scores"}
     assert doc.activations["spancat"]["indices"].shape == (12, 2)
     assert doc.activations["spancat"]["scores"].shape == (12, nO)
-
-    spancat.set_store_activations(["scores"])
-    doc = nlp("This is a test.")
-    assert set(doc.activations["spancat"].keys()) == {"scores"}
-    assert doc.activations["spancat"]["scores"].shape == (12, nO)

@@ -223,20 +223,15 @@ def test_store_activations():
     nlp.initialize(get_examples=lambda: train_examples)

     doc = nlp("This is a test.")
-    assert len(list(doc.activations["tagger"].keys())) == 0
+    assert "tagger" not in doc.activations

-    tagger.set_store_activations(True)
+    tagger.store_activations = True
     doc = nlp("This is a test.")
     assert "tagger" in doc.activations
     assert set(doc.activations["tagger"].keys()) == {"guesses", "probs"}
     assert doc.activations["tagger"]["probs"].shape == (5, len(TAGS))
     assert doc.activations["tagger"]["guesses"].shape == (5,)
-
-    tagger.set_store_activations(["probs"])
-    doc = nlp("This is a test.")
-    assert set(doc.activations["tagger"].keys()) == {"probs"}
-    assert doc.activations["tagger"]["probs"].shape == (5, len(TAGS))


 def test_tagger_requires_labels():
     nlp = English()

@@ -886,14 +886,9 @@ def test_store_activations():
     nO = textcat.model.get_dim("nO")

     doc = nlp("This is a test.")
-    assert len(list(doc.activations["textcat"].keys())) == 0
+    assert "textcat" not in doc.activations

-    textcat.set_store_activations(True)
-    doc = nlp("This is a test.")
-    assert list(doc.activations["textcat"].keys()) == ["probs"]
-    assert doc.activations["textcat"]["probs"].shape == (nO,)
-
-    textcat.set_store_activations(["probs"])
+    textcat.store_activations = True
     doc = nlp("This is a test.")
     assert list(doc.activations["textcat"].keys()) == ["probs"]
     assert doc.activations["textcat"]["probs"].shape == (nO,)
@@ -911,14 +906,9 @@ def test_store_activations_multi():
     nO = textcat.model.get_dim("nO")

     doc = nlp("This is a test.")
-    assert len(list(doc.activations["textcat_multilabel"].keys())) == 0
+    assert "textcat_multilabel" not in doc.activations

-    textcat.set_store_activations(True)
-    doc = nlp("This is a test.")
-    assert list(doc.activations["textcat_multilabel"].keys()) == ["probs"]
-    assert doc.activations["textcat_multilabel"]["probs"].shape == (nO,)
-
-    textcat.set_store_activations(["probs"])
+    textcat.store_activations = True
     doc = nlp("This is a test.")
     assert list(doc.activations["textcat_multilabel"].keys()) == ["probs"]
     assert doc.activations["textcat_multilabel"]["probs"].shape == (nO,)