Add overwrite settings for more components (#9050)
* Add overwrite settings for more components

  For pipeline components where it's relevant and not already
  implemented, add an explicit `overwrite` setting that controls whether
  `set_annotations` overwrites existing annotation. For the
  `morphologizer`, add an additional setting `extend`, which controls
  whether the existing features are preserved.

  * +overwrite, +extend: overwrite values of existing features, add any new features
  * +overwrite, -extend: overwrite completely, removing any existing features
  * -overwrite, +extend: keep values of existing features, add any new features
  * -overwrite, -extend: do not modify the existing value if set

  In all cases an unset value will be set by `set_annotations`.

  Preserve current overwrite defaults:

  * True: morphologizer, entity linker
  * False: tagger, sentencizer, senter

* Add backwards compat overwrite settings

* Put empty line back

  Removed by accident in last commit

* Set backwards-compatible defaults in __init__

  Because the `TrainablePipe` serialization methods update `cfg`, there's
  no straightforward way to detect whether models serialized with a
  previous version are missing the overwrite settings. It would be
  possible in the sentencizer due to its separate serialization methods,
  however to keep the changes parallel, this also sets the default in
  `__init__`.

* Remove traces

Co-authored-by: Paul O'Leary McCann <polm@dampfkraft.com>
This commit is contained in:
parent 8fe525beb5
commit 03fefa37e2
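The four morphologizer modes listed above combine like dictionary merges. A minimal sketch using plain dicts to stand in for morphological feature sets (illustrative only, not spaCy API):

    # existing features already set on the token vs. features predicted
    # by the morphologizer
    existing = {"Feat": "A", "That": "A"}
    predicted = {"Feat": "N"}

    # +overwrite, +extend: predicted values win, existing-only features kept
    assert {**existing, **predicted} == {"Feat": "N", "That": "A"}

    # +overwrite, -extend: predicted replaces the existing features entirely
    assert dict(predicted) == {"Feat": "N"}

    # -overwrite, +extend: existing values win, new features still added
    assert {**predicted, **existing} == {"Feat": "A", "That": "A"}

    # -overwrite, -extend: a set value is left untouched
    assert dict(existing) == {"Feat": "A", "That": "A"}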
spacy/pipeline/entity_linker.py
@@ -20,6 +20,8 @@ from ..util import SimpleFrozenList, registry
 from .. import util
 from ..scorer import Scorer

+# See #9050
+BACKWARD_OVERWRITE = True
+
 default_model_config = """
 [model]
@@ -50,6 +52,7 @@ DEFAULT_NEL_MODEL = Config().from_str(default_model_config)["model"]
         "incl_context": True,
         "entity_vector_length": 64,
         "get_candidates": {"@misc": "spacy.CandidateGenerator.v1"},
+        "overwrite": True,
         "scorer": {"@scorers": "spacy.entity_linker_scorer.v1"},
     },
     default_score_weights={
@@ -69,6 +72,7 @@ def make_entity_linker(
     incl_context: bool,
     entity_vector_length: int,
     get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]],
+    overwrite: bool,
     scorer: Optional[Callable],
 ):
     """Construct an EntityLinker component.
@@ -95,6 +99,7 @@ def make_entity_linker(
         incl_context=incl_context,
         entity_vector_length=entity_vector_length,
         get_candidates=get_candidates,
+        overwrite=overwrite,
         scorer=scorer,
     )

@@ -128,6 +133,7 @@ class EntityLinker(TrainablePipe):
         incl_context: bool,
         entity_vector_length: int,
         get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]],
+        overwrite: bool = BACKWARD_OVERWRITE,
         scorer: Optional[Callable] = entity_linker_score,
     ) -> None:
         """Initialize an entity linker.
@@ -156,7 +162,7 @@ class EntityLinker(TrainablePipe):
         self.incl_prior = incl_prior
         self.incl_context = incl_context
         self.get_candidates = get_candidates
-        self.cfg = {}
+        self.cfg = {"overwrite": overwrite}
         self.distance = CosineDistance(normalize=False)
         # how many neighbour sentences to take into account
         # create an empty KB by default. If you want to load a predefined one, specify it in 'initialize'.
@@ -399,12 +405,14 @@ class EntityLinker(TrainablePipe):
         if count_ents != len(kb_ids):
             raise ValueError(Errors.E148.format(ents=count_ents, ids=len(kb_ids)))
         i = 0
+        overwrite = self.cfg["overwrite"]
         for doc in docs:
             for ent in doc.ents:
                 kb_id = kb_ids[i]
                 i += 1
                 for token in ent:
-                    token.ent_kb_id_ = kb_id
+                    if token.ent_kb_id == 0 or overwrite:
+                        token.ent_kb_id_ = kb_id

     def to_bytes(self, *, exclude=tuple()):
         """Serialize the pipe to a bytestring.
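The new setting is exposed through the factory, so it can be toggled when the component is added. A hedged sketch (the knowledge base and training setup are assumed and not shown):

    import spacy

    nlp = spacy.blank("en")
    # "overwrite": True matches the old behavior and stays the default;
    # with False, KB IDs already set on tokens are preserved.
    entity_linker = nlp.add_pipe("entity_linker", config={"overwrite": False})
    assert entity_linker.cfg["overwrite"] is False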
spacy/pipeline/morphologizer.pyx
@@ -19,6 +19,9 @@ from ..scorer import Scorer
 from ..training import validate_examples, validate_get_examples
 from ..util import registry

+# See #9050
+BACKWARD_OVERWRITE = True
+BACKWARD_EXTEND = False
+
 default_model_config = """
 [model]
@@ -49,16 +52,18 @@ DEFAULT_MORPH_MODEL = Config().from_str(default_model_config)["model"]
 @Language.factory(
     "morphologizer",
     assigns=["token.morph", "token.pos"],
-    default_config={"model": DEFAULT_MORPH_MODEL, "scorer": {"@scorers": "spacy.morphologizer_scorer.v1"}},
+    default_config={"model": DEFAULT_MORPH_MODEL, "overwrite": True, "extend": False, "scorer": {"@scorers": "spacy.morphologizer_scorer.v1"}},
     default_score_weights={"pos_acc": 0.5, "morph_acc": 0.5, "morph_per_feat": None},
 )
 def make_morphologizer(
     nlp: Language,
     model: Model,
     name: str,
+    overwrite: bool,
+    extend: bool,
     scorer: Optional[Callable],
 ):
-    return Morphologizer(nlp.vocab, model, name, scorer=scorer)
+    return Morphologizer(nlp.vocab, model, name, overwrite=overwrite, extend=extend, scorer=scorer)


 def morphologizer_score(examples, **kwargs):
@@ -87,6 +92,8 @@ class Morphologizer(Tagger):
         model: Model,
         name: str = "morphologizer",
         *,
+        overwrite: bool = BACKWARD_OVERWRITE,
+        extend: bool = BACKWARD_EXTEND,
         scorer: Optional[Callable] = morphologizer_score,
     ):
         """Initialize a morphologizer.
@@ -109,7 +116,12 @@ class Morphologizer(Tagger):
         # store mappings from morph+POS labels to token-level annotations:
         # 1) labels_morph stores a mapping from morph+POS->morph
         # 2) labels_pos stores a mapping from morph+POS->POS
-        cfg = {"labels_morph": {}, "labels_pos": {}}
+        cfg = {
+            "labels_morph": {},
+            "labels_pos": {},
+            "overwrite": overwrite,
+            "extend": extend,
+        }
         self.cfg = dict(sorted(cfg.items()))
         self.scorer = scorer

@@ -217,14 +229,34 @@ class Morphologizer(Tagger):
             docs = [docs]
         cdef Doc doc
         cdef Vocab vocab = self.vocab
+        cdef bint overwrite = self.cfg["overwrite"]
+        cdef bint extend = self.cfg["extend"]
         for i, doc in enumerate(docs):
             doc_tag_ids = batch_tag_ids[i]
             if hasattr(doc_tag_ids, "get"):
                 doc_tag_ids = doc_tag_ids.get()
             for j, tag_id in enumerate(doc_tag_ids):
                 morph = self.labels[tag_id]
-                doc.c[j].morph = self.vocab.morphology.add(self.cfg["labels_morph"].get(morph, 0))
-                doc.c[j].pos = self.cfg["labels_pos"].get(morph, 0)
+                # set morph
+                if doc.c[j].morph == 0 or overwrite or extend:
+                    if overwrite and extend:
+                        # morphologizer morph overwrites any existing features
+                        # while extending
+                        extended_morph = Morphology.feats_to_dict(self.vocab.strings[doc.c[j].morph])
+                        extended_morph.update(Morphology.feats_to_dict(self.cfg["labels_morph"].get(morph, 0)))
+                        doc.c[j].morph = self.vocab.morphology.add(extended_morph)
+                    elif extend:
+                        # existing features are preserved and any new features
+                        # are added
+                        extended_morph = Morphology.feats_to_dict(self.cfg["labels_morph"].get(morph, 0))
+                        extended_morph.update(Morphology.feats_to_dict(self.vocab.strings[doc.c[j].morph]))
+                        doc.c[j].morph = self.vocab.morphology.add(extended_morph)
+                    else:
+                        # clobber
+                        doc.c[j].morph = self.vocab.morphology.add(self.cfg["labels_morph"].get(morph, 0))
+                # set POS
+                if doc.c[j].pos == 0 or overwrite:
+                    doc.c[j].pos = self.cfg["labels_pos"].get(morph, 0)

     def get_loss(self, examples, scores):
         """Find the loss and gradient of loss for the batch of documents and
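The merge branches above rely on `Morphology.feats_to_dict` to convert FEATS strings into dicts. A small standalone sketch of the two +extend paths:

    from spacy.morphology import Morphology

    existing = Morphology.feats_to_dict("Feat=A|That=A|This=A")
    predicted = Morphology.feats_to_dict("Feat=N")

    # overwrite and extend: start from the existing features and let the
    # predicted values win on conflicts
    merged = dict(existing)
    merged.update(predicted)
    assert merged == {"Feat": "N", "That": "A", "This": "A"}

    # extend without overwriting: start from the predicted features and
    # let the existing values win on conflicts
    merged = dict(predicted)
    merged.update(existing)
    assert merged == {"Feat": "A", "That": "A", "This": "A"}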
spacy/pipeline/sentencizer.pyx
@@ -10,20 +10,23 @@ from ..language import Language
 from ..scorer import Scorer
 from .. import util

+# see #9050
+BACKWARD_OVERWRITE = False
+
 @Language.factory(
     "sentencizer",
     assigns=["token.is_sent_start", "doc.sents"],
-    default_config={"punct_chars": None, "scorer": {"@scorers": "spacy.senter_scorer.v1"}},
+    default_config={"punct_chars": None, "overwrite": False, "scorer": {"@scorers": "spacy.senter_scorer.v1"}},
     default_score_weights={"sents_f": 1.0, "sents_p": 0.0, "sents_r": 0.0},
 )
 def make_sentencizer(
     nlp: Language,
     name: str,
     punct_chars: Optional[List[str]],
+    overwrite: bool,
     scorer: Optional[Callable],
 ):
-    return Sentencizer(name, punct_chars=punct_chars, scorer=scorer)
+    return Sentencizer(name, punct_chars=punct_chars, overwrite=overwrite, scorer=scorer)


 class Sentencizer(Pipe):
@@ -49,6 +52,7 @@ class Sentencizer(Pipe):
         name="sentencizer",
         *,
         punct_chars=None,
+        overwrite=BACKWARD_OVERWRITE,
         scorer=senter_score,
     ):
         """Initialize the sentencizer.
@@ -65,6 +69,7 @@ class Sentencizer(Pipe):
             self.punct_chars = set(punct_chars)
         else:
             self.punct_chars = set(self.default_punct_chars)
+        self.overwrite = overwrite
         self.scorer = scorer

     def __call__(self, doc):
@@ -126,8 +131,7 @@ class Sentencizer(Pipe):
         for i, doc in enumerate(docs):
             doc_tag_ids = batch_tag_ids[i]
             for j, tag_id in enumerate(doc_tag_ids):
-                # Don't clobber existing sentence boundaries
-                if doc.c[j].sent_start == 0:
+                if doc.c[j].sent_start == 0 or self.overwrite:
                     if tag_id:
                         doc.c[j].sent_start = 1
                     else:
@@ -140,7 +144,7 @@ class Sentencizer(Pipe):

         DOCS: https://spacy.io/api/sentencizer#to_bytes
         """
-        return srsly.msgpack_dumps({"punct_chars": list(self.punct_chars)})
+        return srsly.msgpack_dumps({"punct_chars": list(self.punct_chars), "overwrite": self.overwrite})

     def from_bytes(self, bytes_data, *, exclude=tuple()):
         """Load the sentencizer from a bytestring.
@@ -152,6 +156,7 @@ class Sentencizer(Pipe):
         """
         cfg = srsly.msgpack_loads(bytes_data)
         self.punct_chars = set(cfg.get("punct_chars", self.default_punct_chars))
+        self.overwrite = cfg.get("overwrite", self.overwrite)
         return self

     def to_disk(self, path, *, exclude=tuple()):
@@ -161,7 +166,7 @@ class Sentencizer(Pipe):
         """
         path = util.ensure_path(path)
         path = path.with_suffix(".json")
-        srsly.write_json(path, {"punct_chars": list(self.punct_chars)})
+        srsly.write_json(path, {"punct_chars": list(self.punct_chars), "overwrite": self.overwrite})

     def from_disk(self, path, *, exclude=tuple()):
@@ -173,4 +178,5 @@ class Sentencizer(Pipe):
         path = path.with_suffix(".json")
         cfg = srsly.read_json(path)
         self.punct_chars = set(cfg.get("punct_chars", self.default_punct_chars))
+        self.overwrite = cfg.get("overwrite", self.overwrite)
         return self
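Since the sentencizer now serializes `overwrite` alongside `punct_chars`, the setting survives a roundtrip, while data serialized before this change simply falls back to the value already on the component. A minimal sketch:

    from spacy.pipeline import Sentencizer

    sentencizer = Sentencizer("sentencizer", overwrite=True)
    data = sentencizer.to_bytes()

    # bytes from an older model have no "overwrite" key, in which case
    # from_bytes keeps the receiving component's existing setting
    restored = Sentencizer("sentencizer").from_bytes(data)
    assert restored.overwrite is True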
spacy/pipeline/senter.pyx
@@ -15,6 +15,8 @@ from ..training import validate_examples, validate_get_examples
 from ..util import registry
 from .. import util

+# See #9050
+BACKWARD_OVERWRITE = False
+
 default_model_config = """
 [model]
@@ -36,11 +38,11 @@ DEFAULT_SENTER_MODEL = Config().from_str(default_model_config)["model"]
 @Language.factory(
     "senter",
     assigns=["token.is_sent_start"],
-    default_config={"model": DEFAULT_SENTER_MODEL, "scorer": {"@scorers": "spacy.senter_scorer.v1"}},
+    default_config={"model": DEFAULT_SENTER_MODEL, "overwrite": False, "scorer": {"@scorers": "spacy.senter_scorer.v1"}},
     default_score_weights={"sents_f": 1.0, "sents_p": 0.0, "sents_r": 0.0},
 )
-def make_senter(nlp: Language, name: str, model: Model, scorer: Optional[Callable]):
-    return SentenceRecognizer(nlp.vocab, model, name, scorer=scorer)
+def make_senter(nlp: Language, name: str, model: Model, overwrite: bool, scorer: Optional[Callable]):
+    return SentenceRecognizer(nlp.vocab, model, name, overwrite=overwrite, scorer=scorer)


 def senter_score(examples, **kwargs):
@@ -62,7 +64,15 @@ class SentenceRecognizer(Tagger):

     DOCS: https://spacy.io/api/sentencerecognizer
     """
-    def __init__(self, vocab, model, name="senter", *, scorer=senter_score):
+    def __init__(
+        self,
+        vocab,
+        model,
+        name="senter",
+        *,
+        overwrite=BACKWARD_OVERWRITE,
+        scorer=senter_score,
+    ):
         """Initialize a sentence recognizer.

         vocab (Vocab): The shared vocabulary.
@@ -78,7 +88,7 @@ class SentenceRecognizer(Tagger):
         self.model = model
         self.name = name
         self._rehearsal_model = None
-        self.cfg = {}
+        self.cfg = {"overwrite": overwrite}
         self.scorer = scorer

     @property
@@ -104,13 +114,13 @@ class SentenceRecognizer(Tagger):
         if isinstance(docs, Doc):
             docs = [docs]
         cdef Doc doc
+        cdef bint overwrite = self.cfg["overwrite"]
         for i, doc in enumerate(docs):
             doc_tag_ids = batch_tag_ids[i]
             if hasattr(doc_tag_ids, "get"):
                 doc_tag_ids = doc_tag_ids.get()
             for j, tag_id in enumerate(doc_tag_ids):
-                # Don't clobber existing sentence boundaries
-                if doc.c[j].sent_start == 0:
+                if doc.c[j].sent_start == 0 or overwrite:
                     if tag_id == 1:
                         doc.c[j].sent_start = 1
                     else:
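The guard in `set_annotations` boils down to one pattern: fill unset values, and touch set ones only when overwriting. A sketch in plain Python over optional booleans (names are illustrative, not spaCy API):

    from typing import List, Optional

    def apply_sent_starts(
        existing: List[Optional[bool]], predicted: List[bool], overwrite: bool
    ) -> List[bool]:
        # an unset value (None) is always filled; a set value is only
        # replaced when overwrite is enabled
        return [
            pred if old is None or overwrite else old
            for old, pred in zip(existing, predicted)
        ]

    assert apply_sent_starts([True, None], [False, True], False) == [True, True]
    assert apply_sent_starts([True, None], [False, True], True) == [False, True]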
spacy/pipeline/tagger.pyx
@@ -22,6 +22,8 @@ from ..training import validate_examples, validate_get_examples
 from ..util import registry
 from .. import util

+# See #9050
+BACKWARD_OVERWRITE = False
+
 default_model_config = """
 [model]
@@ -43,10 +45,16 @@ DEFAULT_TAGGER_MODEL = Config().from_str(default_model_config)["model"]
 @Language.factory(
     "tagger",
     assigns=["token.tag"],
-    default_config={"model": DEFAULT_TAGGER_MODEL, "scorer": {"@scorers": "spacy.tagger_scorer.v1"}},
+    default_config={"model": DEFAULT_TAGGER_MODEL, "overwrite": False, "scorer": {"@scorers": "spacy.tagger_scorer.v1"}},
     default_score_weights={"tag_acc": 1.0},
 )
-def make_tagger(nlp: Language, name: str, model: Model, scorer: Optional[Callable]):
+def make_tagger(
+    nlp: Language,
+    name: str,
+    model: Model,
+    overwrite: bool,
+    scorer: Optional[Callable],
+):
     """Construct a part-of-speech tagger component.

     model (Model[List[Doc], List[Floats2d]]): A model instance that predicts
@@ -54,7 +62,7 @@ def make_tagger(nlp: Language, name: str, model: Model, scorer: Optional[Callable]):
     in size, and be normalized as probabilities (all scores between 0 and 1,
     with the rows summing to 1).
     """
-    return Tagger(nlp.vocab, model, name, scorer=scorer)
+    return Tagger(nlp.vocab, model, name, overwrite=overwrite, scorer=scorer)


 def tagger_score(examples, **kwargs):
@@ -71,7 +79,15 @@ class Tagger(TrainablePipe):

     DOCS: https://spacy.io/api/tagger
     """
-    def __init__(self, vocab, model, name="tagger", *, scorer=tagger_score):
+    def __init__(
+        self,
+        vocab,
+        model,
+        name="tagger",
+        *,
+        overwrite=BACKWARD_OVERWRITE,
+        scorer=tagger_score,
+    ):
         """Initialize a part-of-speech tagger.

         vocab (Vocab): The shared vocabulary.
@@ -87,7 +103,7 @@ class Tagger(TrainablePipe):
         self.model = model
         self.name = name
         self._rehearsal_model = None
-        cfg = {"labels": []}
+        cfg = {"labels": [], "overwrite": overwrite}
         self.cfg = dict(sorted(cfg.items()))
         self.scorer = scorer

@@ -149,13 +165,13 @@ class Tagger(TrainablePipe):
             docs = [docs]
         cdef Doc doc
         cdef Vocab vocab = self.vocab
+        cdef bint overwrite = self.cfg["overwrite"]
         for i, doc in enumerate(docs):
             doc_tag_ids = batch_tag_ids[i]
             if hasattr(doc_tag_ids, "get"):
                 doc_tag_ids = doc_tag_ids.get()
             for j, tag_id in enumerate(doc_tag_ids):
-                # Don't clobber preset POS tags
-                if doc.c[j].tag == 0:
+                if doc.c[j].tag == 0 or overwrite:
                     doc.c[j].tag = self.vocab.strings[self.labels[tag_id]]

     def update(self, examples, *, drop=0., sgd=None, losses=None):
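As with the other components, the tagger's setting can be passed through `add_pipe`. A minimal sketch with an untrained blank pipeline:

    import spacy

    nlp = spacy.blank("en")
    # False (the default) preserves pre-set tags, matching the old
    # behavior; True lets the tagger replace them
    tagger = nlp.add_pipe("tagger", config={"overwrite": True})
    assert tagger.cfg["overwrite"] is True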
spacy/tests/pipeline/test_morphologizer.py
@@ -8,6 +8,7 @@ from spacy.language import Language
 from spacy.tests.util import make_tempdir
+from spacy.morphology import Morphology
 from spacy.attrs import MORPH
 from spacy.tokens import Doc


 def test_label_types():
@@ -137,6 +138,41 @@ def test_overfitting_IO():
     assert [str(t.morph) for t in doc] == gold_morphs
     assert [t.pos_ for t in doc] == gold_pos_tags

+    # Test overwrite+extend settings
+    # (note that "" is unset, "_" is set and empty)
+    morphs = ["Feat=V", "Feat=N", "_"]
+    doc = Doc(nlp.vocab, words=["blue", "ham", "like"], morphs=morphs)
+    orig_morphs = [str(t.morph) for t in doc]
+    orig_pos_tags = [t.pos_ for t in doc]
+    morphologizer = nlp.get_pipe("morphologizer")
+
+    # don't overwrite or extend
+    morphologizer.cfg["overwrite"] = False
+    doc = morphologizer(doc)
+    assert [str(t.morph) for t in doc] == orig_morphs
+    assert [t.pos_ for t in doc] == orig_pos_tags
+
+    # overwrite and extend
+    morphologizer.cfg["overwrite"] = True
+    morphologizer.cfg["extend"] = True
+    doc = Doc(nlp.vocab, words=["I", "like"], morphs=["Feat=A|That=A|This=A", ""])
+    doc = morphologizer(doc)
+    assert [str(t.morph) for t in doc] == ["Feat=N|That=A|This=A", "Feat=V"]
+
+    # extend without overwriting
+    morphologizer.cfg["overwrite"] = False
+    morphologizer.cfg["extend"] = True
+    doc = Doc(nlp.vocab, words=["I", "like"], morphs=["Feat=A|That=A|This=A", "That=B"])
+    doc = morphologizer(doc)
+    assert [str(t.morph) for t in doc] == ["Feat=A|That=A|This=A", "Feat=V|That=B"]
+
+    # overwrite without extending
+    morphologizer.cfg["overwrite"] = True
+    morphologizer.cfg["extend"] = False
+    doc = Doc(nlp.vocab, words=["I", "like"], morphs=["Feat=A|That=A|This=A", ""])
+    doc = morphologizer(doc)
+    assert [str(t.morph) for t in doc] == ["Feat=N", "Feat=V"]
+
     # Test with unset morph and partial POS
     nlp.remove_pipe("morphologizer")
     nlp.add_pipe("morphologizer")