Merge pull request #3459 from svlandeg/feature/el-framework

Basic framework and APIs for entity linker
This commit is contained in:
Ines Montani 2019-03-29 14:02:22 +01:00 committed by GitHub
commit 68900066e0
15 changed files with 571 additions and 12 deletions

View File

@@ -0,0 +1,71 @@
# coding: utf-8
"""Demonstrate how to build a simple knowledge base and run an Entity Linking algorithm.
Currently still a bit of a dummy algorithm: it simply takes the entity with the
highest prior probability for a given alias.
"""
from __future__ import unicode_literals
import spacy
from spacy.kb import KnowledgeBase
def create_kb(vocab):
kb = KnowledgeBase(vocab=vocab)
# adding entities
entity_0 = "Q1004791_Douglas"
print("adding entity", entity_0)
kb.add_entity(entity=entity_0, prob=0.5)
entity_1 = "Q42_Douglas_Adams"
print("adding entity", entity_1)
kb.add_entity(entity=entity_1, prob=0.5)
entity_2 = "Q5301561_Douglas_Haig"
print("adding entity", entity_2)
kb.add_entity(entity=entity_2, prob=0.5)
# adding aliases
print()
alias_0 = "Douglas"
print("adding alias", alias_0)
kb.add_alias(alias=alias_0, entities=[entity_0, entity_1, entity_2], probabilities=[0.1, 0.6, 0.2])
alias_1 = "Douglas Adams"
print("adding alias", alias_1)
kb.add_alias(alias=alias_1, entities=[entity_1], probabilities=[0.9])
print()
print("kb size:", len(kb), kb.get_size_entities(), kb.get_size_aliases())
return kb
def add_el(kb, nlp):
el_pipe = nlp.create_pipe(name='entity_linker', config={"kb": kb})
nlp.add_pipe(el_pipe, last=True)
for alias in ["Douglas Adams", "Douglas"]:
candidates = nlp.linker.kb.get_candidates(alias)
print()
print(len(candidates), "candidate(s) for", alias, ":")
for c in candidates:
print(" ", c.entity_, c.prior_prob)
text = "In The Hitchhiker's Guide to the Galaxy, written by Douglas Adams, " \
"Douglas reminds us to always bring our towel. " \
"The main character in Doug's novel is called Arthur Dent."
doc = nlp(text)
print()
for token in doc:
print("token", token.text, token.ent_type_, token.ent_kb_id_)
print()
for ent in doc.ents:
print("ent", ent.text, ent.label_, ent.kb_id_)
if __name__ == "__main__":
nlp = spacy.load('en_core_web_sm')
my_kb = create_kb(nlp.vocab)
add_el(my_kb, nlp)
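
For reference, the KB-construction part of the console output should look roughly like the sketch below. Prior probabilities are stored as 32-bit floats, so the printed values may show float32 rounding (e.g. 0.8999999761581421 instead of 0.9), and the token/entity annotations at the end depend on which spans en_core_web_sm's NER actually recognizes, so they are not reproduced here:

adding entity Q1004791_Douglas
adding entity Q42_Douglas_Adams
adding entity Q5301561_Douglas_Haig

adding alias Douglas
adding alias Douglas Adams

kb size: 3 3 2

1 candidate(s) for Douglas Adams :
  Q42_Douglas_Adams 0.9

3 candidate(s) for Douglas :
  Q1004791_Douglas 0.1
  Q42_Douglas_Adams 0.6
  Q5301561_Douglas_Haig 0.2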

View File

@@ -40,6 +40,7 @@ MOD_NAMES = [
"spacy.lexeme",
"spacy.vocab",
"spacy.attrs",
"spacy.kb",
"spacy.morphology",
"spacy.pipeline.pipes",
"spacy.syntax.stateclass",

View File

@@ -80,6 +80,8 @@ class Warnings(object):
"the v2.x models cannot release the global interpreter lock. "
"Future versions may introduce a `n_process` argument for "
"parallel inference via multiprocessing.")
W017 = ("Alias '{alias}' already exists in the Knowledge base.")
W018 = ("Entity '{entity}' already exists in the Knowledge base.")
@add_codes
@@ -371,6 +373,16 @@ class Errors(object):
"with spacy >= 2.1.0. To fix this, reinstall Python and use a wide "
"unicode build instead. You can also rebuild Python and set the "
"--enable-unicode=ucs4 flag.")
E131 = ("Cannot write the kb_id of an existing Span object because a Span "
"is a read-only view of the underlying Token objects stored in the Doc. "
"Instead, create a new Span object and specify the `kb_id` keyword argument, "
"for example:\nfrom spacy.tokens import Span\n"
"span = Span(doc, start={start}, end={end}, label='{label}', kb_id='{kb_id}')")
E132 = ("The vectors for entities and probabilities for alias '{alias}' should have equal length, "
"but found {entities_length} and {probabilities_length} respectively.")
E133 = ("The sum of prior probabilities for alias '{alias}' should not exceed 1, "
"but found {sum}.")
E134 = ("Alias '{alias}' defined for unknown entity '{entity}'.")
@add_codes

spacy/kb.pxd (new file, 148 lines)
View File

@@ -0,0 +1,148 @@
"""Knowledge-base for entity or concept linking."""
from cymem.cymem cimport Pool
from preshed.maps cimport PreshMap
from libcpp.vector cimport vector
from libc.stdint cimport int32_t, int64_t
from spacy.vocab cimport Vocab
from .typedefs cimport hash_t
# Internal struct, for storage and disambiguation. This isn't what we return
# to the user as the answer to "here's your entity". It's the minimum number
# of bits we need to keep track of the answers.
cdef struct _EntryC:
    # The hash of this entry's unique ID and name in the KB
hash_t entity_hash
# Allows retrieval of one or more vectors.
# Each element of vector_rows should be an index into a vectors table.
# Every entry should have the same number of vectors, so we can avoid storing
# the number of vectors in each knowledge-base struct
int32_t* vector_rows
# Allows retrieval of a struct of non-vector features. We could make this a
# pointer, but we have 32 bits left over in the struct after prob, so we'd
# like this to only be 32 bits. We can also set this to -1, for the common
# case where there are no features.
int32_t feats_row
# log probability of entity, based on corpus frequency
float prob
# Each alias struct stores a list of Entry pointers with their prior probabilities
# for this specific mention/alias.
cdef struct _AliasC:
# All entry candidates for this alias
vector[int64_t] entry_indices
# Prior probability P(entity|alias) - should sum up to (at most) 1.
vector[float] probs
# Object used by the Entity Linker that summarizes one entity-alias candidate combination.
cdef class Candidate:
cdef readonly KnowledgeBase kb
cdef hash_t entity_hash
cdef hash_t alias_hash
cdef float prior_prob
cdef class KnowledgeBase:
cdef Pool mem
cpdef readonly Vocab vocab
# This maps 64bit keys (hash of unique entity string)
# to 64bit values (position of the _EntryC struct in the _entries vector).
# The PreshMap is pretty space efficient, as it uses open addressing. So
# the only overhead is the vacancy rate, which is approximately 30%.
cdef PreshMap _entry_index
    # Each entry takes 128 bits, and again we'll have a 30% or so overhead for
    # over-allocation.
# In total we end up with (N*128*1.3)+(N*128*1.3) bits for N entries.
# Storing 1m entries would take 41.6mb under this scheme.
cdef vector[_EntryC] _entries
# This maps 64bit keys (hash of unique alias string)
# to 64bit values (position of the _AliasC struct in the _aliases_table vector).
cdef PreshMap _alias_index
# This should map mention hashes to (entry_id, prob) tuples. The probability
# should be P(entity | mention), which is pretty important to know.
# We can pack both pieces of information into a 64-bit value, to keep things
# efficient.
cdef vector[_AliasC] _aliases_table
# This is the part which might take more space: storing various
# categorical features for the entries, and storing vectors for disambiguation
# and possibly usage.
# If each entry gets a 300-dimensional vector, for 1m entries we would need
# 1.2gb. That gets expensive fast. What might be better is to avoid learning
# a unique vector for every entity. We could instead have a compositional
# model, that embeds different features of the entities into vectors. We'll
# still want some per-entity features, like the Wikipedia text or entity
# co-occurrence. Hopefully those vectors can be narrow, e.g. 64 dimensions.
cdef object _vectors_table
# It's very useful to track categorical features, at least for output, even
# if they're not useful in the model itself. For instance, we should be
# able to track stuff like a person's date of birth or whatever. This can
# easily make the KB bigger, but if this isn't needed by the model, and it's
# optional data, we can let users configure a DB as the backend for this.
cdef object _features_table
cdef inline int64_t c_add_entity(self, hash_t entity_hash, float prob,
int32_t* vector_rows, int feats_row):
"""Add an entry to the knowledge base."""
# This is what we'll map the hash key to. It's where the entry will sit
# in the vector of entries, so we can get it later.
cdef int64_t new_index = self._entries.size()
self._entries.push_back(
_EntryC(
entity_hash=entity_hash,
vector_rows=vector_rows,
feats_row=feats_row,
prob=prob
))
self._entry_index[entity_hash] = new_index
return new_index
cdef inline int64_t c_add_aliases(self, hash_t alias_hash, vector[int64_t] entry_indices, vector[float] probs):
"""Connect a mention to a list of potential entities with their prior probabilities ."""
cdef int64_t new_index = self._aliases_table.size()
self._aliases_table.push_back(
_AliasC(
entry_indices=entry_indices,
probs=probs
))
self._alias_index[alias_hash] = new_index
return new_index
cdef inline _create_empty_vectors(self):
"""
Making sure the first element of each vector is a dummy,
because the PreshMap maps pointing to indices in these vectors can not contain 0 as value
cf. https://github.com/explosion/preshed/issues/17
"""
cdef int32_t dummy_value = 0
self.vocab.strings.add("")
self._entries.push_back(
_EntryC(
entity_hash=self.vocab.strings[""],
vector_rows=&dummy_value,
feats_row=dummy_value,
prob=dummy_value
))
self._aliases_table.push_back(
_AliasC(
entry_indices=[dummy_value],
probs=[dummy_value]
))
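
The hash-to-index indirection above is easier to see in plain Python. Below is a minimal sketch, with a dict standing in for PreshMap and a list for the C++ vector (all names are illustrative, not part of the API): because a PreshMap lookup returns 0 for a missing key, slot 0 of each table is reserved for a dummy entry and real entries start at index 1.

entries = [None]   # slot 0 reserved: a PreshMap "miss" is indistinguishable from index 0
entry_index = {}   # entity_hash -> position of the entry in `entries`

def add_entry(entity_hash, prob):
    new_index = len(entries)
    entries.append({"entity_hash": entity_hash, "prob": prob})
    entry_index[entity_hash] = new_index
    return new_index

add_entry(hash("Q42"), 0.5)
assert entry_index.get(hash("Q42"), 0) != 0   # present: index is never 0
assert entry_index.get(hash("Q1"), 0) == 0    # absent: lookup "returns 0"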

spacy/kb.pyx (new file, 131 lines)
View File

@@ -0,0 +1,131 @@
# cython: profile=True
# coding: utf8
from spacy.errors import Errors, Warnings, user_warning
cdef class Candidate:
def __init__(self, KnowledgeBase kb, entity_hash, alias_hash, prior_prob):
self.kb = kb
self.entity_hash = entity_hash
self.alias_hash = alias_hash
self.prior_prob = prior_prob
@property
def entity(self):
"""RETURNS (uint64): hash of the entity's KB ID/name"""
return self.entity_hash
@property
def entity_(self):
"""RETURNS (unicode): ID/name of this entity in the KB"""
return self.kb.vocab.strings[self.entity]
@property
def alias(self):
"""RETURNS (uint64): hash of the alias"""
return self.alias_hash
@property
def alias_(self):
"""RETURNS (unicode): ID of the original alias"""
return self.kb.vocab.strings[self.alias]
    @property
    def prior_prob(self):
        """RETURNS (float): prior probability P(entity|alias) of this candidate"""
        return self.prior_prob
cdef class KnowledgeBase:
def __init__(self, Vocab vocab):
self.vocab = vocab
self._entry_index = PreshMap()
self._alias_index = PreshMap()
self.mem = Pool()
self._create_empty_vectors()
def __len__(self):
return self.get_size_entities()
def get_size_entities(self):
return self._entries.size() - 1 # not counting dummy element on index 0
def get_size_aliases(self):
return self._aliases_table.size() - 1 # not counting dummy element on index 0
def add_entity(self, unicode entity, float prob=0.5, vectors=None, features=None):
"""
Add an entity to the KB.
        Return the hash of the entity ID at the end.
"""
cdef hash_t entity_hash = self.vocab.strings.add(entity)
# Return if this entity was added before
if entity_hash in self._entry_index:
user_warning(Warnings.W018.format(entity=entity))
return
cdef int32_t dummy_value = 342
self.c_add_entity(entity_hash=entity_hash, prob=prob,
vector_rows=&dummy_value, feats_row=dummy_value)
# TODO self._vectors_table.get_pointer(vectors),
# self._features_table.get(features))
return entity_hash
def add_alias(self, unicode alias, entities, probabilities):
"""
        For a given alias, add its potential entities and prior probabilities to the KB.
        Return the alias_hash at the end.
"""
        # Throw an error if the lengths of the entities and probabilities lists differ
if not len(entities) == len(probabilities):
raise ValueError(Errors.E132.format(alias=alias,
entities_length=len(entities),
probabilities_length=len(probabilities)))
# Throw an error if the probabilities sum up to more than 1
prob_sum = sum(probabilities)
if prob_sum > 1:
raise ValueError(Errors.E133.format(alias=alias, sum=prob_sum))
cdef hash_t alias_hash = self.vocab.strings.add(alias)
# Return if this alias was added before
if alias_hash in self._alias_index:
user_warning(Warnings.W017.format(alias=alias))
return
cdef hash_t entity_hash
cdef vector[int64_t] entry_indices
cdef vector[float] probs
for entity, prob in zip(entities, probabilities):
entity_hash = self.vocab.strings[entity]
            if entity_hash not in self._entry_index:
raise ValueError(Errors.E134.format(alias=alias, entity=entity))
entry_index = <int64_t>self._entry_index.get(entity_hash)
entry_indices.push_back(int(entry_index))
probs.push_back(float(prob))
self.c_add_aliases(alias_hash=alias_hash, entry_indices=entry_indices, probs=probs)
return alias_hash
def get_candidates(self, unicode alias):
""" TODO: where to put this functionality ?"""
cdef hash_t alias_hash = self.vocab.strings[alias]
alias_index = <int64_t>self._alias_index.get(alias_hash)
alias_entry = self._aliases_table[alias_index]
return [Candidate(kb=self,
entity_hash=self._entries[entry_index].entity_hash,
alias_hash=alias_hash,
prior_prob=prob)
for (entry_index, prob) in zip(alias_entry.entry_indices, alias_entry.probs)
if entry_index != 0]
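
One behavior of the duplicate checks above worth noting: re-adding an existing entity or alias emits a warning (W018/W017) and returns None instead of raising. A short sketch, assuming an English vocab:

from spacy.kb import KnowledgeBase
from spacy.lang.en import English

nlp = English()
kb = KnowledgeBase(nlp.vocab)

first = kb.add_entity(entity=u'Q42', prob=0.3)   # returns the entity hash
second = kb.add_entity(entity=u'Q42', prob=0.4)  # warns with W018, returns None
assert first is not None and second is None
assert kb.get_size_entities() == 1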

View File

@@ -14,7 +14,7 @@ import srsly
from .tokenizer import Tokenizer
from .vocab import Vocab
from .lemmatizer import Lemmatizer
from .pipeline import DependencyParser, Tensorizer, Tagger, EntityRecognizer
from .pipeline import DependencyParser, Tensorizer, Tagger, EntityRecognizer, EntityLinker
from .pipeline import SimilarityHook, TextCategorizer, Sentencizer
from .pipeline import merge_noun_chunks, merge_entities, merge_subtokens
from .pipeline import EntityRuler
@@ -117,6 +117,7 @@ class Language(object):
"tagger": lambda nlp, **cfg: Tagger(nlp.vocab, **cfg),
"parser": lambda nlp, **cfg: DependencyParser(nlp.vocab, **cfg),
"ner": lambda nlp, **cfg: EntityRecognizer(nlp.vocab, **cfg),
"entity_linker": lambda nlp, **cfg: EntityLinker(nlp.vocab, **cfg),
"similarity": lambda nlp, **cfg: SimilarityHook(nlp.vocab, **cfg),
"textcat": lambda nlp, **cfg: TextCategorizer(nlp.vocab, **cfg),
"sentencizer": lambda nlp, **cfg: Sentencizer(**cfg),
@@ -212,6 +213,10 @@ class Language(object):
def entity(self):
return self.get_pipe("ner")
@property
def linker(self):
return self.get_pipe("entity_linker")
@property
def matcher(self):
return self.get_pipe("matcher")
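
With the factory entry and the `linker` property in place, the component is wired up like any other built-in pipe. A sketch, assuming `kb` was built as in the example script above:

el_pipe = nlp.create_pipe("entity_linker", config={"kb": kb})
nlp.add_pipe(el_pipe, last=True)
assert nlp.linker is nlp.get_pipe("entity_linker")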

View File

@@ -1,7 +1,7 @@
# coding: utf8
from __future__ import unicode_literals
from .pipes import Tagger, DependencyParser, EntityRecognizer
from .pipes import Tagger, DependencyParser, EntityRecognizer, EntityLinker
from .pipes import TextCategorizer, Tensorizer, Pipe, Sentencizer
from .entityruler import EntityRuler
from .hooks import SentenceSegmenter, SimilarityHook
@@ -11,6 +11,7 @@ __all__ = [
"Tagger",
"DependencyParser",
"EntityRecognizer",
"EntityLinker",
"TextCategorizer",
"Tensorizer",
"Pipe",

View File

@@ -1061,6 +1061,55 @@ cdef class EntityRecognizer(Parser):
if move[0] in ("B", "I", "L", "U")))
class EntityLinker(Pipe):
name = 'entity_linker'
@classmethod
def Model(cls, nr_class=1, **cfg):
# TODO: non-dummy EL implementation
return None
def __init__(self, model=True, **cfg):
self.model = False
self.cfg = dict(cfg)
self.kb = self.cfg["kb"]
def __call__(self, doc):
self.set_annotations([doc], scores=None, tensors=None)
return doc
def pipe(self, stream, batch_size=128, n_threads=-1):
"""Apply the pipe to a stream of documents.
Both __call__ and pipe should delegate to the `predict()`
and `set_annotations()` methods.
"""
for docs in util.minibatch(stream, size=batch_size):
docs = list(docs)
self.set_annotations(docs, scores=None, tensors=None)
yield from docs
def set_annotations(self, docs, scores, tensors=None):
"""
Currently implemented as taking the KB entry with highest prior probability for each named entity
TODO: actually use context etc
"""
for i, doc in enumerate(docs):
for ent in doc.ents:
candidates = self.kb.get_candidates(ent.text)
if candidates:
best_candidate = max(candidates, key=lambda c: c.prior_prob)
for token in ent:
token.ent_kb_id_ = best_candidate.entity_
def get_loss(self, docs, golds, scores):
# TODO
pass
def add_label(self, label):
# TODO
pass
class Sentencizer(object):
"""Segment the Doc into sentences using a rule-based strategy.
@@ -1146,5 +1195,5 @@ class Sentencizer(object):
self.punct_chars = cfg.get("punct_chars", self.default_punct_chars)
return self
__all__ = ["Tagger", "DependencyParser", "EntityRecognizer", "Tensorizer", "TextCategorizer", "Sentencizer"]
__all__ = ["Tagger", "DependencyParser", "EntityRecognizer", "Tensorizer", "TextCategorizer", "EntityLinker", "Sentencizer"]

View File

@@ -70,4 +70,5 @@ cdef struct TokenC:
int sent_start
int ent_iob
    attr_t ent_type # TODO: Is there a better way to do this? Multiple sources of truth...
attr_t ent_kb_id
hash_t ent_id

View File

@@ -172,10 +172,12 @@ def test_span_as_doc(doc):
assert span_doc[0].idx == 0
def test_span_string_label(doc):
span = Span(doc, 0, 1, label="hello")
def test_span_string_label_kb_id(doc):
span = Span(doc, 0, 1, label="hello", kb_id="Q342")
assert span.label_ == "hello"
assert span.label == doc.vocab.strings["hello"]
assert span.kb_id_ == "Q342"
assert span.kb_id == doc.vocab.strings["Q342"]
def test_span_label_readonly(doc):
@@ -184,6 +186,12 @@ def test_span_label_readonly(doc):
span.label_ = "hello"
def test_span_kb_id_readonly(doc):
span = Span(doc, 0, 1)
with pytest.raises(NotImplementedError):
span.kb_id_ = "Q342"
def test_span_ents_property(doc):
"""Test span.ents for the """
doc.ents = [

View File

@@ -0,0 +1,91 @@
# coding: utf-8
from __future__ import unicode_literals
import pytest
from spacy.kb import KnowledgeBase
from spacy.lang.en import English
@pytest.fixture
def nlp():
return English()
def test_kb_valid_entities(nlp):
"""Test the valid construction of a KB with 3 entities and two aliases"""
mykb = KnowledgeBase(nlp.vocab)
# adding entities
mykb.add_entity(entity=u'Q1', prob=0.9)
mykb.add_entity(entity=u'Q2')
mykb.add_entity(entity=u'Q3', prob=0.5)
# adding aliases
mykb.add_alias(alias=u'douglas', entities=[u'Q2', u'Q3'], probabilities=[0.8, 0.2])
mykb.add_alias(alias=u'adam', entities=[u'Q2'], probabilities=[0.9])
# test the size of the corresponding KB
    assert mykb.get_size_entities() == 3
    assert mykb.get_size_aliases() == 2
def test_kb_invalid_entities(nlp):
"""Test the invalid construction of a KB with an alias linked to a non-existing entity"""
mykb = KnowledgeBase(nlp.vocab)
# adding entities
mykb.add_entity(entity=u'Q1', prob=0.9)
mykb.add_entity(entity=u'Q2', prob=0.2)
mykb.add_entity(entity=u'Q3', prob=0.5)
# adding aliases - should fail because one of the given IDs is not valid
with pytest.raises(ValueError):
mykb.add_alias(alias=u'douglas', entities=[u'Q2', u'Q342'], probabilities=[0.8, 0.2])
def test_kb_invalid_probabilities(nlp):
"""Test the invalid construction of a KB with wrong prior probabilities"""
mykb = KnowledgeBase(nlp.vocab)
# adding entities
mykb.add_entity(entity=u'Q1', prob=0.9)
mykb.add_entity(entity=u'Q2', prob=0.2)
mykb.add_entity(entity=u'Q3', prob=0.5)
# adding aliases - should fail because the sum of the probabilities exceeds 1
with pytest.raises(ValueError):
mykb.add_alias(alias=u'douglas', entities=[u'Q2', u'Q3'], probabilities=[0.8, 0.4])
def test_kb_invalid_combination(nlp):
"""Test the invalid construction of a KB with non-matching entity and probability lists"""
mykb = KnowledgeBase(nlp.vocab)
# adding entities
mykb.add_entity(entity=u'Q1', prob=0.9)
mykb.add_entity(entity=u'Q2', prob=0.2)
mykb.add_entity(entity=u'Q3', prob=0.5)
# adding aliases - should fail because the entities and probabilities vectors are not of equal length
with pytest.raises(ValueError):
mykb.add_alias(alias=u'douglas', entities=[u'Q2', u'Q3'], probabilities=[0.3, 0.4, 0.1])
def test_candidate_generation(nlp):
"""Test correct candidate generation"""
mykb = KnowledgeBase(nlp.vocab)
# adding entities
mykb.add_entity(entity=u'Q1', prob=0.9)
mykb.add_entity(entity=u'Q2', prob=0.2)
mykb.add_entity(entity=u'Q3', prob=0.5)
# adding aliases
mykb.add_alias(alias=u'douglas', entities=[u'Q2', u'Q3'], probabilities=[0.8, 0.2])
mykb.add_alias(alias=u'adam', entities=[u'Q2'], probabilities=[0.9])
# test the size of the relevant candidates
    assert len(mykb.get_candidates(u'douglas')) == 2
    assert len(mykb.get_candidates(u'adam')) == 1
    assert len(mykb.get_candidates(u'shrubbery')) == 0

View File

@@ -326,7 +326,7 @@ cdef class Doc:
def doc(self):
return self
def char_span(self, int start_idx, int end_idx, label=0, vector=None):
def char_span(self, int start_idx, int end_idx, label=0, kb_id=0, vector=None):
"""Create a `Span` object from the slice `doc.text[start : end]`.
doc (Doc): The parent document.
@@ -334,6 +334,7 @@
end (int): The index of the first character after the span.
label (uint64 or string): A label to attach to the Span, e.g. for
named entities.
kb_id (uint64 or string): An ID from a KB to capture the meaning of a named entity.
vector (ndarray[ndim=1, dtype='float32']): A meaning representation of
the span.
RETURNS (Span): The newly constructed object.
@@ -342,6 +343,8 @@
"""
if not isinstance(label, int):
label = self.vocab.strings.add(label)
if not isinstance(kb_id, int):
kb_id = self.vocab.strings.add(kb_id)
cdef int start = token_by_start(self.c, self.length, start_idx)
if start == -1:
return None
@@ -350,7 +353,7 @@
return None
# Currently we have the token index, we want the range-end index
end += 1
cdef Span span = Span(self, start, end, label=label, vector=vector)
cdef Span span = Span(self, start, end, label=label, kb_id=kb_id, vector=vector)
return span
def similarity(self, other):
@@ -484,6 +487,7 @@
cdef const TokenC* token
cdef int start = -1
cdef attr_t label = 0
cdef attr_t kb_id = 0
output = []
for i in range(self.length):
token = &self.c[i]
@@ -493,16 +497,18 @@
raise ValueError(Errors.E093.format(seq=" ".join(seq)))
elif token.ent_iob == 2 or token.ent_iob == 0:
if start != -1:
output.append(Span(self, start, i, label=label))
output.append(Span(self, start, i, label=label, kb_id=kb_id))
start = -1
label = 0
kb_id = 0
elif token.ent_iob == 3:
if start != -1:
output.append(Span(self, start, i, label=label))
output.append(Span(self, start, i, label=label, kb_id=kb_id))
start = i
label = token.ent_type
kb_id = token.ent_kb_id
if start != -1:
output.append(Span(self, start, self.length, label=label))
output.append(Span(self, start, self.length, label=label, kb_id=kb_id))
return tuple(output)
def __set__(self, ents):

View File

@@ -11,6 +11,7 @@ cdef class Span:
cdef readonly int start_char
cdef readonly int end_char
cdef readonly attr_t label
cdef readonly attr_t kb_id
cdef public _vector
cdef public _vector_norm

View File

@@ -85,13 +85,14 @@ cdef class Span:
return Underscore.span_extensions.pop(name)
def __cinit__(self, Doc doc, int start, int end, label=0, vector=None,
vector_norm=None):
vector_norm=None, kb_id=0):
"""Create a `Span` object from the slice `doc[start : end]`.
doc (Doc): The parent document.
start (int): The index of the first token of the span.
end (int): The index of the first token after the span.
label (uint64): A label to attach to the Span, e.g. for named entities.
kb_id (uint64): An identifier from a Knowledge Base to capture the meaning of a named entity.
vector (ndarray[ndim=1, dtype='float32']): A meaning representation
of the span.
RETURNS (Span): The newly constructed object.
@@ -110,11 +111,14 @@
self.end_char = 0
if isinstance(label, basestring_):
label = doc.vocab.strings.add(label)
if isinstance(kb_id, basestring_):
kb_id = doc.vocab.strings.add(kb_id)
if label not in doc.vocab.strings:
raise ValueError(Errors.E084.format(label=label))
self.label = label
self._vector = vector
self._vector_norm = vector_norm
self.kb_id = kb_id
def __richcmp__(self, Span other, int op):
if other is None:
@@ -655,6 +659,20 @@ cdef class Span:
label_ = ''
raise NotImplementedError(Errors.E129.format(start=self.start, end=self.end, label=label_))
property kb_id_:
"""RETURNS (unicode): The named entity's KB ID."""
def __get__(self):
return self.doc.vocab.strings[self.kb_id]
def __set__(self, unicode kb_id_):
if not kb_id_:
kb_id_ = ''
current_label = self.label_
if not current_label:
current_label = ''
raise NotImplementedError(Errors.E131.format(start=self.start, end=self.end,
label=current_label, kb_id=kb_id_))
cdef int _count_words_to_root(const TokenC* token, int sent_length) except -1:
# Don't allow spaces to be the root, if there are
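
As with `label_`, the `kb_id_` of an existing Span is read-only; the E131 message spells out the workaround of constructing a fresh Span. A sketch, assuming `doc` is any processed Doc (the label and KB ID are illustrative):

from spacy.tokens import Span

span = doc[0:2]
# `span.kb_id_ = "Q342"` would raise NotImplementedError (E131); instead:
span = Span(doc, span.start, span.end, label="PERSON", kb_id="Q342")
assert span.kb_id_ == "Q342"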

View File

@@ -770,6 +770,22 @@ cdef class Token:
def __set__(self, name):
self.c.ent_id = self.vocab.strings.add(name)
property ent_kb_id:
"""RETURNS (uint64): Named entity KB ID."""
def __get__(self):
return self.c.ent_kb_id
def __set__(self, attr_t ent_kb_id):
self.c.ent_kb_id = ent_kb_id
property ent_kb_id_:
"""RETURNS (unicode): Named entity KB ID."""
def __get__(self):
return self.vocab.strings[self.c.ent_kb_id]
def __set__(self, ent_kb_id):
self.c.ent_kb_id = self.vocab.strings.add(ent_kb_id)
@property
def whitespace_(self):
"""RETURNS (unicode): The trailing whitespace character, if present."""