mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-26 18:06:29 +03:00
fix NEL config and IO, and n_sents functionality (#7100)
* fix NEL config and IO, and n_sents functionality * add docs * fix test
This commit is contained in:
parent
113e8d082b
commit
b92f81d5da
|
@ -45,6 +45,7 @@ DEFAULT_NEL_MODEL = Config().from_str(default_model_config)["model"]
|
||||||
default_config={
|
default_config={
|
||||||
"model": DEFAULT_NEL_MODEL,
|
"model": DEFAULT_NEL_MODEL,
|
||||||
"labels_discard": [],
|
"labels_discard": [],
|
||||||
|
"n_sents": 0,
|
||||||
"incl_prior": True,
|
"incl_prior": True,
|
||||||
"incl_context": True,
|
"incl_context": True,
|
||||||
"entity_vector_length": 64,
|
"entity_vector_length": 64,
|
||||||
|
@ -62,6 +63,7 @@ def make_entity_linker(
|
||||||
model: Model,
|
model: Model,
|
||||||
*,
|
*,
|
||||||
labels_discard: Iterable[str],
|
labels_discard: Iterable[str],
|
||||||
|
n_sents: int,
|
||||||
incl_prior: bool,
|
incl_prior: bool,
|
||||||
incl_context: bool,
|
incl_context: bool,
|
||||||
entity_vector_length: int,
|
entity_vector_length: int,
|
||||||
|
@ -73,6 +75,7 @@ def make_entity_linker(
|
||||||
representations. Given a batch of Doc objects, it should return a single
|
representations. Given a batch of Doc objects, it should return a single
|
||||||
array, with one row per item in the batch.
|
array, with one row per item in the batch.
|
||||||
labels_discard (Iterable[str]): NER labels that will automatically get a "NIL" prediction.
|
labels_discard (Iterable[str]): NER labels that will automatically get a "NIL" prediction.
|
||||||
|
n_sents (int): The number of neighbouring sentences to take into account.
|
||||||
incl_prior (bool): Whether or not to include prior probabilities from the KB in the model.
|
incl_prior (bool): Whether or not to include prior probabilities from the KB in the model.
|
||||||
incl_context (bool): Whether or not to include the local context in the model.
|
incl_context (bool): Whether or not to include the local context in the model.
|
||||||
entity_vector_length (int): Size of encoding vectors in the KB.
|
entity_vector_length (int): Size of encoding vectors in the KB.
|
||||||
|
@ -84,6 +87,7 @@ def make_entity_linker(
|
||||||
model,
|
model,
|
||||||
name,
|
name,
|
||||||
labels_discard=labels_discard,
|
labels_discard=labels_discard,
|
||||||
|
n_sents=n_sents,
|
||||||
incl_prior=incl_prior,
|
incl_prior=incl_prior,
|
||||||
incl_context=incl_context,
|
incl_context=incl_context,
|
||||||
entity_vector_length=entity_vector_length,
|
entity_vector_length=entity_vector_length,
|
||||||
|
@ -106,6 +110,7 @@ class EntityLinker(TrainablePipe):
|
||||||
name: str = "entity_linker",
|
name: str = "entity_linker",
|
||||||
*,
|
*,
|
||||||
labels_discard: Iterable[str],
|
labels_discard: Iterable[str],
|
||||||
|
n_sents: int,
|
||||||
incl_prior: bool,
|
incl_prior: bool,
|
||||||
incl_context: bool,
|
incl_context: bool,
|
||||||
entity_vector_length: int,
|
entity_vector_length: int,
|
||||||
|
@ -118,6 +123,7 @@ class EntityLinker(TrainablePipe):
|
||||||
name (str): The component instance name, used to add entries to the
|
name (str): The component instance name, used to add entries to the
|
||||||
losses during training.
|
losses during training.
|
||||||
labels_discard (Iterable[str]): NER labels that will automatically get a "NIL" prediction.
|
labels_discard (Iterable[str]): NER labels that will automatically get a "NIL" prediction.
|
||||||
|
n_sents (int): The number of neighbouring sentences to take into account.
|
||||||
incl_prior (bool): Whether or not to include prior probabilities from the KB in the model.
|
incl_prior (bool): Whether or not to include prior probabilities from the KB in the model.
|
||||||
incl_context (bool): Whether or not to include the local context in the model.
|
incl_context (bool): Whether or not to include the local context in the model.
|
||||||
entity_vector_length (int): Size of encoding vectors in the KB.
|
entity_vector_length (int): Size of encoding vectors in the KB.
|
||||||
|
@ -129,17 +135,14 @@ class EntityLinker(TrainablePipe):
|
||||||
self.vocab = vocab
|
self.vocab = vocab
|
||||||
self.model = model
|
self.model = model
|
||||||
self.name = name
|
self.name = name
|
||||||
cfg = {
|
self.labels_discard = list(labels_discard)
|
||||||
"labels_discard": list(labels_discard),
|
self.n_sents = n_sents
|
||||||
"incl_prior": incl_prior,
|
self.incl_prior = incl_prior
|
||||||
"incl_context": incl_context,
|
self.incl_context = incl_context
|
||||||
"entity_vector_length": entity_vector_length,
|
|
||||||
}
|
|
||||||
self.get_candidates = get_candidates
|
self.get_candidates = get_candidates
|
||||||
self.cfg = dict(cfg)
|
self.cfg = {}
|
||||||
self.distance = CosineDistance(normalize=False)
|
self.distance = CosineDistance(normalize=False)
|
||||||
# how many neightbour sentences to take into account
|
# how many neightbour sentences to take into account
|
||||||
self.n_sents = cfg.get("n_sents", 0)
|
|
||||||
# create an empty KB by default. If you want to load a predefined one, specify it in 'initialize'.
|
# create an empty KB by default. If you want to load a predefined one, specify it in 'initialize'.
|
||||||
self.kb = empty_kb(entity_vector_length)(self.vocab)
|
self.kb = empty_kb(entity_vector_length)(self.vocab)
|
||||||
|
|
||||||
|
@ -150,7 +153,6 @@ class EntityLinker(TrainablePipe):
|
||||||
raise ValueError(Errors.E885.format(arg_type=type(kb_loader)))
|
raise ValueError(Errors.E885.format(arg_type=type(kb_loader)))
|
||||||
|
|
||||||
self.kb = kb_loader(self.vocab)
|
self.kb = kb_loader(self.vocab)
|
||||||
self.cfg["entity_vector_length"] = self.kb.entity_vector_length
|
|
||||||
|
|
||||||
def validate_kb(self) -> None:
|
def validate_kb(self) -> None:
|
||||||
# Raise an error if the knowledge base is not initialized.
|
# Raise an error if the knowledge base is not initialized.
|
||||||
|
@ -312,14 +314,13 @@ class EntityLinker(TrainablePipe):
|
||||||
sent_doc = doc[start_token:end_token].as_doc()
|
sent_doc = doc[start_token:end_token].as_doc()
|
||||||
# currently, the context is the same for each entity in a sentence (should be refined)
|
# currently, the context is the same for each entity in a sentence (should be refined)
|
||||||
xp = self.model.ops.xp
|
xp = self.model.ops.xp
|
||||||
if self.cfg.get("incl_context"):
|
if self.incl_context:
|
||||||
sentence_encoding = self.model.predict([sent_doc])[0]
|
sentence_encoding = self.model.predict([sent_doc])[0]
|
||||||
sentence_encoding_t = sentence_encoding.T
|
sentence_encoding_t = sentence_encoding.T
|
||||||
sentence_norm = xp.linalg.norm(sentence_encoding_t)
|
sentence_norm = xp.linalg.norm(sentence_encoding_t)
|
||||||
for ent in sent.ents:
|
for ent in sent.ents:
|
||||||
entity_count += 1
|
entity_count += 1
|
||||||
to_discard = self.cfg.get("labels_discard", [])
|
if ent.label_ in self.labels_discard:
|
||||||
if to_discard and ent.label_ in to_discard:
|
|
||||||
# ignoring this entity - setting to NIL
|
# ignoring this entity - setting to NIL
|
||||||
final_kb_ids.append(self.NIL)
|
final_kb_ids.append(self.NIL)
|
||||||
else:
|
else:
|
||||||
|
@ -337,13 +338,13 @@ class EntityLinker(TrainablePipe):
|
||||||
prior_probs = xp.asarray(
|
prior_probs = xp.asarray(
|
||||||
[c.prior_prob for c in candidates]
|
[c.prior_prob for c in candidates]
|
||||||
)
|
)
|
||||||
if not self.cfg.get("incl_prior"):
|
if not self.incl_prior:
|
||||||
prior_probs = xp.asarray(
|
prior_probs = xp.asarray(
|
||||||
[0.0 for _ in candidates]
|
[0.0 for _ in candidates]
|
||||||
)
|
)
|
||||||
scores = prior_probs
|
scores = prior_probs
|
||||||
# add in similarity from the context
|
# add in similarity from the context
|
||||||
if self.cfg.get("incl_context"):
|
if self.incl_context:
|
||||||
entity_encodings = xp.asarray(
|
entity_encodings = xp.asarray(
|
||||||
[c.entity_vector for c in candidates]
|
[c.entity_vector for c in candidates]
|
||||||
)
|
)
|
||||||
|
|
|
@ -250,6 +250,14 @@ def test_el_pipe_configuration(nlp):
|
||||||
assert doc[2].ent_kb_id_ == "Q2"
|
assert doc[2].ent_kb_id_ == "Q2"
|
||||||
|
|
||||||
|
|
||||||
|
def test_nel_nsents(nlp):
|
||||||
|
"""Test that n_sents can be set through the configuration"""
|
||||||
|
entity_linker = nlp.add_pipe("entity_linker", config={})
|
||||||
|
assert entity_linker.n_sents == 0
|
||||||
|
entity_linker = nlp.replace_pipe("entity_linker", "entity_linker", config={"n_sents": 2})
|
||||||
|
assert entity_linker.n_sents == 2
|
||||||
|
|
||||||
|
|
||||||
def test_vocab_serialization(nlp):
|
def test_vocab_serialization(nlp):
|
||||||
"""Test that string information is retained across storage"""
|
"""Test that string information is retained across storage"""
|
||||||
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
|
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
|
||||||
|
|
|
@ -83,9 +83,9 @@ def test_replace_last_pipe(nlp):
|
||||||
def test_replace_pipe_config(nlp):
|
def test_replace_pipe_config(nlp):
|
||||||
nlp.add_pipe("entity_linker")
|
nlp.add_pipe("entity_linker")
|
||||||
nlp.add_pipe("sentencizer")
|
nlp.add_pipe("sentencizer")
|
||||||
assert nlp.get_pipe("entity_linker").cfg["incl_prior"] is True
|
assert nlp.get_pipe("entity_linker").incl_prior is True
|
||||||
nlp.replace_pipe("entity_linker", "entity_linker", config={"incl_prior": False})
|
nlp.replace_pipe("entity_linker", "entity_linker", config={"incl_prior": False})
|
||||||
assert nlp.get_pipe("entity_linker").cfg["incl_prior"] is False
|
assert nlp.get_pipe("entity_linker").incl_prior is False
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("old_name,new_name", [("old_pipe", "new_pipe")])
|
@pytest.mark.parametrize("old_name,new_name", [("old_pipe", "new_pipe")])
|
||||||
|
|
|
@ -31,6 +31,7 @@ architectures and their arguments and hyperparameters.
|
||||||
> from spacy.pipeline.entity_linker import DEFAULT_NEL_MODEL
|
> from spacy.pipeline.entity_linker import DEFAULT_NEL_MODEL
|
||||||
> config = {
|
> config = {
|
||||||
> "labels_discard": [],
|
> "labels_discard": [],
|
||||||
|
> "n_sents": 0,
|
||||||
> "incl_prior": True,
|
> "incl_prior": True,
|
||||||
> "incl_context": True,
|
> "incl_context": True,
|
||||||
> "model": DEFAULT_NEL_MODEL,
|
> "model": DEFAULT_NEL_MODEL,
|
||||||
|
@ -43,6 +44,7 @@ architectures and their arguments and hyperparameters.
|
||||||
| Setting | Description |
|
| Setting | Description |
|
||||||
| ---------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
|
| ---------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
|
||||||
| `labels_discard` | NER labels that will automatically get a "NIL" prediction. Defaults to `[]`. ~~Iterable[str]~~ |
|
| `labels_discard` | NER labels that will automatically get a "NIL" prediction. Defaults to `[]`. ~~Iterable[str]~~ |
|
||||||
|
| `n_sents` | The number of neighbouring sentences to take into account. Defaults to 0. ~~int~~ |
|
||||||
| `incl_prior` | Whether or not to include prior probabilities from the KB in the model. Defaults to `True`. ~~bool~~ |
|
| `incl_prior` | Whether or not to include prior probabilities from the KB in the model. Defaults to `True`. ~~bool~~ |
|
||||||
| `incl_context` | Whether or not to include the local context in the model. Defaults to `True`. ~~bool~~ |
|
| `incl_context` | Whether or not to include the local context in the model. Defaults to `True`. ~~bool~~ |
|
||||||
| `model` | The [`Model`](https://thinc.ai/docs/api-model) powering the pipeline component. Defaults to [EntityLinker](/api/architectures#EntityLinker). ~~Model~~ |
|
| `model` | The [`Model`](https://thinc.ai/docs/api-model) powering the pipeline component. Defaults to [EntityLinker](/api/architectures#EntityLinker). ~~Model~~ |
|
||||||
|
@ -89,6 +91,7 @@ custom knowledge base, you should either call
|
||||||
| `entity_vector_length` | Size of encoding vectors in the KB. ~~int~~ |
|
| `entity_vector_length` | Size of encoding vectors in the KB. ~~int~~ |
|
||||||
| `get_candidates` | Function that generates plausible candidates for a given `Span` object. ~~Callable[[KnowledgeBase, Span], Iterable[Candidate]]~~ |
|
| `get_candidates` | Function that generates plausible candidates for a given `Span` object. ~~Callable[[KnowledgeBase, Span], Iterable[Candidate]]~~ |
|
||||||
| `labels_discard` | NER labels that will automatically get a `"NIL"` prediction. ~~Iterable[str]~~ |
|
| `labels_discard` | NER labels that will automatically get a `"NIL"` prediction. ~~Iterable[str]~~ |
|
||||||
|
| `n_sents` | The number of neighbouring sentences to take into account. ~~int~~ |
|
||||||
| `incl_prior` | Whether or not to include prior probabilities from the KB in the model. ~~bool~~ |
|
| `incl_prior` | Whether or not to include prior probabilities from the KB in the model. ~~bool~~ |
|
||||||
| `incl_context` | Whether or not to include the local context in the model. ~~bool~~ |
|
| `incl_context` | Whether or not to include the local context in the model. ~~bool~~ |
|
||||||
|
|
||||||
|
@ -248,7 +251,7 @@ pipe's entity linking model and context encoder. Delegates to
|
||||||
> ```
|
> ```
|
||||||
|
|
||||||
| Name | Description |
|
| Name | Description |
|
||||||
| ----------------- | ---------------------------------------------------------------------------------------------------------------------------------- |
|
| -------------- | ------------------------------------------------------------------------------------------------------------------------ |
|
||||||
| `examples` | A batch of [`Example`](/api/example) objects to learn from. ~~Iterable[Example]~~ |
|
| `examples` | A batch of [`Example`](/api/example) objects to learn from. ~~Iterable[Example]~~ |
|
||||||
| _keyword-only_ | |
|
| _keyword-only_ | |
|
||||||
| `drop` | The dropout rate. ~~float~~ |
|
| `drop` | The dropout rate. ~~float~~ |
|
||||||
|
|
Loading…
Reference in New Issue
Block a user