Merge pull request #6216 from svlandeg/feature/nel-initialize
Commit 064575d79d in spaCy (mirror of https://github.com/explosion/spaCy.git)
spacy/language.py

@@ -843,7 +843,7 @@ class Language:
         *,
         config: Dict[str, Any] = SimpleFrozenDict(),
         validate: bool = True,
-    ) -> None:
+    ) -> Callable[[Doc], Doc]:
         """Replace a component in the pipeline.
 
         name (str): Name of the component to replace.
@@ -852,6 +852,7 @@ class Language:
             component. Will be merged with default config, if available.
         validate (bool): Whether to validate the component config against the
             arguments and types expected by the factory.
+        RETURNS (Callable[[Doc], Doc]): The new pipeline component.
 
         DOCS: https://nightly.spacy.io/api/language#replace_pipe
         """
@@ -866,9 +867,11 @@ class Language:
         self.remove_pipe(name)
         if not len(self._components) or pipe_index == len(self._components):
             # we have no components to insert before/after, or we're replacing the last component
-            self.add_pipe(factory_name, name=name, config=config, validate=validate)
+            return self.add_pipe(
+                factory_name, name=name, config=config, validate=validate
+            )
         else:
-            self.add_pipe(
+            return self.add_pipe(
                 factory_name,
                 name=name,
                 before=pipe_index,
@@ -1300,7 +1303,11 @@ class Language:
             kwargs.setdefault("batch_size", batch_size)
             # non-trainable components may have a pipe() implementation that refers to dummy
             # predict and set_annotations methods
-            if not hasattr(pipe, "pipe") or not hasattr(pipe, "is_trainable") or not pipe.is_trainable():
+            if (
+                not hasattr(pipe, "pipe")
+                or not hasattr(pipe, "is_trainable")
+                or not pipe.is_trainable()
+            ):
                 docs = _pipe(docs, pipe, kwargs)
             else:
                 docs = pipe.pipe(docs, **kwargs)
@@ -1412,7 +1419,11 @@ class Language:
             kwargs.setdefault("batch_size", batch_size)
             # non-trainable components may have a pipe() implementation that refers to dummy
             # predict and set_annotations methods
-            if hasattr(proc, "pipe") and hasattr(proc, "is_trainable") and proc.is_trainable():
+            if (
+                hasattr(proc, "pipe")
+                and hasattr(proc, "is_trainable")
+                and proc.is_trainable()
+            ):
                 f = functools.partial(proc.pipe, **kwargs)
             else:
                 # Apply the function, but yield the doc
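
Note: a minimal usage sketch of the change above (illustrative, not part of the
commit). replace_pipe() now returns the newly created component instead of
None, mirroring add_pipe():

    import spacy

    nlp = spacy.blank("en")
    nlp.add_pipe("sentencizer")
    # The replacement component is returned, so it can be captured directly.
    new_sentencizer = nlp.replace_pipe("sentencizer", "sentencizer")
    doc = new_sentencizer(nlp.make_doc("One sentence. Another."))
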
spacy/pipeline/entity_linker.py

@@ -8,6 +8,7 @@ from thinc.api import set_dropout_rate
 import warnings
 
 from ..kb import KnowledgeBase, Candidate
+from ..ml import empty_kb
 from ..tokens import Doc
 from .pipe import Pipe, deserialize_config
 from ..language import Language
@@ -41,11 +42,11 @@ DEFAULT_NEL_MODEL = Config().from_str(default_model_config)["model"]
     requires=["doc.ents", "doc.sents", "token.ent_iob", "token.ent_type"],
     assigns=["token.ent_kb_id"],
     default_config={
-        "kb_loader": {"@misc": "spacy.EmptyKB.v1", "entity_vector_length": 64},
         "model": DEFAULT_NEL_MODEL,
         "labels_discard": [],
         "incl_prior": True,
         "incl_context": True,
+        "entity_vector_length": 64,
         "get_candidates": {"@misc": "spacy.CandidateGenerator.v1"},
     },
     default_score_weights={
@@ -58,11 +59,11 @@ def make_entity_linker(
     nlp: Language,
     name: str,
     model: Model,
-    kb_loader: Callable[[Vocab], KnowledgeBase],
     *,
     labels_discard: Iterable[str],
     incl_prior: bool,
     incl_context: bool,
+    entity_vector_length: int,
     get_candidates: Callable[[KnowledgeBase, "Span"], Iterable[Candidate]],
 ):
     """Construct an EntityLinker component.
@@ -70,19 +71,21 @@ def make_entity_linker(
     model (Model[List[Doc], Floats2d]): A model that learns document vector
         representations. Given a batch of Doc objects, it should return a single
         array, with one row per item in the batch.
-    kb (KnowledgeBase): The knowledge-base to link entities to.
     labels_discard (Iterable[str]): NER labels that will automatically get a "NIL" prediction.
     incl_prior (bool): Whether or not to include prior probabilities from the KB in the model.
     incl_context (bool): Whether or not to include the local context in the model.
+    entity_vector_length (int): Size of encoding vectors in the KB.
     get_candidates (Callable[[KnowledgeBase, "Span"], Iterable[Candidate]]): Function that
         produces a list of candidates, given a certain knowledge base and a textual mention.
     """
     return EntityLinker(
         nlp.vocab,
         model,
         name,
-        kb_loader=kb_loader,
         labels_discard=labels_discard,
         incl_prior=incl_prior,
         incl_context=incl_context,
+        entity_vector_length=entity_vector_length,
         get_candidates=get_candidates,
     )
@@ -101,10 +104,10 @@ class EntityLinker(Pipe):
         model: Model,
         name: str = "entity_linker",
         *,
-        kb_loader: Callable[[Vocab], KnowledgeBase],
         labels_discard: Iterable[str],
         incl_prior: bool,
         incl_context: bool,
+        entity_vector_length: int,
         get_candidates: Callable[[KnowledgeBase, "Span"], Iterable[Candidate]],
     ) -> None:
         """Initialize an entity linker.
@@ -113,10 +116,12 @@ class EntityLinker(Pipe):
         model (thinc.api.Model): The Thinc Model powering the pipeline component.
         name (str): The component instance name, used to add entries to the
             losses during training.
-        kb_loader (Callable[[Vocab], KnowledgeBase]): A function that creates a KnowledgeBase from a Vocab instance.
         labels_discard (Iterable[str]): NER labels that will automatically get a "NIL" prediction.
         incl_prior (bool): Whether or not to include prior probabilities from the KB in the model.
         incl_context (bool): Whether or not to include the local context in the model.
+        entity_vector_length (int): Size of encoding vectors in the KB.
+        get_candidates (Callable[[KnowledgeBase, "Span"], Iterable[Candidate]]): Function that
+            produces a list of candidates, given a certain knowledge base and a textual mention.
 
         DOCS: https://nightly.spacy.io/api/entitylinker#init
         """
@@ -127,15 +132,23 @@ class EntityLinker(Pipe):
             "labels_discard": list(labels_discard),
             "incl_prior": incl_prior,
             "incl_context": incl_context,
+            "entity_vector_length": entity_vector_length,
         }
-        self.kb = kb_loader(self.vocab)
         self.get_candidates = get_candidates
         self.cfg = dict(cfg)
         self.distance = CosineDistance(normalize=False)
         # how many neighbour sentences to take into account
         self.n_sents = cfg.get("n_sents", 0)
+        # create an empty KB by default. If you want to load a predefined one, specify it in 'initialize'.
+        self.kb = empty_kb(entity_vector_length)(self.vocab)
 
-    def _require_kb(self) -> None:
+    def set_kb(self, kb_loader: Callable[[Vocab], KnowledgeBase]):
+        """Define the KB of this pipe by providing a function that will
+        create it using this object's vocab."""
+        self.kb = kb_loader(self.vocab)
+        self.cfg["entity_vector_length"] = self.kb.entity_vector_length
+
+    def validate_kb(self) -> None:
         # Raise an error if the knowledge base is not initialized.
         if len(self.kb) == 0:
             raise ValueError(Errors.E139.format(name=self.name))
@@ -145,6 +158,7 @@ class EntityLinker(Pipe):
         get_examples: Callable[[], Iterable[Example]],
         *,
         nlp: Optional[Language] = None,
+        kb_loader: Callable[[Vocab], KnowledgeBase] = None,
     ):
         """Initialize the pipe for training, using a representative set
         of data examples.
@@ -152,11 +166,16 @@ class EntityLinker(Pipe):
         get_examples (Callable[[], Iterable[Example]]): Function that
             returns a representative sample of gold-standard Example objects.
         nlp (Language): The current nlp object the component is part of.
+        kb_loader (Callable[[Vocab], KnowledgeBase]): A function that creates a KnowledgeBase from a Vocab instance.
+            Note that providing this argument will overwrite all data accumulated in the current KB.
+            Use this only when loading a KB as such from file.
 
         DOCS: https://nightly.spacy.io/api/entitylinker#initialize
         """
         self._ensure_examples(get_examples)
-        self._require_kb()
+        if kb_loader is not None:
+            self.set_kb(kb_loader)
+        self.validate_kb()
         nO = self.kb.entity_vector_length
         doc_sample = []
         vector_sample = []
@@ -192,7 +211,7 @@ class EntityLinker(Pipe):
 
         DOCS: https://nightly.spacy.io/api/entitylinker#update
         """
-        self._require_kb()
+        self.validate_kb()
         if losses is None:
             losses = {}
         losses.setdefault(self.name, 0.0)
@@ -303,7 +322,7 @@ class EntityLinker(Pipe):
 
         DOCS: https://nightly.spacy.io/api/entitylinker#predict
         """
-        self._require_kb()
+        self.validate_kb()
         entity_count = 0
         final_kb_ids = []
         if not docs:
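
Note: a minimal sketch of the new KB lifecycle introduced here (entity IDs,
vectors and the alias below are illustrative). The component now starts with an
empty KB of the configured entity_vector_length; a real KB is supplied
afterwards via set_kb(), or via the kb_loader argument of initialize():

    import spacy
    from spacy.kb import KnowledgeBase

    nlp = spacy.blank("en")
    entity_linker = nlp.add_pipe("entity_linker")  # empty default KB, vector length 64

    def create_kb(vocab):
        kb = KnowledgeBase(vocab, entity_vector_length=64)
        kb.add_entity(entity="Q42", freq=10, entity_vector=[0.0] * 64)
        kb.add_alias(alias="Douglas", entities=["Q42"], probabilities=[0.9])
        return kb

    entity_linker.set_kb(create_kb)  # replaces the default empty KB
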
spacy/tests/pipeline/test_entity_linker.py

@@ -110,7 +110,7 @@ def test_kb_invalid_entity_vector(nlp):
 
 
 def test_kb_default(nlp):
-    """Test that the default (empty) KB is loaded when not providing a config"""
+    """Test that the default (empty) KB is loaded upon construction"""
     entity_linker = nlp.add_pipe("entity_linker", config={})
     assert len(entity_linker.kb) == 0
     assert entity_linker.kb.get_size_entities() == 0
@@ -122,7 +122,7 @@ def test_kb_default(nlp):
 def test_kb_custom_length(nlp):
     """Test that the default (empty) KB can be configured with a custom entity length"""
     entity_linker = nlp.add_pipe(
-        "entity_linker", config={"kb_loader": {"entity_vector_length": 35}}
+        "entity_linker", config={"entity_vector_length": 35}
     )
     assert len(entity_linker.kb) == 0
     assert entity_linker.kb.get_size_entities() == 0
@@ -130,18 +130,9 @@ def test_kb_custom_length(nlp):
     assert entity_linker.kb.entity_vector_length == 35
 
 
-def test_kb_undefined(nlp):
-    """Test that the EL can't train without defining a KB"""
-    entity_linker = nlp.add_pipe("entity_linker", config={})
-    with pytest.raises(ValueError):
-        entity_linker.initialize(lambda: [])
-
-
-def test_kb_empty(nlp):
-    """Test that the EL can't train with an empty KB"""
-    config = {"kb_loader": {"@misc": "spacy.EmptyKB.v1", "entity_vector_length": 342}}
-    entity_linker = nlp.add_pipe("entity_linker", config=config)
-    assert len(entity_linker.kb) == 0
+def test_kb_initialize_empty(nlp):
+    """Test that the EL can't initialize without examples"""
+    entity_linker = nlp.add_pipe("entity_linker")
     with pytest.raises(ValueError):
         entity_linker.initialize(lambda: [])
 
@@ -201,24 +192,21 @@ def test_el_pipe_configuration(nlp):
     ruler = nlp.add_pipe("entity_ruler")
     ruler.add_patterns([pattern])
 
-    @registry.misc.register("myAdamKB.v1")
-    def mykb() -> Callable[["Vocab"], KnowledgeBase]:
-        def create_kb(vocab):
-            kb = KnowledgeBase(vocab, entity_vector_length=1)
-            kb.add_entity(entity="Q2", freq=12, entity_vector=[2])
-            kb.add_entity(entity="Q3", freq=5, entity_vector=[3])
-            kb.add_alias(
-                alias="douglas", entities=["Q2", "Q3"], probabilities=[0.8, 0.1]
-            )
-            return kb
-
-        return create_kb
+    def create_kb(vocab):
+        kb = KnowledgeBase(vocab, entity_vector_length=1)
+        kb.add_entity(entity="Q2", freq=12, entity_vector=[2])
+        kb.add_entity(entity="Q3", freq=5, entity_vector=[3])
+        kb.add_alias(
+            alias="douglas", entities=["Q2", "Q3"], probabilities=[0.8, 0.1]
+        )
+        return kb
 
     # run an EL pipe without a trained context encoder, to check the candidate generation step only
-    nlp.add_pipe(
+    entity_linker = nlp.add_pipe(
         "entity_linker",
-        config={"kb_loader": {"@misc": "myAdamKB.v1"}, "incl_context": False},
+        config={"incl_context": False},
     )
+    entity_linker.set_kb(create_kb)
     # With the default get_candidates function, matching is case-sensitive
     text = "Douglas and douglas are not the same."
     doc = nlp(text)
@@ -234,15 +222,15 @@ def test_el_pipe_configuration(nlp):
         return get_lowercased_candidates
 
     # replace the pipe with a new one with a different candidate generator
-    nlp.replace_pipe(
+    entity_linker = nlp.replace_pipe(
         "entity_linker",
         "entity_linker",
         config={
-            "kb_loader": {"@misc": "myAdamKB.v1"},
             "incl_context": False,
             "get_candidates": {"@misc": "spacy.LowercaseCandidateGenerator.v1"},
         },
     )
+    entity_linker.set_kb(create_kb)
     doc = nlp(text)
     assert doc[0].ent_kb_id_ == "Q2"
     assert doc[1].ent_kb_id_ == ""
@@ -334,19 +322,15 @@ def test_preserving_links_asdoc(nlp):
     """Test that Span.as_doc preserves the existing entity links"""
     vector_length = 1
 
-    @registry.misc.register("myLocationsKB.v1")
-    def dummy_kb() -> Callable[["Vocab"], KnowledgeBase]:
-        def create_kb(vocab):
-            mykb = KnowledgeBase(vocab, entity_vector_length=vector_length)
-            # adding entities
-            mykb.add_entity(entity="Q1", freq=19, entity_vector=[1])
-            mykb.add_entity(entity="Q2", freq=8, entity_vector=[1])
-            # adding aliases
-            mykb.add_alias(alias="Boston", entities=["Q1"], probabilities=[0.7])
-            mykb.add_alias(alias="Denver", entities=["Q2"], probabilities=[0.6])
-            return mykb
-
-        return create_kb
+    def create_kb(vocab):
+        mykb = KnowledgeBase(vocab, entity_vector_length=vector_length)
+        # adding entities
+        mykb.add_entity(entity="Q1", freq=19, entity_vector=[1])
+        mykb.add_entity(entity="Q2", freq=8, entity_vector=[1])
+        # adding aliases
+        mykb.add_alias(alias="Boston", entities=["Q1"], probabilities=[0.7])
+        mykb.add_alias(alias="Denver", entities=["Q2"], probabilities=[0.6])
+        return mykb
 
     # set up pipeline with NER (Entity Ruler) and NEL (prior probability only, model not trained)
     nlp.add_pipe("sentencizer")
@@ -356,8 +340,9 @@ def test_preserving_links_asdoc(nlp):
     ]
     ruler = nlp.add_pipe("entity_ruler")
     ruler.add_patterns(patterns)
-    el_config = {"kb_loader": {"@misc": "myLocationsKB.v1"}, "incl_prior": False}
-    entity_linker = nlp.add_pipe("entity_linker", config=el_config, last=True)
+    config = {"incl_prior": False}
+    entity_linker = nlp.add_pipe("entity_linker", config=config, last=True)
+    entity_linker.set_kb(create_kb)
     nlp.initialize()
     assert entity_linker.model.get_dim("nO") == vector_length
 
@@ -435,30 +420,26 @@ def test_overfitting_IO():
         doc = nlp(text)
         train_examples.append(Example.from_dict(doc, annotation))
 
-    @registry.misc.register("myOverfittingKB.v1")
-    def dummy_kb() -> Callable[["Vocab"], KnowledgeBase]:
-        def create_kb(vocab):
-            # create artificial KB - assign same prior weight to the two russ cochran's
-            # Q2146908 (Russ Cochran): American golfer
-            # Q7381115 (Russ Cochran): publisher
-            mykb = KnowledgeBase(vocab, entity_vector_length=vector_length)
-            mykb.add_entity(entity="Q2146908", freq=12, entity_vector=[6, -4, 3])
-            mykb.add_entity(entity="Q7381115", freq=12, entity_vector=[9, 1, -7])
-            mykb.add_alias(
-                alias="Russ Cochran",
-                entities=["Q2146908", "Q7381115"],
-                probabilities=[0.5, 0.5],
-            )
-            return mykb
-
-        return create_kb
+    def create_kb(vocab):
+        # create artificial KB - assign same prior weight to the two russ cochran's
+        # Q2146908 (Russ Cochran): American golfer
+        # Q7381115 (Russ Cochran): publisher
+        mykb = KnowledgeBase(vocab, entity_vector_length=vector_length)
+        mykb.add_entity(entity="Q2146908", freq=12, entity_vector=[6, -4, 3])
+        mykb.add_entity(entity="Q7381115", freq=12, entity_vector=[9, 1, -7])
+        mykb.add_alias(
+            alias="Russ Cochran",
+            entities=["Q2146908", "Q7381115"],
+            probabilities=[0.5, 0.5],
+        )
+        return mykb
 
     # Create the Entity Linker component and add it to the pipeline
     entity_linker = nlp.add_pipe(
         "entity_linker",
-        config={"kb_loader": {"@misc": "myOverfittingKB.v1"}},
        last=True,
     )
+    entity_linker.set_kb(create_kb)
 
     # train the NEL pipe
     optimizer = nlp.initialize(get_examples=lambda: train_examples)
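
Note: the LowercaseCandidateGenerator referenced in test_el_pipe_configuration
is registered earlier in that test; a hedged sketch of such a registration is
shown below (it assumes the KB exposes get_alias_candidates(), as in spaCy v3):

    from typing import Callable, Iterable
    from spacy.kb import Candidate, KnowledgeBase
    from spacy.util import registry

    @registry.misc.register("myLowercaseCandidateGenerator.v1")
    def create_candidates() -> Callable[[KnowledgeBase, "Span"], Iterable[Candidate]]:
        def get_lowercased_candidates(kb, span):
            # Look up candidates for the lowercased mention text.
            return kb.get_alias_candidates(span.text.lower())

        return get_lowercased_candidates
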
spacy/tests/regression/test_issue5230.py

@@ -71,17 +71,13 @@ def tagger():
 def entity_linker():
     nlp = Language()
 
-    @registry.misc.register("TestIssue5230KB.v1")
-    def dummy_kb() -> Callable[["Vocab"], KnowledgeBase]:
-        def create_kb(vocab):
-            kb = KnowledgeBase(vocab, entity_vector_length=1)
-            kb.add_entity("test", 0.0, zeros((1, 1), dtype="f"))
-            return kb
-
-        return create_kb
+    def create_kb(vocab):
+        kb = KnowledgeBase(vocab, entity_vector_length=1)
+        kb.add_entity("test", 0.0, zeros((1, 1), dtype="f"))
+        return kb
 
-    config = {"kb_loader": {"@misc": "TestIssue5230KB.v1"}}
-    entity_linker = nlp.add_pipe("entity_linker", config=config)
+    entity_linker = nlp.add_pipe("entity_linker")
+    entity_linker.set_kb(create_kb)
     # need to add model for two reasons:
     # 1. no model leads to error in serialization,
     # 2. the affected line is the one for model serialization
spacy/tests/serialize/test_serialize_kb.py

@@ -1,11 +1,12 @@
 from typing import Callable
 
 from spacy import util
 from spacy.lang.en import English
-from spacy.util import ensure_path, registry
+from spacy.util import ensure_path, registry, load_model_from_config
 from spacy.kb import KnowledgeBase
+from thinc.api import Config
 
 from ..util import make_tempdir
+from numpy import zeros
 
 
 def test_serialize_kb_disk(en_vocab):
@@ -80,6 +81,28 @@ def _check_kb(kb):
 def test_serialize_subclassed_kb():
     """Check that IO of a custom KB works fine as part of an EL pipe."""
 
+    config_string = """
+    [nlp]
+    lang = "en"
+    pipeline = ["entity_linker"]
+
+    [components]
+
+    [components.entity_linker]
+    factory = "entity_linker"
+
+    [initialize]
+
+    [initialize.components]
+
+    [initialize.components.entity_linker]
+
+    [initialize.components.entity_linker.kb_loader]
+    @misc = "spacy.CustomKB.v1"
+    entity_vector_length = 342
+    custom_field = 666
+    """
+
     class SubKnowledgeBase(KnowledgeBase):
         def __init__(self, vocab, entity_vector_length, custom_field):
             super().__init__(vocab, entity_vector_length)
@@ -90,23 +113,21 @@ def test_serialize_subclassed_kb():
         entity_vector_length: int, custom_field: int
     ) -> Callable[["Vocab"], KnowledgeBase]:
         def custom_kb_factory(vocab):
-            return SubKnowledgeBase(
+            kb = SubKnowledgeBase(
                 vocab=vocab,
                 entity_vector_length=entity_vector_length,
                 custom_field=custom_field,
             )
+            kb.add_entity("random_entity", 0.0, zeros(entity_vector_length))
+            return kb
 
         return custom_kb_factory
 
-    nlp = English()
-    config = {
-        "kb_loader": {
-            "@misc": "spacy.CustomKB.v1",
-            "entity_vector_length": 342,
-            "custom_field": 666,
-        }
-    }
-    entity_linker = nlp.add_pipe("entity_linker", config=config)
+    config = Config().from_str(config_string)
+    nlp = load_model_from_config(config, auto_fill=True)
+    nlp.initialize()
+
+    entity_linker = nlp.get_pipe("entity_linker")
     assert type(entity_linker.kb) == SubKnowledgeBase
     assert entity_linker.kb.entity_vector_length == 342
     assert entity_linker.kb.custom_field == 666
@@ -116,6 +137,7 @@ def test_serialize_subclassed_kb():
         nlp.to_disk(tmp_dir)
         nlp2 = util.load_model_from_path(tmp_dir)
         entity_linker2 = nlp2.get_pipe("entity_linker")
-        assert type(entity_linker2.kb) == SubKnowledgeBase
+        # After IO, the KB is the standard one
+        assert type(entity_linker2.kb) == KnowledgeBase
         assert entity_linker2.kb.entity_vector_length == 342
-        assert entity_linker2.kb.custom_field == 666
+        assert not hasattr(entity_linker2.kb, "custom_field")
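
Note: a minimal sketch of the KB disk round-trip these serialization tests
build on (the path, entity and vector sizes below are illustrative):

    from spacy.kb import KnowledgeBase
    from spacy.lang.en import English

    nlp = English()
    kb = KnowledgeBase(nlp.vocab, entity_vector_length=3)
    kb.add_entity(entity="Q1", freq=5, entity_vector=[1, 2, 3])
    kb.to_disk("/tmp/kb")  # illustrative path
    # Reload into a fresh KB with the same vector length.
    kb2 = KnowledgeBase(nlp.vocab, entity_vector_length=3)
    kb2.from_disk("/tmp/kb")
    assert len(kb2) == 1
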
website/docs/api/language.md

@@ -524,7 +524,7 @@ Get a pipeline component for a given component name.
 
 ## Language.replace_pipe {#replace_pipe tag="method" new="2"}
 
-Replace a component in the pipeline.
+Replace a component in the pipeline and return the new component.
 
 <Infobox title="Changed in v3.0" variant="warning">
 
@@ -538,7 +538,7 @@ and instead expects the **name of a component factory** registered using
 > #### Example
 >
 > ```python
-> nlp.replace_pipe("parser", my_custom_parser)
+> new_parser = nlp.replace_pipe("parser", "my_custom_parser")
 > ```
 
 | Name | Description |
@@ -548,6 +548,7 @@ and instead expects the **name of a component factory** registered using
 | _keyword-only_ | |
 | `config` <Tag variant="new">3</Tag> | Optional config parameters to use for the new component. Will be merged with the `default_config` specified by the component factory. ~~Optional[Dict[str, Any]]~~ |
 | `validate` <Tag variant="new">3</Tag> | Whether to validate the component config and arguments against the types expected by the factory. Defaults to `True`. ~~bool~~ |
+| **RETURNS** | The new pipeline component. ~~Callable[[Doc], Doc]~~ |
 
 ## Language.rename_pipe {#rename_pipe tag="method" new="2"}
website/docs/usage/saving-loading.md

@@ -297,7 +297,7 @@ packages. This lets one application easily customize the behavior of another, by
 exposing an entry point in its `setup.py`. For a quick and fun intro to entry
 points in Python, check out
 [this excellent blog post](https://amir.rachum.com/blog/2017/07/28/python-entry-points/).
-spaCy can load custom function from several different entry points to add
+spaCy can load custom functions from several different entry points to add
 pipeline component factories, language classes and other settings. To make spaCy
 use your entry points, your package needs to expose them and it needs to be
 installed in the same environment – that's it.
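
Note: a hedged sketch of what exposing such an entry point looks like in a
package's setup.py (the package, component and function names are placeholders):

    from setuptools import setup

    setup(
        name="my_spacy_plugin",
        entry_points={
            # spaCy discovers component factories via the "spacy_factories" group.
            "spacy_factories": [
                "my_component = my_spacy_plugin.factories:create_my_component"
            ]
        },
    )
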
website/docs/usage/v3.md

@@ -395,7 +395,7 @@ type-check model definitions.
 For data validation, spaCy v3.0 adopts
 [`pydantic`](https://github.com/samuelcolvin/pydantic). It also powers the data
 validation of Thinc's [config system](https://thinc.ai/docs/usage-config), which
-lets you to register **custom functions with typed arguments**, reference them
+lets you register **custom functions with typed arguments**, reference them
 in your config and see validation errors if the argument values don't match.
 
 <Infobox title="Details & Documentation" emoji="📖" list>
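
Note: a small sketch of a registered function with typed arguments (the
registry name is illustrative). If a config supplies the wrong type, for
example a string where a list is expected, validation fails up front:

    from typing import List
    from spacy.util import registry

    @registry.misc.register("my_labels.v1")
    def make_labels(base: List[str], extra: List[str]) -> List[str]:
        # Referenced from a config block as: @misc = "my_labels.v1"
        return base + extra
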