mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 02:06:31 +03:00
set_kb method for entity_linker
This commit is contained in:
parent
efedccea8d
commit
eaf5c265cb
|
@ -600,7 +600,6 @@ class Language:
|
||||||
*,
|
*,
|
||||||
config: Optional[Dict[str, Any]] = SimpleFrozenDict(),
|
config: Optional[Dict[str, Any]] = SimpleFrozenDict(),
|
||||||
raw_config: Optional[Config] = None,
|
raw_config: Optional[Config] = None,
|
||||||
init_config: Optional[Dict[str, Any]] = SimpleFrozenDict(),
|
|
||||||
validate: bool = True,
|
validate: bool = True,
|
||||||
) -> Callable[[Doc], Doc]:
|
) -> Callable[[Doc], Doc]:
|
||||||
"""Create a pipeline component. Mostly used internally. To create and
|
"""Create a pipeline component. Mostly used internally. To create and
|
||||||
|
@ -612,9 +611,6 @@ class Language:
|
||||||
config (Optional[Dict[str, Any]]): Config parameters to use for this
|
config (Optional[Dict[str, Any]]): Config parameters to use for this
|
||||||
component. Will be merged with default config, if available.
|
component. Will be merged with default config, if available.
|
||||||
raw_config (Optional[Config]): Internals: the non-interpolated config.
|
raw_config (Optional[Config]): Internals: the non-interpolated config.
|
||||||
init_config (Optional[Dict[str, Any]]): Config parameters to use to
|
|
||||||
initialize this component. Will be used to update the internal
|
|
||||||
'initialize' config.
|
|
||||||
validate (bool): Whether to validate the component config against the
|
validate (bool): Whether to validate the component config against the
|
||||||
arguments and types expected by the factory.
|
arguments and types expected by the factory.
|
||||||
RETURNS (Callable[[Doc], Doc]): The pipeline component.
|
RETURNS (Callable[[Doc], Doc]): The pipeline component.
|
||||||
|
@ -625,13 +621,9 @@ class Language:
|
||||||
if not isinstance(config, dict):
|
if not isinstance(config, dict):
|
||||||
err = Errors.E962.format(style="config", name=name, cfg_type=type(config))
|
err = Errors.E962.format(style="config", name=name, cfg_type=type(config))
|
||||||
raise ValueError(err)
|
raise ValueError(err)
|
||||||
if not isinstance(init_config, dict):
|
|
||||||
err = Errors.E962.format(style="init_config", name=name, cfg_type=type(init_config))
|
|
||||||
raise ValueError(err)
|
raise ValueError(err)
|
||||||
if not srsly.is_json_serializable(config):
|
if not srsly.is_json_serializable(config):
|
||||||
raise ValueError(Errors.E961.format(config=config))
|
raise ValueError(Errors.E961.format(config=config))
|
||||||
if not srsly.is_json_serializable(init_config):
|
|
||||||
raise ValueError(Errors.E961.format(config=init_config))
|
|
||||||
if not self.has_factory(factory_name):
|
if not self.has_factory(factory_name):
|
||||||
err = Errors.E002.format(
|
err = Errors.E002.format(
|
||||||
name=factory_name,
|
name=factory_name,
|
||||||
|
@ -643,8 +635,6 @@ class Language:
|
||||||
raise ValueError(err)
|
raise ValueError(err)
|
||||||
pipe_meta = self.get_factory_meta(factory_name)
|
pipe_meta = self.get_factory_meta(factory_name)
|
||||||
config = config or {}
|
config = config or {}
|
||||||
if init_config:
|
|
||||||
self._config["initialize"]["components"][name] = init_config
|
|
||||||
# This is not ideal, but the alternative would mean you always need to
|
# This is not ideal, but the alternative would mean you always need to
|
||||||
# specify the full config settings, which is not really viable.
|
# specify the full config settings, which is not really viable.
|
||||||
if pipe_meta.default_config:
|
if pipe_meta.default_config:
|
||||||
|
@ -719,7 +709,6 @@ class Language:
|
||||||
source: Optional["Language"] = None,
|
source: Optional["Language"] = None,
|
||||||
config: Optional[Dict[str, Any]] = SimpleFrozenDict(),
|
config: Optional[Dict[str, Any]] = SimpleFrozenDict(),
|
||||||
raw_config: Optional[Config] = None,
|
raw_config: Optional[Config] = None,
|
||||||
init_config: Optional[Dict[str, Any]] = SimpleFrozenDict(),
|
|
||||||
validate: bool = True,
|
validate: bool = True,
|
||||||
) -> Callable[[Doc], Doc]:
|
) -> Callable[[Doc], Doc]:
|
||||||
"""Add a component to the processing pipeline. Valid components are
|
"""Add a component to the processing pipeline. Valid components are
|
||||||
|
@ -742,9 +731,6 @@ class Language:
|
||||||
config (Optional[Dict[str, Any]]): Config parameters to use for this
|
config (Optional[Dict[str, Any]]): Config parameters to use for this
|
||||||
component. Will be merged with default config, if available.
|
component. Will be merged with default config, if available.
|
||||||
raw_config (Optional[Config]): Internals: the non-interpolated config.
|
raw_config (Optional[Config]): Internals: the non-interpolated config.
|
||||||
init_config (Optional[Dict[str, Any]]): Config parameters to use to
|
|
||||||
initialize this component. Will be used to update the internal
|
|
||||||
'initialize' config.
|
|
||||||
validate (bool): Whether to validate the component config against the
|
validate (bool): Whether to validate the component config against the
|
||||||
arguments and types expected by the factory.
|
arguments and types expected by the factory.
|
||||||
RETURNS (Callable[[Doc], Doc]): The pipeline component.
|
RETURNS (Callable[[Doc], Doc]): The pipeline component.
|
||||||
|
@ -778,7 +764,6 @@ class Language:
|
||||||
name=name,
|
name=name,
|
||||||
config=config,
|
config=config,
|
||||||
raw_config=raw_config,
|
raw_config=raw_config,
|
||||||
init_config=init_config,
|
|
||||||
validate=validate,
|
validate=validate,
|
||||||
)
|
)
|
||||||
pipe_index = self._get_pipe_index(before, after, first, last)
|
pipe_index = self._get_pipe_index(before, after, first, last)
|
||||||
|
@ -858,20 +843,17 @@ class Language:
|
||||||
factory_name: str,
|
factory_name: str,
|
||||||
*,
|
*,
|
||||||
config: Dict[str, Any] = SimpleFrozenDict(),
|
config: Dict[str, Any] = SimpleFrozenDict(),
|
||||||
init_config: Dict[str, Any] = SimpleFrozenDict(),
|
|
||||||
validate: bool = True,
|
validate: bool = True,
|
||||||
) -> None:
|
) -> Callable[[Doc], Doc]:
|
||||||
"""Replace a component in the pipeline.
|
"""Replace a component in the pipeline.
|
||||||
|
|
||||||
name (str): Name of the component to replace.
|
name (str): Name of the component to replace.
|
||||||
factory_name (str): Factory name of replacement component.
|
factory_name (str): Factory name of replacement component.
|
||||||
config (Optional[Dict[str, Any]]): Config parameters to use for this
|
config (Optional[Dict[str, Any]]): Config parameters to use for this
|
||||||
component. Will be merged with default config, if available.
|
component. Will be merged with default config, if available.
|
||||||
init_config (Optional[Dict[str, Any]]): Config parameters to use to
|
|
||||||
initialize this component. Will be used to update the internal
|
|
||||||
'initialize' config.
|
|
||||||
validate (bool): Whether to validate the component config against the
|
validate (bool): Whether to validate the component config against the
|
||||||
arguments and types expected by the factory.
|
arguments and types expected by the factory.
|
||||||
|
RETURNS (Callable[[Doc], Doc]): The new pipeline component.
|
||||||
|
|
||||||
DOCS: https://nightly.spacy.io/api/language#replace_pipe
|
DOCS: https://nightly.spacy.io/api/language#replace_pipe
|
||||||
"""
|
"""
|
||||||
|
@ -886,14 +868,15 @@ class Language:
|
||||||
self.remove_pipe(name)
|
self.remove_pipe(name)
|
||||||
if not len(self._components) or pipe_index == len(self._components):
|
if not len(self._components) or pipe_index == len(self._components):
|
||||||
# we have no components to insert before/after, or we're replacing the last component
|
# we have no components to insert before/after, or we're replacing the last component
|
||||||
self.add_pipe(factory_name, name=name, config=config, init_config=init_config, validate=validate)
|
return self.add_pipe(
|
||||||
|
factory_name, name=name, config=config, validate=validate
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
self.add_pipe(
|
return self.add_pipe(
|
||||||
factory_name,
|
factory_name,
|
||||||
name=name,
|
name=name,
|
||||||
before=pipe_index,
|
before=pipe_index,
|
||||||
config=config,
|
config=config,
|
||||||
init_config=init_config,
|
|
||||||
validate=validate,
|
validate=validate,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1321,7 +1304,11 @@ class Language:
|
||||||
kwargs.setdefault("batch_size", batch_size)
|
kwargs.setdefault("batch_size", batch_size)
|
||||||
# non-trainable components may have a pipe() implementation that refers to dummy
|
# non-trainable components may have a pipe() implementation that refers to dummy
|
||||||
# predict and set_annotations methods
|
# predict and set_annotations methods
|
||||||
if not hasattr(pipe, "pipe") or not hasattr(pipe, "is_trainable") or not pipe.is_trainable():
|
if (
|
||||||
|
not hasattr(pipe, "pipe")
|
||||||
|
or not hasattr(pipe, "is_trainable")
|
||||||
|
or not pipe.is_trainable()
|
||||||
|
):
|
||||||
docs = _pipe(docs, pipe, kwargs)
|
docs = _pipe(docs, pipe, kwargs)
|
||||||
else:
|
else:
|
||||||
docs = pipe.pipe(docs, **kwargs)
|
docs = pipe.pipe(docs, **kwargs)
|
||||||
|
@ -1433,7 +1420,11 @@ class Language:
|
||||||
kwargs.setdefault("batch_size", batch_size)
|
kwargs.setdefault("batch_size", batch_size)
|
||||||
# non-trainable components may have a pipe() implementation that refers to dummy
|
# non-trainable components may have a pipe() implementation that refers to dummy
|
||||||
# predict and set_annotations methods
|
# predict and set_annotations methods
|
||||||
if hasattr(proc, "pipe") and hasattr(proc, "is_trainable") and proc.is_trainable():
|
if (
|
||||||
|
hasattr(proc, "pipe")
|
||||||
|
and hasattr(proc, "is_trainable")
|
||||||
|
and proc.is_trainable()
|
||||||
|
):
|
||||||
f = functools.partial(proc.pipe, **kwargs)
|
f = functools.partial(proc.pipe, **kwargs)
|
||||||
else:
|
else:
|
||||||
# Apply the function, but yield the doc
|
# Apply the function, but yield the doc
|
||||||
|
|
|
@ -142,6 +142,12 @@ class EntityLinker(Pipe):
|
||||||
# create an empty KB by default. If you want to load a predefined one, specify it in 'initialize'.
|
# create an empty KB by default. If you want to load a predefined one, specify it in 'initialize'.
|
||||||
self.kb = empty_kb(entity_vector_length)(self.vocab)
|
self.kb = empty_kb(entity_vector_length)(self.vocab)
|
||||||
|
|
||||||
|
def set_kb(self, kb_loader: Callable[[Vocab], KnowledgeBase]):
|
||||||
|
"""Define the KB of this pipe by providing a function that will
|
||||||
|
create it using this object's vocab."""
|
||||||
|
self.kb = kb_loader(self.vocab)
|
||||||
|
self.cfg["entity_vector_length"] = self.kb.entity_vector_length
|
||||||
|
|
||||||
def validate_kb(self) -> None:
|
def validate_kb(self) -> None:
|
||||||
# Raise an error if the knowledge base is not initialized.
|
# Raise an error if the knowledge base is not initialized.
|
||||||
if len(self.kb) == 0:
|
if len(self.kb) == 0:
|
||||||
|
@ -168,8 +174,7 @@ class EntityLinker(Pipe):
|
||||||
"""
|
"""
|
||||||
self._ensure_examples(get_examples)
|
self._ensure_examples(get_examples)
|
||||||
if kb_loader is not None:
|
if kb_loader is not None:
|
||||||
self.kb = kb_loader(self.vocab)
|
self.set_kb(kb_loader)
|
||||||
self.cfg["entity_vector_length"] = self.kb.entity_vector_length
|
|
||||||
self.validate_kb()
|
self.validate_kb()
|
||||||
nO = self.kb.entity_vector_length
|
nO = self.kb.entity_vector_length
|
||||||
doc_sample = []
|
doc_sample = []
|
||||||
|
|
|
@ -130,18 +130,9 @@ def test_kb_custom_length(nlp):
|
||||||
assert entity_linker.kb.entity_vector_length == 35
|
assert entity_linker.kb.entity_vector_length == 35
|
||||||
|
|
||||||
|
|
||||||
def test_kb_undefined(nlp):
|
def test_kb_initialize_empty(nlp):
|
||||||
"""Test that the EL can't train without defining a KB"""
|
"""Test that the EL can't initialize without examples"""
|
||||||
entity_linker = nlp.add_pipe("entity_linker", config={})
|
entity_linker = nlp.add_pipe("entity_linker")
|
||||||
with pytest.raises(ValueError):
|
|
||||||
entity_linker.initialize(lambda: [])
|
|
||||||
|
|
||||||
|
|
||||||
def test_kb_empty(nlp):
|
|
||||||
"""Test that the EL can't train with an empty KB"""
|
|
||||||
config = {"kb_loader": {"@misc": "spacy.EmptyKB.v1", "entity_vector_length": 342}}
|
|
||||||
entity_linker = nlp.add_pipe("entity_linker", init_config=config)
|
|
||||||
assert len(entity_linker.kb) == 0
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
entity_linker.initialize(lambda: [])
|
entity_linker.initialize(lambda: [])
|
||||||
|
|
||||||
|
@ -201,26 +192,21 @@ def test_el_pipe_configuration(nlp):
|
||||||
ruler = nlp.add_pipe("entity_ruler")
|
ruler = nlp.add_pipe("entity_ruler")
|
||||||
ruler.add_patterns([pattern])
|
ruler.add_patterns([pattern])
|
||||||
|
|
||||||
@registry.misc.register("myAdamKB.v1")
|
def create_kb(vocab):
|
||||||
def mykb() -> Callable[["Vocab"], KnowledgeBase]:
|
kb = KnowledgeBase(vocab, entity_vector_length=1)
|
||||||
def create_kb(vocab):
|
kb.add_entity(entity="Q2", freq=12, entity_vector=[2])
|
||||||
kb = KnowledgeBase(vocab, entity_vector_length=1)
|
kb.add_entity(entity="Q3", freq=5, entity_vector=[3])
|
||||||
kb.add_entity(entity="Q2", freq=12, entity_vector=[2])
|
kb.add_alias(
|
||||||
kb.add_entity(entity="Q3", freq=5, entity_vector=[3])
|
alias="douglas", entities=["Q2", "Q3"], probabilities=[0.8, 0.1]
|
||||||
kb.add_alias(
|
)
|
||||||
alias="douglas", entities=["Q2", "Q3"], probabilities=[0.8, 0.1]
|
return kb
|
||||||
)
|
|
||||||
return kb
|
|
||||||
|
|
||||||
return create_kb
|
|
||||||
|
|
||||||
# run an EL pipe without a trained context encoder, to check the candidate generation step only
|
# run an EL pipe without a trained context encoder, to check the candidate generation step only
|
||||||
nlp.add_pipe(
|
entity_linker = nlp.add_pipe(
|
||||||
"entity_linker",
|
"entity_linker",
|
||||||
config={"incl_context": False},
|
config={"incl_context": False},
|
||||||
init_config={"kb_loader": {"@misc": "myAdamKB.v1"}},
|
|
||||||
)
|
)
|
||||||
nlp.initialize()
|
entity_linker.set_kb(create_kb)
|
||||||
# With the default get_candidates function, matching is case-sensitive
|
# With the default get_candidates function, matching is case-sensitive
|
||||||
text = "Douglas and douglas are not the same."
|
text = "Douglas and douglas are not the same."
|
||||||
doc = nlp(text)
|
doc = nlp(text)
|
||||||
|
@ -236,18 +222,15 @@ def test_el_pipe_configuration(nlp):
|
||||||
return get_lowercased_candidates
|
return get_lowercased_candidates
|
||||||
|
|
||||||
# replace the pipe with a new one with a different candidate generator
|
# replace the pipe with a new one with a different candidate generator
|
||||||
nlp.replace_pipe(
|
entity_linker = nlp.replace_pipe(
|
||||||
"entity_linker",
|
"entity_linker",
|
||||||
"entity_linker",
|
"entity_linker",
|
||||||
config={
|
config={
|
||||||
"incl_context": False,
|
"incl_context": False,
|
||||||
"get_candidates": {"@misc": "spacy.LowercaseCandidateGenerator.v1"},
|
"get_candidates": {"@misc": "spacy.LowercaseCandidateGenerator.v1"},
|
||||||
},
|
},
|
||||||
init_config={
|
|
||||||
"kb_loader": {"@misc": "myAdamKB.v1"},
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
nlp.initialize()
|
entity_linker.set_kb(create_kb)
|
||||||
doc = nlp(text)
|
doc = nlp(text)
|
||||||
assert doc[0].ent_kb_id_ == "Q2"
|
assert doc[0].ent_kb_id_ == "Q2"
|
||||||
assert doc[1].ent_kb_id_ == ""
|
assert doc[1].ent_kb_id_ == ""
|
||||||
|
@ -339,19 +322,15 @@ def test_preserving_links_asdoc(nlp):
|
||||||
"""Test that Span.as_doc preserves the existing entity links"""
|
"""Test that Span.as_doc preserves the existing entity links"""
|
||||||
vector_length = 1
|
vector_length = 1
|
||||||
|
|
||||||
@registry.misc.register("myLocationsKB.v1")
|
def create_kb(vocab):
|
||||||
def dummy_kb() -> Callable[["Vocab"], KnowledgeBase]:
|
mykb = KnowledgeBase(vocab, entity_vector_length=vector_length)
|
||||||
def create_kb(vocab):
|
# adding entities
|
||||||
mykb = KnowledgeBase(vocab, entity_vector_length=vector_length)
|
mykb.add_entity(entity="Q1", freq=19, entity_vector=[1])
|
||||||
# adding entities
|
mykb.add_entity(entity="Q2", freq=8, entity_vector=[1])
|
||||||
mykb.add_entity(entity="Q1", freq=19, entity_vector=[1])
|
# adding aliases
|
||||||
mykb.add_entity(entity="Q2", freq=8, entity_vector=[1])
|
mykb.add_alias(alias="Boston", entities=["Q1"], probabilities=[0.7])
|
||||||
# adding aliases
|
mykb.add_alias(alias="Denver", entities=["Q2"], probabilities=[0.6])
|
||||||
mykb.add_alias(alias="Boston", entities=["Q1"], probabilities=[0.7])
|
return mykb
|
||||||
mykb.add_alias(alias="Denver", entities=["Q2"], probabilities=[0.6])
|
|
||||||
return mykb
|
|
||||||
|
|
||||||
return create_kb
|
|
||||||
|
|
||||||
# set up pipeline with NER (Entity Ruler) and NEL (prior probability only, model not trained)
|
# set up pipeline with NER (Entity Ruler) and NEL (prior probability only, model not trained)
|
||||||
nlp.add_pipe("sentencizer")
|
nlp.add_pipe("sentencizer")
|
||||||
|
@ -362,8 +341,8 @@ def test_preserving_links_asdoc(nlp):
|
||||||
ruler = nlp.add_pipe("entity_ruler")
|
ruler = nlp.add_pipe("entity_ruler")
|
||||||
ruler.add_patterns(patterns)
|
ruler.add_patterns(patterns)
|
||||||
config = {"incl_prior": False}
|
config = {"incl_prior": False}
|
||||||
init_config = {"kb_loader": {"@misc": "myLocationsKB.v1"}}
|
entity_linker = nlp.add_pipe("entity_linker", config=config, last=True)
|
||||||
entity_linker = nlp.add_pipe("entity_linker", config=config, init_config=init_config, last=True)
|
entity_linker.set_kb(create_kb)
|
||||||
nlp.initialize()
|
nlp.initialize()
|
||||||
assert entity_linker.model.get_dim("nO") == vector_length
|
assert entity_linker.model.get_dim("nO") == vector_length
|
||||||
|
|
||||||
|
@ -441,30 +420,26 @@ def test_overfitting_IO():
|
||||||
doc = nlp(text)
|
doc = nlp(text)
|
||||||
train_examples.append(Example.from_dict(doc, annotation))
|
train_examples.append(Example.from_dict(doc, annotation))
|
||||||
|
|
||||||
@registry.misc.register("myOverfittingKB.v1")
|
def create_kb(vocab):
|
||||||
def dummy_kb() -> Callable[["Vocab"], KnowledgeBase]:
|
# create artificial KB - assign same prior weight to the two russ cochran's
|
||||||
def create_kb(vocab):
|
# Q2146908 (Russ Cochran): American golfer
|
||||||
# create artificial KB - assign same prior weight to the two russ cochran's
|
# Q7381115 (Russ Cochran): publisher
|
||||||
# Q2146908 (Russ Cochran): American golfer
|
mykb = KnowledgeBase(vocab, entity_vector_length=vector_length)
|
||||||
# Q7381115 (Russ Cochran): publisher
|
mykb.add_entity(entity="Q2146908", freq=12, entity_vector=[6, -4, 3])
|
||||||
mykb = KnowledgeBase(vocab, entity_vector_length=vector_length)
|
mykb.add_entity(entity="Q7381115", freq=12, entity_vector=[9, 1, -7])
|
||||||
mykb.add_entity(entity="Q2146908", freq=12, entity_vector=[6, -4, 3])
|
mykb.add_alias(
|
||||||
mykb.add_entity(entity="Q7381115", freq=12, entity_vector=[9, 1, -7])
|
alias="Russ Cochran",
|
||||||
mykb.add_alias(
|
entities=["Q2146908", "Q7381115"],
|
||||||
alias="Russ Cochran",
|
probabilities=[0.5, 0.5],
|
||||||
entities=["Q2146908", "Q7381115"],
|
)
|
||||||
probabilities=[0.5, 0.5],
|
return mykb
|
||||||
)
|
|
||||||
return mykb
|
|
||||||
|
|
||||||
return create_kb
|
|
||||||
|
|
||||||
# Create the Entity Linker component and add it to the pipeline
|
# Create the Entity Linker component and add it to the pipeline
|
||||||
entity_linker = nlp.add_pipe(
|
entity_linker = nlp.add_pipe(
|
||||||
"entity_linker",
|
"entity_linker",
|
||||||
init_config={"kb_loader": {"@misc": "myOverfittingKB.v1"}},
|
|
||||||
last=True,
|
last=True,
|
||||||
)
|
)
|
||||||
|
entity_linker.set_kb(create_kb)
|
||||||
|
|
||||||
# train the NEL pipe
|
# train the NEL pipe
|
||||||
optimizer = nlp.initialize(get_examples=lambda: train_examples)
|
optimizer = nlp.initialize(get_examples=lambda: train_examples)
|
||||||
|
|
|
@ -71,17 +71,13 @@ def tagger():
|
||||||
def entity_linker():
|
def entity_linker():
|
||||||
nlp = Language()
|
nlp = Language()
|
||||||
|
|
||||||
@registry.misc.register("TestIssue5230KB.v1")
|
def create_kb(vocab):
|
||||||
def dummy_kb() -> Callable[["Vocab"], KnowledgeBase]:
|
kb = KnowledgeBase(vocab, entity_vector_length=1)
|
||||||
def create_kb(vocab):
|
kb.add_entity("test", 0.0, zeros((1, 1), dtype="f"))
|
||||||
kb = KnowledgeBase(vocab, entity_vector_length=1)
|
return kb
|
||||||
kb.add_entity("test", 0.0, zeros((1, 1), dtype="f"))
|
|
||||||
return kb
|
|
||||||
|
|
||||||
return create_kb
|
entity_linker = nlp.add_pipe("entity_linker")
|
||||||
|
entity_linker.set_kb(create_kb)
|
||||||
init_config = {"kb_loader": {"@misc": "TestIssue5230KB.v1"}}
|
|
||||||
entity_linker = nlp.add_pipe("entity_linker", init_config=init_config)
|
|
||||||
# need to add model for two reasons:
|
# need to add model for two reasons:
|
||||||
# 1. no model leads to error in serialization,
|
# 1. no model leads to error in serialization,
|
||||||
# 2. the affected line is the one for model serialization
|
# 2. the affected line is the one for model serialization
|
||||||
|
|
|
@ -1,9 +1,9 @@
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
|
||||||
from spacy import util
|
from spacy import util
|
||||||
from spacy.lang.en import English
|
from spacy.util import ensure_path, registry, load_model_from_config
|
||||||
from spacy.util import ensure_path, registry
|
|
||||||
from spacy.kb import KnowledgeBase
|
from spacy.kb import KnowledgeBase
|
||||||
|
from thinc.api import Config
|
||||||
|
|
||||||
from ..util import make_tempdir
|
from ..util import make_tempdir
|
||||||
from numpy import zeros
|
from numpy import zeros
|
||||||
|
@ -81,6 +81,28 @@ def _check_kb(kb):
|
||||||
def test_serialize_subclassed_kb():
|
def test_serialize_subclassed_kb():
|
||||||
"""Check that IO of a custom KB works fine as part of an EL pipe."""
|
"""Check that IO of a custom KB works fine as part of an EL pipe."""
|
||||||
|
|
||||||
|
config_string = """
|
||||||
|
[nlp]
|
||||||
|
lang = "en"
|
||||||
|
pipeline = ["entity_linker"]
|
||||||
|
|
||||||
|
[components]
|
||||||
|
|
||||||
|
[components.entity_linker]
|
||||||
|
factory = "entity_linker"
|
||||||
|
|
||||||
|
[initialize]
|
||||||
|
|
||||||
|
[initialize.components]
|
||||||
|
|
||||||
|
[initialize.components.entity_linker]
|
||||||
|
|
||||||
|
[initialize.components.entity_linker.kb_loader]
|
||||||
|
@misc = "spacy.CustomKB.v1"
|
||||||
|
entity_vector_length = 342
|
||||||
|
custom_field = 666
|
||||||
|
"""
|
||||||
|
|
||||||
class SubKnowledgeBase(KnowledgeBase):
|
class SubKnowledgeBase(KnowledgeBase):
|
||||||
def __init__(self, vocab, entity_vector_length, custom_field):
|
def __init__(self, vocab, entity_vector_length, custom_field):
|
||||||
super().__init__(vocab, entity_vector_length)
|
super().__init__(vocab, entity_vector_length)
|
||||||
|
@ -101,16 +123,11 @@ def test_serialize_subclassed_kb():
|
||||||
|
|
||||||
return custom_kb_factory
|
return custom_kb_factory
|
||||||
|
|
||||||
nlp = English()
|
config = Config().from_str(config_string)
|
||||||
config = {
|
nlp = load_model_from_config(config, auto_fill=True)
|
||||||
"kb_loader": {
|
|
||||||
"@misc": "spacy.CustomKB.v1",
|
|
||||||
"entity_vector_length": 342,
|
|
||||||
"custom_field": 666,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
entity_linker = nlp.add_pipe("entity_linker", init_config=config)
|
|
||||||
nlp.initialize()
|
nlp.initialize()
|
||||||
|
|
||||||
|
entity_linker = nlp.get_pipe("entity_linker")
|
||||||
assert type(entity_linker.kb) == SubKnowledgeBase
|
assert type(entity_linker.kb) == SubKnowledgeBase
|
||||||
assert entity_linker.kb.entity_vector_length == 342
|
assert entity_linker.kb.entity_vector_length == 342
|
||||||
assert entity_linker.kb.custom_field == 666
|
assert entity_linker.kb.custom_field == 666
|
||||||
|
|
|
@ -524,7 +524,7 @@ Get a pipeline component for a given component name.
|
||||||
|
|
||||||
## Language.replace_pipe {#replace_pipe tag="method" new="2"}
|
## Language.replace_pipe {#replace_pipe tag="method" new="2"}
|
||||||
|
|
||||||
Replace a component in the pipeline.
|
Replace a component in the pipeline and return the new component.
|
||||||
|
|
||||||
<Infobox title="Changed in v3.0" variant="warning">
|
<Infobox title="Changed in v3.0" variant="warning">
|
||||||
|
|
||||||
|
@ -538,7 +538,7 @@ and instead expects the **name of a component factory** registered using
|
||||||
> #### Example
|
> #### Example
|
||||||
>
|
>
|
||||||
> ```python
|
> ```python
|
||||||
> nlp.replace_pipe("parser", my_custom_parser)
|
> new_parser = nlp.replace_pipe("parser", "my_custom_parser")
|
||||||
> ```
|
> ```
|
||||||
|
|
||||||
| Name | Description |
|
| Name | Description |
|
||||||
|
@ -548,6 +548,7 @@ and instead expects the **name of a component factory** registered using
|
||||||
| _keyword-only_ | |
|
| _keyword-only_ | |
|
||||||
| `config` <Tag variant="new">3</Tag> | Optional config parameters to use for the new component. Will be merged with the `default_config` specified by the component factory. ~~Optional[Dict[str, Any]]~~ |
|
| `config` <Tag variant="new">3</Tag> | Optional config parameters to use for the new component. Will be merged with the `default_config` specified by the component factory. ~~Optional[Dict[str, Any]]~~ |
|
||||||
| `validate` <Tag variant="new">3</Tag> | Whether to validate the component config and arguments against the types expected by the factory. Defaults to `True`. ~~bool~~ |
|
| `validate` <Tag variant="new">3</Tag> | Whether to validate the component config and arguments against the types expected by the factory. Defaults to `True`. ~~bool~~ |
|
||||||
|
| **RETURNS** | The new pipeline component. ~~Callable[[Doc], Doc]~~ |
|
||||||
|
|
||||||
## Language.rename_pipe {#rename_pipe tag="method" new="2"}
|
## Language.rename_pipe {#rename_pipe tag="method" new="2"}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user