mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-01 04:46:38 +03:00
Change type of store_activations
to Union[bool, List[str]]
When the value is: - A bool: all activations are stored when set to `True`. - A List[str]: the activations named in the list are stored
This commit is contained in:
parent
b71c6043bc
commit
c3da32b46b
|
@ -209,6 +209,7 @@ class Warnings(metaclass=ErrorsWithCodes):
|
|||
"Only the last span group will be loaded under "
|
||||
"Doc.spans['{group_name}']. Skipping span group with values: "
|
||||
"{group_values}")
|
||||
W121 = ("Activation '{activation}' is unknown for pipe '{pipe_name}'")
|
||||
|
||||
|
||||
class Errors(metaclass=ErrorsWithCodes):
|
||||
|
@ -934,6 +935,7 @@ class Errors(metaclass=ErrorsWithCodes):
|
|||
E1041 = ("Expected a string, Doc, or bytes as input, but got: {type}")
|
||||
E1042 = ("Function was called with `{arg1}`={arg1_values} and "
|
||||
"`{arg2}`={arg2_values} but these arguments are conflicting.")
|
||||
E1043 = ("store_activations attribute must be set to List[str] or bool")
|
||||
|
||||
|
||||
# Deprecated model shortcuts, only used in errors and warnings
|
||||
|
|
|
@ -217,7 +217,7 @@ class Morphologizer(Tagger):
|
|||
assert len(label_sample) > 0, Errors.E923.format(name=self.name)
|
||||
self.model.initialize(X=doc_sample, Y=label_sample)
|
||||
|
||||
def set_annotations(self, docs, scores_guesses):
|
||||
def set_annotations(self, docs, activations):
|
||||
"""Modify a batch of documents, using pre-computed scores.
|
||||
|
||||
docs (Iterable[Doc]): The documents to modify.
|
||||
|
@ -225,7 +225,7 @@ class Morphologizer(Tagger):
|
|||
|
||||
DOCS: https://spacy.io/api/morphologizer#set_annotations
|
||||
"""
|
||||
_, batch_tag_ids = scores_guesses
|
||||
batch_tag_ids = activations["guesses"]
|
||||
if isinstance(docs, Doc):
|
||||
docs = [docs]
|
||||
cdef Doc doc
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
# cython: infer_types=True, profile=True, binding=True
|
||||
from typing import Optional, Callable
|
||||
from typing import Optional, Callable, List, Union
|
||||
from itertools import islice
|
||||
|
||||
import srsly
|
||||
|
@ -46,7 +46,12 @@ DEFAULT_SENTER_MODEL = Config().from_str(default_model_config)["model"]
|
|||
},
|
||||
default_score_weights={"sents_f": 1.0, "sents_p": 0.0, "sents_r": 0.0},
|
||||
)
|
||||
def make_senter(nlp: Language, name: str, model: Model, overwrite: bool, scorer: Optional[Callable], store_activations: bool):
|
||||
def make_senter(nlp: Language,
|
||||
name: str,
|
||||
model: Model,
|
||||
overwrite: bool,
|
||||
scorer: Optional[Callable],
|
||||
store_activations: Union[bool, List[str]]):
|
||||
return SentenceRecognizer(nlp.vocab, model, name, overwrite=overwrite, scorer=scorer, store_activations=store_activations)
|
||||
|
||||
|
||||
|
@ -114,7 +119,7 @@ class SentenceRecognizer(Tagger):
|
|||
def label_data(self):
|
||||
return None
|
||||
|
||||
def set_annotations(self, docs, scores_guesses):
|
||||
def set_annotations(self, docs, activations):
|
||||
"""Modify a batch of documents, using pre-computed scores.
|
||||
|
||||
docs (Iterable[Doc]): The documents to modify.
|
||||
|
@ -122,17 +127,15 @@ class SentenceRecognizer(Tagger):
|
|||
|
||||
DOCS: https://spacy.io/api/sentencerecognizer#set_annotations
|
||||
"""
|
||||
_, batch_tag_ids = scores_guesses
|
||||
batch_tag_ids = activations["guesses"]
|
||||
if isinstance(docs, Doc):
|
||||
docs = [docs]
|
||||
cdef Doc doc
|
||||
cdef bint overwrite = self.cfg["overwrite"]
|
||||
for i, doc in enumerate(docs):
|
||||
if self.store_activations:
|
||||
doc.activations[self.name] = {
|
||||
"probs": scores_guesses[0][i],
|
||||
"guesses": scores_guesses[1][i],
|
||||
}
|
||||
doc.activations[self.name] = {}
|
||||
for activation in self.store_activations:
|
||||
doc.activations[self.name][activation] = activations[activation][i]
|
||||
doc_tag_ids = batch_tag_ids[i]
|
||||
if hasattr(doc_tag_ids, "get"):
|
||||
doc_tag_ids = doc_tag_ids.get()
|
||||
|
@ -199,3 +202,7 @@ class SentenceRecognizer(Tagger):
|
|||
|
||||
def add_label(self, label, values=None):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def activations(self):
|
||||
return ["probs", "guesses"]
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
# cython: infer_types=True, profile=True, binding=True
|
||||
from typing import Callable, Optional
|
||||
from typing import Callable, List, Optional, Union
|
||||
import numpy
|
||||
import srsly
|
||||
from thinc.api import Model, set_dropout_rate, SequenceCategoricalCrossentropy, Config
|
||||
|
@ -61,7 +61,7 @@ def make_tagger(
|
|||
overwrite: bool,
|
||||
scorer: Optional[Callable],
|
||||
neg_prefix: str,
|
||||
store_activations: bool,
|
||||
store_activations: Union[bool, List[str]],
|
||||
):
|
||||
"""Construct a part-of-speech tagger component.
|
||||
|
||||
|
@ -148,12 +148,12 @@ class Tagger(TrainablePipe):
|
|||
n_labels = len(self.labels)
|
||||
guesses = [self.model.ops.alloc((0, n_labels)) for doc in docs]
|
||||
assert len(guesses) == len(docs)
|
||||
return guesses, guesses
|
||||
return {"probs": guesses, "guesses": guesses}
|
||||
scores = self.model.predict(docs)
|
||||
assert len(scores) == len(docs), (len(scores), len(docs))
|
||||
guesses = self._scores2guesses(scores)
|
||||
assert len(guesses) == len(docs)
|
||||
return scores, guesses
|
||||
return {"probs": scores, "guesses": guesses}
|
||||
|
||||
def _scores2guesses(self, scores):
|
||||
guesses = []
|
||||
|
@ -164,7 +164,7 @@ class Tagger(TrainablePipe):
|
|||
guesses.append(doc_guesses)
|
||||
return guesses
|
||||
|
||||
def set_annotations(self, docs, scores_guesses):
|
||||
def set_annotations(self, docs, activations):
|
||||
"""Modify a batch of documents, using pre-computed scores.
|
||||
|
||||
docs (Iterable[Doc]): The documents to modify.
|
||||
|
@ -172,7 +172,7 @@ class Tagger(TrainablePipe):
|
|||
|
||||
DOCS: https://spacy.io/api/tagger#set_annotations
|
||||
"""
|
||||
_, batch_tag_ids = scores_guesses
|
||||
batch_tag_ids = activations["guesses"]
|
||||
if isinstance(docs, Doc):
|
||||
docs = [docs]
|
||||
cdef Doc doc
|
||||
|
@ -180,11 +180,9 @@ class Tagger(TrainablePipe):
|
|||
cdef bint overwrite = self.cfg["overwrite"]
|
||||
labels = self.labels
|
||||
for i, doc in enumerate(docs):
|
||||
if self.store_activations:
|
||||
doc.activations[self.name] = {
|
||||
"probs": scores_guesses[0][i],
|
||||
"guesses": scores_guesses[1][i],
|
||||
}
|
||||
doc.activations[self.name] = {}
|
||||
for activation in self.store_activations:
|
||||
doc.activations[self.name][activation] = activations[activation][i]
|
||||
doc_tag_ids = batch_tag_ids[i]
|
||||
if hasattr(doc_tag_ids, "get"):
|
||||
doc_tag_ids = doc_tag_ids.get()
|
||||
|
@ -338,3 +336,7 @@ class Tagger(TrainablePipe):
|
|||
self.cfg["labels"].append(label)
|
||||
self.vocab.strings.add(label)
|
||||
return 1
|
||||
|
||||
@property
|
||||
def activations(self):
|
||||
return ["probs", "guesses"]
|
||||
|
|
|
@ -6,4 +6,4 @@ cdef class TrainablePipe(Pipe):
|
|||
cdef public object model
|
||||
cdef public object cfg
|
||||
cdef public object scorer
|
||||
cdef public bint store_activations
|
||||
cdef object _store_activations
|
||||
|
|
|
@ -2,11 +2,12 @@
|
|||
from typing import Iterable, Iterator, Optional, Dict, Tuple, Callable
|
||||
import srsly
|
||||
from thinc.api import set_dropout_rate, Model, Optimizer
|
||||
import warnings
|
||||
|
||||
from ..tokens.doc cimport Doc
|
||||
|
||||
from ..training import validate_examples
|
||||
from ..errors import Errors
|
||||
from ..errors import Errors, Warnings
|
||||
from .pipe import Pipe, deserialize_config
|
||||
from .. import util
|
||||
from ..vocab import Vocab
|
||||
|
@ -342,3 +343,29 @@ cdef class TrainablePipe(Pipe):
|
|||
deserialize["model"] = load_model
|
||||
util.from_disk(path, deserialize, exclude)
|
||||
return self
|
||||
|
||||
@property
|
||||
def activations(self):
|
||||
raise NotImplementedError(Errors.E931.format(parent="TrainablePipe", method="activations", name=self.name))
|
||||
|
||||
@property
|
||||
def store_activations(self):
|
||||
return self._store_activations
|
||||
|
||||
@store_activations.setter
|
||||
def store_activations(self, activations):
|
||||
known_activations = self.activations
|
||||
if isinstance(activations, list):
|
||||
self._store_activations = []
|
||||
for activation in activations:
|
||||
if activation in known_activations:
|
||||
self._store_activations.append(activation)
|
||||
else:
|
||||
warnings.warn(Warnings.W121.format(activation=activation, pipe_name=self.name))
|
||||
elif isinstance(activations, bool):
|
||||
if activations:
|
||||
self._store_activations = list(known_activations)
|
||||
else:
|
||||
self._store_activations = []
|
||||
else:
|
||||
raise ValueError(Errors.E1043)
|
||||
|
|
|
@ -117,9 +117,14 @@ def test_store_activations():
|
|||
nlp.initialize(get_examples=lambda: train_examples)
|
||||
|
||||
senter.store_activations = True
|
||||
|
||||
doc = nlp("This is a test.")
|
||||
assert "senter" in doc.activations
|
||||
assert set(doc.activations["senter"].keys()) == {"guesses", "probs"}
|
||||
assert doc.activations["senter"]["probs"].shape == (5, 2)
|
||||
assert doc.activations["senter"]["guesses"].shape == (5,)
|
||||
|
||||
senter.store_activations = ["probs"]
|
||||
doc = nlp("This is a test.")
|
||||
assert "senter" in doc.activations
|
||||
assert set(doc.activations["senter"].keys()) == {"probs"}
|
||||
assert doc.activations["senter"]["probs"].shape == (5, 2)
|
||||
|
|
|
@ -223,14 +223,17 @@ def test_store_activations():
|
|||
nlp.initialize(get_examples=lambda: train_examples)
|
||||
|
||||
tagger.store_activations = True
|
||||
|
||||
doc = nlp("This is a test.")
|
||||
|
||||
assert "tagger" in doc.activations
|
||||
assert set(doc.activations["tagger"].keys()) == {"guesses", "probs"}
|
||||
assert doc.activations["tagger"]["probs"].shape == (5, len(TAGS))
|
||||
assert doc.activations["tagger"]["guesses"].shape == (5,)
|
||||
|
||||
tagger.store_activations = ["probs"]
|
||||
doc = nlp("This is a test.")
|
||||
assert set(doc.activations["tagger"].keys()) == {"probs"}
|
||||
assert doc.activations["tagger"]["probs"].shape == (5, len(TAGS))
|
||||
|
||||
|
||||
def test_tagger_requires_labels():
|
||||
nlp = English()
|
||||
|
|
Loading…
Reference in New Issue
Block a user