Merge pull request #6222 from explosion/fix/initialize-clear

Clear rule-based components on initialize
This commit is contained in:
Ines Montani 2020-10-08 10:37:10 +02:00 committed by GitHub
commit eb28e8ce35
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 30 additions and 2 deletions

View File

@ -53,10 +53,18 @@ class AttributeRuler(Pipe):
self.name = name
self.vocab = vocab
self.matcher = Matcher(self.vocab, validate=validate)
self.validate = validate
self.attrs = []
self._attrs_unnormed = [] # store for reference
self.indices = []
def clear(self) -> None:
    """Discard every stored pattern and return the ruler to an empty state.

    A brand-new matcher is built from the ruler's own vocab and validation
    setting, and the attribute/index bookkeeping lists are replaced with
    fresh empty lists.
    """
    # Swap in a fresh matcher first so no previously added rule survives.
    fresh_matcher = Matcher(self.vocab, validate=self.validate)
    self.matcher = fresh_matcher
    # Each store gets its own new list object (no sharing between them).
    self.attrs, self._attrs_unnormed, self.indices = [], [], []
def initialize(
self,
get_examples: Optional[Callable[[], Iterable[Example]]],
@ -65,13 +73,14 @@ class AttributeRuler(Pipe):
patterns: Optional[Iterable[AttributeRulerPatternType]] = None,
tag_map: Optional[TagMapType] = None,
morph_rules: Optional[MorphRulesType] = None,
):
) -> None:
"""Initialize the attribute ruler by adding zero or more patterns.
Rules can be specified as a sequence of dicts using the `patterns`
keyword argument. You can also provide rules using the "tag map" or
"morph rules" formats supported by spaCy prior to v3.
"""
self.clear()
if patterns:
self.add_patterns(patterns)
if tag_map:

View File

@ -201,10 +201,10 @@ class EntityRuler(Pipe):
DOCS: https://nightly.spacy.io/api/entityruler#initialize
"""
self.clear()
if patterns:
self.add_patterns(patterns)
@property
def ent_ids(self) -> Tuple[str, ...]:
"""All entity ids present in the match patterns `id` properties

View File

@ -136,6 +136,16 @@ def test_attributeruler_init_patterns(nlp, pattern_dicts):
assert doc.has_annotation("MORPH")
def test_attributeruler_init_clear(nlp, pattern_dicts):
    """Check that initializing the attribute ruler drops existing patterns.

    The matcher starts out empty, becomes non-empty once patterns are added,
    and is empty again after `initialize` is called without any patterns.
    """

    def no_examples():
        # `initialize` expects a callable that yields examples; none needed.
        return []

    ruler = nlp.add_pipe("attribute_ruler")
    assert len(ruler.matcher) == 0
    ruler.add_patterns(pattern_dicts)
    assert len(ruler.matcher) > 0
    ruler.initialize(no_examples)
    assert len(ruler.matcher) == 0
def test_attributeruler_score(nlp, pattern_dicts):
# initialize with patterns
ruler = nlp.add_pipe("attribute_ruler")

View File

@ -68,6 +68,15 @@ def test_entity_ruler_init_patterns(nlp, patterns):
assert doc.ents[1].label_ == "BYE"
def test_entity_ruler_init_clear(nlp, patterns):
    """Check that initializing the entity ruler drops existing patterns.

    After adding the fixture patterns the ruler reports four labels; after
    `initialize` is called without any patterns it reports none.
    """

    def no_examples():
        # `initialize` expects a callable that yields examples; none needed.
        return []

    ruler = nlp.add_pipe("entity_ruler")
    ruler.add_patterns(patterns)
    labels_before = len(ruler.labels)
    assert labels_before == 4
    ruler.initialize(no_examples)
    labels_after = len(ruler.labels)
    assert labels_after == 0
def test_entity_ruler_existing(nlp, patterns):
ruler = nlp.add_pipe("entity_ruler")
ruler.add_patterns(patterns)