mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-04 01:48:04 +03:00 
			
		
		
		
	fixes in kb and gold
This commit is contained in:
		
							parent
							
								
									4086c6ff60
								
							
						
					
					
						commit
						d833d4c358
					
				| 
						 | 
					@ -401,15 +401,13 @@ def read_training(nlp, training_dir, dev, limit, kb=None):
 | 
				
			||||||
                                gold_start = int(start) - found_ent.sent.start_char
 | 
					                                gold_start = int(start) - found_ent.sent.start_char
 | 
				
			||||||
                                gold_end = int(end) - found_ent.sent.start_char
 | 
					                                gold_end = int(end) - found_ent.sent.start_char
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                                # add both positive and negative examples (in random order just to be sure)
 | 
					                                # add both pos and neg examples (in random order)
 | 
				
			||||||
                                if kb:
 | 
					                                if kb:
 | 
				
			||||||
                                    gold_entities = {}
 | 
					                                    gold_entities = {}
 | 
				
			||||||
                                    candidate_ids = [
 | 
					                                    candidates = kb.get_candidates(alias)
 | 
				
			||||||
                                        c.entity_ for c in kb.get_candidates(alias)
 | 
					                                    candidate_ids = [c.entity_ for c in candidates]
 | 
				
			||||||
                                    ]
 | 
					                                    # add positive example in case the KB doesn't have it
 | 
				
			||||||
                                    candidate_ids.append(
 | 
					                                    candidate_ids.append(wd_id)
 | 
				
			||||||
                                        wd_id
 | 
					 | 
				
			||||||
                                    )  # in case the KB doesn't have it
 | 
					 | 
				
			||||||
                                    random.shuffle(candidate_ids)
 | 
					                                    random.shuffle(candidate_ids)
 | 
				
			||||||
                                    for kb_id in candidate_ids:
 | 
					                                    for kb_id in candidate_ids:
 | 
				
			||||||
                                        entry = (gold_start, gold_end, kb_id)
 | 
					                                        entry = (gold_start, gold_end, kb_id)
 | 
				
			||||||
| 
						 | 
					@ -418,7 +416,8 @@ def read_training(nlp, training_dir, dev, limit, kb=None):
 | 
				
			||||||
                                        else:
 | 
					                                        else:
 | 
				
			||||||
                                            gold_entities[entry] = 1.0
 | 
					                                            gold_entities[entry] = 1.0
 | 
				
			||||||
                                else:
 | 
					                                else:
 | 
				
			||||||
                                    gold_entities = {}
 | 
					                                    entry = (gold_start, gold_end, wd_id)
 | 
				
			||||||
 | 
					                                    gold_entities = {entry: 1.0}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                                gold = GoldParse(doc=sent, links=gold_entities)
 | 
					                                gold = GoldParse(doc=sent, links=gold_entities)
 | 
				
			||||||
                                data.append((sent, gold))
 | 
					                                data.append((sent, gold))
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -31,7 +31,7 @@ cdef class GoldParse:
 | 
				
			||||||
    cdef public list ents
 | 
					    cdef public list ents
 | 
				
			||||||
    cdef public dict brackets
 | 
					    cdef public dict brackets
 | 
				
			||||||
    cdef public object cats
 | 
					    cdef public object cats
 | 
				
			||||||
    cdef public list links
 | 
					    cdef public dict links
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    cdef readonly list cand_to_gold
 | 
					    cdef readonly list cand_to_gold
 | 
				
			||||||
    cdef readonly list gold_to_cand
 | 
					    cdef readonly list gold_to_cand
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -450,8 +450,10 @@ cdef class GoldParse:
 | 
				
			||||||
            examples of a label to have the value 0.0. Labels not in the
 | 
					            examples of a label to have the value 0.0. Labels not in the
 | 
				
			||||||
            dictionary are treated as missing - the gradient for those labels
 | 
					            dictionary are treated as missing - the gradient for those labels
 | 
				
			||||||
            will be zero.
 | 
					            will be zero.
 | 
				
			||||||
        links (iterable): A sequence of `(start_char, end_char, kb_id)` tuples,
 | 
					        links (dict): A dict with `(start_char, end_char, kb_id)` keys,
 | 
				
			||||||
            representing the external ID of an entity in a knowledge base.
 | 
					            representing the external ID of an entity in a knowledge base,
 | 
				
			||||||
 | 
					            and the values being either 1.0 or 0.0, indicating positive and
 | 
				
			||||||
 | 
					            negative examples, respectively.
 | 
				
			||||||
        RETURNS (GoldParse): The newly constructed object.
 | 
					        RETURNS (GoldParse): The newly constructed object.
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        if words is None:
 | 
					        if words is None:
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
							
								
								
									
										29
									
								
								spacy/kb.pyx
									
									
									
									
									
								
							
							
						
						
									
										29
									
								
								spacy/kb.pyx
									
									
									
									
									
								
							| 
						 | 
					@ -191,7 +191,7 @@ cdef class KnowledgeBase:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get_candidates(self, unicode alias):
 | 
					    def get_candidates(self, unicode alias):
 | 
				
			||||||
        cdef hash_t alias_hash = self.vocab.strings[alias]
 | 
					        cdef hash_t alias_hash = self.vocab.strings[alias]
 | 
				
			||||||
        alias_index = <int64_t>self._alias_index.get(alias_hash)
 | 
					        alias_index = <int64_t>self._alias_index.get(alias_hash)  # TODO: check for error? unit test !
 | 
				
			||||||
        alias_entry = self._aliases_table[alias_index]
 | 
					        alias_entry = self._aliases_table[alias_index]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        return [Candidate(kb=self,
 | 
					        return [Candidate(kb=self,
 | 
				
			||||||
| 
						 | 
					@ -199,12 +199,12 @@ cdef class KnowledgeBase:
 | 
				
			||||||
                          entity_freq=self._entries[entry_index].prob,
 | 
					                          entity_freq=self._entries[entry_index].prob,
 | 
				
			||||||
                          entity_vector=self._vectors_table[self._entries[entry_index].vector_index],
 | 
					                          entity_vector=self._vectors_table[self._entries[entry_index].vector_index],
 | 
				
			||||||
                          alias_hash=alias_hash,
 | 
					                          alias_hash=alias_hash,
 | 
				
			||||||
                          prior_prob=prob)
 | 
					                          prior_prob=prior_prob)
 | 
				
			||||||
                for (entry_index, prob) in zip(alias_entry.entry_indices, alias_entry.probs)
 | 
					                for (entry_index, prior_prob) in zip(alias_entry.entry_indices, alias_entry.probs)
 | 
				
			||||||
                if entry_index != 0]
 | 
					                if entry_index != 0]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get_vector(self, unicode entity):
 | 
					    def get_vector(self, unicode entity):
 | 
				
			||||||
        cdef hash_t entity_hash = self.vocab.strings.add(entity)
 | 
					        cdef hash_t entity_hash = self.vocab.strings[entity]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Return an empty list if this entity is unknown in this KB
 | 
					        # Return an empty list if this entity is unknown in this KB
 | 
				
			||||||
        if entity_hash not in self._entry_index:
 | 
					        if entity_hash not in self._entry_index:
 | 
				
			||||||
| 
						 | 
					@ -213,6 +213,27 @@ cdef class KnowledgeBase:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        return self._vectors_table[self._entries[entry_index].vector_index]
 | 
					        return self._vectors_table[self._entries[entry_index].vector_index]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def get_prior_prob(self, unicode entity, unicode alias):
 | 
				
			||||||
 | 
					        """ Return the prior probability of a given alias being linked to a given entity,
 | 
				
			||||||
 | 
					        or return 0.0 when this combination is not known in the knowledge base"""
 | 
				
			||||||
 | 
					        cdef hash_t alias_hash = self.vocab.strings[alias]
 | 
				
			||||||
 | 
					        cdef hash_t entity_hash = self.vocab.strings[entity]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # TODO: error  ?
 | 
				
			||||||
 | 
					        if entity_hash not in self._entry_index or alias_hash not in self._alias_index:
 | 
				
			||||||
 | 
					            return 0.0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        alias_index = <int64_t>self._alias_index.get(alias_hash)
 | 
				
			||||||
 | 
					        entry_index = self._entry_index[entity_hash]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        alias_entry = self._aliases_table[alias_index]
 | 
				
			||||||
 | 
					        for (entry_index, prior_prob) in zip(alias_entry.entry_indices, alias_entry.probs):
 | 
				
			||||||
 | 
					            if self._entries[entry_index].entity_hash == entity_hash:
 | 
				
			||||||
 | 
					                return prior_prob
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return 0.0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def dump(self, loc):
 | 
					    def dump(self, loc):
 | 
				
			||||||
        cdef Writer writer = Writer(loc)
 | 
					        cdef Writer writer = Writer(loc)
 | 
				
			||||||
        writer.write_header(self.get_size_entities(), self.entity_vector_length)
 | 
					        writer.write_header(self.get_size_entities(), self.entity_vector_length)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -13,6 +13,11 @@ def nlp():
 | 
				
			||||||
    return English()
 | 
					    return English()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def assert_almost_equal(a, b):
 | 
				
			||||||
 | 
					    delta = 0.0001
 | 
				
			||||||
 | 
					    assert a - delta <= b <= a + delta
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def test_kb_valid_entities(nlp):
 | 
					def test_kb_valid_entities(nlp):
 | 
				
			||||||
    """Test the valid construction of a KB with 3 entities and two aliases"""
 | 
					    """Test the valid construction of a KB with 3 entities and two aliases"""
 | 
				
			||||||
    mykb = KnowledgeBase(nlp.vocab, entity_vector_length=3)
 | 
					    mykb = KnowledgeBase(nlp.vocab, entity_vector_length=3)
 | 
				
			||||||
| 
						 | 
					@ -35,6 +40,10 @@ def test_kb_valid_entities(nlp):
 | 
				
			||||||
    assert mykb.get_vector("Q2") == [2, 1, 0]
 | 
					    assert mykb.get_vector("Q2") == [2, 1, 0]
 | 
				
			||||||
    assert mykb.get_vector("Q3") == [-1, -6, 5]
 | 
					    assert mykb.get_vector("Q3") == [-1, -6, 5]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # test retrieval of prior probabilities
 | 
				
			||||||
 | 
					    assert_almost_equal(mykb.get_prior_prob(entity="Q2", alias="douglas"), 0.8)
 | 
				
			||||||
 | 
					    assert_almost_equal(mykb.get_prior_prob(entity="Q3", alias="douglas"), 0.2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def test_kb_invalid_entities(nlp):
 | 
					def test_kb_invalid_entities(nlp):
 | 
				
			||||||
    """Test the invalid construction of a KB with an alias linked to a non-existing entity"""
 | 
					    """Test the invalid construction of a KB with an alias linked to a non-existing entity"""
 | 
				
			||||||
| 
						 | 
					@ -99,12 +108,12 @@ def test_candidate_generation(nlp):
 | 
				
			||||||
    mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
 | 
					    mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # adding entities
 | 
					    # adding entities
 | 
				
			||||||
    mykb.add_entity(entity="Q1", prob=0.9, entity_vector=[1])
 | 
					    mykb.add_entity(entity="Q1", prob=0.7, entity_vector=[1])
 | 
				
			||||||
    mykb.add_entity(entity="Q2", prob=0.2, entity_vector=[2])
 | 
					    mykb.add_entity(entity="Q2", prob=0.2, entity_vector=[2])
 | 
				
			||||||
    mykb.add_entity(entity="Q3", prob=0.5, entity_vector=[3])
 | 
					    mykb.add_entity(entity="Q3", prob=0.5, entity_vector=[3])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # adding aliases
 | 
					    # adding aliases
 | 
				
			||||||
    mykb.add_alias(alias="douglas", entities=["Q2", "Q3"], probabilities=[0.8, 0.2])
 | 
					    mykb.add_alias(alias="douglas", entities=["Q2", "Q3"], probabilities=[0.8, 0.1])
 | 
				
			||||||
    mykb.add_alias(alias="adam", entities=["Q2"], probabilities=[0.9])
 | 
					    mykb.add_alias(alias="adam", entities=["Q2"], probabilities=[0.9])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # test the size of the relevant candidates
 | 
					    # test the size of the relevant candidates
 | 
				
			||||||
| 
						 | 
					@ -112,6 +121,12 @@ def test_candidate_generation(nlp):
 | 
				
			||||||
    assert len(mykb.get_candidates("adam")) == 1
 | 
					    assert len(mykb.get_candidates("adam")) == 1
 | 
				
			||||||
    assert len(mykb.get_candidates("shrubbery")) == 0
 | 
					    assert len(mykb.get_candidates("shrubbery")) == 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # test the content of the candidates
 | 
				
			||||||
 | 
					    assert mykb.get_candidates("adam")[0].entity_ == "Q2"
 | 
				
			||||||
 | 
					    assert mykb.get_candidates("adam")[0].alias_ == "adam"
 | 
				
			||||||
 | 
					    assert_almost_equal(mykb.get_candidates("adam")[0].entity_freq, 0.2)
 | 
				
			||||||
 | 
					    assert_almost_equal(mykb.get_candidates("adam")[0].prior_prob, 0.9)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def test_preserving_links_asdoc(nlp):
 | 
					def test_preserving_links_asdoc(nlp):
 | 
				
			||||||
    """Test that Span.as_doc preserves the existing entity links"""
 | 
					    """Test that Span.as_doc preserves the existing entity links"""
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user