mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 09:26:27 +03:00
Revert added_strings change (#6236)
This commit is contained in:
parent
796f8b9424
commit
bfa3931c9d
|
@ -1,6 +1,6 @@
|
||||||
# fmt: off
|
# fmt: off
|
||||||
__title__ = "spacy-nightly"
|
__title__ = "spacy-nightly"
|
||||||
__version__ = "3.0.0a37"
|
__version__ = "3.0.0a38"
|
||||||
__download_url__ = "https://github.com/explosion/spacy-models/releases/download"
|
__download_url__ = "https://github.com/explosion/spacy-models/releases/download"
|
||||||
__compatibility__ = "https://raw.githubusercontent.com/explosion/spacy-models/master/compatibility.json"
|
__compatibility__ = "https://raw.githubusercontent.com/explosion/spacy-models/master/compatibility.json"
|
||||||
__projects__ = "https://github.com/explosion/projects"
|
__projects__ = "https://github.com/explosion/projects"
|
||||||
|
|
|
@ -456,6 +456,14 @@ class Errors:
|
||||||
"issue tracker: http://github.com/explosion/spaCy/issues")
|
"issue tracker: http://github.com/explosion/spaCy/issues")
|
||||||
|
|
||||||
# TODO: fix numbering after merging develop into master
|
# TODO: fix numbering after merging develop into master
|
||||||
|
E898 = ("Can't serialize trainable pipe '{name}': the `model` attribute "
|
||||||
|
"is not set or None. If you've implemented a custom component, make "
|
||||||
|
"sure to store the component model as `self.model` in your "
|
||||||
|
"component's __init__ method.")
|
||||||
|
E899 = ("Can't serialize trainable pipe '{name}': the `vocab` attribute "
|
||||||
|
"is not set or None. If you've implemented a custom component, make "
|
||||||
|
"sure to store the current `nlp` object's vocab as `self.vocab` in "
|
||||||
|
"your component's __init__ method.")
|
||||||
E900 = ("Could not run the full pipeline for evaluation. If you specified "
|
E900 = ("Could not run the full pipeline for evaluation. If you specified "
|
||||||
"frozen components, make sure they were already initialized and "
|
"frozen components, make sure they were already initialized and "
|
||||||
"trained. Full pipeline: {pipeline}")
|
"trained. Full pipeline: {pipeline}")
|
||||||
|
|
|
@ -30,7 +30,6 @@ cdef class KnowledgeBase:
|
||||||
cdef Pool mem
|
cdef Pool mem
|
||||||
cpdef readonly Vocab vocab
|
cpdef readonly Vocab vocab
|
||||||
cdef int64_t entity_vector_length
|
cdef int64_t entity_vector_length
|
||||||
cdef public set _added_strings
|
|
||||||
|
|
||||||
# This maps 64bit keys (hash of unique entity string)
|
# This maps 64bit keys (hash of unique entity string)
|
||||||
# to 64bit values (position of the _KBEntryC struct in the _entries vector).
|
# to 64bit values (position of the _KBEntryC struct in the _entries vector).
|
||||||
|
|
15
spacy/kb.pyx
15
spacy/kb.pyx
|
@ -92,7 +92,6 @@ cdef class KnowledgeBase:
|
||||||
self._alias_index = PreshMap()
|
self._alias_index = PreshMap()
|
||||||
self.vocab = vocab
|
self.vocab = vocab
|
||||||
self._create_empty_vectors(dummy_hash=self.vocab.strings[""])
|
self._create_empty_vectors(dummy_hash=self.vocab.strings[""])
|
||||||
self._added_strings = set()
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def entity_vector_length(self):
|
def entity_vector_length(self):
|
||||||
|
@ -114,16 +113,12 @@ cdef class KnowledgeBase:
|
||||||
def get_alias_strings(self):
|
def get_alias_strings(self):
|
||||||
return [self.vocab.strings[x] for x in self._alias_index]
|
return [self.vocab.strings[x] for x in self._alias_index]
|
||||||
|
|
||||||
def add_string(self, string: str):
|
|
||||||
self._added_strings.add(string)
|
|
||||||
return self.vocab.strings.add(string)
|
|
||||||
|
|
||||||
def add_entity(self, unicode entity, float freq, vector[float] entity_vector):
|
def add_entity(self, unicode entity, float freq, vector[float] entity_vector):
|
||||||
"""
|
"""
|
||||||
Add an entity to the KB, optionally specifying its log probability based on corpus frequency
|
Add an entity to the KB, optionally specifying its log probability based on corpus frequency
|
||||||
Return the hash of the entity ID/name at the end.
|
Return the hash of the entity ID/name at the end.
|
||||||
"""
|
"""
|
||||||
cdef hash_t entity_hash = self.add_string(entity)
|
cdef hash_t entity_hash = self.vocab.strings.add(entity)
|
||||||
|
|
||||||
# Return if this entity was added before
|
# Return if this entity was added before
|
||||||
if entity_hash in self._entry_index:
|
if entity_hash in self._entry_index:
|
||||||
|
@ -157,7 +152,7 @@ cdef class KnowledgeBase:
|
||||||
cdef hash_t entity_hash
|
cdef hash_t entity_hash
|
||||||
while i < len(entity_list):
|
while i < len(entity_list):
|
||||||
# only process this entity if its unique ID hadn't been added before
|
# only process this entity if its unique ID hadn't been added before
|
||||||
entity_hash = self.add_string(entity_list[i])
|
entity_hash = self.vocab.strings.add(entity_list[i])
|
||||||
if entity_hash in self._entry_index:
|
if entity_hash in self._entry_index:
|
||||||
warnings.warn(Warnings.W018.format(entity=entity_list[i]))
|
warnings.warn(Warnings.W018.format(entity=entity_list[i]))
|
||||||
|
|
||||||
|
@ -203,7 +198,7 @@ cdef class KnowledgeBase:
|
||||||
if prob_sum > 1.00001:
|
if prob_sum > 1.00001:
|
||||||
raise ValueError(Errors.E133.format(alias=alias, sum=prob_sum))
|
raise ValueError(Errors.E133.format(alias=alias, sum=prob_sum))
|
||||||
|
|
||||||
cdef hash_t alias_hash = self.add_string(alias)
|
cdef hash_t alias_hash = self.vocab.strings.add(alias)
|
||||||
|
|
||||||
# Check whether this alias was added before
|
# Check whether this alias was added before
|
||||||
if alias_hash in self._alias_index:
|
if alias_hash in self._alias_index:
|
||||||
|
@ -332,7 +327,7 @@ cdef class KnowledgeBase:
|
||||||
raise ValueError(Errors.E928.format(loc=path))
|
raise ValueError(Errors.E928.format(loc=path))
|
||||||
serialize = {}
|
serialize = {}
|
||||||
serialize["contents"] = lambda p: self.write_contents(p)
|
serialize["contents"] = lambda p: self.write_contents(p)
|
||||||
serialize["strings.json"] = lambda p: srsly.write_json(p, self._added_strings)
|
serialize["strings.json"] = lambda p: self.vocab.strings.to_disk(p)
|
||||||
util.to_disk(path, serialize, exclude)
|
util.to_disk(path, serialize, exclude)
|
||||||
|
|
||||||
def from_disk(self, path, exclude: Iterable[str] = SimpleFrozenList()):
|
def from_disk(self, path, exclude: Iterable[str] = SimpleFrozenList()):
|
||||||
|
@ -343,7 +338,7 @@ cdef class KnowledgeBase:
|
||||||
raise ValueError(Errors.E928.format(loc=path))
|
raise ValueError(Errors.E928.format(loc=path))
|
||||||
deserialize = {}
|
deserialize = {}
|
||||||
deserialize["contents"] = lambda p: self.read_contents(p)
|
deserialize["contents"] = lambda p: self.read_contents(p)
|
||||||
deserialize["strings.json"] = lambda p: [self.add_string(s) for s in srsly.read_json(p)]
|
deserialize["strings.json"] = lambda p: self.vocab.strings.from_disk(p)
|
||||||
util.from_disk(path, deserialize, exclude)
|
util.from_disk(path, deserialize, exclude)
|
||||||
|
|
||||||
def write_contents(self, file_path):
|
def write_contents(self, file_path):
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
from typing import List, Dict, Union, Iterable, Any, Optional, Callable, Iterator
|
from typing import List, Dict, Union, Iterable, Any, Optional, Callable
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
import srsly
|
import srsly
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
@ -57,7 +57,6 @@ class AttributeRuler(Pipe):
|
||||||
self.attrs = []
|
self.attrs = []
|
||||||
self._attrs_unnormed = [] # store for reference
|
self._attrs_unnormed = [] # store for reference
|
||||||
self.indices = []
|
self.indices = []
|
||||||
self._added_strings = set()
|
|
||||||
|
|
||||||
def clear(self) -> None:
|
def clear(self) -> None:
|
||||||
"""Reset all patterns."""
|
"""Reset all patterns."""
|
||||||
|
@ -187,16 +186,12 @@ class AttributeRuler(Pipe):
|
||||||
# We need to make a string here, because otherwise the ID we pass back
|
# We need to make a string here, because otherwise the ID we pass back
|
||||||
# will be interpreted as the hash of a string, rather than an ordinal.
|
# will be interpreted as the hash of a string, rather than an ordinal.
|
||||||
key = str(len(self.attrs))
|
key = str(len(self.attrs))
|
||||||
self.matcher.add(self.add_string(key), patterns)
|
self.matcher.add(self.vocab.strings.add(key), patterns)
|
||||||
self._attrs_unnormed.append(attrs)
|
self._attrs_unnormed.append(attrs)
|
||||||
attrs = normalize_token_attrs(self.vocab, attrs)
|
attrs = normalize_token_attrs(self.vocab, attrs)
|
||||||
self.attrs.append(attrs)
|
self.attrs.append(attrs)
|
||||||
self.indices.append(index)
|
self.indices.append(index)
|
||||||
|
|
||||||
def add_string(self, string: str):
|
|
||||||
self._added_strings.add(string)
|
|
||||||
return self.vocab.strings.add(string)
|
|
||||||
|
|
||||||
def add_patterns(self, patterns: Iterable[AttributeRulerPatternType]) -> None:
|
def add_patterns(self, patterns: Iterable[AttributeRulerPatternType]) -> None:
|
||||||
"""Add patterns from a list of pattern dicts with the keys as the
|
"""Add patterns from a list of pattern dicts with the keys as the
|
||||||
arguments to AttributeRuler.add.
|
arguments to AttributeRuler.add.
|
||||||
|
@ -256,8 +251,8 @@ class AttributeRuler(Pipe):
|
||||||
DOCS: https://nightly.spacy.io/api/attributeruler#to_bytes
|
DOCS: https://nightly.spacy.io/api/attributeruler#to_bytes
|
||||||
"""
|
"""
|
||||||
serialize = {}
|
serialize = {}
|
||||||
|
serialize["vocab"] = self.vocab.to_bytes
|
||||||
serialize["patterns"] = lambda: srsly.msgpack_dumps(self.patterns)
|
serialize["patterns"] = lambda: srsly.msgpack_dumps(self.patterns)
|
||||||
serialize["strings.json"] = lambda: srsly.json_dumps(sorted(self._added_strings))
|
|
||||||
return util.to_bytes(serialize, exclude)
|
return util.to_bytes(serialize, exclude)
|
||||||
|
|
||||||
def from_bytes(
|
def from_bytes(
|
||||||
|
@ -276,7 +271,7 @@ class AttributeRuler(Pipe):
|
||||||
self.add_patterns(srsly.msgpack_loads(b))
|
self.add_patterns(srsly.msgpack_loads(b))
|
||||||
|
|
||||||
deserialize = {
|
deserialize = {
|
||||||
"strings.json": lambda b: [self.add_string(s) for s in srsly.json_loads(b)],
|
"vocab": lambda b: self.vocab.from_bytes(b),
|
||||||
"patterns": load_patterns,
|
"patterns": load_patterns,
|
||||||
}
|
}
|
||||||
util.from_bytes(bytes_data, deserialize, exclude)
|
util.from_bytes(bytes_data, deserialize, exclude)
|
||||||
|
@ -293,7 +288,7 @@ class AttributeRuler(Pipe):
|
||||||
DOCS: https://nightly.spacy.io/api/attributeruler#to_disk
|
DOCS: https://nightly.spacy.io/api/attributeruler#to_disk
|
||||||
"""
|
"""
|
||||||
serialize = {
|
serialize = {
|
||||||
"strings.json": lambda p: srsly.write_json(p, self._added_strings),
|
"vocab": lambda p: self.vocab.to_disk(p),
|
||||||
"patterns": lambda p: srsly.write_msgpack(p, self.patterns),
|
"patterns": lambda p: srsly.write_msgpack(p, self.patterns),
|
||||||
}
|
}
|
||||||
util.to_disk(path, serialize, exclude)
|
util.to_disk(path, serialize, exclude)
|
||||||
|
@ -314,7 +309,7 @@ class AttributeRuler(Pipe):
|
||||||
self.add_patterns(srsly.read_msgpack(p))
|
self.add_patterns(srsly.read_msgpack(p))
|
||||||
|
|
||||||
deserialize = {
|
deserialize = {
|
||||||
"strings.json": lambda p: [self.add_string(s) for s in srsly.read_json(p)],
|
"vocab": lambda p: self.vocab.from_disk(p),
|
||||||
"patterns": load_patterns,
|
"patterns": load_patterns,
|
||||||
}
|
}
|
||||||
util.from_disk(path, deserialize, exclude)
|
util.from_disk(path, deserialize, exclude)
|
||||||
|
|
|
@ -453,6 +453,7 @@ class EntityLinker(TrainablePipe):
|
||||||
DOCS: https://nightly.spacy.io/api/entitylinker#to_disk
|
DOCS: https://nightly.spacy.io/api/entitylinker#to_disk
|
||||||
"""
|
"""
|
||||||
serialize = {}
|
serialize = {}
|
||||||
|
serialize["vocab"] = lambda p: self.vocab.to_disk(p)
|
||||||
serialize["cfg"] = lambda p: srsly.write_json(p, self.cfg)
|
serialize["cfg"] = lambda p: srsly.write_json(p, self.cfg)
|
||||||
serialize["kb"] = lambda p: self.kb.to_disk(p)
|
serialize["kb"] = lambda p: self.kb.to_disk(p)
|
||||||
serialize["model"] = lambda p: self.model.to_disk(p)
|
serialize["model"] = lambda p: self.model.to_disk(p)
|
||||||
|
@ -481,8 +482,6 @@ class EntityLinker(TrainablePipe):
|
||||||
deserialize["kb"] = lambda p: self.kb.from_disk(p)
|
deserialize["kb"] = lambda p: self.kb.from_disk(p)
|
||||||
deserialize["model"] = load_model
|
deserialize["model"] = load_model
|
||||||
util.from_disk(path, deserialize, exclude)
|
util.from_disk(path, deserialize, exclude)
|
||||||
for s in self.kb._added_strings:
|
|
||||||
self.vocab.strings.add(s)
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def rehearse(self, examples, *, sgd=None, losses=None, **config):
|
def rehearse(self, examples, *, sgd=None, losses=None, **config):
|
||||||
|
|
|
@ -281,6 +281,7 @@ class Lemmatizer(Pipe):
|
||||||
DOCS: https://nightly.spacy.io/api/lemmatizer#to_disk
|
DOCS: https://nightly.spacy.io/api/lemmatizer#to_disk
|
||||||
"""
|
"""
|
||||||
serialize = {}
|
serialize = {}
|
||||||
|
serialize["vocab"] = lambda p: self.vocab.to_disk(p)
|
||||||
serialize["lookups"] = lambda p: self.lookups.to_disk(p)
|
serialize["lookups"] = lambda p: self.lookups.to_disk(p)
|
||||||
util.to_disk(path, serialize, exclude)
|
util.to_disk(path, serialize, exclude)
|
||||||
|
|
||||||
|
@ -296,6 +297,7 @@ class Lemmatizer(Pipe):
|
||||||
DOCS: https://nightly.spacy.io/api/lemmatizer#from_disk
|
DOCS: https://nightly.spacy.io/api/lemmatizer#from_disk
|
||||||
"""
|
"""
|
||||||
deserialize = {}
|
deserialize = {}
|
||||||
|
deserialize["vocab"] = lambda p: self.vocab.from_disk(p)
|
||||||
deserialize["lookups"] = lambda p: self.lookups.from_disk(p)
|
deserialize["lookups"] = lambda p: self.lookups.from_disk(p)
|
||||||
util.from_disk(path, deserialize, exclude)
|
util.from_disk(path, deserialize, exclude)
|
||||||
self._validate_tables()
|
self._validate_tables()
|
||||||
|
@ -310,6 +312,7 @@ class Lemmatizer(Pipe):
|
||||||
DOCS: https://nightly.spacy.io/api/lemmatizer#to_bytes
|
DOCS: https://nightly.spacy.io/api/lemmatizer#to_bytes
|
||||||
"""
|
"""
|
||||||
serialize = {}
|
serialize = {}
|
||||||
|
serialize["vocab"] = self.vocab.to_bytes
|
||||||
serialize["lookups"] = self.lookups.to_bytes
|
serialize["lookups"] = self.lookups.to_bytes
|
||||||
return util.to_bytes(serialize, exclude)
|
return util.to_bytes(serialize, exclude)
|
||||||
|
|
||||||
|
@ -325,6 +328,7 @@ class Lemmatizer(Pipe):
|
||||||
DOCS: https://nightly.spacy.io/api/lemmatizer#from_bytes
|
DOCS: https://nightly.spacy.io/api/lemmatizer#from_bytes
|
||||||
"""
|
"""
|
||||||
deserialize = {}
|
deserialize = {}
|
||||||
|
deserialize["vocab"] = lambda b: self.vocab.from_bytes(b)
|
||||||
deserialize["lookups"] = lambda b: self.lookups.from_bytes(b)
|
deserialize["lookups"] = lambda b: self.lookups.from_bytes(b)
|
||||||
util.from_bytes(bytes_data, deserialize, exclude)
|
util.from_bytes(bytes_data, deserialize, exclude)
|
||||||
self._validate_tables()
|
self._validate_tables()
|
||||||
|
|
|
@ -95,7 +95,6 @@ class Morphologizer(Tagger):
|
||||||
# add mappings for empty morph
|
# add mappings for empty morph
|
||||||
self.cfg["labels_morph"][Morphology.EMPTY_MORPH] = Morphology.EMPTY_MORPH
|
self.cfg["labels_morph"][Morphology.EMPTY_MORPH] = Morphology.EMPTY_MORPH
|
||||||
self.cfg["labels_pos"][Morphology.EMPTY_MORPH] = POS_IDS[""]
|
self.cfg["labels_pos"][Morphology.EMPTY_MORPH] = POS_IDS[""]
|
||||||
self._added_strings = set()
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def labels(self):
|
def labels(self):
|
||||||
|
@ -129,7 +128,6 @@ class Morphologizer(Tagger):
|
||||||
label_dict.pop(self.POS_FEAT)
|
label_dict.pop(self.POS_FEAT)
|
||||||
# normalize morph string and add to morphology table
|
# normalize morph string and add to morphology table
|
||||||
norm_morph = self.vocab.strings[self.vocab.morphology.add(label_dict)]
|
norm_morph = self.vocab.strings[self.vocab.morphology.add(label_dict)]
|
||||||
self.add_string(norm_morph)
|
|
||||||
# add label mappings
|
# add label mappings
|
||||||
if norm_label not in self.cfg["labels_morph"]:
|
if norm_label not in self.cfg["labels_morph"]:
|
||||||
self.cfg["labels_morph"][norm_label] = norm_morph
|
self.cfg["labels_morph"][norm_label] = norm_morph
|
||||||
|
@ -161,7 +159,6 @@ class Morphologizer(Tagger):
|
||||||
if pos:
|
if pos:
|
||||||
morph_dict[self.POS_FEAT] = pos
|
morph_dict[self.POS_FEAT] = pos
|
||||||
norm_label = self.vocab.strings[self.vocab.morphology.add(morph_dict)]
|
norm_label = self.vocab.strings[self.vocab.morphology.add(morph_dict)]
|
||||||
self.add_string(norm_label)
|
|
||||||
# add label->morph and label->POS mappings
|
# add label->morph and label->POS mappings
|
||||||
if norm_label not in self.cfg["labels_morph"]:
|
if norm_label not in self.cfg["labels_morph"]:
|
||||||
self.cfg["labels_morph"][norm_label] = morph
|
self.cfg["labels_morph"][norm_label] = morph
|
||||||
|
@ -179,7 +176,6 @@ class Morphologizer(Tagger):
|
||||||
if pos:
|
if pos:
|
||||||
morph_dict[self.POS_FEAT] = pos
|
morph_dict[self.POS_FEAT] = pos
|
||||||
norm_label = self.vocab.strings[self.vocab.morphology.add(morph_dict)]
|
norm_label = self.vocab.strings[self.vocab.morphology.add(morph_dict)]
|
||||||
self.add_string(norm_label)
|
|
||||||
gold_array.append([1.0 if label == norm_label else 0.0 for label in self.labels])
|
gold_array.append([1.0 if label == norm_label else 0.0 for label in self.labels])
|
||||||
doc_sample.append(example.x)
|
doc_sample.append(example.x)
|
||||||
label_sample.append(self.model.ops.asarray(gold_array, dtype="float32"))
|
label_sample.append(self.model.ops.asarray(gold_array, dtype="float32"))
|
||||||
|
@ -238,7 +234,6 @@ class Morphologizer(Tagger):
|
||||||
if pos:
|
if pos:
|
||||||
label_dict[self.POS_FEAT] = pos
|
label_dict[self.POS_FEAT] = pos
|
||||||
label = self.vocab.strings[self.vocab.morphology.add(label_dict)]
|
label = self.vocab.strings[self.vocab.morphology.add(label_dict)]
|
||||||
self.add_string(label)
|
|
||||||
eg_truths.append(label)
|
eg_truths.append(label)
|
||||||
truths.append(eg_truths)
|
truths.append(eg_truths)
|
||||||
d_scores, loss = loss_func(scores, truths)
|
d_scores, loss = loss_func(scores, truths)
|
||||||
|
|
|
@ -61,7 +61,6 @@ class SentenceRecognizer(Tagger):
|
||||||
self.name = name
|
self.name = name
|
||||||
self._rehearsal_model = None
|
self._rehearsal_model = None
|
||||||
self.cfg = {}
|
self.cfg = {}
|
||||||
self._added_strings = set()
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def labels(self):
|
def labels(self):
|
||||||
|
|
|
@ -78,7 +78,6 @@ class Tagger(TrainablePipe):
|
||||||
self._rehearsal_model = None
|
self._rehearsal_model = None
|
||||||
cfg = {"labels": labels or []}
|
cfg = {"labels": labels or []}
|
||||||
self.cfg = dict(sorted(cfg.items()))
|
self.cfg = dict(sorted(cfg.items()))
|
||||||
self._added_strings = set()
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def labels(self):
|
def labels(self):
|
||||||
|
@ -313,7 +312,7 @@ class Tagger(TrainablePipe):
|
||||||
return 0
|
return 0
|
||||||
self._allow_extra_label()
|
self._allow_extra_label()
|
||||||
self.cfg["labels"].append(label)
|
self.cfg["labels"].append(label)
|
||||||
self.add_string(label)
|
self.vocab.strings.add(label)
|
||||||
return 1
|
return 1
|
||||||
|
|
||||||
def score(self, examples, **kwargs):
|
def score(self, examples, **kwargs):
|
||||||
|
|
|
@ -110,7 +110,6 @@ class TextCategorizer(TrainablePipe):
|
||||||
self._rehearsal_model = None
|
self._rehearsal_model = None
|
||||||
cfg = {"labels": [], "threshold": threshold, "positive_label": None}
|
cfg = {"labels": [], "threshold": threshold, "positive_label": None}
|
||||||
self.cfg = dict(cfg)
|
self.cfg = dict(cfg)
|
||||||
self._added_strings = set()
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def labels(self) -> Tuple[str]:
|
def labels(self) -> Tuple[str]:
|
||||||
|
@ -301,7 +300,7 @@ class TextCategorizer(TrainablePipe):
|
||||||
return 0
|
return 0
|
||||||
self._allow_extra_label()
|
self._allow_extra_label()
|
||||||
self.cfg["labels"].append(label)
|
self.cfg["labels"].append(label)
|
||||||
self.add_string(label)
|
self.vocab.strings.add(label)
|
||||||
return 1
|
return 1
|
||||||
|
|
||||||
def initialize(
|
def initialize(
|
||||||
|
|
|
@ -64,7 +64,6 @@ class Tok2Vec(TrainablePipe):
|
||||||
self.name = name
|
self.name = name
|
||||||
self.listeners = []
|
self.listeners = []
|
||||||
self.cfg = {}
|
self.cfg = {}
|
||||||
self._added_strings = set()
|
|
||||||
|
|
||||||
def add_listener(self, listener: "Tok2VecListener") -> None:
|
def add_listener(self, listener: "Tok2VecListener") -> None:
|
||||||
"""Add a listener for a downstream component. Usually internals."""
|
"""Add a listener for a downstream component. Usually internals."""
|
||||||
|
|
|
@ -5,4 +5,3 @@ cdef class TrainablePipe(Pipe):
|
||||||
cdef public Vocab vocab
|
cdef public Vocab vocab
|
||||||
cdef public object model
|
cdef public object model
|
||||||
cdef public object cfg
|
cdef public object cfg
|
||||||
cdef public set _added_strings
|
|
||||||
|
|
|
@ -13,6 +13,7 @@ from ..vocab import Vocab
|
||||||
from ..language import Language
|
from ..language import Language
|
||||||
from ..training import Example
|
from ..training import Example
|
||||||
|
|
||||||
|
|
||||||
cdef class TrainablePipe(Pipe):
|
cdef class TrainablePipe(Pipe):
|
||||||
"""This class is a base class and not instantiated directly. Trainable
|
"""This class is a base class and not instantiated directly. Trainable
|
||||||
pipeline components like the EntityRecognizer or TextCategorizer inherit
|
pipeline components like the EntityRecognizer or TextCategorizer inherit
|
||||||
|
@ -35,7 +36,6 @@ cdef class TrainablePipe(Pipe):
|
||||||
self.model = model
|
self.model = model
|
||||||
self.name = name
|
self.name = name
|
||||||
self.cfg = dict(cfg)
|
self.cfg = dict(cfg)
|
||||||
self._added_strings = set()
|
|
||||||
|
|
||||||
def __call__(self, Doc doc) -> Doc:
|
def __call__(self, Doc doc) -> Doc:
|
||||||
"""Apply the pipe to one document. The document is modified in place,
|
"""Apply the pipe to one document. The document is modified in place,
|
||||||
|
@ -198,10 +198,6 @@ cdef class TrainablePipe(Pipe):
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError(Errors.E931.format(parent="Pipe", method="add_label", name=self.name))
|
raise NotImplementedError(Errors.E931.format(parent="Pipe", method="add_label", name=self.name))
|
||||||
|
|
||||||
def add_string(self, string: str):
|
|
||||||
self._added_strings.add(string)
|
|
||||||
return self.vocab.strings.add(string)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_trainable(self) -> bool:
|
def is_trainable(self) -> bool:
|
||||||
return True
|
return True
|
||||||
|
@ -244,6 +240,16 @@ cdef class TrainablePipe(Pipe):
|
||||||
"""
|
"""
|
||||||
self.model.finish_update(sgd)
|
self.model.finish_update(sgd)
|
||||||
|
|
||||||
|
def _validate_serialization_attrs(self):
|
||||||
|
"""Check that the pipe implements the required attributes. If a subclass
|
||||||
|
implements a custom __init__ method but doesn't set these attributes,
|
||||||
|
the currently default to None, so we need to perform additonal checks.
|
||||||
|
"""
|
||||||
|
if not hasattr(self, "vocab") or self.vocab is None:
|
||||||
|
raise ValueError(Errors.E899.format(name=util.get_object_name(self)))
|
||||||
|
if not hasattr(self, "model") or self.model is None:
|
||||||
|
raise ValueError(Errors.E898.format(name=util.get_object_name(self)))
|
||||||
|
|
||||||
def to_bytes(self, *, exclude=tuple()):
|
def to_bytes(self, *, exclude=tuple()):
|
||||||
"""Serialize the pipe to a bytestring.
|
"""Serialize the pipe to a bytestring.
|
||||||
|
|
||||||
|
@ -252,11 +258,12 @@ cdef class TrainablePipe(Pipe):
|
||||||
|
|
||||||
DOCS: https://nightly.spacy.io/api/pipe#to_bytes
|
DOCS: https://nightly.spacy.io/api/pipe#to_bytes
|
||||||
"""
|
"""
|
||||||
|
self._validate_serialization_attrs()
|
||||||
serialize = {}
|
serialize = {}
|
||||||
if hasattr(self, "cfg"):
|
if hasattr(self, "cfg") and self.cfg is not None:
|
||||||
serialize["cfg"] = lambda: srsly.json_dumps(self.cfg)
|
serialize["cfg"] = lambda: srsly.json_dumps(self.cfg)
|
||||||
|
serialize["vocab"] = self.vocab.to_bytes
|
||||||
serialize["model"] = self.model.to_bytes
|
serialize["model"] = self.model.to_bytes
|
||||||
serialize["strings.json"] = lambda: srsly.json_dumps(sorted(self._added_strings))
|
|
||||||
return util.to_bytes(serialize, exclude)
|
return util.to_bytes(serialize, exclude)
|
||||||
|
|
||||||
def from_bytes(self, bytes_data, *, exclude=tuple()):
|
def from_bytes(self, bytes_data, *, exclude=tuple()):
|
||||||
|
@ -267,6 +274,7 @@ cdef class TrainablePipe(Pipe):
|
||||||
|
|
||||||
DOCS: https://nightly.spacy.io/api/pipe#from_bytes
|
DOCS: https://nightly.spacy.io/api/pipe#from_bytes
|
||||||
"""
|
"""
|
||||||
|
self._validate_serialization_attrs()
|
||||||
|
|
||||||
def load_model(b):
|
def load_model(b):
|
||||||
try:
|
try:
|
||||||
|
@ -275,9 +283,9 @@ cdef class TrainablePipe(Pipe):
|
||||||
raise ValueError(Errors.E149) from None
|
raise ValueError(Errors.E149) from None
|
||||||
|
|
||||||
deserialize = {}
|
deserialize = {}
|
||||||
deserialize["strings.json"] = lambda b: [self.add_string(s) for s in srsly.json_loads(b)]
|
if hasattr(self, "cfg") and self.cfg is not None:
|
||||||
if hasattr(self, "cfg"):
|
|
||||||
deserialize["cfg"] = lambda b: self.cfg.update(srsly.json_loads(b))
|
deserialize["cfg"] = lambda b: self.cfg.update(srsly.json_loads(b))
|
||||||
|
deserialize["vocab"] = lambda b: self.vocab.from_bytes(b)
|
||||||
deserialize["model"] = load_model
|
deserialize["model"] = load_model
|
||||||
util.from_bytes(bytes_data, deserialize, exclude)
|
util.from_bytes(bytes_data, deserialize, exclude)
|
||||||
return self
|
return self
|
||||||
|
@ -290,10 +298,11 @@ cdef class TrainablePipe(Pipe):
|
||||||
|
|
||||||
DOCS: https://nightly.spacy.io/api/pipe#to_disk
|
DOCS: https://nightly.spacy.io/api/pipe#to_disk
|
||||||
"""
|
"""
|
||||||
|
self._validate_serialization_attrs()
|
||||||
serialize = {}
|
serialize = {}
|
||||||
if hasattr(self, "cfg"):
|
if hasattr(self, "cfg") and self.cfg is not None:
|
||||||
serialize["cfg"] = lambda p: srsly.write_json(p, self.cfg)
|
serialize["cfg"] = lambda p: srsly.write_json(p, self.cfg)
|
||||||
serialize["strings.json"] = lambda p: srsly.write_json(p, self._added_strings)
|
serialize["vocab"] = lambda p: self.vocab.to_disk(p)
|
||||||
serialize["model"] = lambda p: self.model.to_disk(p)
|
serialize["model"] = lambda p: self.model.to_disk(p)
|
||||||
util.to_disk(path, serialize, exclude)
|
util.to_disk(path, serialize, exclude)
|
||||||
|
|
||||||
|
@ -306,6 +315,7 @@ cdef class TrainablePipe(Pipe):
|
||||||
|
|
||||||
DOCS: https://nightly.spacy.io/api/pipe#from_disk
|
DOCS: https://nightly.spacy.io/api/pipe#from_disk
|
||||||
"""
|
"""
|
||||||
|
self._validate_serialization_attrs()
|
||||||
|
|
||||||
def load_model(p):
|
def load_model(p):
|
||||||
try:
|
try:
|
||||||
|
@ -314,9 +324,9 @@ cdef class TrainablePipe(Pipe):
|
||||||
raise ValueError(Errors.E149) from None
|
raise ValueError(Errors.E149) from None
|
||||||
|
|
||||||
deserialize = {}
|
deserialize = {}
|
||||||
deserialize["strings.json"] = lambda p: [self.add_string(s) for s in srsly.read_json(p)]
|
if hasattr(self, "cfg") and self.cfg is not None:
|
||||||
if hasattr(self, "cfg"):
|
|
||||||
deserialize["cfg"] = lambda p: self.cfg.update(deserialize_config(p))
|
deserialize["cfg"] = lambda p: self.cfg.update(deserialize_config(p))
|
||||||
|
deserialize["vocab"] = lambda p: self.vocab.from_disk(p)
|
||||||
deserialize["model"] = load_model
|
deserialize["model"] = load_model
|
||||||
util.from_disk(path, deserialize, exclude)
|
util.from_disk(path, deserialize, exclude)
|
||||||
return self
|
return self
|
||||||
|
|
|
@ -76,7 +76,6 @@ cdef class Parser(TrainablePipe):
|
||||||
self.add_multitask_objective(multitask)
|
self.add_multitask_objective(multitask)
|
||||||
|
|
||||||
self._rehearsal_model = None
|
self._rehearsal_model = None
|
||||||
self._added_strings = set()
|
|
||||||
|
|
||||||
def __getnewargs_ex__(self):
|
def __getnewargs_ex__(self):
|
||||||
"""This allows pickling the Parser and its keyword-only init arguments"""
|
"""This allows pickling the Parser and its keyword-only init arguments"""
|
||||||
|
@ -120,7 +119,7 @@ cdef class Parser(TrainablePipe):
|
||||||
resized = True
|
resized = True
|
||||||
if resized:
|
if resized:
|
||||||
self._resize()
|
self._resize()
|
||||||
self.add_string(label)
|
self.vocab.strings.add(label)
|
||||||
return 1
|
return 1
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
@ -456,24 +455,24 @@ cdef class Parser(TrainablePipe):
|
||||||
|
|
||||||
def to_disk(self, path, exclude=tuple()):
|
def to_disk(self, path, exclude=tuple()):
|
||||||
serializers = {
|
serializers = {
|
||||||
'model': lambda p: (self.model.to_disk(p) if self.model is not True else True),
|
"model": lambda p: (self.model.to_disk(p) if self.model is not True else True),
|
||||||
'strings.json': lambda p: srsly.write_json(p, self._added_strings),
|
"vocab": lambda p: self.vocab.to_disk(p),
|
||||||
'moves': lambda p: self.moves.to_disk(p, exclude=["strings"]),
|
"moves": lambda p: self.moves.to_disk(p, exclude=["strings"]),
|
||||||
'cfg': lambda p: srsly.write_json(p, self.cfg)
|
"cfg": lambda p: srsly.write_json(p, self.cfg)
|
||||||
}
|
}
|
||||||
util.to_disk(path, serializers, exclude)
|
util.to_disk(path, serializers, exclude)
|
||||||
|
|
||||||
def from_disk(self, path, exclude=tuple()):
|
def from_disk(self, path, exclude=tuple()):
|
||||||
deserializers = {
|
deserializers = {
|
||||||
'strings.json': lambda p: [self.add_string(s) for s in srsly.read_json(p)],
|
"vocab": lambda p: self.vocab.from_disk(p),
|
||||||
'moves': lambda p: self.moves.from_disk(p, exclude=["strings"]),
|
"moves": lambda p: self.moves.from_disk(p, exclude=["strings"]),
|
||||||
'cfg': lambda p: self.cfg.update(srsly.read_json(p)),
|
"cfg": lambda p: self.cfg.update(srsly.read_json(p)),
|
||||||
'model': lambda p: None,
|
"model": lambda p: None,
|
||||||
}
|
}
|
||||||
util.from_disk(path, deserializers, exclude)
|
util.from_disk(path, deserializers, exclude)
|
||||||
if 'model' not in exclude:
|
if "model" not in exclude:
|
||||||
path = util.ensure_path(path)
|
path = util.ensure_path(path)
|
||||||
with (path / 'model').open('rb') as file_:
|
with (path / "model").open("rb") as file_:
|
||||||
bytes_data = file_.read()
|
bytes_data = file_.read()
|
||||||
try:
|
try:
|
||||||
self._resize()
|
self._resize()
|
||||||
|
@ -485,7 +484,7 @@ cdef class Parser(TrainablePipe):
|
||||||
def to_bytes(self, exclude=tuple()):
|
def to_bytes(self, exclude=tuple()):
|
||||||
serializers = {
|
serializers = {
|
||||||
"model": lambda: (self.model.to_bytes()),
|
"model": lambda: (self.model.to_bytes()),
|
||||||
"strings.json": lambda: srsly.json_dumps(sorted(self._added_strings)),
|
"vocab": lambda: self.vocab.to_bytes(),
|
||||||
"moves": lambda: self.moves.to_bytes(exclude=["strings"]),
|
"moves": lambda: self.moves.to_bytes(exclude=["strings"]),
|
||||||
"cfg": lambda: srsly.json_dumps(self.cfg, indent=2, sort_keys=True)
|
"cfg": lambda: srsly.json_dumps(self.cfg, indent=2, sort_keys=True)
|
||||||
}
|
}
|
||||||
|
@ -493,7 +492,7 @@ cdef class Parser(TrainablePipe):
|
||||||
|
|
||||||
def from_bytes(self, bytes_data, exclude=tuple()):
|
def from_bytes(self, bytes_data, exclude=tuple()):
|
||||||
deserializers = {
|
deserializers = {
|
||||||
"strings.json": lambda b: [self.add_string(s) for s in srsly.json_loads(b)],
|
"vocab": lambda b: self.vocab.from_bytes(b),
|
||||||
"moves": lambda b: self.moves.from_bytes(b, exclude=["strings"]),
|
"moves": lambda b: self.moves.from_bytes(b, exclude=["strings"]),
|
||||||
"cfg": lambda b: self.cfg.update(srsly.json_loads(b)),
|
"cfg": lambda b: self.cfg.update(srsly.json_loads(b)),
|
||||||
"model": lambda b: None,
|
"model": lambda b: None,
|
||||||
|
|
|
@ -121,9 +121,7 @@ def test_kb_default(nlp):
|
||||||
|
|
||||||
def test_kb_custom_length(nlp):
|
def test_kb_custom_length(nlp):
|
||||||
"""Test that the default (empty) KB can be configured with a custom entity length"""
|
"""Test that the default (empty) KB can be configured with a custom entity length"""
|
||||||
entity_linker = nlp.add_pipe(
|
entity_linker = nlp.add_pipe("entity_linker", config={"entity_vector_length": 35})
|
||||||
"entity_linker", config={"entity_vector_length": 35}
|
|
||||||
)
|
|
||||||
assert len(entity_linker.kb) == 0
|
assert len(entity_linker.kb) == 0
|
||||||
assert entity_linker.kb.get_size_entities() == 0
|
assert entity_linker.kb.get_size_entities() == 0
|
||||||
assert entity_linker.kb.get_size_aliases() == 0
|
assert entity_linker.kb.get_size_aliases() == 0
|
||||||
|
@ -213,16 +211,11 @@ def test_el_pipe_configuration(nlp):
|
||||||
kb = KnowledgeBase(vocab, entity_vector_length=1)
|
kb = KnowledgeBase(vocab, entity_vector_length=1)
|
||||||
kb.add_entity(entity="Q2", freq=12, entity_vector=[2])
|
kb.add_entity(entity="Q2", freq=12, entity_vector=[2])
|
||||||
kb.add_entity(entity="Q3", freq=5, entity_vector=[3])
|
kb.add_entity(entity="Q3", freq=5, entity_vector=[3])
|
||||||
kb.add_alias(
|
kb.add_alias(alias="douglas", entities=["Q2", "Q3"], probabilities=[0.8, 0.1])
|
||||||
alias="douglas", entities=["Q2", "Q3"], probabilities=[0.8, 0.1]
|
|
||||||
)
|
|
||||||
return kb
|
return kb
|
||||||
|
|
||||||
# run an EL pipe without a trained context encoder, to check the candidate generation step only
|
# run an EL pipe without a trained context encoder, to check the candidate generation step only
|
||||||
entity_linker = nlp.add_pipe(
|
entity_linker = nlp.add_pipe("entity_linker", config={"incl_context": False},)
|
||||||
"entity_linker",
|
|
||||||
config={"incl_context": False},
|
|
||||||
)
|
|
||||||
entity_linker.set_kb(create_kb)
|
entity_linker.set_kb(create_kb)
|
||||||
# With the default get_candidates function, matching is case-sensitive
|
# With the default get_candidates function, matching is case-sensitive
|
||||||
text = "Douglas and douglas are not the same."
|
text = "Douglas and douglas are not the same."
|
||||||
|
@ -453,14 +446,10 @@ def test_overfitting_IO():
|
||||||
return mykb
|
return mykb
|
||||||
|
|
||||||
# Create the Entity Linker component and add it to the pipeline
|
# Create the Entity Linker component and add it to the pipeline
|
||||||
entity_linker = nlp.add_pipe(
|
entity_linker = nlp.add_pipe("entity_linker", last=True,)
|
||||||
"entity_linker",
|
|
||||||
last=True,
|
|
||||||
)
|
|
||||||
entity_linker.set_kb(create_kb)
|
entity_linker.set_kb(create_kb)
|
||||||
assert "Q2146908" in entity_linker.vocab.strings
|
assert "Q2146908" in entity_linker.vocab.strings
|
||||||
assert "Q2146908" in entity_linker.kb.vocab.strings
|
assert "Q2146908" in entity_linker.kb.vocab.strings
|
||||||
assert "Q2146908" in entity_linker.kb._added_strings
|
|
||||||
|
|
||||||
# train the NEL pipe
|
# train the NEL pipe
|
||||||
optimizer = nlp.initialize(get_examples=lambda: train_examples)
|
optimizer = nlp.initialize(get_examples=lambda: train_examples)
|
||||||
|
|
|
@ -101,4 +101,3 @@ def test_overfitting_IO():
|
||||||
doc2 = nlp2(test_text)
|
doc2 = nlp2(test_text)
|
||||||
assert [str(t.morph) for t in doc2] == gold_morphs
|
assert [str(t.morph) for t in doc2] == gold_morphs
|
||||||
assert [t.pos_ for t in doc2] == gold_pos_tags
|
assert [t.pos_ for t in doc2] == gold_pos_tags
|
||||||
assert nlp.get_pipe("morphologizer")._added_strings == nlp2.get_pipe("morphologizer")._added_strings
|
|
||||||
|
|
|
@ -80,4 +80,3 @@ def test_overfitting_IO():
|
||||||
nlp2 = util.load_model_from_path(tmp_dir)
|
nlp2 = util.load_model_from_path(tmp_dir)
|
||||||
doc2 = nlp2(test_text)
|
doc2 = nlp2(test_text)
|
||||||
assert [int(t.is_sent_start) for t in doc2] == gold_sent_starts
|
assert [int(t.is_sent_start) for t in doc2] == gold_sent_starts
|
||||||
assert nlp.get_pipe("senter")._added_strings == nlp2.get_pipe("senter")._added_strings
|
|
||||||
|
|
|
@ -98,7 +98,6 @@ def test_overfitting_IO():
|
||||||
losses = {}
|
losses = {}
|
||||||
nlp.update(train_examples, sgd=optimizer, losses=losses)
|
nlp.update(train_examples, sgd=optimizer, losses=losses)
|
||||||
assert losses["tagger"] < 0.00001
|
assert losses["tagger"] < 0.00001
|
||||||
assert tagger._added_strings == {"J", "N", "V"}
|
|
||||||
|
|
||||||
# test the trained model
|
# test the trained model
|
||||||
test_text = "I like blue eggs"
|
test_text = "I like blue eggs"
|
||||||
|
@ -117,7 +116,6 @@ def test_overfitting_IO():
|
||||||
assert doc2[1].tag_ is "V"
|
assert doc2[1].tag_ is "V"
|
||||||
assert doc2[2].tag_ is "J"
|
assert doc2[2].tag_ is "J"
|
||||||
assert doc2[3].tag_ is "N"
|
assert doc2[3].tag_ is "N"
|
||||||
assert nlp2.get_pipe("tagger")._added_strings == {"J", "N", "V"}
|
|
||||||
|
|
||||||
|
|
||||||
def test_tagger_requires_labels():
|
def test_tagger_requires_labels():
|
||||||
|
|
|
@ -146,7 +146,6 @@ def test_overfitting_IO():
|
||||||
train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))
|
train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))
|
||||||
optimizer = nlp.initialize(get_examples=lambda: train_examples)
|
optimizer = nlp.initialize(get_examples=lambda: train_examples)
|
||||||
assert textcat.model.get_dim("nO") == 2
|
assert textcat.model.get_dim("nO") == 2
|
||||||
assert textcat._added_strings == {"NEGATIVE", "POSITIVE"}
|
|
||||||
|
|
||||||
for i in range(50):
|
for i in range(50):
|
||||||
losses = {}
|
losses = {}
|
||||||
|
@ -168,7 +167,6 @@ def test_overfitting_IO():
|
||||||
cats2 = doc2.cats
|
cats2 = doc2.cats
|
||||||
assert cats2["POSITIVE"] > 0.9
|
assert cats2["POSITIVE"] > 0.9
|
||||||
assert cats2["POSITIVE"] + cats2["NEGATIVE"] == pytest.approx(1.0, 0.001)
|
assert cats2["POSITIVE"] + cats2["NEGATIVE"] == pytest.approx(1.0, 0.001)
|
||||||
assert nlp2.get_pipe("textcat")._added_strings == {"NEGATIVE", "POSITIVE"}
|
|
||||||
|
|
||||||
# Test scoring
|
# Test scoring
|
||||||
scores = nlp.evaluate(train_examples)
|
scores = nlp.evaluate(train_examples)
|
||||||
|
|
|
@ -7,6 +7,7 @@ from spacy.kb import KnowledgeBase, Writer
|
||||||
from spacy.vectors import Vectors
|
from spacy.vectors import Vectors
|
||||||
from spacy.language import Language
|
from spacy.language import Language
|
||||||
from spacy.pipeline import TrainablePipe
|
from spacy.pipeline import TrainablePipe
|
||||||
|
from spacy.vocab import Vocab
|
||||||
|
|
||||||
from ..util import make_tempdir
|
from ..util import make_tempdir
|
||||||
|
|
||||||
|
@ -50,8 +51,9 @@ def custom_pipe():
|
||||||
else:
|
else:
|
||||||
self.cfg = None
|
self.cfg = None
|
||||||
self.model = SerializableDummy()
|
self.model = SerializableDummy()
|
||||||
|
self.vocab = vocab
|
||||||
|
|
||||||
return MyPipe(None)
|
return MyPipe(Vocab())
|
||||||
|
|
||||||
|
|
||||||
def tagger():
|
def tagger():
|
||||||
|
|
|
@ -1,13 +1,13 @@
|
||||||
import pytest
|
import pytest
|
||||||
import srsly
|
|
||||||
from spacy import registry, Vocab
|
from spacy import registry, Vocab
|
||||||
from spacy.pipeline import Tagger, DependencyParser, EntityRecognizer
|
from spacy.pipeline import Tagger, DependencyParser, EntityRecognizer
|
||||||
from spacy.pipeline import TextCategorizer, SentenceRecognizer
|
from spacy.pipeline import TextCategorizer, SentenceRecognizer, TrainablePipe
|
||||||
from spacy.pipeline.dep_parser import DEFAULT_PARSER_MODEL
|
from spacy.pipeline.dep_parser import DEFAULT_PARSER_MODEL
|
||||||
from spacy.pipeline.tagger import DEFAULT_TAGGER_MODEL
|
from spacy.pipeline.tagger import DEFAULT_TAGGER_MODEL
|
||||||
from spacy.pipeline.textcat import DEFAULT_TEXTCAT_MODEL
|
from spacy.pipeline.textcat import DEFAULT_TEXTCAT_MODEL
|
||||||
from spacy.pipeline.senter import DEFAULT_SENTER_MODEL
|
from spacy.pipeline.senter import DEFAULT_SENTER_MODEL
|
||||||
from spacy.lang.en import English
|
from spacy.lang.en import English
|
||||||
|
from thinc.api import Linear
|
||||||
import spacy
|
import spacy
|
||||||
|
|
||||||
from ..util import make_tempdir
|
from ..util import make_tempdir
|
||||||
|
@ -89,7 +89,6 @@ def test_serialize_parser_strings(Parser):
|
||||||
assert label not in vocab2.strings
|
assert label not in vocab2.strings
|
||||||
parser2 = Parser(vocab2, model, **config)
|
parser2 = Parser(vocab2, model, **config)
|
||||||
parser2 = parser2.from_bytes(parser1.to_bytes(exclude=["vocab"]))
|
parser2 = parser2.from_bytes(parser1.to_bytes(exclude=["vocab"]))
|
||||||
assert parser1._added_strings == parser2._added_strings == {"FunnyLabel"}
|
|
||||||
assert label in parser2.vocab.strings
|
assert label in parser2.vocab.strings
|
||||||
|
|
||||||
|
|
||||||
|
@ -166,17 +165,13 @@ def test_serialize_tagger_strings(en_vocab, de_vocab, taggers):
|
||||||
# check that custom labels are serialized as part of the component's strings.jsonl
|
# check that custom labels are serialized as part of the component's strings.jsonl
|
||||||
tagger.add_label(label)
|
tagger.add_label(label)
|
||||||
assert label in tagger.vocab.strings
|
assert label in tagger.vocab.strings
|
||||||
assert tagger._added_strings == {label}
|
|
||||||
file_path = d / "tagger1"
|
file_path = d / "tagger1"
|
||||||
tagger.to_disk(file_path)
|
tagger.to_disk(file_path)
|
||||||
strings = srsly.read_json(file_path / "strings.json")
|
|
||||||
assert strings == ["SomeWeirdLabel"]
|
|
||||||
# ensure that the custom strings are loaded back in when using the tagger in another pipeline
|
# ensure that the custom strings are loaded back in when using the tagger in another pipeline
|
||||||
cfg = {"model": DEFAULT_TAGGER_MODEL}
|
cfg = {"model": DEFAULT_TAGGER_MODEL}
|
||||||
model = registry.resolve(cfg, validate=True)["model"]
|
model = registry.resolve(cfg, validate=True)["model"]
|
||||||
tagger2 = Tagger(de_vocab, model).from_disk(file_path)
|
tagger2 = Tagger(de_vocab, model).from_disk(file_path)
|
||||||
assert label in tagger2.vocab.strings
|
assert label in tagger2.vocab.strings
|
||||||
assert tagger2._added_strings == {label}
|
|
||||||
|
|
||||||
|
|
||||||
def test_serialize_textcat_empty(en_vocab):
|
def test_serialize_textcat_empty(en_vocab):
|
||||||
|
@ -253,3 +248,40 @@ def test_serialize_pipeline_disable_enable():
|
||||||
assert nlp5.pipe_names == ["ner"]
|
assert nlp5.pipe_names == ["ner"]
|
||||||
assert nlp5.component_names == ["ner"]
|
assert nlp5.component_names == ["ner"]
|
||||||
assert nlp5.disabled == []
|
assert nlp5.disabled == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_serialize_custom_trainable_pipe():
|
||||||
|
class BadCustomPipe1(TrainablePipe):
|
||||||
|
def __init__(self, vocab):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class BadCustomPipe2(TrainablePipe):
|
||||||
|
def __init__(self, vocab):
|
||||||
|
self.vocab = vocab
|
||||||
|
self.model = None
|
||||||
|
|
||||||
|
class CustomPipe(TrainablePipe):
|
||||||
|
def __init__(self, vocab, model):
|
||||||
|
self.vocab = vocab
|
||||||
|
self.model = model
|
||||||
|
|
||||||
|
pipe = BadCustomPipe1(Vocab())
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
pipe.to_bytes()
|
||||||
|
with make_tempdir() as d:
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
pipe.to_disk(d)
|
||||||
|
pipe = BadCustomPipe2(Vocab())
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
pipe.to_bytes()
|
||||||
|
with make_tempdir() as d:
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
pipe.to_disk(d)
|
||||||
|
pipe = CustomPipe(Vocab(), Linear())
|
||||||
|
pipe_bytes = pipe.to_bytes()
|
||||||
|
new_pipe = CustomPipe(Vocab(), Linear()).from_bytes(pipe_bytes)
|
||||||
|
assert new_pipe.to_bytes() == pipe_bytes
|
||||||
|
with make_tempdir() as d:
|
||||||
|
pipe.to_disk(d)
|
||||||
|
new_pipe = CustomPipe(Vocab(), Linear()).from_disk(d)
|
||||||
|
assert new_pipe.to_bytes() == pipe_bytes
|
||||||
|
|
|
@ -821,7 +821,7 @@ def get_object_name(obj: Any) -> str:
|
||||||
obj (Any): The Python object, typically a function or class.
|
obj (Any): The Python object, typically a function or class.
|
||||||
RETURNS (str): A human-readable name.
|
RETURNS (str): A human-readable name.
|
||||||
"""
|
"""
|
||||||
if hasattr(obj, "name"):
|
if hasattr(obj, "name") and obj.name is not None:
|
||||||
return obj.name
|
return obj.name
|
||||||
if hasattr(obj, "__name__"):
|
if hasattr(obj, "__name__"):
|
||||||
return obj.__name__
|
return obj.__name__
|
||||||
|
|
Loading…
Reference in New Issue
Block a user