mirror of
https://github.com/explosion/spaCy.git
synced 2025-02-09 08:00:34 +03:00
Add overwrite settings for more components (#9050)
* Add overwrite settings for more components For pipeline components where it's relevant and not already implemented, add an explicit `overwrite` setting that controls whether `set_annotations` overwrites existing annotation. For the `morphologizer`, add an additional setting `extend`, which controls whether the existing features are preserved. * +overwrite, +extend: overwrite values of existing features, add any new features * +overwrite, -extend: overwrite completely, removing any existing features * -overwrite, +extend: keep values of existing features, add any new features * -overwrite, -extend: do not modify the existing value if set In all cases an unset value will be set by `set_annotations`. Preserve current overwrite defaults: * True: morphologizer, entity linker * False: tagger, sentencizer, senter * Add backwards compat overwrite settings * Put empty line back Removed by accident in last commit * Set backwards-compatible defaults in __init__ Because the `TrainablePipe` serialization methods update `cfg`, there's no straightforward way to detect whether models serialized with a previous version are missing the overwrite settings. It would be possible in the sentencizer due to its separate serialization methods, however to keep the changes parallel, this also sets the default in `__init__`. * Remove traces Co-authored-by: Paul O'Leary McCann <polm@dampfkraft.com>
This commit is contained in:
parent
8fe525beb5
commit
03fefa37e2
|
@ -20,6 +20,8 @@ from ..util import SimpleFrozenList, registry
|
||||||
from .. import util
|
from .. import util
|
||||||
from ..scorer import Scorer
|
from ..scorer import Scorer
|
||||||
|
|
||||||
|
# See #9050
|
||||||
|
BACKWARD_OVERWRITE = True
|
||||||
|
|
||||||
default_model_config = """
|
default_model_config = """
|
||||||
[model]
|
[model]
|
||||||
|
@ -50,6 +52,7 @@ DEFAULT_NEL_MODEL = Config().from_str(default_model_config)["model"]
|
||||||
"incl_context": True,
|
"incl_context": True,
|
||||||
"entity_vector_length": 64,
|
"entity_vector_length": 64,
|
||||||
"get_candidates": {"@misc": "spacy.CandidateGenerator.v1"},
|
"get_candidates": {"@misc": "spacy.CandidateGenerator.v1"},
|
||||||
|
"overwrite": True,
|
||||||
"scorer": {"@scorers": "spacy.entity_linker_scorer.v1"},
|
"scorer": {"@scorers": "spacy.entity_linker_scorer.v1"},
|
||||||
},
|
},
|
||||||
default_score_weights={
|
default_score_weights={
|
||||||
|
@ -69,6 +72,7 @@ def make_entity_linker(
|
||||||
incl_context: bool,
|
incl_context: bool,
|
||||||
entity_vector_length: int,
|
entity_vector_length: int,
|
||||||
get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]],
|
get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]],
|
||||||
|
overwrite: bool,
|
||||||
scorer: Optional[Callable],
|
scorer: Optional[Callable],
|
||||||
):
|
):
|
||||||
"""Construct an EntityLinker component.
|
"""Construct an EntityLinker component.
|
||||||
|
@ -95,6 +99,7 @@ def make_entity_linker(
|
||||||
incl_context=incl_context,
|
incl_context=incl_context,
|
||||||
entity_vector_length=entity_vector_length,
|
entity_vector_length=entity_vector_length,
|
||||||
get_candidates=get_candidates,
|
get_candidates=get_candidates,
|
||||||
|
overwrite=overwrite,
|
||||||
scorer=scorer,
|
scorer=scorer,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -128,6 +133,7 @@ class EntityLinker(TrainablePipe):
|
||||||
incl_context: bool,
|
incl_context: bool,
|
||||||
entity_vector_length: int,
|
entity_vector_length: int,
|
||||||
get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]],
|
get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]],
|
||||||
|
overwrite: bool = BACKWARD_OVERWRITE,
|
||||||
scorer: Optional[Callable] = entity_linker_score,
|
scorer: Optional[Callable] = entity_linker_score,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize an entity linker.
|
"""Initialize an entity linker.
|
||||||
|
@ -156,7 +162,7 @@ class EntityLinker(TrainablePipe):
|
||||||
self.incl_prior = incl_prior
|
self.incl_prior = incl_prior
|
||||||
self.incl_context = incl_context
|
self.incl_context = incl_context
|
||||||
self.get_candidates = get_candidates
|
self.get_candidates = get_candidates
|
||||||
self.cfg = {}
|
self.cfg = {"overwrite": overwrite}
|
||||||
self.distance = CosineDistance(normalize=False)
|
self.distance = CosineDistance(normalize=False)
|
||||||
# how many neighbour sentences to take into account
|
# how many neighbour sentences to take into account
|
||||||
# create an empty KB by default. If you want to load a predefined one, specify it in 'initialize'.
|
# create an empty KB by default. If you want to load a predefined one, specify it in 'initialize'.
|
||||||
|
@ -399,12 +405,14 @@ class EntityLinker(TrainablePipe):
|
||||||
if count_ents != len(kb_ids):
|
if count_ents != len(kb_ids):
|
||||||
raise ValueError(Errors.E148.format(ents=count_ents, ids=len(kb_ids)))
|
raise ValueError(Errors.E148.format(ents=count_ents, ids=len(kb_ids)))
|
||||||
i = 0
|
i = 0
|
||||||
|
overwrite = self.cfg["overwrite"]
|
||||||
for doc in docs:
|
for doc in docs:
|
||||||
for ent in doc.ents:
|
for ent in doc.ents:
|
||||||
kb_id = kb_ids[i]
|
kb_id = kb_ids[i]
|
||||||
i += 1
|
i += 1
|
||||||
for token in ent:
|
for token in ent:
|
||||||
token.ent_kb_id_ = kb_id
|
if token.ent_kb_id == 0 or overwrite:
|
||||||
|
token.ent_kb_id_ = kb_id
|
||||||
|
|
||||||
def to_bytes(self, *, exclude=tuple()):
|
def to_bytes(self, *, exclude=tuple()):
|
||||||
"""Serialize the pipe to a bytestring.
|
"""Serialize the pipe to a bytestring.
|
||||||
|
|
|
@ -19,6 +19,9 @@ from ..scorer import Scorer
|
||||||
from ..training import validate_examples, validate_get_examples
|
from ..training import validate_examples, validate_get_examples
|
||||||
from ..util import registry
|
from ..util import registry
|
||||||
|
|
||||||
|
# See #9050
|
||||||
|
BACKWARD_OVERWRITE = True
|
||||||
|
BACKWARD_EXTEND = False
|
||||||
|
|
||||||
default_model_config = """
|
default_model_config = """
|
||||||
[model]
|
[model]
|
||||||
|
@ -49,16 +52,18 @@ DEFAULT_MORPH_MODEL = Config().from_str(default_model_config)["model"]
|
||||||
@Language.factory(
|
@Language.factory(
|
||||||
"morphologizer",
|
"morphologizer",
|
||||||
assigns=["token.morph", "token.pos"],
|
assigns=["token.morph", "token.pos"],
|
||||||
default_config={"model": DEFAULT_MORPH_MODEL, "scorer": {"@scorers": "spacy.morphologizer_scorer.v1"}},
|
default_config={"model": DEFAULT_MORPH_MODEL, "overwrite": True, "extend": False, "scorer": {"@scorers": "spacy.morphologizer_scorer.v1"}},
|
||||||
default_score_weights={"pos_acc": 0.5, "morph_acc": 0.5, "morph_per_feat": None},
|
default_score_weights={"pos_acc": 0.5, "morph_acc": 0.5, "morph_per_feat": None},
|
||||||
)
|
)
|
||||||
def make_morphologizer(
|
def make_morphologizer(
|
||||||
nlp: Language,
|
nlp: Language,
|
||||||
model: Model,
|
model: Model,
|
||||||
name: str,
|
name: str,
|
||||||
|
overwrite: bool,
|
||||||
|
extend: bool,
|
||||||
scorer: Optional[Callable],
|
scorer: Optional[Callable],
|
||||||
):
|
):
|
||||||
return Morphologizer(nlp.vocab, model, name, scorer=scorer)
|
return Morphologizer(nlp.vocab, model, name, overwrite=overwrite, extend=extend, scorer=scorer)
|
||||||
|
|
||||||
|
|
||||||
def morphologizer_score(examples, **kwargs):
|
def morphologizer_score(examples, **kwargs):
|
||||||
|
@ -87,6 +92,8 @@ class Morphologizer(Tagger):
|
||||||
model: Model,
|
model: Model,
|
||||||
name: str = "morphologizer",
|
name: str = "morphologizer",
|
||||||
*,
|
*,
|
||||||
|
overwrite: bool = BACKWARD_OVERWRITE,
|
||||||
|
extend: bool = BACKWARD_EXTEND,
|
||||||
scorer: Optional[Callable] = morphologizer_score,
|
scorer: Optional[Callable] = morphologizer_score,
|
||||||
):
|
):
|
||||||
"""Initialize a morphologizer.
|
"""Initialize a morphologizer.
|
||||||
|
@ -109,7 +116,12 @@ class Morphologizer(Tagger):
|
||||||
# store mappings from morph+POS labels to token-level annotations:
|
# store mappings from morph+POS labels to token-level annotations:
|
||||||
# 1) labels_morph stores a mapping from morph+POS->morph
|
# 1) labels_morph stores a mapping from morph+POS->morph
|
||||||
# 2) labels_pos stores a mapping from morph+POS->POS
|
# 2) labels_pos stores a mapping from morph+POS->POS
|
||||||
cfg = {"labels_morph": {}, "labels_pos": {}}
|
cfg = {
|
||||||
|
"labels_morph": {},
|
||||||
|
"labels_pos": {},
|
||||||
|
"overwrite": overwrite,
|
||||||
|
"extend": extend,
|
||||||
|
}
|
||||||
self.cfg = dict(sorted(cfg.items()))
|
self.cfg = dict(sorted(cfg.items()))
|
||||||
self.scorer = scorer
|
self.scorer = scorer
|
||||||
|
|
||||||
|
@ -217,14 +229,34 @@ class Morphologizer(Tagger):
|
||||||
docs = [docs]
|
docs = [docs]
|
||||||
cdef Doc doc
|
cdef Doc doc
|
||||||
cdef Vocab vocab = self.vocab
|
cdef Vocab vocab = self.vocab
|
||||||
|
cdef bint overwrite = self.cfg["overwrite"]
|
||||||
|
cdef bint extend = self.cfg["extend"]
|
||||||
for i, doc in enumerate(docs):
|
for i, doc in enumerate(docs):
|
||||||
doc_tag_ids = batch_tag_ids[i]
|
doc_tag_ids = batch_tag_ids[i]
|
||||||
if hasattr(doc_tag_ids, "get"):
|
if hasattr(doc_tag_ids, "get"):
|
||||||
doc_tag_ids = doc_tag_ids.get()
|
doc_tag_ids = doc_tag_ids.get()
|
||||||
for j, tag_id in enumerate(doc_tag_ids):
|
for j, tag_id in enumerate(doc_tag_ids):
|
||||||
morph = self.labels[tag_id]
|
morph = self.labels[tag_id]
|
||||||
doc.c[j].morph = self.vocab.morphology.add(self.cfg["labels_morph"].get(morph, 0))
|
# set morph
|
||||||
doc.c[j].pos = self.cfg["labels_pos"].get(morph, 0)
|
if doc.c[j].morph == 0 or overwrite or extend:
|
||||||
|
if overwrite and extend:
|
||||||
|
# morphologizer morph overwrites any existing features
|
||||||
|
# while extending
|
||||||
|
extended_morph = Morphology.feats_to_dict(self.vocab.strings[doc.c[j].morph])
|
||||||
|
extended_morph.update(Morphology.feats_to_dict(self.cfg["labels_morph"].get(morph, 0)))
|
||||||
|
doc.c[j].morph = self.vocab.morphology.add(extended_morph)
|
||||||
|
elif extend:
|
||||||
|
# existing features are preserved and any new features
|
||||||
|
# are added
|
||||||
|
extended_morph = Morphology.feats_to_dict(self.cfg["labels_morph"].get(morph, 0))
|
||||||
|
extended_morph.update(Morphology.feats_to_dict(self.vocab.strings[doc.c[j].morph]))
|
||||||
|
doc.c[j].morph = self.vocab.morphology.add(extended_morph)
|
||||||
|
else:
|
||||||
|
# clobber
|
||||||
|
doc.c[j].morph = self.vocab.morphology.add(self.cfg["labels_morph"].get(morph, 0))
|
||||||
|
# set POS
|
||||||
|
if doc.c[j].pos == 0 or overwrite:
|
||||||
|
doc.c[j].pos = self.cfg["labels_pos"].get(morph, 0)
|
||||||
|
|
||||||
def get_loss(self, examples, scores):
|
def get_loss(self, examples, scores):
|
||||||
"""Find the loss and gradient of loss for the batch of documents and
|
"""Find the loss and gradient of loss for the batch of documents and
|
||||||
|
|
|
@ -10,20 +10,23 @@ from ..language import Language
|
||||||
from ..scorer import Scorer
|
from ..scorer import Scorer
|
||||||
from .. import util
|
from .. import util
|
||||||
|
|
||||||
|
# see #9050
|
||||||
|
BACKWARD_OVERWRITE = False
|
||||||
|
|
||||||
@Language.factory(
|
@Language.factory(
|
||||||
"sentencizer",
|
"sentencizer",
|
||||||
assigns=["token.is_sent_start", "doc.sents"],
|
assigns=["token.is_sent_start", "doc.sents"],
|
||||||
default_config={"punct_chars": None, "scorer": {"@scorers": "spacy.senter_scorer.v1"}},
|
default_config={"punct_chars": None, "overwrite": False, "scorer": {"@scorers": "spacy.senter_scorer.v1"}},
|
||||||
default_score_weights={"sents_f": 1.0, "sents_p": 0.0, "sents_r": 0.0},
|
default_score_weights={"sents_f": 1.0, "sents_p": 0.0, "sents_r": 0.0},
|
||||||
)
|
)
|
||||||
def make_sentencizer(
|
def make_sentencizer(
|
||||||
nlp: Language,
|
nlp: Language,
|
||||||
name: str,
|
name: str,
|
||||||
punct_chars: Optional[List[str]],
|
punct_chars: Optional[List[str]],
|
||||||
|
overwrite: bool,
|
||||||
scorer: Optional[Callable],
|
scorer: Optional[Callable],
|
||||||
):
|
):
|
||||||
return Sentencizer(name, punct_chars=punct_chars, scorer=scorer)
|
return Sentencizer(name, punct_chars=punct_chars, overwrite=overwrite, scorer=scorer)
|
||||||
|
|
||||||
|
|
||||||
class Sentencizer(Pipe):
|
class Sentencizer(Pipe):
|
||||||
|
@ -49,6 +52,7 @@ class Sentencizer(Pipe):
|
||||||
name="sentencizer",
|
name="sentencizer",
|
||||||
*,
|
*,
|
||||||
punct_chars=None,
|
punct_chars=None,
|
||||||
|
overwrite=BACKWARD_OVERWRITE,
|
||||||
scorer=senter_score,
|
scorer=senter_score,
|
||||||
):
|
):
|
||||||
"""Initialize the sentencizer.
|
"""Initialize the sentencizer.
|
||||||
|
@ -65,6 +69,7 @@ class Sentencizer(Pipe):
|
||||||
self.punct_chars = set(punct_chars)
|
self.punct_chars = set(punct_chars)
|
||||||
else:
|
else:
|
||||||
self.punct_chars = set(self.default_punct_chars)
|
self.punct_chars = set(self.default_punct_chars)
|
||||||
|
self.overwrite = overwrite
|
||||||
self.scorer = scorer
|
self.scorer = scorer
|
||||||
|
|
||||||
def __call__(self, doc):
|
def __call__(self, doc):
|
||||||
|
@ -126,8 +131,7 @@ class Sentencizer(Pipe):
|
||||||
for i, doc in enumerate(docs):
|
for i, doc in enumerate(docs):
|
||||||
doc_tag_ids = batch_tag_ids[i]
|
doc_tag_ids = batch_tag_ids[i]
|
||||||
for j, tag_id in enumerate(doc_tag_ids):
|
for j, tag_id in enumerate(doc_tag_ids):
|
||||||
# Don't clobber existing sentence boundaries
|
if doc.c[j].sent_start == 0 or self.overwrite:
|
||||||
if doc.c[j].sent_start == 0:
|
|
||||||
if tag_id:
|
if tag_id:
|
||||||
doc.c[j].sent_start = 1
|
doc.c[j].sent_start = 1
|
||||||
else:
|
else:
|
||||||
|
@ -140,7 +144,7 @@ class Sentencizer(Pipe):
|
||||||
|
|
||||||
DOCS: https://spacy.io/api/sentencizer#to_bytes
|
DOCS: https://spacy.io/api/sentencizer#to_bytes
|
||||||
"""
|
"""
|
||||||
return srsly.msgpack_dumps({"punct_chars": list(self.punct_chars)})
|
return srsly.msgpack_dumps({"punct_chars": list(self.punct_chars), "overwrite": self.overwrite})
|
||||||
|
|
||||||
def from_bytes(self, bytes_data, *, exclude=tuple()):
|
def from_bytes(self, bytes_data, *, exclude=tuple()):
|
||||||
"""Load the sentencizer from a bytestring.
|
"""Load the sentencizer from a bytestring.
|
||||||
|
@ -152,6 +156,7 @@ class Sentencizer(Pipe):
|
||||||
"""
|
"""
|
||||||
cfg = srsly.msgpack_loads(bytes_data)
|
cfg = srsly.msgpack_loads(bytes_data)
|
||||||
self.punct_chars = set(cfg.get("punct_chars", self.default_punct_chars))
|
self.punct_chars = set(cfg.get("punct_chars", self.default_punct_chars))
|
||||||
|
self.overwrite = cfg.get("overwrite", self.overwrite)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def to_disk(self, path, *, exclude=tuple()):
|
def to_disk(self, path, *, exclude=tuple()):
|
||||||
|
@ -161,7 +166,7 @@ class Sentencizer(Pipe):
|
||||||
"""
|
"""
|
||||||
path = util.ensure_path(path)
|
path = util.ensure_path(path)
|
||||||
path = path.with_suffix(".json")
|
path = path.with_suffix(".json")
|
||||||
srsly.write_json(path, {"punct_chars": list(self.punct_chars)})
|
srsly.write_json(path, {"punct_chars": list(self.punct_chars), "overwrite": self.overwrite})
|
||||||
|
|
||||||
|
|
||||||
def from_disk(self, path, *, exclude=tuple()):
|
def from_disk(self, path, *, exclude=tuple()):
|
||||||
|
@ -173,4 +178,5 @@ class Sentencizer(Pipe):
|
||||||
path = path.with_suffix(".json")
|
path = path.with_suffix(".json")
|
||||||
cfg = srsly.read_json(path)
|
cfg = srsly.read_json(path)
|
||||||
self.punct_chars = set(cfg.get("punct_chars", self.default_punct_chars))
|
self.punct_chars = set(cfg.get("punct_chars", self.default_punct_chars))
|
||||||
|
self.overwrite = cfg.get("overwrite", self.overwrite)
|
||||||
return self
|
return self
|
||||||
|
|
|
@ -15,6 +15,8 @@ from ..training import validate_examples, validate_get_examples
|
||||||
from ..util import registry
|
from ..util import registry
|
||||||
from .. import util
|
from .. import util
|
||||||
|
|
||||||
|
# See #9050
|
||||||
|
BACKWARD_OVERWRITE = False
|
||||||
|
|
||||||
default_model_config = """
|
default_model_config = """
|
||||||
[model]
|
[model]
|
||||||
|
@ -36,11 +38,11 @@ DEFAULT_SENTER_MODEL = Config().from_str(default_model_config)["model"]
|
||||||
@Language.factory(
|
@Language.factory(
|
||||||
"senter",
|
"senter",
|
||||||
assigns=["token.is_sent_start"],
|
assigns=["token.is_sent_start"],
|
||||||
default_config={"model": DEFAULT_SENTER_MODEL, "scorer": {"@scorers": "spacy.senter_scorer.v1"}},
|
default_config={"model": DEFAULT_SENTER_MODEL, "overwrite": False, "scorer": {"@scorers": "spacy.senter_scorer.v1"}},
|
||||||
default_score_weights={"sents_f": 1.0, "sents_p": 0.0, "sents_r": 0.0},
|
default_score_weights={"sents_f": 1.0, "sents_p": 0.0, "sents_r": 0.0},
|
||||||
)
|
)
|
||||||
def make_senter(nlp: Language, name: str, model: Model, scorer: Optional[Callable]):
|
def make_senter(nlp: Language, name: str, model: Model, overwrite: bool, scorer: Optional[Callable]):
|
||||||
return SentenceRecognizer(nlp.vocab, model, name, scorer=scorer)
|
return SentenceRecognizer(nlp.vocab, model, name, overwrite=overwrite, scorer=scorer)
|
||||||
|
|
||||||
|
|
||||||
def senter_score(examples, **kwargs):
|
def senter_score(examples, **kwargs):
|
||||||
|
@ -62,7 +64,15 @@ class SentenceRecognizer(Tagger):
|
||||||
|
|
||||||
DOCS: https://spacy.io/api/sentencerecognizer
|
DOCS: https://spacy.io/api/sentencerecognizer
|
||||||
"""
|
"""
|
||||||
def __init__(self, vocab, model, name="senter", *, scorer=senter_score):
|
def __init__(
|
||||||
|
self,
|
||||||
|
vocab,
|
||||||
|
model,
|
||||||
|
name="senter",
|
||||||
|
*,
|
||||||
|
overwrite=BACKWARD_OVERWRITE,
|
||||||
|
scorer=senter_score,
|
||||||
|
):
|
||||||
"""Initialize a sentence recognizer.
|
"""Initialize a sentence recognizer.
|
||||||
|
|
||||||
vocab (Vocab): The shared vocabulary.
|
vocab (Vocab): The shared vocabulary.
|
||||||
|
@ -78,7 +88,7 @@ class SentenceRecognizer(Tagger):
|
||||||
self.model = model
|
self.model = model
|
||||||
self.name = name
|
self.name = name
|
||||||
self._rehearsal_model = None
|
self._rehearsal_model = None
|
||||||
self.cfg = {}
|
self.cfg = {"overwrite": overwrite}
|
||||||
self.scorer = scorer
|
self.scorer = scorer
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -104,13 +114,13 @@ class SentenceRecognizer(Tagger):
|
||||||
if isinstance(docs, Doc):
|
if isinstance(docs, Doc):
|
||||||
docs = [docs]
|
docs = [docs]
|
||||||
cdef Doc doc
|
cdef Doc doc
|
||||||
|
cdef bint overwrite = self.cfg["overwrite"]
|
||||||
for i, doc in enumerate(docs):
|
for i, doc in enumerate(docs):
|
||||||
doc_tag_ids = batch_tag_ids[i]
|
doc_tag_ids = batch_tag_ids[i]
|
||||||
if hasattr(doc_tag_ids, "get"):
|
if hasattr(doc_tag_ids, "get"):
|
||||||
doc_tag_ids = doc_tag_ids.get()
|
doc_tag_ids = doc_tag_ids.get()
|
||||||
for j, tag_id in enumerate(doc_tag_ids):
|
for j, tag_id in enumerate(doc_tag_ids):
|
||||||
# Don't clobber existing sentence boundaries
|
if doc.c[j].sent_start == 0 or overwrite:
|
||||||
if doc.c[j].sent_start == 0:
|
|
||||||
if tag_id == 1:
|
if tag_id == 1:
|
||||||
doc.c[j].sent_start = 1
|
doc.c[j].sent_start = 1
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -22,6 +22,8 @@ from ..training import validate_examples, validate_get_examples
|
||||||
from ..util import registry
|
from ..util import registry
|
||||||
from .. import util
|
from .. import util
|
||||||
|
|
||||||
|
# See #9050
|
||||||
|
BACKWARD_OVERWRITE = False
|
||||||
|
|
||||||
default_model_config = """
|
default_model_config = """
|
||||||
[model]
|
[model]
|
||||||
|
@ -43,10 +45,16 @@ DEFAULT_TAGGER_MODEL = Config().from_str(default_model_config)["model"]
|
||||||
@Language.factory(
|
@Language.factory(
|
||||||
"tagger",
|
"tagger",
|
||||||
assigns=["token.tag"],
|
assigns=["token.tag"],
|
||||||
default_config={"model": DEFAULT_TAGGER_MODEL, "scorer": {"@scorers": "spacy.tagger_scorer.v1"}},
|
default_config={"model": DEFAULT_TAGGER_MODEL, "overwrite": False, "scorer": {"@scorers": "spacy.tagger_scorer.v1"}},
|
||||||
default_score_weights={"tag_acc": 1.0},
|
default_score_weights={"tag_acc": 1.0},
|
||||||
)
|
)
|
||||||
def make_tagger(nlp: Language, name: str, model: Model, scorer: Optional[Callable]):
|
def make_tagger(
|
||||||
|
nlp: Language,
|
||||||
|
name: str,
|
||||||
|
model: Model,
|
||||||
|
overwrite: bool,
|
||||||
|
scorer: Optional[Callable],
|
||||||
|
):
|
||||||
"""Construct a part-of-speech tagger component.
|
"""Construct a part-of-speech tagger component.
|
||||||
|
|
||||||
model (Model[List[Doc], List[Floats2d]]): A model instance that predicts
|
model (Model[List[Doc], List[Floats2d]]): A model instance that predicts
|
||||||
|
@ -54,7 +62,7 @@ def make_tagger(nlp: Language, name: str, model: Model, scorer: Optional[Callabl
|
||||||
in size, and be normalized as probabilities (all scores between 0 and 1,
|
in size, and be normalized as probabilities (all scores between 0 and 1,
|
||||||
with the rows summing to 1).
|
with the rows summing to 1).
|
||||||
"""
|
"""
|
||||||
return Tagger(nlp.vocab, model, name, scorer=scorer)
|
return Tagger(nlp.vocab, model, name, overwrite=overwrite, scorer=scorer)
|
||||||
|
|
||||||
|
|
||||||
def tagger_score(examples, **kwargs):
|
def tagger_score(examples, **kwargs):
|
||||||
|
@ -71,7 +79,15 @@ class Tagger(TrainablePipe):
|
||||||
|
|
||||||
DOCS: https://spacy.io/api/tagger
|
DOCS: https://spacy.io/api/tagger
|
||||||
"""
|
"""
|
||||||
def __init__(self, vocab, model, name="tagger", *, scorer=tagger_score):
|
def __init__(
|
||||||
|
self,
|
||||||
|
vocab,
|
||||||
|
model,
|
||||||
|
name="tagger",
|
||||||
|
*,
|
||||||
|
overwrite=BACKWARD_OVERWRITE,
|
||||||
|
scorer=tagger_score,
|
||||||
|
):
|
||||||
"""Initialize a part-of-speech tagger.
|
"""Initialize a part-of-speech tagger.
|
||||||
|
|
||||||
vocab (Vocab): The shared vocabulary.
|
vocab (Vocab): The shared vocabulary.
|
||||||
|
@ -87,7 +103,7 @@ class Tagger(TrainablePipe):
|
||||||
self.model = model
|
self.model = model
|
||||||
self.name = name
|
self.name = name
|
||||||
self._rehearsal_model = None
|
self._rehearsal_model = None
|
||||||
cfg = {"labels": []}
|
cfg = {"labels": [], "overwrite": overwrite}
|
||||||
self.cfg = dict(sorted(cfg.items()))
|
self.cfg = dict(sorted(cfg.items()))
|
||||||
self.scorer = scorer
|
self.scorer = scorer
|
||||||
|
|
||||||
|
@ -149,13 +165,13 @@ class Tagger(TrainablePipe):
|
||||||
docs = [docs]
|
docs = [docs]
|
||||||
cdef Doc doc
|
cdef Doc doc
|
||||||
cdef Vocab vocab = self.vocab
|
cdef Vocab vocab = self.vocab
|
||||||
|
cdef bint overwrite = self.cfg["overwrite"]
|
||||||
for i, doc in enumerate(docs):
|
for i, doc in enumerate(docs):
|
||||||
doc_tag_ids = batch_tag_ids[i]
|
doc_tag_ids = batch_tag_ids[i]
|
||||||
if hasattr(doc_tag_ids, "get"):
|
if hasattr(doc_tag_ids, "get"):
|
||||||
doc_tag_ids = doc_tag_ids.get()
|
doc_tag_ids = doc_tag_ids.get()
|
||||||
for j, tag_id in enumerate(doc_tag_ids):
|
for j, tag_id in enumerate(doc_tag_ids):
|
||||||
# Don't clobber preset POS tags
|
if doc.c[j].tag == 0 or overwrite:
|
||||||
if doc.c[j].tag == 0:
|
|
||||||
doc.c[j].tag = self.vocab.strings[self.labels[tag_id]]
|
doc.c[j].tag = self.vocab.strings[self.labels[tag_id]]
|
||||||
|
|
||||||
def update(self, examples, *, drop=0., sgd=None, losses=None):
|
def update(self, examples, *, drop=0., sgd=None, losses=None):
|
||||||
|
|
|
@ -8,6 +8,7 @@ from spacy.language import Language
|
||||||
from spacy.tests.util import make_tempdir
|
from spacy.tests.util import make_tempdir
|
||||||
from spacy.morphology import Morphology
|
from spacy.morphology import Morphology
|
||||||
from spacy.attrs import MORPH
|
from spacy.attrs import MORPH
|
||||||
|
from spacy.tokens import Doc
|
||||||
|
|
||||||
|
|
||||||
def test_label_types():
|
def test_label_types():
|
||||||
|
@ -137,6 +138,41 @@ def test_overfitting_IO():
|
||||||
assert [str(t.morph) for t in doc] == gold_morphs
|
assert [str(t.morph) for t in doc] == gold_morphs
|
||||||
assert [t.pos_ for t in doc] == gold_pos_tags
|
assert [t.pos_ for t in doc] == gold_pos_tags
|
||||||
|
|
||||||
|
# Test overwrite+extend settings
|
||||||
|
# (note that "" is unset, "_" is set and empty)
|
||||||
|
morphs = ["Feat=V", "Feat=N", "_"]
|
||||||
|
doc = Doc(nlp.vocab, words=["blue", "ham", "like"], morphs=morphs)
|
||||||
|
orig_morphs = [str(t.morph) for t in doc]
|
||||||
|
orig_pos_tags = [t.pos_ for t in doc]
|
||||||
|
morphologizer = nlp.get_pipe("morphologizer")
|
||||||
|
|
||||||
|
# don't overwrite or extend
|
||||||
|
morphologizer.cfg["overwrite"] = False
|
||||||
|
doc = morphologizer(doc)
|
||||||
|
assert [str(t.morph) for t in doc] == orig_morphs
|
||||||
|
assert [t.pos_ for t in doc] == orig_pos_tags
|
||||||
|
|
||||||
|
# overwrite and extend
|
||||||
|
morphologizer.cfg["overwrite"] = True
|
||||||
|
morphologizer.cfg["extend"] = True
|
||||||
|
doc = Doc(nlp.vocab, words=["I", "like"], morphs=["Feat=A|That=A|This=A", ""])
|
||||||
|
doc = morphologizer(doc)
|
||||||
|
assert [str(t.morph) for t in doc] == ["Feat=N|That=A|This=A", "Feat=V"]
|
||||||
|
|
||||||
|
# extend without overwriting
|
||||||
|
morphologizer.cfg["overwrite"] = False
|
||||||
|
morphologizer.cfg["extend"] = True
|
||||||
|
doc = Doc(nlp.vocab, words=["I", "like"], morphs=["Feat=A|That=A|This=A", "That=B"])
|
||||||
|
doc = morphologizer(doc)
|
||||||
|
assert [str(t.morph) for t in doc] == ["Feat=A|That=A|This=A", "Feat=V|That=B"]
|
||||||
|
|
||||||
|
# overwrite without extending
|
||||||
|
morphologizer.cfg["overwrite"] = True
|
||||||
|
morphologizer.cfg["extend"] = False
|
||||||
|
doc = Doc(nlp.vocab, words=["I", "like"], morphs=["Feat=A|That=A|This=A", ""])
|
||||||
|
doc = morphologizer(doc)
|
||||||
|
assert [str(t.morph) for t in doc] == ["Feat=N", "Feat=V"]
|
||||||
|
|
||||||
# Test with unset morph and partial POS
|
# Test with unset morph and partial POS
|
||||||
nlp.remove_pipe("morphologizer")
|
nlp.remove_pipe("morphologizer")
|
||||||
nlp.add_pipe("morphologizer")
|
nlp.add_pipe("morphologizer")
|
||||||
|
|
Loading…
Reference in New Issue
Block a user