2020-08-18 17:10:36 +03:00
|
|
|
from typing import Optional, Callable, Iterable
|
2020-02-27 20:42:27 +03:00
|
|
|
from thinc.api import chain, clone, list2ragged, reduce_mean, residual
|
|
|
|
from thinc.api import Model, Maxout, Linear
|
|
|
|
|
2020-02-28 13:57:41 +03:00
|
|
|
from ...util import registry
|
2020-08-18 17:10:36 +03:00
|
|
|
from ...kb import KnowledgeBase, Candidate, get_candidates
|
2020-05-20 12:41:12 +03:00
|
|
|
from ...vocab import Vocab
|
2020-02-27 20:42:27 +03:00
|
|
|
|
|
|
|
|
|
|
|
@registry.architectures.register("spacy.EntityLinker.v1")
def build_nel_encoder(tok2vec: Model, nO: Optional[int] = None) -> Model:
    """Build the encoder network used by the entity linker.

    The pipeline is:
    ``tok2vec >> list2ragged >> reduce_mean >> residual(Maxout) >> Linear``.

    The per-token vectors from ``tok2vec`` are gathered into a ragged array
    and mean-pooled. Presumably this yields one vector per input sequence;
    confirm against thinc's ``reduce_mean``. The pooled vector then passes
    through a residual Maxout block of width ``tok2vec``'s ``nO``. A Linear
    layer projects it to ``nO`` outputs.

    Args:
        tok2vec: Token-to-vector sub-network. Its ``nO`` dimension must
            already be set, because it is read here to size the later
            layers.
        nO: Output width of the final Linear layer. This may be None; the
            dimension is then left unset here (presumably inferred when the
            model is initialized).

    Returns:
        The chained model. It has refs ``"output_layer"`` (the final Linear)
        and ``"tok2vec"`` (the ``tok2vec`` argument).
    """
    width = tok2vec.get_dim("nO")
    projection = Linear(nO=nO, nI=width)
    with Model.define_operators({">>": chain, "**": clone}):
        # Chain one step at a time. Each new layer is still created right
        # before it is chained, the same order as a single long expression.
        encoder = tok2vec >> list2ragged()
        encoder = encoder >> reduce_mean()
        encoder = encoder >> residual(
            Maxout(nO=width, nI=width, nP=2, dropout=0.0)
        )
        encoder = encoder >> projection
    encoder.set_ref("output_layer", projection)
    encoder.set_ref("tok2vec", tok2vec)
    return encoder
|
2020-05-20 12:41:12 +03:00
|
|
|
|
|
|
|
|
|
|
|
@registry.assets.register("spacy.KBFromFile.v1")
def load_kb(kb_path: str) -> Callable[[Vocab], KnowledgeBase]:
    """Return a factory that loads a KnowledgeBase from ``kb_path``.

    Only the path is captured when this function runs. Nothing is read from
    disk until the returned callable is invoked with a Vocab.

    Args:
        kb_path: Location passed unchanged to ``KnowledgeBase.from_disk``.

    Returns:
        A callable that takes a Vocab and returns the loaded KnowledgeBase.
    """

    def kb_from_file(vocab: Vocab) -> KnowledgeBase:
        # The KB is constructed with entity_vector_length=1 before loading.
        # NOTE(review): presumably from_disk restores the stored length — confirm.
        loaded = KnowledgeBase(vocab, entity_vector_length=1)
        loaded.from_disk(kb_path)
        return loaded

    return kb_from_file
|
2020-08-04 15:34:09 +03:00
|
|
|
|
|
|
|
|
|
|
|
@registry.assets.register("spacy.EmptyKB.v1")
def empty_kb(entity_vector_length: int) -> Callable[[Vocab], KnowledgeBase]:
    """Return a factory that creates a new, empty KnowledgeBase.

    Args:
        entity_vector_length: Value passed as ``entity_vector_length`` to
            every KnowledgeBase the factory creates.

    Returns:
        A callable that takes a Vocab and returns a new KnowledgeBase built
        from that vocab. Nothing is loaded into it.
    """

    def empty_kb_factory(vocab: Vocab) -> KnowledgeBase:
        fresh_kb = KnowledgeBase(vocab=vocab, entity_vector_length=entity_vector_length)
        return fresh_kb

    return empty_kb_factory
|
|
|
|
|
|
|
|
|
|
|
|
@registry.assets.register("spacy.CandidateGenerator.v1")
def create_candidates() -> Callable[[KnowledgeBase, "Span"], Iterable[Candidate]]:
    """Return the default candidate generator.

    The generator is ``get_candidates`` from the ``kb`` module, returned as-is.
    According to the return annotation, it takes a KnowledgeBase and a Span
    and yields Candidate objects. ``Span`` is a string forward reference and
    is not imported in this module.
    """
    candidate_generator = get_candidates
    return candidate_generator
|