Revert added_strings change (#6236)

Ines Montani 2020-10-10 18:55:07 +02:00 committed by GitHub
parent 796f8b9424
commit bfa3931c9d
23 changed files with 110 additions and 94 deletions

View File

@ -1,6 +1,6 @@
# fmt: off
__title__ = "spacy-nightly"
__version__ = "3.0.0a37"
__version__ = "3.0.0a38"
__download_url__ = "https://github.com/explosion/spacy-models/releases/download"
__compatibility__ = "https://raw.githubusercontent.com/explosion/spacy-models/master/compatibility.json"
__projects__ = "https://github.com/explosion/projects"

View File

@ -456,6 +456,14 @@ class Errors:
"issue tracker: http://github.com/explosion/spaCy/issues")
# TODO: fix numbering after merging develop into master
E898 = ("Can't serialize trainable pipe '{name}': the `model` attribute "
"is not set or None. If you've implemented a custom component, make "
"sure to store the component model as `self.model` in your "
"component's __init__ method.")
E899 = ("Can't serialize trainable pipe '{name}': the `vocab` attribute "
"is not set or None. If you've implemented a custom component, make "
"sure to store the current `nlp` object's vocab as `self.vocab` in "
"your component's __init__ method.")
E900 = ("Could not run the full pipeline for evaluation. If you specified "
"frozen components, make sure they were already initialized and "
"trained. Full pipeline: {pipeline}")
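
The new E898/E899 errors fire when a custom trainable component reaches serialization without its model or vocab set. A minimal sketch of an __init__ that satisfies the new checks, mirroring the CustomPipe used in the serialization test added further down in this commit (the component name "custom" is only illustrative):

    from spacy.pipeline import TrainablePipe

    class CustomPipe(TrainablePipe):
        def __init__(self, vocab, model, name="custom"):
            # Store the nlp object's vocab and the component's model so that
            # to_bytes()/to_disk() can serialize them; leaving either unset
            # or None now raises E899/E898 respectively.
            self.vocab = vocab
            self.model = model
            self.name = name
            self.cfg = {}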

View File

@ -30,7 +30,6 @@ cdef class KnowledgeBase:
cdef Pool mem
cpdef readonly Vocab vocab
cdef int64_t entity_vector_length
cdef public set _added_strings
# This maps 64bit keys (hash of unique entity string)
# to 64bit values (position of the _KBEntryC struct in the _entries vector).

View File

@ -92,7 +92,6 @@ cdef class KnowledgeBase:
self._alias_index = PreshMap()
self.vocab = vocab
self._create_empty_vectors(dummy_hash=self.vocab.strings[""])
self._added_strings = set()
@property
def entity_vector_length(self):
@ -114,16 +113,12 @@ cdef class KnowledgeBase:
def get_alias_strings(self):
return [self.vocab.strings[x] for x in self._alias_index]
def add_string(self, string: str):
self._added_strings.add(string)
return self.vocab.strings.add(string)
def add_entity(self, unicode entity, float freq, vector[float] entity_vector):
"""
Add an entity to the KB, optionally specifying its log probability based on corpus frequency
Return the hash of the entity ID/name at the end.
"""
cdef hash_t entity_hash = self.add_string(entity)
cdef hash_t entity_hash = self.vocab.strings.add(entity)
# Return if this entity was added before
if entity_hash in self._entry_index:
@ -157,7 +152,7 @@ cdef class KnowledgeBase:
cdef hash_t entity_hash
while i < len(entity_list):
# only process this entity if its unique ID hadn't been added before
entity_hash = self.add_string(entity_list[i])
entity_hash = self.vocab.strings.add(entity_list[i])
if entity_hash in self._entry_index:
warnings.warn(Warnings.W018.format(entity=entity_list[i]))
@ -203,7 +198,7 @@ cdef class KnowledgeBase:
if prob_sum > 1.00001:
raise ValueError(Errors.E133.format(alias=alias, sum=prob_sum))
cdef hash_t alias_hash = self.add_string(alias)
cdef hash_t alias_hash = self.vocab.strings.add(alias)
# Check whether this alias was added before
if alias_hash in self._alias_index:
@ -332,7 +327,7 @@ cdef class KnowledgeBase:
raise ValueError(Errors.E928.format(loc=path))
serialize = {}
serialize["contents"] = lambda p: self.write_contents(p)
serialize["strings.json"] = lambda p: srsly.write_json(p, self._added_strings)
serialize["strings.json"] = lambda p: self.vocab.strings.to_disk(p)
util.to_disk(path, serialize, exclude)
def from_disk(self, path, exclude: Iterable[str] = SimpleFrozenList()):
@ -343,7 +338,7 @@ cdef class KnowledgeBase:
raise ValueError(Errors.E928.format(loc=path))
deserialize = {}
deserialize["contents"] = lambda p: self.read_contents(p)
deserialize["strings.json"] = lambda p: [self.add_string(s) for s in srsly.read_json(p)]
deserialize["strings.json"] = lambda p: self.vocab.strings.from_disk(p)
util.from_disk(path, deserialize, exclude)
def write_contents(self, file_path):

View File

@ -1,4 +1,4 @@
from typing import List, Dict, Union, Iterable, Any, Optional, Callable, Iterator
from typing import List, Dict, Union, Iterable, Any, Optional, Callable
from typing import Tuple
import srsly
from pathlib import Path
@ -57,7 +57,6 @@ class AttributeRuler(Pipe):
self.attrs = []
self._attrs_unnormed = [] # store for reference
self.indices = []
self._added_strings = set()
def clear(self) -> None:
"""Reset all patterns."""
@ -187,16 +186,12 @@ class AttributeRuler(Pipe):
# We need to make a string here, because otherwise the ID we pass back
# will be interpreted as the hash of a string, rather than an ordinal.
key = str(len(self.attrs))
self.matcher.add(self.add_string(key), patterns)
self.matcher.add(self.vocab.strings.add(key), patterns)
self._attrs_unnormed.append(attrs)
attrs = normalize_token_attrs(self.vocab, attrs)
self.attrs.append(attrs)
self.indices.append(index)
def add_string(self, string: str):
self._added_strings.add(string)
return self.vocab.strings.add(string)
def add_patterns(self, patterns: Iterable[AttributeRulerPatternType]) -> None:
"""Add patterns from a list of pattern dicts with the keys as the
arguments to AttributeRuler.add.
@ -256,8 +251,8 @@ class AttributeRuler(Pipe):
DOCS: https://nightly.spacy.io/api/attributeruler#to_bytes
"""
serialize = {}
serialize["vocab"] = self.vocab.to_bytes
serialize["patterns"] = lambda: srsly.msgpack_dumps(self.patterns)
serialize["strings.json"] = lambda: srsly.json_dumps(sorted(self._added_strings))
return util.to_bytes(serialize, exclude)
def from_bytes(
@ -276,7 +271,7 @@ class AttributeRuler(Pipe):
self.add_patterns(srsly.msgpack_loads(b))
deserialize = {
"strings.json": lambda b: [self.add_string(s) for s in srsly.json_loads(b)],
"vocab": lambda b: self.vocab.from_bytes(b),
"patterns": load_patterns,
}
util.from_bytes(bytes_data, deserialize, exclude)
@ -293,7 +288,7 @@ class AttributeRuler(Pipe):
DOCS: https://nightly.spacy.io/api/attributeruler#to_disk
"""
serialize = {
"strings.json": lambda p: srsly.write_json(p, self._added_strings),
"vocab": lambda p: self.vocab.to_disk(p),
"patterns": lambda p: srsly.write_msgpack(p, self.patterns),
}
util.to_disk(path, serialize, exclude)
@ -314,7 +309,7 @@ class AttributeRuler(Pipe):
self.add_patterns(srsly.read_msgpack(p))
deserialize = {
"strings.json": lambda p: [self.add_string(s) for s in srsly.read_json(p)],
"vocab": lambda p: self.vocab.from_disk(p),
"patterns": load_patterns,
}
util.from_disk(path, deserialize, exclude)

View File

@ -453,6 +453,7 @@ class EntityLinker(TrainablePipe):
DOCS: https://nightly.spacy.io/api/entitylinker#to_disk
"""
serialize = {}
serialize["vocab"] = lambda p: self.vocab.to_disk(p)
serialize["cfg"] = lambda p: srsly.write_json(p, self.cfg)
serialize["kb"] = lambda p: self.kb.to_disk(p)
serialize["model"] = lambda p: self.model.to_disk(p)
@ -481,8 +482,6 @@ class EntityLinker(TrainablePipe):
deserialize["kb"] = lambda p: self.kb.from_disk(p)
deserialize["model"] = load_model
util.from_disk(path, deserialize, exclude)
for s in self.kb._added_strings:
self.vocab.strings.add(s)
return self
def rehearse(self, examples, *, sgd=None, losses=None, **config):

View File

@ -281,6 +281,7 @@ class Lemmatizer(Pipe):
DOCS: https://nightly.spacy.io/api/lemmatizer#to_disk
"""
serialize = {}
serialize["vocab"] = lambda p: self.vocab.to_disk(p)
serialize["lookups"] = lambda p: self.lookups.to_disk(p)
util.to_disk(path, serialize, exclude)
@ -296,6 +297,7 @@ class Lemmatizer(Pipe):
DOCS: https://nightly.spacy.io/api/lemmatizer#from_disk
"""
deserialize = {}
deserialize["vocab"] = lambda p: self.vocab.from_disk(p)
deserialize["lookups"] = lambda p: self.lookups.from_disk(p)
util.from_disk(path, deserialize, exclude)
self._validate_tables()
@ -310,6 +312,7 @@ class Lemmatizer(Pipe):
DOCS: https://nightly.spacy.io/api/lemmatizer#to_bytes
"""
serialize = {}
serialize["vocab"] = self.vocab.to_bytes
serialize["lookups"] = self.lookups.to_bytes
return util.to_bytes(serialize, exclude)
@ -325,6 +328,7 @@ class Lemmatizer(Pipe):
DOCS: https://nightly.spacy.io/api/lemmatizer#from_bytes
"""
deserialize = {}
deserialize["vocab"] = lambda b: self.vocab.from_bytes(b)
deserialize["lookups"] = lambda b: self.lookups.from_bytes(b)
util.from_bytes(bytes_data, deserialize, exclude)
self._validate_tables()

View File

@ -95,7 +95,6 @@ class Morphologizer(Tagger):
# add mappings for empty morph
self.cfg["labels_morph"][Morphology.EMPTY_MORPH] = Morphology.EMPTY_MORPH
self.cfg["labels_pos"][Morphology.EMPTY_MORPH] = POS_IDS[""]
self._added_strings = set()
@property
def labels(self):
@ -129,7 +128,6 @@ class Morphologizer(Tagger):
label_dict.pop(self.POS_FEAT)
# normalize morph string and add to morphology table
norm_morph = self.vocab.strings[self.vocab.morphology.add(label_dict)]
self.add_string(norm_morph)
# add label mappings
if norm_label not in self.cfg["labels_morph"]:
self.cfg["labels_morph"][norm_label] = norm_morph
@ -161,7 +159,6 @@ class Morphologizer(Tagger):
if pos:
morph_dict[self.POS_FEAT] = pos
norm_label = self.vocab.strings[self.vocab.morphology.add(morph_dict)]
self.add_string(norm_label)
# add label->morph and label->POS mappings
if norm_label not in self.cfg["labels_morph"]:
self.cfg["labels_morph"][norm_label] = morph
@ -179,7 +176,6 @@ class Morphologizer(Tagger):
if pos:
morph_dict[self.POS_FEAT] = pos
norm_label = self.vocab.strings[self.vocab.morphology.add(morph_dict)]
self.add_string(norm_label)
gold_array.append([1.0 if label == norm_label else 0.0 for label in self.labels])
doc_sample.append(example.x)
label_sample.append(self.model.ops.asarray(gold_array, dtype="float32"))
@ -238,7 +234,6 @@ class Morphologizer(Tagger):
if pos:
label_dict[self.POS_FEAT] = pos
label = self.vocab.strings[self.vocab.morphology.add(label_dict)]
self.add_string(label)
eg_truths.append(label)
truths.append(eg_truths)
d_scores, loss = loss_func(scores, truths)

View File

@ -61,7 +61,6 @@ class SentenceRecognizer(Tagger):
self.name = name
self._rehearsal_model = None
self.cfg = {}
self._added_strings = set()
@property
def labels(self):

View File

@ -78,7 +78,6 @@ class Tagger(TrainablePipe):
self._rehearsal_model = None
cfg = {"labels": labels or []}
self.cfg = dict(sorted(cfg.items()))
self._added_strings = set()
@property
def labels(self):
@ -313,7 +312,7 @@ class Tagger(TrainablePipe):
return 0
self._allow_extra_label()
self.cfg["labels"].append(label)
self.add_string(label)
self.vocab.strings.add(label)
return 1
def score(self, examples, **kwargs):

View File

@ -110,7 +110,6 @@ class TextCategorizer(TrainablePipe):
self._rehearsal_model = None
cfg = {"labels": [], "threshold": threshold, "positive_label": None}
self.cfg = dict(cfg)
self._added_strings = set()
@property
def labels(self) -> Tuple[str]:
@ -301,7 +300,7 @@ class TextCategorizer(TrainablePipe):
return 0
self._allow_extra_label()
self.cfg["labels"].append(label)
self.add_string(label)
self.vocab.strings.add(label)
return 1
def initialize(

View File

@ -64,7 +64,6 @@ class Tok2Vec(TrainablePipe):
self.name = name
self.listeners = []
self.cfg = {}
self._added_strings = set()
def add_listener(self, listener: "Tok2VecListener") -> None:
"""Add a listener for a downstream component. Usually internals."""

View File

@ -5,4 +5,3 @@ cdef class TrainablePipe(Pipe):
cdef public Vocab vocab
cdef public object model
cdef public object cfg
cdef public set _added_strings

View File

@ -13,6 +13,7 @@ from ..vocab import Vocab
from ..language import Language
from ..training import Example
cdef class TrainablePipe(Pipe):
"""This class is a base class and not instantiated directly. Trainable
pipeline components like the EntityRecognizer or TextCategorizer inherit
@ -35,7 +36,6 @@ cdef class TrainablePipe(Pipe):
self.model = model
self.name = name
self.cfg = dict(cfg)
self._added_strings = set()
def __call__(self, Doc doc) -> Doc:
"""Apply the pipe to one document. The document is modified in place,
@ -198,10 +198,6 @@ cdef class TrainablePipe(Pipe):
"""
raise NotImplementedError(Errors.E931.format(parent="Pipe", method="add_label", name=self.name))
def add_string(self, string: str):
self._added_strings.add(string)
return self.vocab.strings.add(string)
@property
def is_trainable(self) -> bool:
return True
@ -244,6 +240,16 @@ cdef class TrainablePipe(Pipe):
"""
self.model.finish_update(sgd)
def _validate_serialization_attrs(self):
"""Check that the pipe implements the required attributes. If a subclass
implements a custom __init__ method but doesn't set these attributes,
they currently default to None, so we need to perform additional checks.
"""
if not hasattr(self, "vocab") or self.vocab is None:
raise ValueError(Errors.E899.format(name=util.get_object_name(self)))
if not hasattr(self, "model") or self.model is None:
raise ValueError(Errors.E898.format(name=util.get_object_name(self)))
def to_bytes(self, *, exclude=tuple()):
"""Serialize the pipe to a bytestring.
@ -252,11 +258,12 @@ cdef class TrainablePipe(Pipe):
DOCS: https://nightly.spacy.io/api/pipe#to_bytes
"""
self._validate_serialization_attrs()
serialize = {}
if hasattr(self, "cfg"):
if hasattr(self, "cfg") and self.cfg is not None:
serialize["cfg"] = lambda: srsly.json_dumps(self.cfg)
serialize["vocab"] = self.vocab.to_bytes
serialize["model"] = self.model.to_bytes
serialize["strings.json"] = lambda: srsly.json_dumps(sorted(self._added_strings))
return util.to_bytes(serialize, exclude)
def from_bytes(self, bytes_data, *, exclude=tuple()):
@ -267,6 +274,7 @@ cdef class TrainablePipe(Pipe):
DOCS: https://nightly.spacy.io/api/pipe#from_bytes
"""
self._validate_serialization_attrs()
def load_model(b):
try:
@ -275,9 +283,9 @@ cdef class TrainablePipe(Pipe):
raise ValueError(Errors.E149) from None
deserialize = {}
deserialize["strings.json"] = lambda b: [self.add_string(s) for s in srsly.json_loads(b)]
if hasattr(self, "cfg"):
if hasattr(self, "cfg") and self.cfg is not None:
deserialize["cfg"] = lambda b: self.cfg.update(srsly.json_loads(b))
deserialize["vocab"] = lambda b: self.vocab.from_bytes(b)
deserialize["model"] = load_model
util.from_bytes(bytes_data, deserialize, exclude)
return self
@ -290,10 +298,11 @@ cdef class TrainablePipe(Pipe):
DOCS: https://nightly.spacy.io/api/pipe#to_disk
"""
self._validate_serialization_attrs()
serialize = {}
if hasattr(self, "cfg"):
if hasattr(self, "cfg") and self.cfg is not None:
serialize["cfg"] = lambda p: srsly.write_json(p, self.cfg)
serialize["strings.json"] = lambda p: srsly.write_json(p, self._added_strings)
serialize["vocab"] = lambda p: self.vocab.to_disk(p)
serialize["model"] = lambda p: self.model.to_disk(p)
util.to_disk(path, serialize, exclude)
@ -306,6 +315,7 @@ cdef class TrainablePipe(Pipe):
DOCS: https://nightly.spacy.io/api/pipe#from_disk
"""
self._validate_serialization_attrs()
def load_model(p):
try:
@ -314,9 +324,9 @@ cdef class TrainablePipe(Pipe):
raise ValueError(Errors.E149) from None
deserialize = {}
deserialize["strings.json"] = lambda p: [self.add_string(s) for s in srsly.read_json(p)]
if hasattr(self, "cfg"):
if hasattr(self, "cfg") and self.cfg is not None:
deserialize["cfg"] = lambda p: self.cfg.update(deserialize_config(p))
deserialize["vocab"] = lambda p: self.vocab.from_disk(p)
deserialize["model"] = load_model
util.from_disk(path, deserialize, exclude)
return self

View File

@ -76,7 +76,6 @@ cdef class Parser(TrainablePipe):
self.add_multitask_objective(multitask)
self._rehearsal_model = None
self._added_strings = set()
def __getnewargs_ex__(self):
"""This allows pickling the Parser and its keyword-only init arguments"""
@ -120,7 +119,7 @@ cdef class Parser(TrainablePipe):
resized = True
if resized:
self._resize()
self.add_string(label)
self.vocab.strings.add(label)
return 1
return 0
@ -456,24 +455,24 @@ cdef class Parser(TrainablePipe):
def to_disk(self, path, exclude=tuple()):
serializers = {
'model': lambda p: (self.model.to_disk(p) if self.model is not True else True),
'strings.json': lambda p: srsly.write_json(p, self._added_strings),
'moves': lambda p: self.moves.to_disk(p, exclude=["strings"]),
'cfg': lambda p: srsly.write_json(p, self.cfg)
"model": lambda p: (self.model.to_disk(p) if self.model is not True else True),
"vocab": lambda p: self.vocab.to_disk(p),
"moves": lambda p: self.moves.to_disk(p, exclude=["strings"]),
"cfg": lambda p: srsly.write_json(p, self.cfg)
}
util.to_disk(path, serializers, exclude)
def from_disk(self, path, exclude=tuple()):
deserializers = {
'strings.json': lambda p: [self.add_string(s) for s in srsly.read_json(p)],
'moves': lambda p: self.moves.from_disk(p, exclude=["strings"]),
'cfg': lambda p: self.cfg.update(srsly.read_json(p)),
'model': lambda p: None,
"vocab": lambda p: self.vocab.from_disk(p),
"moves": lambda p: self.moves.from_disk(p, exclude=["strings"]),
"cfg": lambda p: self.cfg.update(srsly.read_json(p)),
"model": lambda p: None,
}
util.from_disk(path, deserializers, exclude)
if 'model' not in exclude:
if "model" not in exclude:
path = util.ensure_path(path)
with (path / 'model').open('rb') as file_:
with (path / "model").open("rb") as file_:
bytes_data = file_.read()
try:
self._resize()
@ -485,7 +484,7 @@ cdef class Parser(TrainablePipe):
def to_bytes(self, exclude=tuple()):
serializers = {
"model": lambda: (self.model.to_bytes()),
"strings.json": lambda: srsly.json_dumps(sorted(self._added_strings)),
"vocab": lambda: self.vocab.to_bytes(),
"moves": lambda: self.moves.to_bytes(exclude=["strings"]),
"cfg": lambda: srsly.json_dumps(self.cfg, indent=2, sort_keys=True)
}
@ -493,7 +492,7 @@ cdef class Parser(TrainablePipe):
def from_bytes(self, bytes_data, exclude=tuple()):
deserializers = {
"strings.json": lambda b: [self.add_string(s) for s in srsly.json_loads(b)],
"vocab": lambda b: self.vocab.from_bytes(b),
"moves": lambda b: self.moves.from_bytes(b, exclude=["strings"]),
"cfg": lambda b: self.cfg.update(srsly.json_loads(b)),
"model": lambda b: None,

View File

@ -121,9 +121,7 @@ def test_kb_default(nlp):
def test_kb_custom_length(nlp):
"""Test that the default (empty) KB can be configured with a custom entity length"""
entity_linker = nlp.add_pipe(
"entity_linker", config={"entity_vector_length": 35}
)
entity_linker = nlp.add_pipe("entity_linker", config={"entity_vector_length": 35})
assert len(entity_linker.kb) == 0
assert entity_linker.kb.get_size_entities() == 0
assert entity_linker.kb.get_size_aliases() == 0
@ -213,16 +211,11 @@ def test_el_pipe_configuration(nlp):
kb = KnowledgeBase(vocab, entity_vector_length=1)
kb.add_entity(entity="Q2", freq=12, entity_vector=[2])
kb.add_entity(entity="Q3", freq=5, entity_vector=[3])
kb.add_alias(
alias="douglas", entities=["Q2", "Q3"], probabilities=[0.8, 0.1]
)
kb.add_alias(alias="douglas", entities=["Q2", "Q3"], probabilities=[0.8, 0.1])
return kb
# run an EL pipe without a trained context encoder, to check the candidate generation step only
entity_linker = nlp.add_pipe(
"entity_linker",
config={"incl_context": False},
)
entity_linker = nlp.add_pipe("entity_linker", config={"incl_context": False},)
entity_linker.set_kb(create_kb)
# With the default get_candidates function, matching is case-sensitive
text = "Douglas and douglas are not the same."
@ -453,14 +446,10 @@ def test_overfitting_IO():
return mykb
# Create the Entity Linker component and add it to the pipeline
entity_linker = nlp.add_pipe(
"entity_linker",
last=True,
)
entity_linker = nlp.add_pipe("entity_linker", last=True,)
entity_linker.set_kb(create_kb)
assert "Q2146908" in entity_linker.vocab.strings
assert "Q2146908" in entity_linker.kb.vocab.strings
assert "Q2146908" in entity_linker.kb._added_strings
# train the NEL pipe
optimizer = nlp.initialize(get_examples=lambda: train_examples)

View File

@ -101,4 +101,3 @@ def test_overfitting_IO():
doc2 = nlp2(test_text)
assert [str(t.morph) for t in doc2] == gold_morphs
assert [t.pos_ for t in doc2] == gold_pos_tags
assert nlp.get_pipe("morphologizer")._added_strings == nlp2.get_pipe("morphologizer")._added_strings

View File

@ -80,4 +80,3 @@ def test_overfitting_IO():
nlp2 = util.load_model_from_path(tmp_dir)
doc2 = nlp2(test_text)
assert [int(t.is_sent_start) for t in doc2] == gold_sent_starts
assert nlp.get_pipe("senter")._added_strings == nlp2.get_pipe("senter")._added_strings

View File

@ -98,7 +98,6 @@ def test_overfitting_IO():
losses = {}
nlp.update(train_examples, sgd=optimizer, losses=losses)
assert losses["tagger"] < 0.00001
assert tagger._added_strings == {"J", "N", "V"}
# test the trained model
test_text = "I like blue eggs"
@ -117,7 +116,6 @@ def test_overfitting_IO():
assert doc2[1].tag_ is "V"
assert doc2[2].tag_ is "J"
assert doc2[3].tag_ is "N"
assert nlp2.get_pipe("tagger")._added_strings == {"J", "N", "V"}
def test_tagger_requires_labels():

View File

@ -146,7 +146,6 @@ def test_overfitting_IO():
train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))
optimizer = nlp.initialize(get_examples=lambda: train_examples)
assert textcat.model.get_dim("nO") == 2
assert textcat._added_strings == {"NEGATIVE", "POSITIVE"}
for i in range(50):
losses = {}
@ -168,7 +167,6 @@ def test_overfitting_IO():
cats2 = doc2.cats
assert cats2["POSITIVE"] > 0.9
assert cats2["POSITIVE"] + cats2["NEGATIVE"] == pytest.approx(1.0, 0.001)
assert nlp2.get_pipe("textcat")._added_strings == {"NEGATIVE", "POSITIVE"}
# Test scoring
scores = nlp.evaluate(train_examples)

View File

@ -7,6 +7,7 @@ from spacy.kb import KnowledgeBase, Writer
from spacy.vectors import Vectors
from spacy.language import Language
from spacy.pipeline import TrainablePipe
from spacy.vocab import Vocab
from ..util import make_tempdir
@ -50,8 +51,9 @@ def custom_pipe():
else:
self.cfg = None
self.model = SerializableDummy()
self.vocab = vocab
return MyPipe(None)
return MyPipe(Vocab())
def tagger():

View File

@ -1,13 +1,13 @@
import pytest
import srsly
from spacy import registry, Vocab
from spacy.pipeline import Tagger, DependencyParser, EntityRecognizer
from spacy.pipeline import TextCategorizer, SentenceRecognizer
from spacy.pipeline import TextCategorizer, SentenceRecognizer, TrainablePipe
from spacy.pipeline.dep_parser import DEFAULT_PARSER_MODEL
from spacy.pipeline.tagger import DEFAULT_TAGGER_MODEL
from spacy.pipeline.textcat import DEFAULT_TEXTCAT_MODEL
from spacy.pipeline.senter import DEFAULT_SENTER_MODEL
from spacy.lang.en import English
from thinc.api import Linear
import spacy
from ..util import make_tempdir
@ -89,7 +89,6 @@ def test_serialize_parser_strings(Parser):
assert label not in vocab2.strings
parser2 = Parser(vocab2, model, **config)
parser2 = parser2.from_bytes(parser1.to_bytes(exclude=["vocab"]))
assert parser1._added_strings == parser2._added_strings == {"FunnyLabel"}
assert label in parser2.vocab.strings
@ -166,17 +165,13 @@ def test_serialize_tagger_strings(en_vocab, de_vocab, taggers):
# check that custom labels are serialized as part of the component's strings.json
tagger.add_label(label)
assert label in tagger.vocab.strings
assert tagger._added_strings == {label}
file_path = d / "tagger1"
tagger.to_disk(file_path)
strings = srsly.read_json(file_path / "strings.json")
assert strings == ["SomeWeirdLabel"]
# ensure that the custom strings are loaded back in when using the tagger in another pipeline
cfg = {"model": DEFAULT_TAGGER_MODEL}
model = registry.resolve(cfg, validate=True)["model"]
tagger2 = Tagger(de_vocab, model).from_disk(file_path)
assert label in tagger2.vocab.strings
assert tagger2._added_strings == {label}
def test_serialize_textcat_empty(en_vocab):
@ -253,3 +248,40 @@ def test_serialize_pipeline_disable_enable():
assert nlp5.pipe_names == ["ner"]
assert nlp5.component_names == ["ner"]
assert nlp5.disabled == []
def test_serialize_custom_trainable_pipe():
class BadCustomPipe1(TrainablePipe):
def __init__(self, vocab):
pass
class BadCustomPipe2(TrainablePipe):
def __init__(self, vocab):
self.vocab = vocab
self.model = None
class CustomPipe(TrainablePipe):
def __init__(self, vocab, model):
self.vocab = vocab
self.model = model
pipe = BadCustomPipe1(Vocab())
with pytest.raises(ValueError):
pipe.to_bytes()
with make_tempdir() as d:
with pytest.raises(ValueError):
pipe.to_disk(d)
pipe = BadCustomPipe2(Vocab())
with pytest.raises(ValueError):
pipe.to_bytes()
with make_tempdir() as d:
with pytest.raises(ValueError):
pipe.to_disk(d)
pipe = CustomPipe(Vocab(), Linear())
pipe_bytes = pipe.to_bytes()
new_pipe = CustomPipe(Vocab(), Linear()).from_bytes(pipe_bytes)
assert new_pipe.to_bytes() == pipe_bytes
with make_tempdir() as d:
pipe.to_disk(d)
new_pipe = CustomPipe(Vocab(), Linear()).from_disk(d)
assert new_pipe.to_bytes() == pipe_bytes

View File

@ -821,7 +821,7 @@ def get_object_name(obj: Any) -> str:
obj (Any): The Python object, typically a function or class.
RETURNS (str): A human-readable name.
"""
if hasattr(obj, "name"):
if hasattr(obj, "name") and obj.name is not None:
return obj.name
if hasattr(obj, "__name__"):
return obj.__name__
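
The added `obj.name is not None` guard matters for objects whose name attribute exists but is left unset, such as the bad custom pipes exercised in the new serialization test: the helper should then fall back to `__name__` so the E898/E899 messages still show a usable name. A small sketch of the assumed behavior, using a hypothetical class and assuming the helper is importable from spacy.util:

    from spacy.util import get_object_name

    class NamelessComponent:
        name = None  # attribute exists, but is not a usable name

    # Previously this returned None; with the extra check it falls through
    # to __name__ and returns a readable class name.
    assert get_object_name(NamelessComponent) == "NamelessComponent"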