Merge pull request #3459 from svlandeg/feature/el-framework
Basic framework and APIs for entity linker
Commit 68900066e0

examples/pipeline/dummy_entity_linking.py (new file, 71 lines)
@@ -0,0 +1,71 @@
# coding: utf-8
from __future__ import unicode_literals

"""Demonstrate how to build a simple knowledge base and run an Entity Linking algorithm.
Currently still a bit of a dummy algorithm: it simply takes the entity with the highest probability for a given alias.
"""
import spacy
from spacy.kb import KnowledgeBase


def create_kb(vocab):
    kb = KnowledgeBase(vocab=vocab)

    # adding entities
    entity_0 = "Q1004791_Douglas"
    print("adding entity", entity_0)
    kb.add_entity(entity=entity_0, prob=0.5)

    entity_1 = "Q42_Douglas_Adams"
    print("adding entity", entity_1)
    kb.add_entity(entity=entity_1, prob=0.5)

    entity_2 = "Q5301561_Douglas_Haig"
    print("adding entity", entity_2)
    kb.add_entity(entity=entity_2, prob=0.5)

    # adding aliases
    print()
    alias_0 = "Douglas"
    print("adding alias", alias_0)
    kb.add_alias(alias=alias_0, entities=[entity_0, entity_1, entity_2], probabilities=[0.1, 0.6, 0.2])

    alias_1 = "Douglas Adams"
    print("adding alias", alias_1)
    kb.add_alias(alias=alias_1, entities=[entity_1], probabilities=[0.9])

    print()
    print("kb size:", len(kb), kb.get_size_entities(), kb.get_size_aliases())

    return kb


def add_el(kb, nlp):
    el_pipe = nlp.create_pipe(name='entity_linker', config={"kb": kb})
    nlp.add_pipe(el_pipe, last=True)

    for alias in ["Douglas Adams", "Douglas"]:
        candidates = nlp.linker.kb.get_candidates(alias)
        print()
        print(len(candidates), "candidate(s) for", alias, ":")
        for c in candidates:
            print(" ", c.entity_, c.prior_prob)

    text = "In The Hitchhiker's Guide to the Galaxy, written by Douglas Adams, " \
           "Douglas reminds us to always bring our towel. " \
           "The main character in Doug's novel is called Arthur Dent."
    doc = nlp(text)

    print()
    for token in doc:
        print("token", token.text, token.ent_type_, token.ent_kb_id_)

    print()
    for ent in doc.ents:
        print("ent", ent.text, ent.label_, ent.kb_id_)


if __name__ == "__main__":
    nlp = spacy.load('en_core_web_sm')
    my_kb = create_kb(nlp.vocab)
    add_el(my_kb, nlp)
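
For orientation (editor's sketch, not part of the diff): the candidate lookups in add_el are fully determined by the KB built above, so they should print roughly the following, modulo C-float rounding in the printed probabilities (e.g. 0.9 may display as 0.8999999761581421):

1 candidate(s) for Douglas Adams :
  Q42_Douglas_Adams 0.9

3 candidate(s) for Douglas :
  Q1004791_Douglas 0.1
  Q42_Douglas_Adams 0.6
  Q5301561_Douglas_Haig 0.2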

setup.py
@@ -40,6 +40,7 @@ MOD_NAMES = [
"spacy.lexeme",
|
||||
"spacy.vocab",
|
||||
"spacy.attrs",
|
||||
"spacy.kb",
|
||||
"spacy.morphology",
|
||||
"spacy.pipeline.pipes",
|
||||
"spacy.syntax.stateclass",
|
||||

spacy/errors.py
@@ -80,6 +80,8 @@ class Warnings(object):
             "the v2.x models cannot release the global interpreter lock. "
             "Future versions may introduce a `n_process` argument for "
             "parallel inference via multiprocessing.")
+    W017 = ("Alias '{alias}' already exists in the Knowledge base.")
+    W018 = ("Entity '{entity}' already exists in the Knowledge base.")


 @add_codes
@@ -371,6 +373,16 @@ class Errors(object):
             "with spacy >= 2.1.0. To fix this, reinstall Python and use a wide "
             "unicode build instead. You can also rebuild Python and set the "
             "--enable-unicode=ucs4 flag.")
+    E131 = ("Cannot write the kb_id of an existing Span object because a Span "
+            "is a read-only view of the underlying Token objects stored in the Doc. "
+            "Instead, create a new Span object and specify the `kb_id` keyword argument, "
+            "for example:\nfrom spacy.tokens import Span\n"
+            "span = Span(doc, start={start}, end={end}, label='{label}', kb_id='{kb_id}')")
+    E132 = ("The vectors for entities and probabilities for alias '{alias}' should have equal length, "
+            "but found {entities_length} and {probabilities_length} respectively.")
+    E133 = ("The sum of prior probabilities for alias '{alias}' should not exceed 1, "
+            "but found {sum}.")
+    E134 = ("Alias '{alias}' defined for unknown entity '{entity}'.")


 @add_codes
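
A minimal sketch of when the new codes fire, using only the KB API added in this PR (editor's snippet, not part of the diff):

from spacy.kb import KnowledgeBase
from spacy.lang.en import English

kb = KnowledgeBase(English().vocab)
kb.add_entity(entity=u'Q42', prob=0.3)
kb.add_entity(entity=u'Q42', prob=0.3)   # duplicate: warns with W018 and returns
kb.add_alias(alias=u'Adams', entities=[u'Q42'], probabilities=[0.5, 0.5])   # raises ValueError with E132 (1 entity vs 2 probabilities)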

spacy/kb.pxd (new file, 148 lines)
@@ -0,0 +1,148 @@
"""Knowledge-base for entity or concept linking."""
|
||||
from cymem.cymem cimport Pool
|
||||
from preshed.maps cimport PreshMap
|
||||
from libcpp.vector cimport vector
|
||||
from libc.stdint cimport int32_t, int64_t
|
||||
|
||||
from spacy.vocab cimport Vocab
|
||||
from .typedefs cimport hash_t
|
||||
|
||||
|
||||
# Internal struct, for storage and disambiguation. This isn't what we return
|
||||
# to the user as the answer to "here's your entity". It's the minimum number
|
||||
# of bits we need to keep track of the answers.
|
||||
cdef struct _EntryC:
|
||||
|
||||
# The hash of this entry's unique ID and name in the kB
|
||||
hash_t entity_hash
|
||||
|
||||
# Allows retrieval of one or more vectors.
|
||||
# Each element of vector_rows should be an index into a vectors table.
|
||||
# Every entry should have the same number of vectors, so we can avoid storing
|
||||
# the number of vectors in each knowledge-base struct
|
||||
int32_t* vector_rows
|
||||
|
||||
# Allows retrieval of a struct of non-vector features. We could make this a
|
||||
# pointer, but we have 32 bits left over in the struct after prob, so we'd
|
||||
# like this to only be 32 bits. We can also set this to -1, for the common
|
||||
# case where there are no features.
|
||||
int32_t feats_row
|
||||
|
||||
# log probability of entity, based on corpus frequency
|
||||
float prob
|
||||
|
||||
|
||||
# Each alias struct stores a list of Entry pointers with their prior probabilities
|
||||
# for this specific mention/alias.
|
||||
cdef struct _AliasC:
|
||||
|
||||
# All entry candidates for this alias
|
||||
vector[int64_t] entry_indices
|
||||
|
||||
# Prior probability P(entity|alias) - should sum up to (at most) 1.
|
||||
vector[float] probs
|
||||
|
||||
|
||||
# Object used by the Entity Linker that summarizes one entity-alias candidate combination.
|
||||
cdef class Candidate:
|
||||
|
||||
cdef readonly KnowledgeBase kb
|
||||
cdef hash_t entity_hash
|
||||
cdef hash_t alias_hash
|
||||
cdef float prior_prob
|
||||
|
||||
|
||||
cdef class KnowledgeBase:
|
||||
cdef Pool mem
|
||||
cpdef readonly Vocab vocab
|
||||
|
||||
# This maps 64bit keys (hash of unique entity string)
|
||||
# to 64bit values (position of the _EntryC struct in the _entries vector).
|
||||
# The PreshMap is pretty space efficient, as it uses open addressing. So
|
||||
# the only overhead is the vacancy rate, which is approximately 30%.
|
||||
cdef PreshMap _entry_index
|
||||
|
||||
# Each entry takes 128 bits, and again we'll have a 30% or so overhead for
|
||||
# over allocation.
|
||||
# In total we end up with (N*128*1.3)+(N*128*1.3) bits for N entries.
|
||||
# Storing 1m entries would take 41.6mb under this scheme.
|
||||
cdef vector[_EntryC] _entries
|
||||
|
||||
# This maps 64bit keys (hash of unique alias string)
|
||||
# to 64bit values (position of the _AliasC struct in the _aliases_table vector).
|
||||
cdef PreshMap _alias_index
|
||||
|
||||
# This should map mention hashes to (entry_id, prob) tuples. The probability
|
||||
# should be P(entity | mention), which is pretty important to know.
|
||||
# We can pack both pieces of information into a 64-bit value, to keep things
|
||||
# efficient.
|
||||
cdef vector[_AliasC] _aliases_table
|
||||
|
||||
# This is the part which might take more space: storing various
|
||||
# categorical features for the entries, and storing vectors for disambiguation
|
||||
# and possibly usage.
|
||||
# If each entry gets a 300-dimensional vector, for 1m entries we would need
|
||||
# 1.2gb. That gets expensive fast. What might be better is to avoid learning
|
||||
# a unique vector for every entity. We could instead have a compositional
|
||||
# model, that embeds different features of the entities into vectors. We'll
|
||||
# still want some per-entity features, like the Wikipedia text or entity
|
||||
# co-occurrence. Hopefully those vectors can be narrow, e.g. 64 dimensions.
|
||||
cdef object _vectors_table
|
||||
|
||||
# It's very useful to track categorical features, at least for output, even
|
||||
# if they're not useful in the model itself. For instance, we should be
|
||||
# able to track stuff like a person's date of birth or whatever. This can
|
||||
# easily make the KB bigger, but if this isn't needed by the model, and it's
|
||||
# optional data, we can let users configure a DB as the backend for this.
|
||||
cdef object _features_table
|
||||
|
||||
cdef inline int64_t c_add_entity(self, hash_t entity_hash, float prob,
|
||||
int32_t* vector_rows, int feats_row):
|
||||
"""Add an entry to the knowledge base."""
|
||||
# This is what we'll map the hash key to. It's where the entry will sit
|
||||
# in the vector of entries, so we can get it later.
|
||||
cdef int64_t new_index = self._entries.size()
|
||||
self._entries.push_back(
|
||||
_EntryC(
|
||||
entity_hash=entity_hash,
|
||||
vector_rows=vector_rows,
|
||||
feats_row=feats_row,
|
||||
prob=prob
|
||||
))
|
||||
self._entry_index[entity_hash] = new_index
|
||||
return new_index
|
||||
|
||||
cdef inline int64_t c_add_aliases(self, hash_t alias_hash, vector[int64_t] entry_indices, vector[float] probs):
|
||||
"""Connect a mention to a list of potential entities with their prior probabilities ."""
|
||||
cdef int64_t new_index = self._aliases_table.size()
|
||||
|
||||
self._aliases_table.push_back(
|
||||
_AliasC(
|
||||
entry_indices=entry_indices,
|
||||
probs=probs
|
||||
))
|
||||
self._alias_index[alias_hash] = new_index
|
||||
return new_index
|
||||
|
||||
cdef inline _create_empty_vectors(self):
|
||||
"""
|
||||
Making sure the first element of each vector is a dummy,
|
||||
because the PreshMap maps pointing to indices in these vectors can not contain 0 as value
|
||||
cf. https://github.com/explosion/preshed/issues/17
|
||||
"""
|
||||
cdef int32_t dummy_value = 0
|
||||
self.vocab.strings.add("")
|
||||
self._entries.push_back(
|
||||
_EntryC(
|
||||
entity_hash=self.vocab.strings[""],
|
||||
vector_rows=&dummy_value,
|
||||
feats_row=dummy_value,
|
||||
prob=dummy_value
|
||||
))
|
||||
self._aliases_table.push_back(
|
||||
_AliasC(
|
||||
entry_indices=[dummy_value],
|
||||
probs=[dummy_value]
|
||||
))
|
||||
|
||||
|
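
A quick check of the size estimate in the comments above (editor's sketch): the entries vector and the 64-bit-key/64-bit-value index each cost about 128 bits per entry, plus ~30% open-addressing overhead.

n = 1000000                 # number of entities
bits = n * 128 * 1.3        # _entries: 128-bit structs, ~30% over-allocation
bits += n * 128 * 1.3       # _entry_index: 64-bit key + 64-bit value, same overhead
print(bits / 8 / 1e6)       # -> 41.6 (MB), matching the comment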

spacy/kb.pyx (new file, 131 lines)
@@ -0,0 +1,131 @@
# cython: profile=True
# coding: utf8
from spacy.errors import Errors, Warnings, user_warning


cdef class Candidate:

    def __init__(self, KnowledgeBase kb, entity_hash, alias_hash, prior_prob):
        self.kb = kb
        self.entity_hash = entity_hash
        self.alias_hash = alias_hash
        self.prior_prob = prior_prob

    @property
    def entity(self):
        """RETURNS (uint64): hash of the entity's KB ID/name"""
        return self.entity_hash

    @property
    def entity_(self):
        """RETURNS (unicode): ID/name of this entity in the KB"""
        return self.kb.vocab.strings[self.entity]

    @property
    def alias(self):
        """RETURNS (uint64): hash of the alias"""
        return self.alias_hash

    @property
    def alias_(self):
        """RETURNS (unicode): ID of the original alias"""
        return self.kb.vocab.strings[self.alias]

    @property
    def prior_prob(self):
        return self.prior_prob


cdef class KnowledgeBase:

    def __init__(self, Vocab vocab):
        self.vocab = vocab
        self._entry_index = PreshMap()
        self._alias_index = PreshMap()
        self.mem = Pool()
        self._create_empty_vectors()

    def __len__(self):
        return self.get_size_entities()

    def get_size_entities(self):
        return self._entries.size() - 1  # not counting dummy element on index 0

    def get_size_aliases(self):
        return self._aliases_table.size() - 1  # not counting dummy element on index 0

    def add_entity(self, unicode entity, float prob=0.5, vectors=None, features=None):
        """
        Add an entity to the KB.
        Return the hash of the entity ID at the end.
        """
        cdef hash_t entity_hash = self.vocab.strings.add(entity)

        # Return if this entity was added before
        if entity_hash in self._entry_index:
            user_warning(Warnings.W018.format(entity=entity))
            return

        cdef int32_t dummy_value = 342
        self.c_add_entity(entity_hash=entity_hash, prob=prob,
                          vector_rows=&dummy_value, feats_row=dummy_value)
        # TODO self._vectors_table.get_pointer(vectors),
        # self._features_table.get(features))

        return entity_hash

    def add_alias(self, unicode alias, entities, probabilities):
        """
        For a given alias, add its potential entities and prior probabilities to the KB.
        Return the alias_hash at the end.
        """

        # Throw an error if the lengths of entities and probabilities are not the same
        if not len(entities) == len(probabilities):
            raise ValueError(Errors.E132.format(alias=alias,
                                                entities_length=len(entities),
                                                probabilities_length=len(probabilities)))

        # Throw an error if the probabilities sum up to more than 1
        prob_sum = sum(probabilities)
        if prob_sum > 1:
            raise ValueError(Errors.E133.format(alias=alias, sum=prob_sum))

        cdef hash_t alias_hash = self.vocab.strings.add(alias)

        # Return if this alias was added before
        if alias_hash in self._alias_index:
            user_warning(Warnings.W017.format(alias=alias))
            return

        cdef hash_t entity_hash

        cdef vector[int64_t] entry_indices
        cdef vector[float] probs

        for entity, prob in zip(entities, probabilities):
            entity_hash = self.vocab.strings[entity]
            if not entity_hash in self._entry_index:
                raise ValueError(Errors.E134.format(alias=alias, entity=entity))

            entry_index = <int64_t>self._entry_index.get(entity_hash)
            entry_indices.push_back(int(entry_index))
            probs.push_back(float(prob))

        self.c_add_aliases(alias_hash=alias_hash, entry_indices=entry_indices, probs=probs)

        return alias_hash

    def get_candidates(self, unicode alias):
        """TODO: where to put this functionality?"""
        cdef hash_t alias_hash = self.vocab.strings[alias]
        alias_index = <int64_t>self._alias_index.get(alias_hash)
        alias_entry = self._aliases_table[alias_index]

        return [Candidate(kb=self,
                          entity_hash=self._entries[entry_index].entity_hash,
                          alias_hash=alias_hash,
                          prior_prob=prob)
                for (entry_index, prob) in zip(alias_entry.entry_indices, alias_entry.probs)
                if entry_index != 0]
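
Putting the two classes together, the intended Python-level API looks like this (editor's sketch based on the definitions above):

from spacy.kb import KnowledgeBase
from spacy.lang.en import English

kb = KnowledgeBase(English().vocab)
kb.add_entity(entity=u'Q42', prob=0.3)
kb.add_alias(alias=u'Adams', entities=[u'Q42'], probabilities=[0.9])
for c in kb.get_candidates(u'Adams'):
    print(c.entity, c.entity_, c.alias_, c.prior_prob)   # hash, u'Q42', u'Adams', ~0.9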

spacy/language.py
@@ -14,7 +14,7 @@ import srsly
 from .tokenizer import Tokenizer
 from .vocab import Vocab
 from .lemmatizer import Lemmatizer
-from .pipeline import DependencyParser, Tensorizer, Tagger, EntityRecognizer
+from .pipeline import DependencyParser, Tensorizer, Tagger, EntityRecognizer, EntityLinker
 from .pipeline import SimilarityHook, TextCategorizer, Sentencizer
 from .pipeline import merge_noun_chunks, merge_entities, merge_subtokens
 from .pipeline import EntityRuler
@@ -117,6 +117,7 @@ class Language(object):
         "tagger": lambda nlp, **cfg: Tagger(nlp.vocab, **cfg),
         "parser": lambda nlp, **cfg: DependencyParser(nlp.vocab, **cfg),
         "ner": lambda nlp, **cfg: EntityRecognizer(nlp.vocab, **cfg),
+        "entity_linker": lambda nlp, **cfg: EntityLinker(nlp.vocab, **cfg),
         "similarity": lambda nlp, **cfg: SimilarityHook(nlp.vocab, **cfg),
         "textcat": lambda nlp, **cfg: TextCategorizer(nlp.vocab, **cfg),
         "sentencizer": lambda nlp, **cfg: Sentencizer(**cfg),
@@ -212,6 +213,10 @@ class Language(object):
     def entity(self):
         return self.get_pipe("ner")

+    @property
+    def linker(self):
+        return self.get_pipe("entity_linker")
+
     @property
     def matcher(self):
         return self.get_pipe("matcher")
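
With the factory registered, the component is created like any other pipe, and the new linker property retrieves it (editor's sketch, assuming an nlp pipeline and a my_kb KnowledgeBase as in the example script):

el_pipe = nlp.create_pipe("entity_linker", config={"kb": my_kb})
nlp.add_pipe(el_pipe, last=True)
assert nlp.linker is el_pipe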

spacy/pipeline/__init__.py
@@ -1,7 +1,7 @@
 # coding: utf8
 from __future__ import unicode_literals

-from .pipes import Tagger, DependencyParser, EntityRecognizer
+from .pipes import Tagger, DependencyParser, EntityRecognizer, EntityLinker
 from .pipes import TextCategorizer, Tensorizer, Pipe, Sentencizer
 from .entityruler import EntityRuler
 from .hooks import SentenceSegmenter, SimilarityHook
@@ -11,6 +11,7 @@ __all__ = [
     "Tagger",
     "DependencyParser",
     "EntityRecognizer",
+    "EntityLinker",
     "TextCategorizer",
     "Tensorizer",
     "Pipe",

spacy/pipeline/pipes.pyx
@@ -1061,6 +1061,55 @@ cdef class EntityRecognizer(Parser):
                     if move[0] in ("B", "I", "L", "U")))


+class EntityLinker(Pipe):
+    name = 'entity_linker'
+
+    @classmethod
+    def Model(cls, nr_class=1, **cfg):
+        # TODO: non-dummy EL implementation
+        return None
+
+    def __init__(self, model=True, **cfg):
+        self.model = False
+        self.cfg = dict(cfg)
+        self.kb = self.cfg["kb"]
+
+    def __call__(self, doc):
+        self.set_annotations([doc], scores=None, tensors=None)
+        return doc
+
+    def pipe(self, stream, batch_size=128, n_threads=-1):
+        """Apply the pipe to a stream of documents.
+        Both __call__ and pipe should delegate to the `predict()`
+        and `set_annotations()` methods.
+        """
+        for docs in util.minibatch(stream, size=batch_size):
+            docs = list(docs)
+            self.set_annotations(docs, scores=None, tensors=None)
+            yield from docs
+
+    def set_annotations(self, docs, scores, tensors=None):
+        """
+        Currently implemented as taking the KB entry with the highest prior
+        probability for each named entity.
+        TODO: actually use context etc.
+        """
+        for i, doc in enumerate(docs):
+            for ent in doc.ents:
+                candidates = self.kb.get_candidates(ent.text)
+                if candidates:
+                    best_candidate = max(candidates, key=lambda c: c.prior_prob)
+                    for token in ent:
+                        token.ent_kb_id_ = best_candidate.entity_
+
+    def get_loss(self, docs, golds, scores):
+        # TODO
+        pass
+
+    def add_label(self, label):
+        # TODO
+        pass
+
+
 class Sentencizer(object):
     """Segment the Doc into sentences using a rule-based strategy.

@@ -1146,5 +1195,5 @@ class Sentencizer(object):
         self.punct_chars = cfg.get("punct_chars", self.default_punct_chars)
         return self

-__all__ = ["Tagger", "DependencyParser", "EntityRecognizer", "Tensorizer", "TextCategorizer", "Sentencizer"]
+__all__ = ["Tagger", "DependencyParser", "EntityRecognizer", "Tensorizer", "TextCategorizer", "EntityLinker", "Sentencizer"]
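
Since pipe() follows the same streaming contract as the other components, the dummy linker also works in bulk through nlp.pipe (editor's sketch, continuing the pipeline set up above):

texts = [u"Douglas Adams wrote the Guide.", u"Douglas was here."]
for doc in nlp.pipe(texts, batch_size=2):
    print([(ent.text, ent.kb_id_) for ent in doc.ents])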

spacy/structs.pxd
@@ -70,4 +70,5 @@ cdef struct TokenC:
     int sent_start
     int ent_iob
     attr_t ent_type  # TODO: Is there a better way to do this? Multiple sources of truth..
+    attr_t ent_kb_id
     hash_t ent_id

spacy/tests/doc/test_span.py
@@ -172,10 +172,12 @@ def test_span_as_doc(doc):
     assert span_doc[0].idx == 0


-def test_span_string_label(doc):
-    span = Span(doc, 0, 1, label="hello")
+def test_span_string_label_kb_id(doc):
+    span = Span(doc, 0, 1, label="hello", kb_id="Q342")
     assert span.label_ == "hello"
     assert span.label == doc.vocab.strings["hello"]
+    assert span.kb_id_ == "Q342"
+    assert span.kb_id == doc.vocab.strings["Q342"]


 def test_span_label_readonly(doc):
@@ -184,6 +186,12 @@ def test_span_label_readonly(doc):
         span.label_ = "hello"


+def test_span_kb_id_readonly(doc):
+    span = Span(doc, 0, 1)
+    with pytest.raises(NotImplementedError):
+        span.kb_id_ = "Q342"
+
+
 def test_span_ents_property(doc):
     """Test span.ents for the """
     doc.ents = [

spacy/tests/pipeline/test_el.py (new file, 91 lines)
@@ -0,0 +1,91 @@
# coding: utf-8
from __future__ import unicode_literals

import pytest

from spacy.kb import KnowledgeBase
from spacy.lang.en import English


@pytest.fixture
def nlp():
    return English()


def test_kb_valid_entities(nlp):
    """Test the valid construction of a KB with 3 entities and 2 aliases"""
    mykb = KnowledgeBase(nlp.vocab)

    # adding entities
    mykb.add_entity(entity=u'Q1', prob=0.9)
    mykb.add_entity(entity=u'Q2')
    mykb.add_entity(entity=u'Q3', prob=0.5)

    # adding aliases
    mykb.add_alias(alias=u'douglas', entities=[u'Q2', u'Q3'], probabilities=[0.8, 0.2])
    mykb.add_alias(alias=u'adam', entities=[u'Q2'], probabilities=[0.9])

    # test the size of the corresponding KB
    assert(mykb.get_size_entities() == 3)
    assert(mykb.get_size_aliases() == 2)


def test_kb_invalid_entities(nlp):
    """Test the invalid construction of a KB with an alias linked to a non-existing entity"""
    mykb = KnowledgeBase(nlp.vocab)

    # adding entities
    mykb.add_entity(entity=u'Q1', prob=0.9)
    mykb.add_entity(entity=u'Q2', prob=0.2)
    mykb.add_entity(entity=u'Q3', prob=0.5)

    # adding aliases - should fail because one of the given IDs is not valid
    with pytest.raises(ValueError):
        mykb.add_alias(alias=u'douglas', entities=[u'Q2', u'Q342'], probabilities=[0.8, 0.2])


def test_kb_invalid_probabilities(nlp):
    """Test the invalid construction of a KB with wrong prior probabilities"""
    mykb = KnowledgeBase(nlp.vocab)

    # adding entities
    mykb.add_entity(entity=u'Q1', prob=0.9)
    mykb.add_entity(entity=u'Q2', prob=0.2)
    mykb.add_entity(entity=u'Q3', prob=0.5)

    # adding aliases - should fail because the sum of the probabilities exceeds 1
    with pytest.raises(ValueError):
        mykb.add_alias(alias=u'douglas', entities=[u'Q2', u'Q3'], probabilities=[0.8, 0.4])


def test_kb_invalid_combination(nlp):
    """Test the invalid construction of a KB with non-matching entity and probability lists"""
    mykb = KnowledgeBase(nlp.vocab)

    # adding entities
    mykb.add_entity(entity=u'Q1', prob=0.9)
    mykb.add_entity(entity=u'Q2', prob=0.2)
    mykb.add_entity(entity=u'Q3', prob=0.5)

    # adding aliases - should fail because the entities and probabilities vectors are not of equal length
    with pytest.raises(ValueError):
        mykb.add_alias(alias=u'douglas', entities=[u'Q2', u'Q3'], probabilities=[0.3, 0.4, 0.1])


def test_candidate_generation(nlp):
    """Test correct candidate generation"""
    mykb = KnowledgeBase(nlp.vocab)

    # adding entities
    mykb.add_entity(entity=u'Q1', prob=0.9)
    mykb.add_entity(entity=u'Q2', prob=0.2)
    mykb.add_entity(entity=u'Q3', prob=0.5)

    # adding aliases
    mykb.add_alias(alias=u'douglas', entities=[u'Q2', u'Q3'], probabilities=[0.8, 0.2])
    mykb.add_alias(alias=u'adam', entities=[u'Q2'], probabilities=[0.9])

    # test the size of the relevant candidates
    assert(len(mykb.get_candidates(u'douglas')) == 2)
    assert(len(mykb.get_candidates(u'adam')) == 1)
    assert(len(mykb.get_candidates(u'shrubbery')) == 0)
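
Editor's note on the last assertion: an unseen alias like 'shrubbery' resolves to index 0 in _alias_index, which is exactly the dummy row that _create_empty_vectors reserved, and get_candidates filters it out rather than raising:

# unseen alias -> _alias_index.get(...) returns 0 -> dummy row -> removed by the `entry_index != 0` filter
assert len(mykb.get_candidates(u'never seen')) == 0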

spacy/tokens/doc.pyx
@@ -326,7 +326,7 @@ cdef class Doc:
     def doc(self):
         return self

-    def char_span(self, int start_idx, int end_idx, label=0, vector=None):
+    def char_span(self, int start_idx, int end_idx, label=0, kb_id=0, vector=None):
         """Create a `Span` object from the slice `doc.text[start : end]`.

         doc (Doc): The parent document.
@@ -334,6 +334,7 @@ cdef class Doc:
         end (int): The index of the first character after the span.
         label (uint64 or string): A label to attach to the Span, e.g. for
             named entities.
+        kb_id (uint64 or string): An ID from a KB to capture the meaning of a named entity.
         vector (ndarray[ndim=1, dtype='float32']): A meaning representation of
             the span.
         RETURNS (Span): The newly constructed object.
@@ -342,6 +343,8 @@
         """
         if not isinstance(label, int):
             label = self.vocab.strings.add(label)
+        if not isinstance(kb_id, int):
+            kb_id = self.vocab.strings.add(kb_id)
         cdef int start = token_by_start(self.c, self.length, start_idx)
         if start == -1:
             return None
@@ -350,7 +353,7 @@
             return None
         # Currently we have the token index, we want the range-end index
         end += 1
-        cdef Span span = Span(self, start, end, label=label, vector=vector)
+        cdef Span span = Span(self, start, end, label=label, kb_id=kb_id, vector=vector)
         return span

     def similarity(self, other):
@@ -484,6 +487,7 @@
         cdef const TokenC* token
         cdef int start = -1
         cdef attr_t label = 0
+        cdef attr_t kb_id = 0
         output = []
         for i in range(self.length):
             token = &self.c[i]
@@ -493,16 +497,18 @@
                 raise ValueError(Errors.E093.format(seq=" ".join(seq)))
             elif token.ent_iob == 2 or token.ent_iob == 0:
                 if start != -1:
-                    output.append(Span(self, start, i, label=label))
+                    output.append(Span(self, start, i, label=label, kb_id=kb_id))
                 start = -1
                 label = 0
+                kb_id = 0
             elif token.ent_iob == 3:
                 if start != -1:
-                    output.append(Span(self, start, i, label=label))
+                    output.append(Span(self, start, i, label=label, kb_id=kb_id))
                 start = i
                 label = token.ent_type
+                kb_id = token.ent_kb_id
         if start != -1:
-            output.append(Span(self, start, self.length, label=label))
+            output.append(Span(self, start, self.length, label=label, kb_id=kb_id))
         return tuple(output)

     def __set__(self, ents):
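
The extended char_span then accepts the new keyword in both string and hash form (editor's sketch; character offsets illustrative):

doc = nlp(u"Douglas Adams wrote it.")
span = doc.char_span(0, 13, label=u"PERSON", kb_id=u"Q42")
print(span.text, span.label_, span.kb_id_)   # Douglas Adams PERSON Q42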

spacy/tokens/span.pxd
@@ -11,6 +11,7 @@ cdef class Span:
     cdef readonly int start_char
     cdef readonly int end_char
     cdef readonly attr_t label
+    cdef readonly attr_t kb_id

     cdef public _vector
     cdef public _vector_norm

spacy/tokens/span.pyx
@@ -85,13 +85,14 @@ cdef class Span:
         return Underscore.span_extensions.pop(name)

     def __cinit__(self, Doc doc, int start, int end, label=0, vector=None,
-                  vector_norm=None):
+                  vector_norm=None, kb_id=0):
         """Create a `Span` object from the slice `doc[start : end]`.

         doc (Doc): The parent document.
         start (int): The index of the first token of the span.
         end (int): The index of the first token after the span.
         label (uint64): A label to attach to the Span, e.g. for named entities.
+        kb_id (uint64): An identifier from a Knowledge Base to capture the meaning of a named entity.
         vector (ndarray[ndim=1, dtype='float32']): A meaning representation
             of the span.
         RETURNS (Span): The newly constructed object.
@@ -110,11 +111,14 @@
             self.end_char = 0
         if isinstance(label, basestring_):
             label = doc.vocab.strings.add(label)
+        if isinstance(kb_id, basestring_):
+            kb_id = doc.vocab.strings.add(kb_id)
         if label not in doc.vocab.strings:
             raise ValueError(Errors.E084.format(label=label))
         self.label = label
         self._vector = vector
         self._vector_norm = vector_norm
+        self.kb_id = kb_id

     def __richcmp__(self, Span other, int op):
         if other is None:
@@ -655,6 +659,20 @@
                 label_ = ''
             raise NotImplementedError(Errors.E129.format(start=self.start, end=self.end, label=label_))

+    property kb_id_:
+        """RETURNS (unicode): The named entity's KB ID."""
+        def __get__(self):
+            return self.doc.vocab.strings[self.kb_id]
+
+        def __set__(self, unicode kb_id_):
+            if not kb_id_:
+                kb_id_ = ''
+            current_label = self.label_
+            if not current_label:
+                current_label = ''
+            raise NotImplementedError(Errors.E131.format(start=self.start, end=self.end,
+                                                         label=current_label, kb_id=kb_id_))
+
+
 cdef int _count_words_to_root(const TokenC* token, int sent_length) except -1:
     # Don't allow spaces to be the root, if there are
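
Reading kb_id works on any Span, while writing deliberately raises E131 with a hint to construct a new Span instead, mirroring the read-only label_ behaviour (editor's sketch, given any doc):

from spacy.tokens import Span

span = Span(doc, 0, 2, label=u"PERSON", kb_id=u"Q42")
assert span.kb_id_ == u"Q42"
span.kb_id_ = u"Q5"   # raises NotImplementedError with E131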

spacy/tokens/token.pyx
@@ -770,6 +770,22 @@ cdef class Token:
         def __set__(self, name):
             self.c.ent_id = self.vocab.strings.add(name)

+    property ent_kb_id:
+        """RETURNS (uint64): Named entity KB ID."""
+        def __get__(self):
+            return self.c.ent_kb_id
+
+        def __set__(self, attr_t ent_kb_id):
+            self.c.ent_kb_id = ent_kb_id
+
+    property ent_kb_id_:
+        """RETURNS (unicode): Named entity KB ID."""
+        def __get__(self):
+            return self.vocab.strings[self.c.ent_kb_id]
+
+        def __set__(self, ent_kb_id):
+            self.c.ent_kb_id = self.vocab.strings.add(ent_kb_id)
+
     @property
     def whitespace_(self):
         """RETURNS (unicode): The trailing whitespace character, if present."""
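
Token-level assignment is what EntityLinker.set_annotations uses internally, and Doc.ents then carries the ID through to the Span level (editor's sketch):

for token in doc:
    if token.ent_type_:
        token.ent_kb_id_ = u"Q42"
for ent in doc.ents:
    print(ent.text, ent.label_, ent.kb_id_)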