mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-26 01:46:28 +03:00
fix NEL config and IO, and n_sents functionality (#7100)
* fix NEL config and IO, and n_sents functionality * add docs * fix test
This commit is contained in:
parent
113e8d082b
commit
b92f81d5da
|
@ -45,6 +45,7 @@ DEFAULT_NEL_MODEL = Config().from_str(default_model_config)["model"]
|
|||
default_config={
|
||||
"model": DEFAULT_NEL_MODEL,
|
||||
"labels_discard": [],
|
||||
"n_sents": 0,
|
||||
"incl_prior": True,
|
||||
"incl_context": True,
|
||||
"entity_vector_length": 64,
|
||||
|
@ -62,6 +63,7 @@ def make_entity_linker(
|
|||
model: Model,
|
||||
*,
|
||||
labels_discard: Iterable[str],
|
||||
n_sents: int,
|
||||
incl_prior: bool,
|
||||
incl_context: bool,
|
||||
entity_vector_length: int,
|
||||
|
@ -73,6 +75,7 @@ def make_entity_linker(
|
|||
representations. Given a batch of Doc objects, it should return a single
|
||||
array, with one row per item in the batch.
|
||||
labels_discard (Iterable[str]): NER labels that will automatically get a "NIL" prediction.
|
||||
n_sents (int): The number of neighbouring sentences to take into account.
|
||||
incl_prior (bool): Whether or not to include prior probabilities from the KB in the model.
|
||||
incl_context (bool): Whether or not to include the local context in the model.
|
||||
entity_vector_length (int): Size of encoding vectors in the KB.
|
||||
|
@ -84,6 +87,7 @@ def make_entity_linker(
|
|||
model,
|
||||
name,
|
||||
labels_discard=labels_discard,
|
||||
n_sents=n_sents,
|
||||
incl_prior=incl_prior,
|
||||
incl_context=incl_context,
|
||||
entity_vector_length=entity_vector_length,
|
||||
|
@ -106,6 +110,7 @@ class EntityLinker(TrainablePipe):
|
|||
name: str = "entity_linker",
|
||||
*,
|
||||
labels_discard: Iterable[str],
|
||||
n_sents: int,
|
||||
incl_prior: bool,
|
||||
incl_context: bool,
|
||||
entity_vector_length: int,
|
||||
|
@ -118,6 +123,7 @@ class EntityLinker(TrainablePipe):
|
|||
name (str): The component instance name, used to add entries to the
|
||||
losses during training.
|
||||
labels_discard (Iterable[str]): NER labels that will automatically get a "NIL" prediction.
|
||||
n_sents (int): The number of neighbouring sentences to take into account.
|
||||
incl_prior (bool): Whether or not to include prior probabilities from the KB in the model.
|
||||
incl_context (bool): Whether or not to include the local context in the model.
|
||||
entity_vector_length (int): Size of encoding vectors in the KB.
|
||||
|
@ -129,17 +135,14 @@ class EntityLinker(TrainablePipe):
|
|||
self.vocab = vocab
|
||||
self.model = model
|
||||
self.name = name
|
||||
cfg = {
|
||||
"labels_discard": list(labels_discard),
|
||||
"incl_prior": incl_prior,
|
||||
"incl_context": incl_context,
|
||||
"entity_vector_length": entity_vector_length,
|
||||
}
|
||||
self.labels_discard = list(labels_discard)
|
||||
self.n_sents = n_sents
|
||||
self.incl_prior = incl_prior
|
||||
self.incl_context = incl_context
|
||||
self.get_candidates = get_candidates
|
||||
self.cfg = dict(cfg)
|
||||
self.cfg = {}
|
||||
self.distance = CosineDistance(normalize=False)
|
||||
# how many neightbour sentences to take into account
|
||||
self.n_sents = cfg.get("n_sents", 0)
|
||||
# create an empty KB by default. If you want to load a predefined one, specify it in 'initialize'.
|
||||
self.kb = empty_kb(entity_vector_length)(self.vocab)
|
||||
|
||||
|
@ -150,7 +153,6 @@ class EntityLinker(TrainablePipe):
|
|||
raise ValueError(Errors.E885.format(arg_type=type(kb_loader)))
|
||||
|
||||
self.kb = kb_loader(self.vocab)
|
||||
self.cfg["entity_vector_length"] = self.kb.entity_vector_length
|
||||
|
||||
def validate_kb(self) -> None:
|
||||
# Raise an error if the knowledge base is not initialized.
|
||||
|
@ -312,14 +314,13 @@ class EntityLinker(TrainablePipe):
|
|||
sent_doc = doc[start_token:end_token].as_doc()
|
||||
# currently, the context is the same for each entity in a sentence (should be refined)
|
||||
xp = self.model.ops.xp
|
||||
if self.cfg.get("incl_context"):
|
||||
if self.incl_context:
|
||||
sentence_encoding = self.model.predict([sent_doc])[0]
|
||||
sentence_encoding_t = sentence_encoding.T
|
||||
sentence_norm = xp.linalg.norm(sentence_encoding_t)
|
||||
for ent in sent.ents:
|
||||
entity_count += 1
|
||||
to_discard = self.cfg.get("labels_discard", [])
|
||||
if to_discard and ent.label_ in to_discard:
|
||||
if ent.label_ in self.labels_discard:
|
||||
# ignoring this entity - setting to NIL
|
||||
final_kb_ids.append(self.NIL)
|
||||
else:
|
||||
|
@ -337,13 +338,13 @@ class EntityLinker(TrainablePipe):
|
|||
prior_probs = xp.asarray(
|
||||
[c.prior_prob for c in candidates]
|
||||
)
|
||||
if not self.cfg.get("incl_prior"):
|
||||
if not self.incl_prior:
|
||||
prior_probs = xp.asarray(
|
||||
[0.0 for _ in candidates]
|
||||
)
|
||||
scores = prior_probs
|
||||
# add in similarity from the context
|
||||
if self.cfg.get("incl_context"):
|
||||
if self.incl_context:
|
||||
entity_encodings = xp.asarray(
|
||||
[c.entity_vector for c in candidates]
|
||||
)
|
||||
|
|
|
@ -250,6 +250,14 @@ def test_el_pipe_configuration(nlp):
|
|||
assert doc[2].ent_kb_id_ == "Q2"
|
||||
|
||||
|
||||
def test_nel_nsents(nlp):
|
||||
"""Test that n_sents can be set through the configuration"""
|
||||
entity_linker = nlp.add_pipe("entity_linker", config={})
|
||||
assert entity_linker.n_sents == 0
|
||||
entity_linker = nlp.replace_pipe("entity_linker", "entity_linker", config={"n_sents": 2})
|
||||
assert entity_linker.n_sents == 2
|
||||
|
||||
|
||||
def test_vocab_serialization(nlp):
|
||||
"""Test that string information is retained across storage"""
|
||||
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
|
||||
|
|
|
@ -83,9 +83,9 @@ def test_replace_last_pipe(nlp):
|
|||
def test_replace_pipe_config(nlp):
|
||||
nlp.add_pipe("entity_linker")
|
||||
nlp.add_pipe("sentencizer")
|
||||
assert nlp.get_pipe("entity_linker").cfg["incl_prior"] is True
|
||||
assert nlp.get_pipe("entity_linker").incl_prior is True
|
||||
nlp.replace_pipe("entity_linker", "entity_linker", config={"incl_prior": False})
|
||||
assert nlp.get_pipe("entity_linker").cfg["incl_prior"] is False
|
||||
assert nlp.get_pipe("entity_linker").incl_prior is False
|
||||
|
||||
|
||||
@pytest.mark.parametrize("old_name,new_name", [("old_pipe", "new_pipe")])
|
||||
|
|
|
@ -31,6 +31,7 @@ architectures and their arguments and hyperparameters.
|
|||
> from spacy.pipeline.entity_linker import DEFAULT_NEL_MODEL
|
||||
> config = {
|
||||
> "labels_discard": [],
|
||||
> "n_sents": 0,
|
||||
> "incl_prior": True,
|
||||
> "incl_context": True,
|
||||
> "model": DEFAULT_NEL_MODEL,
|
||||
|
@ -43,6 +44,7 @@ architectures and their arguments and hyperparameters.
|
|||
| Setting | Description |
|
||||
| ---------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
|
||||
| `labels_discard` | NER labels that will automatically get a "NIL" prediction. Defaults to `[]`. ~~Iterable[str]~~ |
|
||||
| `n_sents` | The number of neighbouring sentences to take into account. Defaults to 0. ~~int~~ |
|
||||
| `incl_prior` | Whether or not to include prior probabilities from the KB in the model. Defaults to `True`. ~~bool~~ |
|
||||
| `incl_context` | Whether or not to include the local context in the model. Defaults to `True`. ~~bool~~ |
|
||||
| `model` | The [`Model`](https://thinc.ai/docs/api-model) powering the pipeline component. Defaults to [EntityLinker](/api/architectures#EntityLinker). ~~Model~~ |
|
||||
|
@ -89,6 +91,7 @@ custom knowledge base, you should either call
|
|||
| `entity_vector_length` | Size of encoding vectors in the KB. ~~int~~ |
|
||||
| `get_candidates` | Function that generates plausible candidates for a given `Span` object. ~~Callable[[KnowledgeBase, Span], Iterable[Candidate]]~~ |
|
||||
| `labels_discard` | NER labels that will automatically get a `"NIL"` prediction. ~~Iterable[str]~~ |
|
||||
| `n_sents` | The number of neighbouring sentences to take into account. ~~int~~ |
|
||||
| `incl_prior` | Whether or not to include prior probabilities from the KB in the model. ~~bool~~ |
|
||||
| `incl_context` | Whether or not to include the local context in the model. ~~bool~~ |
|
||||
|
||||
|
@ -248,7 +251,7 @@ pipe's entity linking model and context encoder. Delegates to
|
|||
> ```
|
||||
|
||||
| Name | Description |
|
||||
| ----------------- | ---------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| -------------- | ------------------------------------------------------------------------------------------------------------------------ |
|
||||
| `examples` | A batch of [`Example`](/api/example) objects to learn from. ~~Iterable[Example]~~ |
|
||||
| _keyword-only_ | |
|
||||
| `drop` | The dropout rate. ~~float~~ |
|
||||
|
|
Loading…
Reference in New Issue
Block a user