mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-30 23:47:31 +03:00 
			
		
		
		
	* Don't add duplicate patterns (fix #8216)
	* Refactor EntityRuler init
	  This simplifies the EntityRuler init code. This is helpful as prep for allowing the EntityRuler to reset itself.
	* Make EntityRuler.clear reset matchers
	  Includes a new test for this.
	* Tidy PhraseMatcher instantiation
	  Since the attr can be None safely now, the `if` guard is no longer required here. Also renamed the `_validate` attr. Maybe it's not needed?
	* Fix NER test
	* Add test to make sure patterns aren't increasing
	* Move test to regression tests
This commit is contained in:
		
							parent
							
								
									1db18732e0
								
							
						
					
					
						commit
						ad026dc5fd
					
				|  | @ -50,6 +50,8 @@ cdef class PhraseMatcher: | ||||||
|         if isinstance(attr, (int, long)): |         if isinstance(attr, (int, long)): | ||||||
|             self.attr = attr |             self.attr = attr | ||||||
|         else: |         else: | ||||||
|  |             if attr is None: | ||||||
|  |                 attr = "ORTH" | ||||||
|             attr = attr.upper() |             attr = attr.upper() | ||||||
|             if attr == "TEXT": |             if attr == "TEXT": | ||||||
|                 attr = "ORTH" |                 attr = "ORTH" | ||||||
|  |  | ||||||
|  | @ -102,17 +102,12 @@ class EntityRuler(Pipe): | ||||||
|         self.overwrite = overwrite_ents |         self.overwrite = overwrite_ents | ||||||
|         self.token_patterns = defaultdict(list) |         self.token_patterns = defaultdict(list) | ||||||
|         self.phrase_patterns = defaultdict(list) |         self.phrase_patterns = defaultdict(list) | ||||||
|  |         self._validate = validate | ||||||
|         self.matcher = Matcher(nlp.vocab, validate=validate) |         self.matcher = Matcher(nlp.vocab, validate=validate) | ||||||
|         if phrase_matcher_attr is not None: |  | ||||||
|             if phrase_matcher_attr.upper() == "TEXT": |  | ||||||
|                 phrase_matcher_attr = "ORTH" |  | ||||||
|         self.phrase_matcher_attr = phrase_matcher_attr |         self.phrase_matcher_attr = phrase_matcher_attr | ||||||
|         self.phrase_matcher = PhraseMatcher( |         self.phrase_matcher = PhraseMatcher( | ||||||
|             nlp.vocab, attr=self.phrase_matcher_attr, validate=validate |             nlp.vocab, attr=self.phrase_matcher_attr, validate=validate | ||||||
|         ) |         ) | ||||||
|         else: |  | ||||||
|             self.phrase_matcher_attr = None |  | ||||||
|             self.phrase_matcher = PhraseMatcher(nlp.vocab, validate=validate) |  | ||||||
|         self.ent_id_sep = ent_id_sep |         self.ent_id_sep = ent_id_sep | ||||||
|         self._ent_ids = defaultdict(dict) |         self._ent_ids = defaultdict(dict) | ||||||
|         if patterns is not None: |         if patterns is not None: | ||||||
|  | @ -317,20 +312,22 @@ class EntityRuler(Pipe): | ||||||
|                 pattern = entry["pattern"] |                 pattern = entry["pattern"] | ||||||
|                 if isinstance(pattern, Doc): |                 if isinstance(pattern, Doc): | ||||||
|                     self.phrase_patterns[label].append(pattern) |                     self.phrase_patterns[label].append(pattern) | ||||||
|  |                     self.phrase_matcher.add(label, [pattern]) | ||||||
|                 elif isinstance(pattern, list): |                 elif isinstance(pattern, list): | ||||||
|                     self.token_patterns[label].append(pattern) |                     self.token_patterns[label].append(pattern) | ||||||
|  |                     self.matcher.add(label, [pattern]) | ||||||
|                 else: |                 else: | ||||||
|                     raise ValueError(Errors.E097.format(pattern=pattern)) |                     raise ValueError(Errors.E097.format(pattern=pattern)) | ||||||
|             for label, patterns in self.token_patterns.items(): |  | ||||||
|                 self.matcher.add(label, patterns) |  | ||||||
|             for label, patterns in self.phrase_patterns.items(): |  | ||||||
|                 self.phrase_matcher.add(label, patterns) |  | ||||||
| 
 | 
 | ||||||
|     def clear(self) -> None: |     def clear(self) -> None: | ||||||
|         """Reset all patterns.""" |         """Reset all patterns.""" | ||||||
|         self.token_patterns = defaultdict(list) |         self.token_patterns = defaultdict(list) | ||||||
|         self.phrase_patterns = defaultdict(list) |         self.phrase_patterns = defaultdict(list) | ||||||
|         self._ent_ids = defaultdict(dict) |         self._ent_ids = defaultdict(dict) | ||||||
|  |         self.matcher = Matcher(self.nlp.vocab, validate=self._validate) | ||||||
|  |         self.phrase_matcher = PhraseMatcher( | ||||||
|  |             self.nlp.vocab, attr=self.phrase_matcher_attr, validate=self._validate | ||||||
|  |         ) | ||||||
| 
 | 
 | ||||||
|     def _require_patterns(self) -> None: |     def _require_patterns(self) -> None: | ||||||
|         """Raise a warning if this component has no patterns defined.""" |         """Raise a warning if this component has no patterns defined.""" | ||||||
|  | @ -381,7 +378,6 @@ class EntityRuler(Pipe): | ||||||
|             self.add_patterns(cfg.get("patterns", cfg)) |             self.add_patterns(cfg.get("patterns", cfg)) | ||||||
|             self.overwrite = cfg.get("overwrite", False) |             self.overwrite = cfg.get("overwrite", False) | ||||||
|             self.phrase_matcher_attr = cfg.get("phrase_matcher_attr", None) |             self.phrase_matcher_attr = cfg.get("phrase_matcher_attr", None) | ||||||
|             if self.phrase_matcher_attr is not None: |  | ||||||
|             self.phrase_matcher = PhraseMatcher( |             self.phrase_matcher = PhraseMatcher( | ||||||
|                 self.nlp.vocab, attr=self.phrase_matcher_attr |                 self.nlp.vocab, attr=self.phrase_matcher_attr | ||||||
|             ) |             ) | ||||||
|  | @ -435,7 +431,6 @@ class EntityRuler(Pipe): | ||||||
|             self.phrase_matcher_attr = cfg.get("phrase_matcher_attr") |             self.phrase_matcher_attr = cfg.get("phrase_matcher_attr") | ||||||
|             self.ent_id_sep = cfg.get("ent_id_sep", DEFAULT_ENT_ID_SEP) |             self.ent_id_sep = cfg.get("ent_id_sep", DEFAULT_ENT_ID_SEP) | ||||||
| 
 | 
 | ||||||
|             if self.phrase_matcher_attr is not None: |  | ||||||
|             self.phrase_matcher = PhraseMatcher( |             self.phrase_matcher = PhraseMatcher( | ||||||
|                 self.nlp.vocab, attr=self.phrase_matcher_attr |                 self.nlp.vocab, attr=self.phrase_matcher_attr | ||||||
|             ) |             ) | ||||||
|  |  | ||||||
|  | @ -252,12 +252,12 @@ def test_ruler_before_ner(): | ||||||
|     # 1 : Entity Ruler - should set "this" to B and everything else to empty |     # 1 : Entity Ruler - should set "this" to B and everything else to empty | ||||||
|     patterns = [{"label": "THING", "pattern": "This"}] |     patterns = [{"label": "THING", "pattern": "This"}] | ||||||
|     ruler = nlp.add_pipe("entity_ruler") |     ruler = nlp.add_pipe("entity_ruler") | ||||||
|     ruler.add_patterns(patterns) |  | ||||||
| 
 | 
 | ||||||
|     # 2: untrained NER - should set everything else to O |     # 2: untrained NER - should set everything else to O | ||||||
|     untrained_ner = nlp.add_pipe("ner") |     untrained_ner = nlp.add_pipe("ner") | ||||||
|     untrained_ner.add_label("MY_LABEL") |     untrained_ner.add_label("MY_LABEL") | ||||||
|     nlp.initialize() |     nlp.initialize() | ||||||
|  |     ruler.add_patterns(patterns) | ||||||
|     doc = nlp("This is Antti Korhonen speaking in Finland") |     doc = nlp("This is Antti Korhonen speaking in Finland") | ||||||
|     expected_iobs = ["B", "O", "O", "O", "O", "O", "O"] |     expected_iobs = ["B", "O", "O", "O", "O", "O", "O"] | ||||||
|     expected_types = ["THING", "", "", "", "", "", ""] |     expected_types = ["THING", "", "", "", "", "", ""] | ||||||
|  |  | ||||||
|  | @ -89,6 +89,19 @@ def test_entity_ruler_init_clear(nlp, patterns): | ||||||
|     assert len(ruler.labels) == 0 |     assert len(ruler.labels) == 0 | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | def test_entity_ruler_clear(nlp, patterns): | ||||||
|  |     """Test that initialization clears patterns.""" | ||||||
|  |     ruler = nlp.add_pipe("entity_ruler") | ||||||
|  |     ruler.add_patterns(patterns) | ||||||
|  |     assert len(ruler.labels) == 4 | ||||||
|  |     doc = nlp("hello world") | ||||||
|  |     assert len(doc.ents) == 1 | ||||||
|  |     ruler.clear() | ||||||
|  |     assert len(ruler.labels) == 0 | ||||||
|  |     doc = nlp("hello world") | ||||||
|  |     assert len(doc.ents) == 0 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| def test_entity_ruler_existing(nlp, patterns): | def test_entity_ruler_existing(nlp, patterns): | ||||||
|     ruler = nlp.add_pipe("entity_ruler") |     ruler = nlp.add_pipe("entity_ruler") | ||||||
|     ruler.add_patterns(patterns) |     ruler.add_patterns(patterns) | ||||||
|  |  | ||||||
							
								
								
									
										34
									
								
								spacy/tests/regression/test_issue8216.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										34
									
								
								spacy/tests/regression/test_issue8216.py
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,34 @@ | ||||||
|  | import pytest | ||||||
|  | 
 | ||||||
|  | from spacy import registry | ||||||
|  | from spacy.language import Language | ||||||
|  | from spacy.pipeline import EntityRuler | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | @pytest.fixture | ||||||
|  | def nlp(): | ||||||
|  |     return Language() | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | @pytest.fixture | ||||||
|  | @registry.misc("entity_ruler_patterns") | ||||||
|  | def patterns(): | ||||||
|  |     return [ | ||||||
|  |         {"label": "HELLO", "pattern": "hello world"}, | ||||||
|  |         {"label": "BYE", "pattern": [{"LOWER": "bye"}, {"LOWER": "bye"}]}, | ||||||
|  |         {"label": "HELLO", "pattern": [{"ORTH": "HELLO"}]}, | ||||||
|  |         {"label": "COMPLEX", "pattern": [{"ORTH": "foo", "OP": "*"}]}, | ||||||
|  |         {"label": "TECH_ORG", "pattern": "Apple", "id": "a1"}, | ||||||
|  |         {"label": "TECH_ORG", "pattern": "Microsoft", "id": "a2"}, | ||||||
|  |     ] | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def test_entity_ruler_fix8216(nlp, patterns): | ||||||
|  |     """Test that patterns don't get added excessively.""" | ||||||
|  |     ruler = nlp.add_pipe("entity_ruler", config={"validate": True}) | ||||||
|  |     ruler.add_patterns(patterns) | ||||||
|  |     pattern_count = sum(len(mm) for mm in ruler.matcher._patterns.values()) | ||||||
|  |     assert pattern_count > 0 | ||||||
|  |     ruler.add_patterns([]) | ||||||
|  |     after_count = sum(len(mm) for mm in ruler.matcher._patterns.values()) | ||||||
|  |     assert after_count == pattern_count | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user