fix NEL config and IO, and n_sents functionality (#7100)

* fix NEL config and IO, and n_sents functionality

* add docs

* fix test
This commit is contained in:
Sofie Van Landeghem 2021-02-22 04:49:52 +01:00 committed by GitHub
parent 113e8d082b
commit b92f81d5da
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 36 additions and 24 deletions

View File

@ -45,6 +45,7 @@ DEFAULT_NEL_MODEL = Config().from_str(default_model_config)["model"]
default_config={ default_config={
"model": DEFAULT_NEL_MODEL, "model": DEFAULT_NEL_MODEL,
"labels_discard": [], "labels_discard": [],
"n_sents": 0,
"incl_prior": True, "incl_prior": True,
"incl_context": True, "incl_context": True,
"entity_vector_length": 64, "entity_vector_length": 64,
@ -62,6 +63,7 @@ def make_entity_linker(
model: Model, model: Model,
*, *,
labels_discard: Iterable[str], labels_discard: Iterable[str],
n_sents: int,
incl_prior: bool, incl_prior: bool,
incl_context: bool, incl_context: bool,
entity_vector_length: int, entity_vector_length: int,
@ -73,6 +75,7 @@ def make_entity_linker(
representations. Given a batch of Doc objects, it should return a single representations. Given a batch of Doc objects, it should return a single
array, with one row per item in the batch. array, with one row per item in the batch.
labels_discard (Iterable[str]): NER labels that will automatically get a "NIL" prediction. labels_discard (Iterable[str]): NER labels that will automatically get a "NIL" prediction.
n_sents (int): The number of neighbouring sentences to take into account.
incl_prior (bool): Whether or not to include prior probabilities from the KB in the model. incl_prior (bool): Whether or not to include prior probabilities from the KB in the model.
incl_context (bool): Whether or not to include the local context in the model. incl_context (bool): Whether or not to include the local context in the model.
entity_vector_length (int): Size of encoding vectors in the KB. entity_vector_length (int): Size of encoding vectors in the KB.
@ -84,6 +87,7 @@ def make_entity_linker(
model, model,
name, name,
labels_discard=labels_discard, labels_discard=labels_discard,
n_sents=n_sents,
incl_prior=incl_prior, incl_prior=incl_prior,
incl_context=incl_context, incl_context=incl_context,
entity_vector_length=entity_vector_length, entity_vector_length=entity_vector_length,
@ -106,6 +110,7 @@ class EntityLinker(TrainablePipe):
name: str = "entity_linker", name: str = "entity_linker",
*, *,
labels_discard: Iterable[str], labels_discard: Iterable[str],
n_sents: int,
incl_prior: bool, incl_prior: bool,
incl_context: bool, incl_context: bool,
entity_vector_length: int, entity_vector_length: int,
@ -118,6 +123,7 @@ class EntityLinker(TrainablePipe):
name (str): The component instance name, used to add entries to the name (str): The component instance name, used to add entries to the
losses during training. losses during training.
labels_discard (Iterable[str]): NER labels that will automatically get a "NIL" prediction. labels_discard (Iterable[str]): NER labels that will automatically get a "NIL" prediction.
n_sents (int): The number of neighbouring sentences to take into account.
incl_prior (bool): Whether or not to include prior probabilities from the KB in the model. incl_prior (bool): Whether or not to include prior probabilities from the KB in the model.
incl_context (bool): Whether or not to include the local context in the model. incl_context (bool): Whether or not to include the local context in the model.
entity_vector_length (int): Size of encoding vectors in the KB. entity_vector_length (int): Size of encoding vectors in the KB.
@ -129,17 +135,14 @@ class EntityLinker(TrainablePipe):
self.vocab = vocab self.vocab = vocab
self.model = model self.model = model
self.name = name self.name = name
cfg = { self.labels_discard = list(labels_discard)
"labels_discard": list(labels_discard), self.n_sents = n_sents
"incl_prior": incl_prior, self.incl_prior = incl_prior
"incl_context": incl_context, self.incl_context = incl_context
"entity_vector_length": entity_vector_length,
}
self.get_candidates = get_candidates self.get_candidates = get_candidates
self.cfg = dict(cfg) self.cfg = {}
self.distance = CosineDistance(normalize=False) self.distance = CosineDistance(normalize=False)
# how many neightbour sentences to take into account # how many neightbour sentences to take into account
self.n_sents = cfg.get("n_sents", 0)
# create an empty KB by default. If you want to load a predefined one, specify it in 'initialize'. # create an empty KB by default. If you want to load a predefined one, specify it in 'initialize'.
self.kb = empty_kb(entity_vector_length)(self.vocab) self.kb = empty_kb(entity_vector_length)(self.vocab)
@ -150,7 +153,6 @@ class EntityLinker(TrainablePipe):
raise ValueError(Errors.E885.format(arg_type=type(kb_loader))) raise ValueError(Errors.E885.format(arg_type=type(kb_loader)))
self.kb = kb_loader(self.vocab) self.kb = kb_loader(self.vocab)
self.cfg["entity_vector_length"] = self.kb.entity_vector_length
def validate_kb(self) -> None: def validate_kb(self) -> None:
# Raise an error if the knowledge base is not initialized. # Raise an error if the knowledge base is not initialized.
@ -312,14 +314,13 @@ class EntityLinker(TrainablePipe):
sent_doc = doc[start_token:end_token].as_doc() sent_doc = doc[start_token:end_token].as_doc()
# currently, the context is the same for each entity in a sentence (should be refined) # currently, the context is the same for each entity in a sentence (should be refined)
xp = self.model.ops.xp xp = self.model.ops.xp
if self.cfg.get("incl_context"): if self.incl_context:
sentence_encoding = self.model.predict([sent_doc])[0] sentence_encoding = self.model.predict([sent_doc])[0]
sentence_encoding_t = sentence_encoding.T sentence_encoding_t = sentence_encoding.T
sentence_norm = xp.linalg.norm(sentence_encoding_t) sentence_norm = xp.linalg.norm(sentence_encoding_t)
for ent in sent.ents: for ent in sent.ents:
entity_count += 1 entity_count += 1
to_discard = self.cfg.get("labels_discard", []) if ent.label_ in self.labels_discard:
if to_discard and ent.label_ in to_discard:
# ignoring this entity - setting to NIL # ignoring this entity - setting to NIL
final_kb_ids.append(self.NIL) final_kb_ids.append(self.NIL)
else: else:
@ -337,13 +338,13 @@ class EntityLinker(TrainablePipe):
prior_probs = xp.asarray( prior_probs = xp.asarray(
[c.prior_prob for c in candidates] [c.prior_prob for c in candidates]
) )
if not self.cfg.get("incl_prior"): if not self.incl_prior:
prior_probs = xp.asarray( prior_probs = xp.asarray(
[0.0 for _ in candidates] [0.0 for _ in candidates]
) )
scores = prior_probs scores = prior_probs
# add in similarity from the context # add in similarity from the context
if self.cfg.get("incl_context"): if self.incl_context:
entity_encodings = xp.asarray( entity_encodings = xp.asarray(
[c.entity_vector for c in candidates] [c.entity_vector for c in candidates]
) )

View File

@ -250,6 +250,14 @@ def test_el_pipe_configuration(nlp):
assert doc[2].ent_kb_id_ == "Q2" assert doc[2].ent_kb_id_ == "Q2"
def test_nel_nsents(nlp):
"""Test that n_sents can be set through the configuration"""
entity_linker = nlp.add_pipe("entity_linker", config={})
assert entity_linker.n_sents == 0
entity_linker = nlp.replace_pipe("entity_linker", "entity_linker", config={"n_sents": 2})
assert entity_linker.n_sents == 2
def test_vocab_serialization(nlp): def test_vocab_serialization(nlp):
"""Test that string information is retained across storage""" """Test that string information is retained across storage"""
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1) mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)

View File

@ -83,9 +83,9 @@ def test_replace_last_pipe(nlp):
def test_replace_pipe_config(nlp): def test_replace_pipe_config(nlp):
nlp.add_pipe("entity_linker") nlp.add_pipe("entity_linker")
nlp.add_pipe("sentencizer") nlp.add_pipe("sentencizer")
assert nlp.get_pipe("entity_linker").cfg["incl_prior"] is True assert nlp.get_pipe("entity_linker").incl_prior is True
nlp.replace_pipe("entity_linker", "entity_linker", config={"incl_prior": False}) nlp.replace_pipe("entity_linker", "entity_linker", config={"incl_prior": False})
assert nlp.get_pipe("entity_linker").cfg["incl_prior"] is False assert nlp.get_pipe("entity_linker").incl_prior is False
@pytest.mark.parametrize("old_name,new_name", [("old_pipe", "new_pipe")]) @pytest.mark.parametrize("old_name,new_name", [("old_pipe", "new_pipe")])

View File

@ -31,6 +31,7 @@ architectures and their arguments and hyperparameters.
> from spacy.pipeline.entity_linker import DEFAULT_NEL_MODEL > from spacy.pipeline.entity_linker import DEFAULT_NEL_MODEL
> config = { > config = {
> "labels_discard": [], > "labels_discard": [],
> "n_sents": 0,
> "incl_prior": True, > "incl_prior": True,
> "incl_context": True, > "incl_context": True,
> "model": DEFAULT_NEL_MODEL, > "model": DEFAULT_NEL_MODEL,
@ -43,6 +44,7 @@ architectures and their arguments and hyperparameters.
| Setting | Description | | Setting | Description |
| ---------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | | ---------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| `labels_discard` | NER labels that will automatically get a "NIL" prediction. Defaults to `[]`. ~~Iterable[str]~~ | | `labels_discard` | NER labels that will automatically get a "NIL" prediction. Defaults to `[]`. ~~Iterable[str]~~ |
| `n_sents` | The number of neighbouring sentences to take into account. Defaults to 0. ~~int~~ |
| `incl_prior` | Whether or not to include prior probabilities from the KB in the model. Defaults to `True`. ~~bool~~ | | `incl_prior` | Whether or not to include prior probabilities from the KB in the model. Defaults to `True`. ~~bool~~ |
| `incl_context` | Whether or not to include the local context in the model. Defaults to `True`. ~~bool~~ | | `incl_context` | Whether or not to include the local context in the model. Defaults to `True`. ~~bool~~ |
| `model` | The [`Model`](https://thinc.ai/docs/api-model) powering the pipeline component. Defaults to [EntityLinker](/api/architectures#EntityLinker). ~~Model~~ | | `model` | The [`Model`](https://thinc.ai/docs/api-model) powering the pipeline component. Defaults to [EntityLinker](/api/architectures#EntityLinker). ~~Model~~ |
@ -89,6 +91,7 @@ custom knowledge base, you should either call
| `entity_vector_length` | Size of encoding vectors in the KB. ~~int~~ | | `entity_vector_length` | Size of encoding vectors in the KB. ~~int~~ |
| `get_candidates` | Function that generates plausible candidates for a given `Span` object. ~~Callable[[KnowledgeBase, Span], Iterable[Candidate]]~~ | | `get_candidates` | Function that generates plausible candidates for a given `Span` object. ~~Callable[[KnowledgeBase, Span], Iterable[Candidate]]~~ |
| `labels_discard` | NER labels that will automatically get a `"NIL"` prediction. ~~Iterable[str]~~ | | `labels_discard` | NER labels that will automatically get a `"NIL"` prediction. ~~Iterable[str]~~ |
| `n_sents` | The number of neighbouring sentences to take into account. ~~int~~ |
| `incl_prior` | Whether or not to include prior probabilities from the KB in the model. ~~bool~~ | | `incl_prior` | Whether or not to include prior probabilities from the KB in the model. ~~bool~~ |
| `incl_context` | Whether or not to include the local context in the model. ~~bool~~ | | `incl_context` | Whether or not to include the local context in the model. ~~bool~~ |
@ -248,7 +251,7 @@ pipe's entity linking model and context encoder. Delegates to
> ``` > ```
| Name | Description | | Name | Description |
| ----------------- | ---------------------------------------------------------------------------------------------------------------------------------- | | -------------- | ------------------------------------------------------------------------------------------------------------------------ |
| `examples` | A batch of [`Example`](/api/example) objects to learn from. ~~Iterable[Example]~~ | | `examples` | A batch of [`Example`](/api/example) objects to learn from. ~~Iterable[Example]~~ |
| _keyword-only_ | | | _keyword-only_ | |
| `drop` | The dropout rate. ~~float~~ | | `drop` | The dropout rate. ~~float~~ |