Merge pull request #6216 from svlandeg/feature/nel-initialize

Ines Montani 2020-10-08 11:14:12 +02:00 committed by GitHub
commit 064575d79d
8 changed files with 137 additions and 107 deletions

View File

@@ -843,7 +843,7 @@ class Language:
*,
config: Dict[str, Any] = SimpleFrozenDict(),
validate: bool = True,
) -> None:
) -> Callable[[Doc], Doc]:
"""Replace a component in the pipeline.
name (str): Name of the component to replace.
@@ -852,6 +852,7 @@ class Language:
component. Will be merged with default config, if available.
validate (bool): Whether to validate the component config against the
arguments and types expected by the factory.
RETURNS (Callable[[Doc], Doc]): The new pipeline component.
DOCS: https://nightly.spacy.io/api/language#replace_pipe
"""
@@ -866,9 +867,11 @@ class Language:
self.remove_pipe(name)
if not len(self._components) or pipe_index == len(self._components):
# we have no components to insert before/after, or we're replacing the last component
self.add_pipe(factory_name, name=name, config=config, validate=validate)
return self.add_pipe(
factory_name, name=name, config=config, validate=validate
)
else:
self.add_pipe(
return self.add_pipe(
factory_name,
name=name,
before=pipe_index,
@@ -1300,7 +1303,11 @@ class Language:
kwargs.setdefault("batch_size", batch_size)
# non-trainable components may have a pipe() implementation that refers to dummy
# predict and set_annotations methods
if not hasattr(pipe, "pipe") or not hasattr(pipe, "is_trainable") or not pipe.is_trainable():
if (
not hasattr(pipe, "pipe")
or not hasattr(pipe, "is_trainable")
or not pipe.is_trainable()
):
docs = _pipe(docs, pipe, kwargs)
else:
docs = pipe.pipe(docs, **kwargs)
@@ -1412,7 +1419,11 @@ class Language:
kwargs.setdefault("batch_size", batch_size)
# non-trainable components may have a pipe() implementation that refers to dummy
# predict and set_annotations methods
if hasattr(proc, "pipe") and hasattr(proc, "is_trainable") and proc.is_trainable():
if (
hasattr(proc, "pipe")
and hasattr(proc, "is_trainable")
and proc.is_trainable()
):
f = functools.partial(proc.pipe, **kwargs)
else:
# Apply the function, but yield the doc
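
As a usage sketch of the change above (not part of the diff): `replace_pipe` now returns the new component, mirroring `add_pipe`, so callers no longer need a separate `get_pipe` lookup. The pipeline contents here are illustrative.

```python
import spacy

nlp = spacy.blank("en")
nlp.add_pipe("sentencizer")

# replace_pipe now returns the newly created component (previously None),
# so it can be captured directly, just like the return value of add_pipe.
new_component = nlp.replace_pipe("sentencizer", "sentencizer")
assert new_component is nlp.get_pipe("sentencizer")
```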

View File

@@ -8,6 +8,7 @@ from thinc.api import set_dropout_rate
import warnings
from ..kb import KnowledgeBase, Candidate
from ..ml import empty_kb
from ..tokens import Doc
from .pipe import Pipe, deserialize_config
from ..language import Language
@@ -41,11 +42,11 @@ DEFAULT_NEL_MODEL = Config().from_str(default_model_config)["model"]
requires=["doc.ents", "doc.sents", "token.ent_iob", "token.ent_type"],
assigns=["token.ent_kb_id"],
default_config={
"kb_loader": {"@misc": "spacy.EmptyKB.v1", "entity_vector_length": 64},
"model": DEFAULT_NEL_MODEL,
"labels_discard": [],
"incl_prior": True,
"incl_context": True,
"entity_vector_length": 64,
"get_candidates": {"@misc": "spacy.CandidateGenerator.v1"},
},
default_score_weights={
@@ -58,11 +59,11 @@ def make_entity_linker(
nlp: Language,
name: str,
model: Model,
kb_loader: Callable[[Vocab], KnowledgeBase],
*,
labels_discard: Iterable[str],
incl_prior: bool,
incl_context: bool,
entity_vector_length: int,
get_candidates: Callable[[KnowledgeBase, "Span"], Iterable[Candidate]],
):
"""Construct an EntityLinker component.
@@ -70,19 +71,21 @@
model (Model[List[Doc], Floats2d]): A model that learns document vector
representations. Given a batch of Doc objects, it should return a single
array, with one row per item in the batch.
kb (KnowledgeBase): The knowledge-base to link entities to.
labels_discard (Iterable[str]): NER labels that will automatically get a "NIL" prediction.
incl_prior (bool): Whether or not to include prior probabilities from the KB in the model.
incl_context (bool): Whether or not to include the local context in the model.
entity_vector_length (int): Size of encoding vectors in the KB.
get_candidates (Callable[[KnowledgeBase, "Span"], Iterable[Candidate]]): Function that
produces a list of candidates, given a certain knowledge base and a textual mention.
"""
return EntityLinker(
nlp.vocab,
model,
name,
kb_loader=kb_loader,
labels_discard=labels_discard,
incl_prior=incl_prior,
incl_context=incl_context,
entity_vector_length=entity_vector_length,
get_candidates=get_candidates,
)
@@ -101,10 +104,10 @@ class EntityLinker(Pipe):
model: Model,
name: str = "entity_linker",
*,
kb_loader: Callable[[Vocab], KnowledgeBase],
labels_discard: Iterable[str],
incl_prior: bool,
incl_context: bool,
entity_vector_length: int,
get_candidates: Callable[[KnowledgeBase, "Span"], Iterable[Candidate]],
) -> None:
"""Initialize an entity linker.
@@ -113,10 +116,12 @@ class EntityLinker(Pipe):
model (thinc.api.Model): The Thinc Model powering the pipeline component.
name (str): The component instance name, used to add entries to the
losses during training.
kb_loader (Callable[[Vocab], KnowledgeBase]): A function that creates a KnowledgeBase from a Vocab instance.
labels_discard (Iterable[str]): NER labels that will automatically get a "NIL" prediction.
incl_prior (bool): Whether or not to include prior probabilities from the KB in the model.
incl_context (bool): Whether or not to include the local context in the model.
entity_vector_length (int): Size of encoding vectors in the KB.
get_candidates (Callable[[KnowledgeBase, "Span"], Iterable[Candidate]]): Function that
produces a list of candidates, given a certain knowledge base and a textual mention.
DOCS: https://nightly.spacy.io/api/entitylinker#init
"""
@@ -127,15 +132,23 @@ class EntityLinker(Pipe):
"labels_discard": list(labels_discard),
"incl_prior": incl_prior,
"incl_context": incl_context,
"entity_vector_length": entity_vector_length,
}
self.kb = kb_loader(self.vocab)
self.get_candidates = get_candidates
self.cfg = dict(cfg)
self.distance = CosineDistance(normalize=False)
# how many neighbour sentences to take into account
self.n_sents = cfg.get("n_sents", 0)
# create an empty KB by default. If you want to load a predefined one, specify it in 'initialize'.
self.kb = empty_kb(entity_vector_length)(self.vocab)
def _require_kb(self) -> None:
def set_kb(self, kb_loader: Callable[[Vocab], KnowledgeBase]):
"""Define the KB of this pipe by providing a function that will
create it using this object's vocab."""
self.kb = kb_loader(self.vocab)
self.cfg["entity_vector_length"] = self.kb.entity_vector_length
def validate_kb(self) -> None:
# Raise an error if the knowledge base is not initialized.
if len(self.kb) == 0:
raise ValueError(Errors.E139.format(name=self.name))
@@ -145,6 +158,7 @@ class EntityLinker(Pipe):
get_examples: Callable[[], Iterable[Example]],
*,
nlp: Optional[Language] = None,
kb_loader: Callable[[Vocab], KnowledgeBase] = None,
):
"""Initialize the pipe for training, using a representative set
of data examples.
@@ -152,11 +166,16 @@
get_examples (Callable[[], Iterable[Example]]): Function that
returns a representative sample of gold-standard Example objects.
nlp (Language): The current nlp object the component is part of.
kb_loader (Callable[[Vocab], KnowledgeBase]): A function that creates a KnowledgeBase from a Vocab instance.
Note that providing this argument will overwrite all data accumulated in the current KB.
Use this only when loading an existing KB from file.
DOCS: https://nightly.spacy.io/api/entitylinker#initialize
"""
self._ensure_examples(get_examples)
self._require_kb()
if kb_loader is not None:
self.set_kb(kb_loader)
self.validate_kb()
nO = self.kb.entity_vector_length
doc_sample = []
vector_sample = []
@@ -192,7 +211,7 @@
DOCS: https://nightly.spacy.io/api/entitylinker#update
"""
self._require_kb()
self.validate_kb()
if losses is None:
losses = {}
losses.setdefault(self.name, 0.0)
@@ -303,7 +322,7 @@
DOCS: https://nightly.spacy.io/api/entitylinker#predict
"""
self._require_kb()
self.validate_kb()
entity_count = 0
final_kb_ids = []
if not docs:
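
Taken together, the changes above decouple KB creation from component construction: the `EntityLinker` now starts with an empty KB and receives a real one later, either via `set_kb` or via `initialize(kb_loader=...)`. A minimal sketch of the new flow (not part of the diff; entities and vector length are made up):

```python
import spacy
from spacy.kb import KnowledgeBase

nlp = spacy.blank("en")
# The component starts out with an empty KB of the configured vector length.
entity_linker = nlp.add_pipe("entity_linker", config={"entity_vector_length": 3})

def create_kb(vocab):
    kb = KnowledgeBase(vocab, entity_vector_length=3)
    kb.add_entity(entity="Q42", freq=10, entity_vector=[1, 2, 3])
    kb.add_alias(alias="Douglas", entities=["Q42"], probabilities=[0.9])
    return kb

# Option 1: supply the KB directly on the component.
entity_linker.set_kb(create_kb)

# Option 2: pass a loader at initialization time. Note that this overwrites
# the current KB; train_examples would be real Example objects.
# entity_linker.initialize(lambda: train_examples, kb_loader=create_kb)
```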

View File

@@ -110,7 +110,7 @@ def test_kb_invalid_entity_vector(nlp):
def test_kb_default(nlp):
"""Test that the default (empty) KB is loaded when not providing a config"""
"""Test that the default (empty) KB is loaded upon construction"""
entity_linker = nlp.add_pipe("entity_linker", config={})
assert len(entity_linker.kb) == 0
assert entity_linker.kb.get_size_entities() == 0
@@ -122,7 +122,7 @@ def test_kb_default(nlp):
def test_kb_custom_length(nlp):
"""Test that the default (empty) KB can be configured with a custom entity length"""
entity_linker = nlp.add_pipe(
"entity_linker", config={"kb_loader": {"entity_vector_length": 35}}
"entity_linker", config={"entity_vector_length": 35}
)
assert len(entity_linker.kb) == 0
assert entity_linker.kb.get_size_entities() == 0
@@ -130,18 +130,9 @@
assert entity_linker.kb.entity_vector_length == 35
def test_kb_undefined(nlp):
"""Test that the EL can't train without defining a KB"""
entity_linker = nlp.add_pipe("entity_linker", config={})
with pytest.raises(ValueError):
entity_linker.initialize(lambda: [])
def test_kb_empty(nlp):
"""Test that the EL can't train with an empty KB"""
config = {"kb_loader": {"@misc": "spacy.EmptyKB.v1", "entity_vector_length": 342}}
entity_linker = nlp.add_pipe("entity_linker", config=config)
assert len(entity_linker.kb) == 0
def test_kb_initialize_empty(nlp):
"""Test that the EL can't initialize without examples"""
entity_linker = nlp.add_pipe("entity_linker")
with pytest.raises(ValueError):
entity_linker.initialize(lambda: [])
@@ -201,24 +192,21 @@ def test_el_pipe_configuration(nlp):
ruler = nlp.add_pipe("entity_ruler")
ruler.add_patterns([pattern])
@registry.misc.register("myAdamKB.v1")
def mykb() -> Callable[["Vocab"], KnowledgeBase]:
def create_kb(vocab):
kb = KnowledgeBase(vocab, entity_vector_length=1)
kb.add_entity(entity="Q2", freq=12, entity_vector=[2])
kb.add_entity(entity="Q3", freq=5, entity_vector=[3])
kb.add_alias(
alias="douglas", entities=["Q2", "Q3"], probabilities=[0.8, 0.1]
)
return kb
return create_kb
def create_kb(vocab):
kb = KnowledgeBase(vocab, entity_vector_length=1)
kb.add_entity(entity="Q2", freq=12, entity_vector=[2])
kb.add_entity(entity="Q3", freq=5, entity_vector=[3])
kb.add_alias(
alias="douglas", entities=["Q2", "Q3"], probabilities=[0.8, 0.1]
)
return kb
# run an EL pipe without a trained context encoder, to check the candidate generation step only
nlp.add_pipe(
entity_linker = nlp.add_pipe(
"entity_linker",
config={"kb_loader": {"@misc": "myAdamKB.v1"}, "incl_context": False},
config={"incl_context": False},
)
entity_linker.set_kb(create_kb)
# With the default get_candidates function, matching is case-sensitive
text = "Douglas and douglas are not the same."
doc = nlp(text)
@@ -234,15 +222,15 @@ def test_el_pipe_configuration(nlp):
return get_lowercased_candidates
# replace the pipe with a new one with a different candidate generator
nlp.replace_pipe(
entity_linker = nlp.replace_pipe(
"entity_linker",
"entity_linker",
config={
"kb_loader": {"@misc": "myAdamKB.v1"},
"incl_context": False,
"get_candidates": {"@misc": "spacy.LowercaseCandidateGenerator.v1"},
},
)
entity_linker.set_kb(create_kb)
doc = nlp(text)
assert doc[0].ent_kb_id_ == "Q2"
assert doc[1].ent_kb_id_ == ""
@@ -334,19 +322,15 @@ def test_preserving_links_asdoc(nlp):
"""Test that Span.as_doc preserves the existing entity links"""
vector_length = 1
@registry.misc.register("myLocationsKB.v1")
def dummy_kb() -> Callable[["Vocab"], KnowledgeBase]:
def create_kb(vocab):
mykb = KnowledgeBase(vocab, entity_vector_length=vector_length)
# adding entities
mykb.add_entity(entity="Q1", freq=19, entity_vector=[1])
mykb.add_entity(entity="Q2", freq=8, entity_vector=[1])
# adding aliases
mykb.add_alias(alias="Boston", entities=["Q1"], probabilities=[0.7])
mykb.add_alias(alias="Denver", entities=["Q2"], probabilities=[0.6])
return mykb
return create_kb
def create_kb(vocab):
mykb = KnowledgeBase(vocab, entity_vector_length=vector_length)
# adding entities
mykb.add_entity(entity="Q1", freq=19, entity_vector=[1])
mykb.add_entity(entity="Q2", freq=8, entity_vector=[1])
# adding aliases
mykb.add_alias(alias="Boston", entities=["Q1"], probabilities=[0.7])
mykb.add_alias(alias="Denver", entities=["Q2"], probabilities=[0.6])
return mykb
# set up pipeline with NER (Entity Ruler) and NEL (prior probability only, model not trained)
nlp.add_pipe("sentencizer")
@@ -356,8 +340,9 @@ def test_preserving_links_asdoc(nlp):
]
ruler = nlp.add_pipe("entity_ruler")
ruler.add_patterns(patterns)
el_config = {"kb_loader": {"@misc": "myLocationsKB.v1"}, "incl_prior": False}
entity_linker = nlp.add_pipe("entity_linker", config=el_config, last=True)
config = {"incl_prior": False}
entity_linker = nlp.add_pipe("entity_linker", config=config, last=True)
entity_linker.set_kb(create_kb)
nlp.initialize()
assert entity_linker.model.get_dim("nO") == vector_length
@@ -435,30 +420,26 @@ def test_overfitting_IO():
doc = nlp(text)
train_examples.append(Example.from_dict(doc, annotation))
@registry.misc.register("myOverfittingKB.v1")
def dummy_kb() -> Callable[["Vocab"], KnowledgeBase]:
def create_kb(vocab):
# create artificial KB - assign same prior weight to the two russ cochran's
# Q2146908 (Russ Cochran): American golfer
# Q7381115 (Russ Cochran): publisher
mykb = KnowledgeBase(vocab, entity_vector_length=vector_length)
mykb.add_entity(entity="Q2146908", freq=12, entity_vector=[6, -4, 3])
mykb.add_entity(entity="Q7381115", freq=12, entity_vector=[9, 1, -7])
mykb.add_alias(
alias="Russ Cochran",
entities=["Q2146908", "Q7381115"],
probabilities=[0.5, 0.5],
)
return mykb
return create_kb
def create_kb(vocab):
# create artificial KB - assign same prior weight to the two russ cochran's
# Q2146908 (Russ Cochran): American golfer
# Q7381115 (Russ Cochran): publisher
mykb = KnowledgeBase(vocab, entity_vector_length=vector_length)
mykb.add_entity(entity="Q2146908", freq=12, entity_vector=[6, -4, 3])
mykb.add_entity(entity="Q7381115", freq=12, entity_vector=[9, 1, -7])
mykb.add_alias(
alias="Russ Cochran",
entities=["Q2146908", "Q7381115"],
probabilities=[0.5, 0.5],
)
return mykb
# Create the Entity Linker component and add it to the pipeline
entity_linker = nlp.add_pipe(
"entity_linker",
config={"kb_loader": {"@misc": "myOverfittingKB.v1"}},
last=True,
)
entity_linker.set_kb(create_kb)
# train the NEL pipe
optimizer = nlp.initialize(get_examples=lambda: train_examples)

View File

@@ -71,17 +71,13 @@ def tagger():
def entity_linker():
nlp = Language()
@registry.misc.register("TestIssue5230KB.v1")
def dummy_kb() -> Callable[["Vocab"], KnowledgeBase]:
def create_kb(vocab):
kb = KnowledgeBase(vocab, entity_vector_length=1)
kb.add_entity("test", 0.0, zeros((1, 1), dtype="f"))
return kb
def create_kb(vocab):
kb = KnowledgeBase(vocab, entity_vector_length=1)
kb.add_entity("test", 0.0, zeros((1, 1), dtype="f"))
return kb
return create_kb
config = {"kb_loader": {"@misc": "TestIssue5230KB.v1"}}
entity_linker = nlp.add_pipe("entity_linker", config=config)
entity_linker = nlp.add_pipe("entity_linker")
entity_linker.set_kb(create_kb)
# need to add model for two reasons:
# 1. no model leads to error in serialization,
# 2. the affected line is the one for model serialization

View File

@@ -1,11 +1,12 @@
from typing import Callable
from spacy import util
from spacy.lang.en import English
from spacy.util import ensure_path, registry
from spacy.util import ensure_path, registry, load_model_from_config
from spacy.kb import KnowledgeBase
from thinc.api import Config
from ..util import make_tempdir
from numpy import zeros
def test_serialize_kb_disk(en_vocab):
@@ -80,6 +81,28 @@ def _check_kb(kb):
def test_serialize_subclassed_kb():
"""Check that IO of a custom KB works fine as part of an EL pipe."""
config_string = """
[nlp]
lang = "en"
pipeline = ["entity_linker"]
[components]
[components.entity_linker]
factory = "entity_linker"
[initialize]
[initialize.components]
[initialize.components.entity_linker]
[initialize.components.entity_linker.kb_loader]
@misc = "spacy.CustomKB.v1"
entity_vector_length = 342
custom_field = 666
"""
class SubKnowledgeBase(KnowledgeBase):
def __init__(self, vocab, entity_vector_length, custom_field):
super().__init__(vocab, entity_vector_length)
@@ -90,23 +113,21 @@ def test_serialize_subclassed_kb():
entity_vector_length: int, custom_field: int
) -> Callable[["Vocab"], KnowledgeBase]:
def custom_kb_factory(vocab):
return SubKnowledgeBase(
kb = SubKnowledgeBase(
vocab=vocab,
entity_vector_length=entity_vector_length,
custom_field=custom_field,
)
kb.add_entity("random_entity", 0.0, zeros(entity_vector_length))
return kb
return custom_kb_factory
nlp = English()
config = {
"kb_loader": {
"@misc": "spacy.CustomKB.v1",
"entity_vector_length": 342,
"custom_field": 666,
}
}
entity_linker = nlp.add_pipe("entity_linker", config=config)
config = Config().from_str(config_string)
nlp = load_model_from_config(config, auto_fill=True)
nlp.initialize()
entity_linker = nlp.get_pipe("entity_linker")
assert type(entity_linker.kb) == SubKnowledgeBase
assert entity_linker.kb.entity_vector_length == 342
assert entity_linker.kb.custom_field == 666
@@ -116,6 +137,7 @@ def test_serialize_subclassed_kb():
nlp.to_disk(tmp_dir)
nlp2 = util.load_model_from_path(tmp_dir)
entity_linker2 = nlp2.get_pipe("entity_linker")
assert type(entity_linker2.kb) == SubKnowledgeBase
# After IO, the KB is the standard one
assert type(entity_linker2.kb) == KnowledgeBase
assert entity_linker2.kb.entity_vector_length == 342
assert entity_linker2.kb.custom_field == 666
assert not hasattr(entity_linker2.kb, "custom_field")

View File

@@ -524,7 +524,7 @@ Get a pipeline component for a given component name.
## Language.replace_pipe {#replace_pipe tag="method" new="2"}
Replace a component in the pipeline.
Replace a component in the pipeline and return the new component.
<Infobox title="Changed in v3.0" variant="warning">
@@ -538,7 +538,7 @@ and instead expects the **name of a component factory** registered using
> #### Example
>
> ```python
> nlp.replace_pipe("parser", my_custom_parser)
> new_parser = nlp.replace_pipe("parser", "my_custom_parser")
> ```
| Name | Description |
@@ -548,6 +548,7 @@ and instead expects the **name of a component factory** registered using
| _keyword-only_ | |
| `config` <Tag variant="new">3</Tag> | Optional config parameters to use for the new component. Will be merged with the `default_config` specified by the component factory. ~~Optional[Dict[str, Any]]~~ |
| `validate` <Tag variant="new">3</Tag> | Whether to validate the component config and arguments against the types expected by the factory. Defaults to `True`. ~~bool~~ |
| **RETURNS** | The new pipeline component. ~~Callable[[Doc], Doc]~~ |
## Language.rename_pipe {#rename_pipe tag="method" new="2"}

View File

@@ -297,7 +297,7 @@ packages. This lets one application easily customize the behavior of another, by
exposing an entry point in its `setup.py`. For a quick and fun intro to entry
points in Python, check out
[this excellent blog post](https://amir.rachum.com/blog/2017/07/28/python-entry-points/).
spaCy can load custom function from several different entry points to add
spaCy can load custom functions from several different entry points to add
pipeline component factories, language classes and other settings. To make spaCy
use your entry points, your package needs to expose them and it needs to be
installed in the same environment. That's it.
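
For instance, a package could expose a component factory through the `spacy_factories` entry point group in its `setup.py`. A sketch, with hypothetical package and function names:

```python
# Hypothetical setup.py exposing a component factory via the
# "spacy_factories" entry point group (all names are illustrative).
from setuptools import setup

setup(
    name="my_spacy_plugin",
    entry_points={
        "spacy_factories": [
            "my_component = my_spacy_plugin:create_my_component",
        ],
    },
)
```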

View File

@@ -395,7 +395,7 @@ type-check model definitions.
For data validation, spaCy v3.0 adopts
[`pydantic`](https://github.com/samuelcolvin/pydantic). It also powers the data
validation of Thinc's [config system](https://thinc.ai/docs/usage-config), which
lets you to register **custom functions with typed arguments**, reference them
lets you register **custom functions with typed arguments**, reference them
in your config and see validation errors if the argument values don't match.
<Infobox title="Details & Documentation" emoji="📖" list>
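
As a rough sketch of what this looks like in practice (the registered name and logic are illustrative): a function registered with typed arguments can be referenced from a config block such as `{"@misc": "myScaledVectors.v1", "scale": 2}`, and a value that doesn't match the annotation (e.g. `scale = "two"`) produces a validation error when the config is resolved.

```python
from typing import List
from spacy.util import registry

# Registered function with typed arguments; pydantic validates config
# values against these annotations when the function is resolved.
@registry.misc.register("myScaledVectors.v1")
def create_scaler(scale: int):
    def scale_vector(vector: List[float]) -> List[float]:
        return [scale * v for v in vector]
    return scale_vector
```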