mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-25 05:01:02 +03:00 
			
		
		
		
	check and unit test in case prior probs exceed 1
This commit is contained in:
		
							parent
							
								
									b55baaa1dc
								
							
						
					
					
						commit
						33f8a0fe2e
					
				|  | @ -35,6 +35,13 @@ cdef class KnowledgeBase: | ||||||
| 
 | 
 | ||||||
|     def add_alias(self, unicode alias, entities, probabilities): |     def add_alias(self, unicode alias, entities, probabilities): | ||||||
|         """For a given alias, add its potential entities and prior probabilies to the KB.""" |         """For a given alias, add its potential entities and prior probabilies to the KB.""" | ||||||
|  | 
 | ||||||
|  |         # Throw an error if the probabilities sum up to more than 1 | ||||||
|  |         prob_sum = sum(probabilities) | ||||||
|  |         if prob_sum > 1: | ||||||
|  |             raise ValueError("The sum of prior probabilities for alias '" + alias + "' should not exceed 1, " | ||||||
|  |                                                                                     "but found " + str(prob_sum)) | ||||||
|  | 
 | ||||||
|         cdef hash_t alias_hash = self.strings.add(alias) |         cdef hash_t alias_hash = self.strings.add(alias) | ||||||
| 
 | 
 | ||||||
|         # Return if this alias was added before |         # Return if this alias was added before | ||||||
|  |  | ||||||
|  | @ -42,6 +42,12 @@ def create_kb(): | ||||||
| 
 | 
 | ||||||
|     print("kb size", len(mykb), mykb.get_size_entities(), mykb.get_size_aliases()) |     print("kb size", len(mykb), mykb.get_size_entities(), mykb.get_size_aliases()) | ||||||
| 
 | 
 | ||||||
|  |     alias2 = "johny" | ||||||
|  |     print(" adding alias2", alias2) | ||||||
|  |     mykb.add_alias(alias=alias2, entities=["Q0", "Q42"], probabilities=[0.3, 1.1]) | ||||||
|  | 
 | ||||||
|  |     print("kb size", len(mykb), mykb.get_size_entities(), mykb.get_size_aliases()) | ||||||
|  | 
 | ||||||
|     print("candidates for", alias) |     print("candidates for", alias) | ||||||
|     candidates = mykb.get_candidates(alias) |     candidates = mykb.get_candidates(alias) | ||||||
|     print(" ", candidates) |     print(" ", candidates) | ||||||
|  |  | ||||||
|  | @ -1,14 +1,16 @@ | ||||||
|  | # coding: utf-8 | ||||||
| import pytest | import pytest | ||||||
| 
 | 
 | ||||||
| from spacy.kb import KnowledgeBase | from spacy.kb import KnowledgeBase | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def test_kb_valid_entities(): | def test_kb_valid_entities(): | ||||||
|  |     """Test the valid construction of a KB with 3 entities and one alias""" | ||||||
|     mykb = KnowledgeBase() |     mykb = KnowledgeBase() | ||||||
| 
 | 
 | ||||||
|     # adding entities |     # adding entities | ||||||
|     mykb.add_entity(entity_id="Q1", prob=0.5) |     mykb.add_entity(entity_id="Q1", prob=0.9) | ||||||
|     mykb.add_entity(entity_id="Q2", prob=0.5) |     mykb.add_entity(entity_id="Q2", prob=0.2) | ||||||
|     mykb.add_entity(entity_id="Q3", prob=0.5) |     mykb.add_entity(entity_id="Q3", prob=0.5) | ||||||
| 
 | 
 | ||||||
|     # adding aliases |     # adding aliases | ||||||
|  | @ -16,14 +18,29 @@ def test_kb_valid_entities(): | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def test_kb_invalid_entities(): | def test_kb_invalid_entities(): | ||||||
|  |     """Test the invalid construction of a KB with an alias linked to a non-existing entity""" | ||||||
|     mykb = KnowledgeBase() |     mykb = KnowledgeBase() | ||||||
| 
 | 
 | ||||||
|     # adding entities |     # adding entities | ||||||
|     mykb.add_entity(entity_id="Q1", prob=0.5) |     mykb.add_entity(entity_id="Q1", prob=0.9) | ||||||
|     mykb.add_entity(entity_id="Q2", prob=0.5) |     mykb.add_entity(entity_id="Q2", prob=0.2) | ||||||
|     mykb.add_entity(entity_id="Q3", prob=0.5) |     mykb.add_entity(entity_id="Q3", prob=0.5) | ||||||
| 
 | 
 | ||||||
|     # adding aliases - should fail because one of the given IDs is not valid |     # adding aliases - should fail because one of the given IDs is not valid | ||||||
|     with pytest.raises(ValueError): |     with pytest.raises(ValueError): | ||||||
|         mykb.add_alias(alias="douglassss", entities=["Q2", "Q342"], probabilities=[0.8, 0.2]) |         mykb.add_alias(alias="douglassss", entities=["Q2", "Q342"], probabilities=[0.8, 0.2]) | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
|  | def test_kb_invalid_probabilities(): | ||||||
|  |     """Test the invalid construction of a KB with wrong prior probabilities""" | ||||||
|  |     mykb = KnowledgeBase() | ||||||
|  | 
 | ||||||
|  |     # adding entities | ||||||
|  |     mykb.add_entity(entity_id="Q1", prob=0.9) | ||||||
|  |     mykb.add_entity(entity_id="Q2", prob=0.2) | ||||||
|  |     mykb.add_entity(entity_id="Q3", prob=0.5) | ||||||
|  | 
 | ||||||
|  |     # adding aliases - should fail because the sum of the probabilities exceeds 1 | ||||||
|  |     with pytest.raises(ValueError): | ||||||
|  |         mykb.add_alias(alias="douglassss", entities=["Q2", "Q3"], probabilities=[0.8, 0.4]) | ||||||
|  | 
 | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user