Mirror of https://github.com/explosion/spaCy.git (synced 2025-01-04 06:16:33 +03:00)
Change type of store_activations to Union[bool, List[str]]

When the value is:

- A bool: all activations are stored when set to `True`.
- A List[str]: the activations named in the list are stored.
parent b71c6043bc
commit c3da32b46b
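For illustration only, here is a minimal usage sketch of the two accepted value types, written against the code at this commit (not a released spaCy API). The component and activation names are taken from the diff below, and the pipeline setup mirrors the tests at the end of this change set.

# Sketch: both accepted forms of store_activations, read back via Doc.activations.
# Assumes the code at this commit; "tagger", "probs" and "guesses" come from the diff.
import spacy
from spacy.training import Example

nlp = spacy.blank("en")
tagger = nlp.add_pipe("tagger")
examples = [Example.from_dict(nlp.make_doc("I like cats"),
                              {"tags": ["PRP", "VBP", "NNS"]})]
nlp.initialize(get_examples=lambda: examples)

tagger.store_activations = True             # bool: store every activation the pipe exposes
doc = nlp("This is a test.")
assert set(doc.activations["tagger"].keys()) == {"probs", "guesses"}

tagger.store_activations = ["probs"]        # List[str]: store only the named activations
doc = nlp("This is a test.")
assert set(doc.activations["tagger"].keys()) == {"probs"}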
@@ -209,6 +209,7 @@ class Warnings(metaclass=ErrorsWithCodes):
             "Only the last span group will be loaded under "
             "Doc.spans['{group_name}']. Skipping span group with values: "
             "{group_values}")
+    W121 = ("Activation '{activation}' is unknown for pipe '{pipe_name}'")
 
 
 class Errors(metaclass=ErrorsWithCodes):
@@ -934,6 +935,7 @@ class Errors(metaclass=ErrorsWithCodes):
     E1041 = ("Expected a string, Doc, or bytes as input, but got: {type}")
     E1042 = ("Function was called with `{arg1}`={arg1_values} and "
              "`{arg2}`={arg2_values} but these arguments are conflicting.")
+    E1043 = ("store_activations attribute must be set to List[str] or bool")
 
 
 # Deprecated model shortcuts, only used in errors and warnings
@@ -217,7 +217,7 @@ class Morphologizer(Tagger):
         assert len(label_sample) > 0, Errors.E923.format(name=self.name)
         self.model.initialize(X=doc_sample, Y=label_sample)
 
-    def set_annotations(self, docs, scores_guesses):
+    def set_annotations(self, docs, activations):
         """Modify a batch of documents, using pre-computed scores.
 
         docs (Iterable[Doc]): The documents to modify.
@@ -225,7 +225,7 @@ class Morphologizer(Tagger):
 
         DOCS: https://spacy.io/api/morphologizer#set_annotations
         """
-        _, batch_tag_ids = scores_guesses
+        batch_tag_ids = activations["guesses"]
         if isinstance(docs, Doc):
             docs = [docs]
         cdef Doc doc
@@ -1,5 +1,5 @@
 # cython: infer_types=True, profile=True, binding=True
-from typing import Optional, Callable
+from typing import Optional, Callable, List, Union
 from itertools import islice
 
 import srsly
@@ -46,7 +46,12 @@ DEFAULT_SENTER_MODEL = Config().from_str(default_model_config)["model"]
     },
     default_score_weights={"sents_f": 1.0, "sents_p": 0.0, "sents_r": 0.0},
 )
-def make_senter(nlp: Language, name: str, model: Model, overwrite: bool, scorer: Optional[Callable], store_activations: bool):
+def make_senter(nlp: Language,
+                name: str,
+                model: Model,
+                overwrite: bool,
+                scorer: Optional[Callable],
+                store_activations: Union[bool, List[str]]):
     return SentenceRecognizer(nlp.vocab, model, name, overwrite=overwrite, scorer=scorer, store_activations=store_activations)
 
 
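Since `store_activations` is a factory argument of `make_senter`, it can presumably also be set through the component config when the pipe is added. A hedged sketch against the code at this commit:

# Sketch: passing store_activations via the component config at creation time.
import spacy

nlp = spacy.blank("en")
senter = nlp.add_pipe("senter", config={"store_activations": ["probs"]})
assert senter.store_activations == ["probs"]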
@@ -114,7 +119,7 @@ class SentenceRecognizer(Tagger):
     def label_data(self):
         return None
 
-    def set_annotations(self, docs, scores_guesses):
+    def set_annotations(self, docs, activations):
         """Modify a batch of documents, using pre-computed scores.
 
         docs (Iterable[Doc]): The documents to modify.
@@ -122,17 +127,15 @@ class SentenceRecognizer(Tagger):
 
         DOCS: https://spacy.io/api/sentencerecognizer#set_annotations
         """
-        _, batch_tag_ids = scores_guesses
+        batch_tag_ids = activations["guesses"]
         if isinstance(docs, Doc):
             docs = [docs]
         cdef Doc doc
         cdef bint overwrite = self.cfg["overwrite"]
         for i, doc in enumerate(docs):
-            if self.store_activations:
-                doc.activations[self.name] = {
-                    "probs": scores_guesses[0][i],
-                    "guesses": scores_guesses[1][i],
-                }
+            doc.activations[self.name] = {}
+            for activation in self.store_activations:
+                doc.activations[self.name][activation] = activations[activation][i]
             doc_tag_ids = batch_tag_ids[i]
             if hasattr(doc_tag_ids, "get"):
                 doc_tag_ids = doc_tag_ids.get()
@@ -199,3 +202,7 @@ class SentenceRecognizer(Tagger):
 
     def add_label(self, label, values=None):
         raise NotImplementedError
+
+    @property
+    def activations(self):
+        return ["probs", "guesses"]
@@ -1,5 +1,5 @@
 # cython: infer_types=True, profile=True, binding=True
-from typing import Callable, Optional
+from typing import Callable, List, Optional, Union
 import numpy
 import srsly
 from thinc.api import Model, set_dropout_rate, SequenceCategoricalCrossentropy, Config
@@ -61,7 +61,7 @@ def make_tagger(
     overwrite: bool,
     scorer: Optional[Callable],
     neg_prefix: str,
-    store_activations: bool,
+    store_activations: Union[bool, List[str]],
 ):
     """Construct a part-of-speech tagger component.
 
@@ -148,12 +148,12 @@ class Tagger(TrainablePipe):
             n_labels = len(self.labels)
             guesses = [self.model.ops.alloc((0, n_labels)) for doc in docs]
             assert len(guesses) == len(docs)
-            return guesses, guesses
+            return {"probs": guesses, "guesses": guesses}
         scores = self.model.predict(docs)
         assert len(scores) == len(docs), (len(scores), len(docs))
         guesses = self._scores2guesses(scores)
         assert len(guesses) == len(docs)
-        return scores, guesses
+        return {"probs": scores, "guesses": guesses}
 
     def _scores2guesses(self, scores):
         guesses = []
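To make the change above concrete: `Tagger.predict` now returns a dict of activations keyed by name instead of a `(scores, guesses)` tuple, and `set_annotations` picks out the `"guesses"` entry. A rough, hedged sketch against the code at this commit:

# Sketch: the dict returned by Tagger.predict after this change.
import spacy
from spacy.training import Example

nlp = spacy.blank("en")
tagger = nlp.add_pipe("tagger")
nlp.initialize(get_examples=lambda: [
    Example.from_dict(nlp.make_doc("I like cats"), {"tags": ["PRP", "VBP", "NNS"]})
])

docs = [nlp.make_doc("I like cats")]
activations = tagger.predict(docs)           # previously (scores, guesses), now a dict
assert set(activations.keys()) == {"probs", "guesses"}
tagger.set_annotations(docs, activations)    # reads activations["guesses"] internally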
@@ -164,7 +164,7 @@ class Tagger(TrainablePipe):
             guesses.append(doc_guesses)
         return guesses
 
-    def set_annotations(self, docs, scores_guesses):
+    def set_annotations(self, docs, activations):
         """Modify a batch of documents, using pre-computed scores.
 
         docs (Iterable[Doc]): The documents to modify.
@@ -172,7 +172,7 @@ class Tagger(TrainablePipe):
 
         DOCS: https://spacy.io/api/tagger#set_annotations
         """
-        _, batch_tag_ids = scores_guesses
+        batch_tag_ids = activations["guesses"]
         if isinstance(docs, Doc):
             docs = [docs]
         cdef Doc doc
@@ -180,11 +180,9 @@ class Tagger(TrainablePipe):
         cdef bint overwrite = self.cfg["overwrite"]
         labels = self.labels
         for i, doc in enumerate(docs):
-            if self.store_activations:
-                doc.activations[self.name] = {
-                    "probs": scores_guesses[0][i],
-                    "guesses": scores_guesses[1][i],
-                }
+            doc.activations[self.name] = {}
+            for activation in self.store_activations:
+                doc.activations[self.name][activation] = activations[activation][i]
             doc_tag_ids = batch_tag_ids[i]
             if hasattr(doc_tag_ids, "get"):
                 doc_tag_ids = doc_tag_ids.get()
@@ -338,3 +336,7 @@ class Tagger(TrainablePipe):
         self.cfg["labels"].append(label)
         self.vocab.strings.add(label)
         return 1
+
+    @property
+    def activations(self):
+        return ["probs", "guesses"]
@@ -6,4 +6,4 @@ cdef class TrainablePipe(Pipe):
     cdef public object model
     cdef public object cfg
     cdef public object scorer
-    cdef public bint store_activations
+    cdef object _store_activations
@@ -2,11 +2,12 @@
 from typing import Iterable, Iterator, Optional, Dict, Tuple, Callable
 import srsly
 from thinc.api import set_dropout_rate, Model, Optimizer
+import warnings
 
 from ..tokens.doc cimport Doc
 
 from ..training import validate_examples
-from ..errors import Errors
+from ..errors import Errors, Warnings
 from .pipe import Pipe, deserialize_config
 from .. import util
 from ..vocab import Vocab
@@ -342,3 +343,29 @@ cdef class TrainablePipe(Pipe):
         deserialize["model"] = load_model
         util.from_disk(path, deserialize, exclude)
         return self
+
+    @property
+    def activations(self):
+        raise NotImplementedError(Errors.E931.format(parent="TrainablePipe", method="activations", name=self.name))
+
+    @property
+    def store_activations(self):
+        return self._store_activations
+
+    @store_activations.setter
+    def store_activations(self, activations):
+        known_activations = self.activations
+        if isinstance(activations, list):
+            self._store_activations = []
+            for activation in activations:
+                if activation in known_activations:
+                    self._store_activations.append(activation)
+                else:
+                    warnings.warn(Warnings.W121.format(activation=activation, pipe_name=self.name))
+        elif isinstance(activations, bool):
+            if activations:
+                self._store_activations = list(known_activations)
+            else:
+                self._store_activations = []
+        else:
+            raise ValueError(Errors.E1043)
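The setter added above normalizes every accepted value to a list of known activation names, warning about unknown names (W121) and rejecting anything that is neither a bool nor a list (E1043). A hedged sketch of the resulting behavior, using the Tagger component from this diff:

# Sketch: normalization rules of the new store_activations setter at this commit.
import warnings
import spacy

nlp = spacy.blank("en")
tagger = nlp.add_pipe("tagger")              # Tagger.activations == ["probs", "guesses"]

tagger.store_activations = True              # True -> all known activation names
assert tagger.store_activations == ["probs", "guesses"]

tagger.store_activations = False             # False -> nothing is stored
assert tagger.store_activations == []

with warnings.catch_warnings(record=True):
    warnings.simplefilter("always")
    tagger.store_activations = ["probs", "bogus"]   # unknown name -> W121, entry skipped
assert tagger.store_activations == ["probs"]

try:
    tagger.store_activations = "probs"       # neither bool nor list -> ValueError (E1043)
except ValueError:
    pass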
@@ -117,9 +117,14 @@ def test_store_activations():
     nlp.initialize(get_examples=lambda: train_examples)
 
     senter.store_activations = True
 
     doc = nlp("This is a test.")
     assert "senter" in doc.activations
     assert set(doc.activations["senter"].keys()) == {"guesses", "probs"}
     assert doc.activations["senter"]["probs"].shape == (5, 2)
     assert doc.activations["senter"]["guesses"].shape == (5,)
+
+    senter.store_activations = ["probs"]
+    doc = nlp("This is a test.")
+    assert "senter" in doc.activations
+    assert set(doc.activations["senter"].keys()) == {"probs"}
+    assert doc.activations["senter"]["probs"].shape == (5, 2)
@@ -223,14 +223,17 @@ def test_store_activations():
     nlp.initialize(get_examples=lambda: train_examples)
 
     tagger.store_activations = True
 
     doc = nlp("This is a test.")
 
     assert "tagger" in doc.activations
     assert set(doc.activations["tagger"].keys()) == {"guesses", "probs"}
     assert doc.activations["tagger"]["probs"].shape == (5, len(TAGS))
     assert doc.activations["tagger"]["guesses"].shape == (5,)
 
+    tagger.store_activations = ["probs"]
+    doc = nlp("This is a test.")
+    assert set(doc.activations["tagger"].keys()) == {"probs"}
+    assert doc.activations["tagger"]["probs"].shape == (5, len(TAGS))
 
 
 def test_tagger_requires_labels():
     nlp = English()