mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 18:26:30 +03:00
fixes in kb and gold
This commit is contained in:
parent
4086c6ff60
commit
d833d4c358
|
@ -401,15 +401,13 @@ def read_training(nlp, training_dir, dev, limit, kb=None):
|
||||||
gold_start = int(start) - found_ent.sent.start_char
|
gold_start = int(start) - found_ent.sent.start_char
|
||||||
gold_end = int(end) - found_ent.sent.start_char
|
gold_end = int(end) - found_ent.sent.start_char
|
||||||
|
|
||||||
# add both positive and negative examples (in random order just to be sure)
|
# add both pos and neg examples (in random order)
|
||||||
if kb:
|
if kb:
|
||||||
gold_entities = {}
|
gold_entities = {}
|
||||||
candidate_ids = [
|
candidates = kb.get_candidates(alias)
|
||||||
c.entity_ for c in kb.get_candidates(alias)
|
candidate_ids = [c.entity_ for c in candidates]
|
||||||
]
|
# add positive example in case the KB doesn't have it
|
||||||
candidate_ids.append(
|
candidate_ids.append(wd_id)
|
||||||
wd_id
|
|
||||||
) # in case the KB doesn't have it
|
|
||||||
random.shuffle(candidate_ids)
|
random.shuffle(candidate_ids)
|
||||||
for kb_id in candidate_ids:
|
for kb_id in candidate_ids:
|
||||||
entry = (gold_start, gold_end, kb_id)
|
entry = (gold_start, gold_end, kb_id)
|
||||||
|
@ -418,7 +416,8 @@ def read_training(nlp, training_dir, dev, limit, kb=None):
|
||||||
else:
|
else:
|
||||||
gold_entities[entry] = 1.0
|
gold_entities[entry] = 1.0
|
||||||
else:
|
else:
|
||||||
gold_entities = {}
|
entry = (gold_start, gold_end, wd_id)
|
||||||
|
gold_entities = {entry: 1.0}
|
||||||
|
|
||||||
gold = GoldParse(doc=sent, links=gold_entities)
|
gold = GoldParse(doc=sent, links=gold_entities)
|
||||||
data.append((sent, gold))
|
data.append((sent, gold))
|
||||||
|
|
|
@ -31,7 +31,7 @@ cdef class GoldParse:
|
||||||
cdef public list ents
|
cdef public list ents
|
||||||
cdef public dict brackets
|
cdef public dict brackets
|
||||||
cdef public object cats
|
cdef public object cats
|
||||||
cdef public list links
|
cdef public dict links
|
||||||
|
|
||||||
cdef readonly list cand_to_gold
|
cdef readonly list cand_to_gold
|
||||||
cdef readonly list gold_to_cand
|
cdef readonly list gold_to_cand
|
||||||
|
|
|
@ -450,8 +450,10 @@ cdef class GoldParse:
|
||||||
examples of a label to have the value 0.0. Labels not in the
|
examples of a label to have the value 0.0. Labels not in the
|
||||||
dictionary are treated as missing - the gradient for those labels
|
dictionary are treated as missing - the gradient for those labels
|
||||||
will be zero.
|
will be zero.
|
||||||
links (iterable): A sequence of `(start_char, end_char, kb_id)` tuples,
|
links (dict): A dict with `(start_char, end_char, kb_id)` keys,
|
||||||
representing the external ID of an entity in a knowledge base.
|
representing the external ID of an entity in a knowledge base,
|
||||||
|
and the values being either 1.0 or 0.0, indicating positive and
|
||||||
|
negative examples, respectively.
|
||||||
RETURNS (GoldParse): The newly constructed object.
|
RETURNS (GoldParse): The newly constructed object.
|
||||||
"""
|
"""
|
||||||
if words is None:
|
if words is None:
|
||||||
|
|
29
spacy/kb.pyx
29
spacy/kb.pyx
|
@ -191,7 +191,7 @@ cdef class KnowledgeBase:
|
||||||
|
|
||||||
def get_candidates(self, unicode alias):
|
def get_candidates(self, unicode alias):
|
||||||
cdef hash_t alias_hash = self.vocab.strings[alias]
|
cdef hash_t alias_hash = self.vocab.strings[alias]
|
||||||
alias_index = <int64_t>self._alias_index.get(alias_hash)
|
alias_index = <int64_t>self._alias_index.get(alias_hash) # TODO: check for error? unit test !
|
||||||
alias_entry = self._aliases_table[alias_index]
|
alias_entry = self._aliases_table[alias_index]
|
||||||
|
|
||||||
return [Candidate(kb=self,
|
return [Candidate(kb=self,
|
||||||
|
@ -199,12 +199,12 @@ cdef class KnowledgeBase:
|
||||||
entity_freq=self._entries[entry_index].prob,
|
entity_freq=self._entries[entry_index].prob,
|
||||||
entity_vector=self._vectors_table[self._entries[entry_index].vector_index],
|
entity_vector=self._vectors_table[self._entries[entry_index].vector_index],
|
||||||
alias_hash=alias_hash,
|
alias_hash=alias_hash,
|
||||||
prior_prob=prob)
|
prior_prob=prior_prob)
|
||||||
for (entry_index, prob) in zip(alias_entry.entry_indices, alias_entry.probs)
|
for (entry_index, prior_prob) in zip(alias_entry.entry_indices, alias_entry.probs)
|
||||||
if entry_index != 0]
|
if entry_index != 0]
|
||||||
|
|
||||||
def get_vector(self, unicode entity):
|
def get_vector(self, unicode entity):
|
||||||
cdef hash_t entity_hash = self.vocab.strings.add(entity)
|
cdef hash_t entity_hash = self.vocab.strings[entity]
|
||||||
|
|
||||||
# Return an empty list if this entity is unknown in this KB
|
# Return an empty list if this entity is unknown in this KB
|
||||||
if entity_hash not in self._entry_index:
|
if entity_hash not in self._entry_index:
|
||||||
|
@ -213,6 +213,27 @@ cdef class KnowledgeBase:
|
||||||
|
|
||||||
return self._vectors_table[self._entries[entry_index].vector_index]
|
return self._vectors_table[self._entries[entry_index].vector_index]
|
||||||
|
|
||||||
|
def get_prior_prob(self, unicode entity, unicode alias):
|
||||||
|
""" Return the prior probability of a given alias being linked to a given entity,
|
||||||
|
or return 0.0 when this combination is not known in the knowledge base"""
|
||||||
|
cdef hash_t alias_hash = self.vocab.strings[alias]
|
||||||
|
cdef hash_t entity_hash = self.vocab.strings[entity]
|
||||||
|
|
||||||
|
# TODO: error ?
|
||||||
|
if entity_hash not in self._entry_index or alias_hash not in self._alias_index:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
alias_index = <int64_t>self._alias_index.get(alias_hash)
|
||||||
|
entry_index = self._entry_index[entity_hash]
|
||||||
|
|
||||||
|
alias_entry = self._aliases_table[alias_index]
|
||||||
|
for (entry_index, prior_prob) in zip(alias_entry.entry_indices, alias_entry.probs):
|
||||||
|
if self._entries[entry_index].entity_hash == entity_hash:
|
||||||
|
return prior_prob
|
||||||
|
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
|
||||||
def dump(self, loc):
|
def dump(self, loc):
|
||||||
cdef Writer writer = Writer(loc)
|
cdef Writer writer = Writer(loc)
|
||||||
writer.write_header(self.get_size_entities(), self.entity_vector_length)
|
writer.write_header(self.get_size_entities(), self.entity_vector_length)
|
||||||
|
|
|
@ -13,6 +13,11 @@ def nlp():
|
||||||
return English()
|
return English()
|
||||||
|
|
||||||
|
|
||||||
|
def assert_almost_equal(a, b):
|
||||||
|
delta = 0.0001
|
||||||
|
assert a - delta <= b <= a + delta
|
||||||
|
|
||||||
|
|
||||||
def test_kb_valid_entities(nlp):
|
def test_kb_valid_entities(nlp):
|
||||||
"""Test the valid construction of a KB with 3 entities and two aliases"""
|
"""Test the valid construction of a KB with 3 entities and two aliases"""
|
||||||
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=3)
|
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=3)
|
||||||
|
@ -35,6 +40,10 @@ def test_kb_valid_entities(nlp):
|
||||||
assert mykb.get_vector("Q2") == [2, 1, 0]
|
assert mykb.get_vector("Q2") == [2, 1, 0]
|
||||||
assert mykb.get_vector("Q3") == [-1, -6, 5]
|
assert mykb.get_vector("Q3") == [-1, -6, 5]
|
||||||
|
|
||||||
|
# test retrieval of prior probabilities
|
||||||
|
assert_almost_equal(mykb.get_prior_prob(entity="Q2", alias="douglas"), 0.8)
|
||||||
|
assert_almost_equal(mykb.get_prior_prob(entity="Q3", alias="douglas"), 0.2)
|
||||||
|
|
||||||
|
|
||||||
def test_kb_invalid_entities(nlp):
|
def test_kb_invalid_entities(nlp):
|
||||||
"""Test the invalid construction of a KB with an alias linked to a non-existing entity"""
|
"""Test the invalid construction of a KB with an alias linked to a non-existing entity"""
|
||||||
|
@ -99,12 +108,12 @@ def test_candidate_generation(nlp):
|
||||||
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
|
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
|
||||||
|
|
||||||
# adding entities
|
# adding entities
|
||||||
mykb.add_entity(entity="Q1", prob=0.9, entity_vector=[1])
|
mykb.add_entity(entity="Q1", prob=0.7, entity_vector=[1])
|
||||||
mykb.add_entity(entity="Q2", prob=0.2, entity_vector=[2])
|
mykb.add_entity(entity="Q2", prob=0.2, entity_vector=[2])
|
||||||
mykb.add_entity(entity="Q3", prob=0.5, entity_vector=[3])
|
mykb.add_entity(entity="Q3", prob=0.5, entity_vector=[3])
|
||||||
|
|
||||||
# adding aliases
|
# adding aliases
|
||||||
mykb.add_alias(alias="douglas", entities=["Q2", "Q3"], probabilities=[0.8, 0.2])
|
mykb.add_alias(alias="douglas", entities=["Q2", "Q3"], probabilities=[0.8, 0.1])
|
||||||
mykb.add_alias(alias="adam", entities=["Q2"], probabilities=[0.9])
|
mykb.add_alias(alias="adam", entities=["Q2"], probabilities=[0.9])
|
||||||
|
|
||||||
# test the size of the relevant candidates
|
# test the size of the relevant candidates
|
||||||
|
@ -112,6 +121,12 @@ def test_candidate_generation(nlp):
|
||||||
assert len(mykb.get_candidates("adam")) == 1
|
assert len(mykb.get_candidates("adam")) == 1
|
||||||
assert len(mykb.get_candidates("shrubbery")) == 0
|
assert len(mykb.get_candidates("shrubbery")) == 0
|
||||||
|
|
||||||
|
# test the content of the candidates
|
||||||
|
assert mykb.get_candidates("adam")[0].entity_ == "Q2"
|
||||||
|
assert mykb.get_candidates("adam")[0].alias_ == "adam"
|
||||||
|
assert_almost_equal(mykb.get_candidates("adam")[0].entity_freq, 0.2)
|
||||||
|
assert_almost_equal(mykb.get_candidates("adam")[0].prior_prob, 0.9)
|
||||||
|
|
||||||
|
|
||||||
def test_preserving_links_asdoc(nlp):
|
def test_preserving_links_asdoc(nlp):
|
||||||
"""Test that Span.as_doc preserves the existing entity links"""
|
"""Test that Span.as_doc preserves the existing entity links"""
|
||||||
|
|
Loading…
Reference in New Issue
Block a user