set_kb method for entity_linker

This commit is contained in:
svlandeg 2020-10-08 10:34:01 +02:00
parent efedccea8d
commit eaf5c265cb
6 changed files with 100 additions and 115 deletions

View File

@ -600,7 +600,6 @@ class Language:
*, *,
config: Optional[Dict[str, Any]] = SimpleFrozenDict(), config: Optional[Dict[str, Any]] = SimpleFrozenDict(),
raw_config: Optional[Config] = None, raw_config: Optional[Config] = None,
init_config: Optional[Dict[str, Any]] = SimpleFrozenDict(),
validate: bool = True, validate: bool = True,
) -> Callable[[Doc], Doc]: ) -> Callable[[Doc], Doc]:
"""Create a pipeline component. Mostly used internally. To create and """Create a pipeline component. Mostly used internally. To create and
@ -612,9 +611,6 @@ class Language:
config (Optional[Dict[str, Any]]): Config parameters to use for this config (Optional[Dict[str, Any]]): Config parameters to use for this
component. Will be merged with default config, if available. component. Will be merged with default config, if available.
raw_config (Optional[Config]): Internals: the non-interpolated config. raw_config (Optional[Config]): Internals: the non-interpolated config.
init_config (Optional[Dict[str, Any]]): Config parameters to use to
initialize this component. Will be used to update the internal
'initialize' config.
validate (bool): Whether to validate the component config against the validate (bool): Whether to validate the component config against the
arguments and types expected by the factory. arguments and types expected by the factory.
RETURNS (Callable[[Doc], Doc]): The pipeline component. RETURNS (Callable[[Doc], Doc]): The pipeline component.
@ -625,13 +621,9 @@ class Language:
if not isinstance(config, dict): if not isinstance(config, dict):
err = Errors.E962.format(style="config", name=name, cfg_type=type(config)) err = Errors.E962.format(style="config", name=name, cfg_type=type(config))
raise ValueError(err) raise ValueError(err)
if not isinstance(init_config, dict):
err = Errors.E962.format(style="init_config", name=name, cfg_type=type(init_config))
raise ValueError(err) raise ValueError(err)
if not srsly.is_json_serializable(config): if not srsly.is_json_serializable(config):
raise ValueError(Errors.E961.format(config=config)) raise ValueError(Errors.E961.format(config=config))
if not srsly.is_json_serializable(init_config):
raise ValueError(Errors.E961.format(config=init_config))
if not self.has_factory(factory_name): if not self.has_factory(factory_name):
err = Errors.E002.format( err = Errors.E002.format(
name=factory_name, name=factory_name,
@ -643,8 +635,6 @@ class Language:
raise ValueError(err) raise ValueError(err)
pipe_meta = self.get_factory_meta(factory_name) pipe_meta = self.get_factory_meta(factory_name)
config = config or {} config = config or {}
if init_config:
self._config["initialize"]["components"][name] = init_config
# This is unideal, but the alternative would mean you always need to # This is unideal, but the alternative would mean you always need to
# specify the full config settings, which is not really viable. # specify the full config settings, which is not really viable.
if pipe_meta.default_config: if pipe_meta.default_config:
@ -719,7 +709,6 @@ class Language:
source: Optional["Language"] = None, source: Optional["Language"] = None,
config: Optional[Dict[str, Any]] = SimpleFrozenDict(), config: Optional[Dict[str, Any]] = SimpleFrozenDict(),
raw_config: Optional[Config] = None, raw_config: Optional[Config] = None,
init_config: Optional[Dict[str, Any]] = SimpleFrozenDict(),
validate: bool = True, validate: bool = True,
) -> Callable[[Doc], Doc]: ) -> Callable[[Doc], Doc]:
"""Add a component to the processing pipeline. Valid components are """Add a component to the processing pipeline. Valid components are
@ -742,9 +731,6 @@ class Language:
config (Optional[Dict[str, Any]]): Config parameters to use for this config (Optional[Dict[str, Any]]): Config parameters to use for this
component. Will be merged with default config, if available. component. Will be merged with default config, if available.
raw_config (Optional[Config]): Internals: the non-interpolated config. raw_config (Optional[Config]): Internals: the non-interpolated config.
init_config (Optional[Dict[str, Any]]): Config parameters to use to
initialize this component. Will be used to update the internal
'initialize' config.
validate (bool): Whether to validate the component config against the validate (bool): Whether to validate the component config against the
arguments and types expected by the factory. arguments and types expected by the factory.
RETURNS (Callable[[Doc], Doc]): The pipeline component. RETURNS (Callable[[Doc], Doc]): The pipeline component.
@ -778,7 +764,6 @@ class Language:
name=name, name=name,
config=config, config=config,
raw_config=raw_config, raw_config=raw_config,
init_config=init_config,
validate=validate, validate=validate,
) )
pipe_index = self._get_pipe_index(before, after, first, last) pipe_index = self._get_pipe_index(before, after, first, last)
@ -858,20 +843,17 @@ class Language:
factory_name: str, factory_name: str,
*, *,
config: Dict[str, Any] = SimpleFrozenDict(), config: Dict[str, Any] = SimpleFrozenDict(),
init_config: Dict[str, Any] = SimpleFrozenDict(),
validate: bool = True, validate: bool = True,
) -> None: ) -> Callable[[Doc], Doc]:
"""Replace a component in the pipeline. """Replace a component in the pipeline.
name (str): Name of the component to replace. name (str): Name of the component to replace.
factory_name (str): Factory name of replacement component. factory_name (str): Factory name of replacement component.
config (Optional[Dict[str, Any]]): Config parameters to use for this config (Optional[Dict[str, Any]]): Config parameters to use for this
component. Will be merged with default config, if available. component. Will be merged with default config, if available.
init_config (Optional[Dict[str, Any]]): Config parameters to use to
initialize this component. Will be used to update the internal
'initialize' config.
validate (bool): Whether to validate the component config against the validate (bool): Whether to validate the component config against the
arguments and types expected by the factory. arguments and types expected by the factory.
RETURNS (Callable[[Doc], Doc]): The new pipeline component.
DOCS: https://nightly.spacy.io/api/language#replace_pipe DOCS: https://nightly.spacy.io/api/language#replace_pipe
""" """
@ -886,14 +868,15 @@ class Language:
self.remove_pipe(name) self.remove_pipe(name)
if not len(self._components) or pipe_index == len(self._components): if not len(self._components) or pipe_index == len(self._components):
# we have no components to insert before/after, or we're replacing the last component # we have no components to insert before/after, or we're replacing the last component
self.add_pipe(factory_name, name=name, config=config, init_config=init_config, validate=validate) return self.add_pipe(
factory_name, name=name, config=config, validate=validate
)
else: else:
self.add_pipe( return self.add_pipe(
factory_name, factory_name,
name=name, name=name,
before=pipe_index, before=pipe_index,
config=config, config=config,
init_config=init_config,
validate=validate, validate=validate,
) )
@ -1321,7 +1304,11 @@ class Language:
kwargs.setdefault("batch_size", batch_size) kwargs.setdefault("batch_size", batch_size)
# non-trainable components may have a pipe() implementation that refers to dummy # non-trainable components may have a pipe() implementation that refers to dummy
# predict and set_annotations methods # predict and set_annotations methods
if not hasattr(pipe, "pipe") or not hasattr(pipe, "is_trainable") or not pipe.is_trainable(): if (
not hasattr(pipe, "pipe")
or not hasattr(pipe, "is_trainable")
or not pipe.is_trainable()
):
docs = _pipe(docs, pipe, kwargs) docs = _pipe(docs, pipe, kwargs)
else: else:
docs = pipe.pipe(docs, **kwargs) docs = pipe.pipe(docs, **kwargs)
@ -1433,7 +1420,11 @@ class Language:
kwargs.setdefault("batch_size", batch_size) kwargs.setdefault("batch_size", batch_size)
# non-trainable components may have a pipe() implementation that refers to dummy # non-trainable components may have a pipe() implementation that refers to dummy
# predict and set_annotations methods # predict and set_annotations methods
if hasattr(proc, "pipe") and hasattr(proc, "is_trainable") and proc.is_trainable(): if (
hasattr(proc, "pipe")
and hasattr(proc, "is_trainable")
and proc.is_trainable()
):
f = functools.partial(proc.pipe, **kwargs) f = functools.partial(proc.pipe, **kwargs)
else: else:
# Apply the function, but yield the doc # Apply the function, but yield the doc

View File

@ -142,6 +142,12 @@ class EntityLinker(Pipe):
# create an empty KB by default. If you want to load a predefined one, specify it in 'initialize'. # create an empty KB by default. If you want to load a predefined one, specify it in 'initialize'.
self.kb = empty_kb(entity_vector_length)(self.vocab) self.kb = empty_kb(entity_vector_length)(self.vocab)
def set_kb(self, kb_loader: Callable[[Vocab], KnowledgeBase]):
"""Define the KB of this pipe by providing a function that will
create it using this object's vocab."""
self.kb = kb_loader(self.vocab)
self.cfg["entity_vector_length"] = self.kb.entity_vector_length
def validate_kb(self) -> None: def validate_kb(self) -> None:
# Raise an error if the knowledge base is not initialized. # Raise an error if the knowledge base is not initialized.
if len(self.kb) == 0: if len(self.kb) == 0:
@ -168,8 +174,7 @@ class EntityLinker(Pipe):
""" """
self._ensure_examples(get_examples) self._ensure_examples(get_examples)
if kb_loader is not None: if kb_loader is not None:
self.kb = kb_loader(self.vocab) self.set_kb(kb_loader)
self.cfg["entity_vector_length"] = self.kb.entity_vector_length
self.validate_kb() self.validate_kb()
nO = self.kb.entity_vector_length nO = self.kb.entity_vector_length
doc_sample = [] doc_sample = []

View File

@ -130,18 +130,9 @@ def test_kb_custom_length(nlp):
assert entity_linker.kb.entity_vector_length == 35 assert entity_linker.kb.entity_vector_length == 35
def test_kb_undefined(nlp): def test_kb_initialize_empty(nlp):
"""Test that the EL can't train without defining a KB""" """Test that the EL can't initialize without examples"""
entity_linker = nlp.add_pipe("entity_linker", config={}) entity_linker = nlp.add_pipe("entity_linker")
with pytest.raises(ValueError):
entity_linker.initialize(lambda: [])
def test_kb_empty(nlp):
"""Test that the EL can't train with an empty KB"""
config = {"kb_loader": {"@misc": "spacy.EmptyKB.v1", "entity_vector_length": 342}}
entity_linker = nlp.add_pipe("entity_linker", init_config=config)
assert len(entity_linker.kb) == 0
with pytest.raises(ValueError): with pytest.raises(ValueError):
entity_linker.initialize(lambda: []) entity_linker.initialize(lambda: [])
@ -201,8 +192,6 @@ def test_el_pipe_configuration(nlp):
ruler = nlp.add_pipe("entity_ruler") ruler = nlp.add_pipe("entity_ruler")
ruler.add_patterns([pattern]) ruler.add_patterns([pattern])
@registry.misc.register("myAdamKB.v1")
def mykb() -> Callable[["Vocab"], KnowledgeBase]:
def create_kb(vocab): def create_kb(vocab):
kb = KnowledgeBase(vocab, entity_vector_length=1) kb = KnowledgeBase(vocab, entity_vector_length=1)
kb.add_entity(entity="Q2", freq=12, entity_vector=[2]) kb.add_entity(entity="Q2", freq=12, entity_vector=[2])
@ -212,15 +201,12 @@ def test_el_pipe_configuration(nlp):
) )
return kb return kb
return create_kb
# run an EL pipe without a trained context encoder, to check the candidate generation step only # run an EL pipe without a trained context encoder, to check the candidate generation step only
nlp.add_pipe( entity_linker = nlp.add_pipe(
"entity_linker", "entity_linker",
config={"incl_context": False}, config={"incl_context": False},
init_config={"kb_loader": {"@misc": "myAdamKB.v1"}},
) )
nlp.initialize() entity_linker.set_kb(create_kb)
# With the default get_candidates function, matching is case-sensitive # With the default get_candidates function, matching is case-sensitive
text = "Douglas and douglas are not the same." text = "Douglas and douglas are not the same."
doc = nlp(text) doc = nlp(text)
@ -236,18 +222,15 @@ def test_el_pipe_configuration(nlp):
return get_lowercased_candidates return get_lowercased_candidates
# replace the pipe with a new one with with a different candidate generator # replace the pipe with a new one with with a different candidate generator
nlp.replace_pipe( entity_linker = nlp.replace_pipe(
"entity_linker", "entity_linker",
"entity_linker", "entity_linker",
config={ config={
"incl_context": False, "incl_context": False,
"get_candidates": {"@misc": "spacy.LowercaseCandidateGenerator.v1"}, "get_candidates": {"@misc": "spacy.LowercaseCandidateGenerator.v1"},
}, },
init_config={
"kb_loader": {"@misc": "myAdamKB.v1"},
},
) )
nlp.initialize() entity_linker.set_kb(create_kb)
doc = nlp(text) doc = nlp(text)
assert doc[0].ent_kb_id_ == "Q2" assert doc[0].ent_kb_id_ == "Q2"
assert doc[1].ent_kb_id_ == "" assert doc[1].ent_kb_id_ == ""
@ -339,8 +322,6 @@ def test_preserving_links_asdoc(nlp):
"""Test that Span.as_doc preserves the existing entity links""" """Test that Span.as_doc preserves the existing entity links"""
vector_length = 1 vector_length = 1
@registry.misc.register("myLocationsKB.v1")
def dummy_kb() -> Callable[["Vocab"], KnowledgeBase]:
def create_kb(vocab): def create_kb(vocab):
mykb = KnowledgeBase(vocab, entity_vector_length=vector_length) mykb = KnowledgeBase(vocab, entity_vector_length=vector_length)
# adding entities # adding entities
@ -351,8 +332,6 @@ def test_preserving_links_asdoc(nlp):
mykb.add_alias(alias="Denver", entities=["Q2"], probabilities=[0.6]) mykb.add_alias(alias="Denver", entities=["Q2"], probabilities=[0.6])
return mykb return mykb
return create_kb
# set up pipeline with NER (Entity Ruler) and NEL (prior probability only, model not trained) # set up pipeline with NER (Entity Ruler) and NEL (prior probability only, model not trained)
nlp.add_pipe("sentencizer") nlp.add_pipe("sentencizer")
patterns = [ patterns = [
@ -362,8 +341,8 @@ def test_preserving_links_asdoc(nlp):
ruler = nlp.add_pipe("entity_ruler") ruler = nlp.add_pipe("entity_ruler")
ruler.add_patterns(patterns) ruler.add_patterns(patterns)
config = {"incl_prior": False} config = {"incl_prior": False}
init_config = {"kb_loader": {"@misc": "myLocationsKB.v1"}} entity_linker = nlp.add_pipe("entity_linker", config=config, last=True)
entity_linker = nlp.add_pipe("entity_linker", config=config, init_config=init_config, last=True) entity_linker.set_kb(create_kb)
nlp.initialize() nlp.initialize()
assert entity_linker.model.get_dim("nO") == vector_length assert entity_linker.model.get_dim("nO") == vector_length
@ -441,8 +420,6 @@ def test_overfitting_IO():
doc = nlp(text) doc = nlp(text)
train_examples.append(Example.from_dict(doc, annotation)) train_examples.append(Example.from_dict(doc, annotation))
@registry.misc.register("myOverfittingKB.v1")
def dummy_kb() -> Callable[["Vocab"], KnowledgeBase]:
def create_kb(vocab): def create_kb(vocab):
# create artificial KB - assign same prior weight to the two russ cochran's # create artificial KB - assign same prior weight to the two russ cochran's
# Q2146908 (Russ Cochran): American golfer # Q2146908 (Russ Cochran): American golfer
@ -457,14 +434,12 @@ def test_overfitting_IO():
) )
return mykb return mykb
return create_kb
# Create the Entity Linker component and add it to the pipeline # Create the Entity Linker component and add it to the pipeline
entity_linker = nlp.add_pipe( entity_linker = nlp.add_pipe(
"entity_linker", "entity_linker",
init_config={"kb_loader": {"@misc": "myOverfittingKB.v1"}},
last=True, last=True,
) )
entity_linker.set_kb(create_kb)
# train the NEL pipe # train the NEL pipe
optimizer = nlp.initialize(get_examples=lambda: train_examples) optimizer = nlp.initialize(get_examples=lambda: train_examples)

View File

@ -71,17 +71,13 @@ def tagger():
def entity_linker(): def entity_linker():
nlp = Language() nlp = Language()
@registry.misc.register("TestIssue5230KB.v1")
def dummy_kb() -> Callable[["Vocab"], KnowledgeBase]:
def create_kb(vocab): def create_kb(vocab):
kb = KnowledgeBase(vocab, entity_vector_length=1) kb = KnowledgeBase(vocab, entity_vector_length=1)
kb.add_entity("test", 0.0, zeros((1, 1), dtype="f")) kb.add_entity("test", 0.0, zeros((1, 1), dtype="f"))
return kb return kb
return create_kb entity_linker = nlp.add_pipe("entity_linker")
entity_linker.set_kb(create_kb)
init_config = {"kb_loader": {"@misc": "TestIssue5230KB.v1"}}
entity_linker = nlp.add_pipe("entity_linker", init_config=init_config)
# need to add model for two reasons: # need to add model for two reasons:
# 1. no model leads to error in serialization, # 1. no model leads to error in serialization,
# 2. the affected line is the one for model serialization # 2. the affected line is the one for model serialization

View File

@ -1,9 +1,9 @@
from typing import Callable from typing import Callable
from spacy import util from spacy import util
from spacy.lang.en import English from spacy.util import ensure_path, registry, load_model_from_config
from spacy.util import ensure_path, registry
from spacy.kb import KnowledgeBase from spacy.kb import KnowledgeBase
from thinc.api import Config
from ..util import make_tempdir from ..util import make_tempdir
from numpy import zeros from numpy import zeros
@ -81,6 +81,28 @@ def _check_kb(kb):
def test_serialize_subclassed_kb(): def test_serialize_subclassed_kb():
"""Check that IO of a custom KB works fine as part of an EL pipe.""" """Check that IO of a custom KB works fine as part of an EL pipe."""
config_string = """
[nlp]
lang = "en"
pipeline = ["entity_linker"]
[components]
[components.entity_linker]
factory = "entity_linker"
[initialize]
[initialize.components]
[initialize.components.entity_linker]
[initialize.components.entity_linker.kb_loader]
@misc = "spacy.CustomKB.v1"
entity_vector_length = 342
custom_field = 666
"""
class SubKnowledgeBase(KnowledgeBase): class SubKnowledgeBase(KnowledgeBase):
def __init__(self, vocab, entity_vector_length, custom_field): def __init__(self, vocab, entity_vector_length, custom_field):
super().__init__(vocab, entity_vector_length) super().__init__(vocab, entity_vector_length)
@ -101,16 +123,11 @@ def test_serialize_subclassed_kb():
return custom_kb_factory return custom_kb_factory
nlp = English() config = Config().from_str(config_string)
config = { nlp = load_model_from_config(config, auto_fill=True)
"kb_loader": {
"@misc": "spacy.CustomKB.v1",
"entity_vector_length": 342,
"custom_field": 666,
}
}
entity_linker = nlp.add_pipe("entity_linker", init_config=config)
nlp.initialize() nlp.initialize()
entity_linker = nlp.get_pipe("entity_linker")
assert type(entity_linker.kb) == SubKnowledgeBase assert type(entity_linker.kb) == SubKnowledgeBase
assert entity_linker.kb.entity_vector_length == 342 assert entity_linker.kb.entity_vector_length == 342
assert entity_linker.kb.custom_field == 666 assert entity_linker.kb.custom_field == 666

View File

@ -524,7 +524,7 @@ Get a pipeline component for a given component name.
## Language.replace_pipe {#replace_pipe tag="method" new="2"} ## Language.replace_pipe {#replace_pipe tag="method" new="2"}
Replace a component in the pipeline. Replace a component in the pipeline and return the new component.
<Infobox title="Changed in v3.0" variant="warning"> <Infobox title="Changed in v3.0" variant="warning">
@ -538,7 +538,7 @@ and instead expects the **name of a component factory** registered using
> #### Example > #### Example
> >
> ```python > ```python
> nlp.replace_pipe("parser", my_custom_parser) > new_parser = nlp.replace_pipe("parser", "my_custom_parser")
> ``` > ```
| Name | Description | | Name | Description |
@ -548,6 +548,7 @@ and instead expects the **name of a component factory** registered using
| _keyword-only_ | | | _keyword-only_ | |
| `config` <Tag variant="new">3</Tag> | Optional config parameters to use for the new component. Will be merged with the `default_config` specified by the component factory. ~~Optional[Dict[str, Any]]~~ | | `config` <Tag variant="new">3</Tag> | Optional config parameters to use for the new component. Will be merged with the `default_config` specified by the component factory. ~~Optional[Dict[str, Any]]~~ |
| `validate` <Tag variant="new">3</Tag> | Whether to validate the component config and arguments against the types expected by the factory. Defaults to `True`. ~~bool~~ | | `validate` <Tag variant="new">3</Tag> | Whether to validate the component config and arguments against the types expected by the factory. Defaults to `True`. ~~bool~~ |
| **RETURNS** | The new pipeline component. ~~Callable[[Doc], Doc]~~ |
## Language.rename_pipe {#rename_pipe tag="method" new="2"} ## Language.rename_pipe {#rename_pipe tag="method" new="2"}