diff --git a/spacy/pipeline/attributeruler.py b/spacy/pipeline/attributeruler.py index 61649d676..706873290 100644 --- a/spacy/pipeline/attributeruler.py +++ b/spacy/pipeline/attributeruler.py @@ -1,5 +1,5 @@ import srsly -from typing import List, Dict, Union, Iterable +from typing import List, Dict, Union, Iterable, Any from pathlib import Path from .pipe import Pipe @@ -7,11 +7,14 @@ from ..errors import Errors from ..language import Language from ..matcher import Matcher from ..symbols import IDS -from ..tokens import Doc +from ..tokens import Doc, Span from ..vocab import Vocab from .. import util +MatcherPatternType = List[Dict[Union[int, str], Any]] + + @Language.factory( "attribute_ruler", assigns=[], @@ -20,9 +23,9 @@ from .. import util default_score_weights={}, ) def make_attribute_ruler( - nlp: Language, name: str, + nlp: Language, name: str, pattern_dicts: Iterable[Dict] = tuple() ): - return AttributeRuler(nlp.vocab, name) + return AttributeRuler(nlp.vocab, name, pattern_dicts=pattern_dicts) class AttributeRuler(Pipe): @@ -32,10 +35,22 @@ class AttributeRuler(Pipe): DOCS: https://spacy.io/api/attributeruler """ - def __init__(self, vocab: Vocab, name: str = "attribute_ruler") -> None: - """Initialize the attributeruler. + def __init__( + self, + vocab: Vocab, + name: str = "attribute_ruler", + *, + pattern_dicts: List[Dict[str, Union[List[MatcherPatternType], Dict, int]]] = {}, + ) -> None: + """Initialize the AttributeRuler. - RETURNS (AttributeRuler): The attributeruler component. + vocab (Vocab): The vocab. + name (str): The pipe name. Defaults to "attribute_ruler". + pattern_dicts (List[Dict]): A list of pattern dicts with the keys as + the arguments to AttributeRuler.add (`patterns`/`attrs`/`index`) to add + as patterns. + + RETURNS (AttributeRuler): The AttributeRuler component. DOCS: https://spacy.io/api/attributeruler#init """ @@ -45,6 +60,9 @@ class AttributeRuler(Pipe): self.attrs = [] self.indices = [] + for p in pattern_dicts: + self.add(**p) + def __call__(self, doc: Doc) -> Doc: """Apply the attributeruler to a Doc and set all attribute exceptions. @@ -54,21 +72,36 @@ class AttributeRuler(Pipe): DOCS: https://spacy.io/api/attributeruler#call """ matches = self.matcher(doc) - with doc.retokenize() as retokenizer: - for match_id, start, end in matches: - attrs = self.attrs[match_id] - index = self.indices[match_id] - token = doc[start:end][index] - if start <= token.i < end: - retokenizer.merge(doc[token.i : token.i + 1], attrs) - else: - raise ValueError( - Errors.E1001.format( - patterns=self.matcher.get(match_id), - span=[t.text for t in doc[start:end]], - index=index, + + # Multiple patterns may apply to the same token but the retokenizer can + # only handle one merge per token, so split the matches into sets of + # disjoint spans. + original_spans = set( + [Span(doc, start, end, label=match_id) for match_id, start, end in matches] + ) + disjoint_span_sets = [] + while original_spans: + filtered_spans = set(util.filter_spans(original_spans)) + disjoint_span_sets.append(filtered_spans) + original_spans -= filtered_spans + + # Retokenize with each set of disjoint spans separately + for span_set in disjoint_span_sets: + with doc.retokenize() as retokenizer: + for span in span_set: + attrs = self.attrs[span.label] + index = self.indices[span.label] + token = span[index] + if span.start <= token.i < span.end: + retokenizer.merge(doc[token.i : token.i + 1], attrs) + else: + raise ValueError( + Errors.E1001.format( + patterns=self.matcher.get(span.label), + span=[t.text for t in span], + index=index, + ) ) - ) return doc def load_from_tag_map( @@ -93,12 +126,14 @@ class AttributeRuler(Pipe): attrs["MORPH"] = self.vocab.strings[morph] self.add([pattern], attrs) - def add(self, patterns: List[List[Dict]], attrs: Dict, index: int = 0) -> None: + def add( + self, patterns: Iterable[MatcherPatternType], attrs: Dict, index: int = 0 + ) -> None: """Add Matcher patterns for tokens that should be modified with the provided attributes. The token at the specified index within the matched span will be assigned the attributes. - pattern (List[List[Dict]]): A list of Matcher patterns. + pattern (Iterable[List[Dict]]): A list of Matcher patterns. attrs (Dict): The attributes to assign to the target token in the matched span. index (int): The index of the token in the matched span to modify. May diff --git a/spacy/tests/pipeline/test_attributeruler.py b/spacy/tests/pipeline/test_attributeruler.py index bee286b83..d168a948d 100644 --- a/spacy/tests/pipeline/test_attributeruler.py +++ b/spacy/tests/pipeline/test_attributeruler.py @@ -12,6 +12,24 @@ def nlp(): return English() +@pytest.fixture +def pattern_dicts(): + return [ + { + "patterns": [[{"ORTH": "a"}]], + "attrs": {"LEMMA": "the", "MORPH": "Case=Nom|Number=Plur"}, + }, + # one pattern sets the lemma + {"patterns": [[{"ORTH": "test"}]], "attrs": {"LEMMA": "cat"}}, + # another pattern sets the morphology + { + "patterns": [[{"ORTH": "test"}]], + "attrs": {"MORPH": "Case=Nom|Number=Sing"}, + "index": 0, + }, + ] + + @pytest.fixture def tag_map(): return { @@ -25,13 +43,21 @@ def morph_rules(): return {"DT": {"the": {"POS": "DET", "LEMMA": "a", "Case": "Nom"}}} -def test_attributeruler_init(nlp): - a = AttributeRuler(nlp.vocab) - +def test_attributeruler_init(nlp, pattern_dicts): a = nlp.add_pipe("attribute_ruler") - a.add([[{"ORTH": "a"}]], {"LEMMA": "the", "MORPH": "Case=Nom|Number=Plur"}) - a.add([[{"ORTH": "test"}]], {"LEMMA": "cat", "MORPH": "Number=Sing|Case=Nom"}) - a.add([[{"ORTH": "test"}]], {"LEMMA": "dog"}) + for p in pattern_dicts: + a.add(**p) + + doc = nlp("This is a test.") + assert doc[2].lemma_ == "the" + assert doc[2].morph_ == "Case=Nom|Number=Plur" + assert doc[3].lemma_ == "cat" + assert doc[3].morph_ == "Case=Nom|Number=Sing" + + +def test_attributeruler_init_patterns(nlp, pattern_dicts): + # initialize with patterns + a = nlp.add_pipe("attribute_ruler", config={"pattern_dicts": pattern_dicts}) doc = nlp("This is a test.") assert doc[2].lemma_ == "the" @@ -43,7 +69,11 @@ def test_attributeruler_init(nlp): def test_attributeruler_tag_map(nlp, tag_map): a = AttributeRuler(nlp.vocab) a.load_from_tag_map(tag_map) - doc = get_doc(nlp.vocab, words=["This", "is", "a", "test", "."], tags=["DT", "VBZ", "DT", "NN", "."]) + doc = get_doc( + nlp.vocab, + words=["This", "is", "a", "test", "."], + tags=["DT", "VBZ", "DT", "NN", "."], + ) doc = a(doc) for i in range(len(doc)): @@ -58,7 +88,11 @@ def test_attributeruler_tag_map(nlp, tag_map): def test_attributeruler_morph_rules(nlp, morph_rules): a = AttributeRuler(nlp.vocab) a.load_from_morph_rules(morph_rules) - doc = get_doc(nlp.vocab, words=["This", "is", "the", "test", "."], tags=["DT", "VBZ", "DT", "NN", "."]) + doc = get_doc( + nlp.vocab, + words=["This", "is", "the", "test", "."], + tags=["DT", "VBZ", "DT", "NN", "."], + ) doc = a(doc) for i in range(len(doc)): @@ -73,8 +107,16 @@ def test_attributeruler_morph_rules(nlp, morph_rules): def test_attributeruler_indices(nlp): a = nlp.add_pipe("attribute_ruler") - a.add([[{"ORTH": "a"}, {"ORTH": "test"}]], {"LEMMA": "the", "MORPH": "Case=Nom|Number=Plur"}, index=0) - a.add([[{"ORTH": "This"}, {"ORTH": "is"}]], {"LEMMA": "was", "MORPH": "Case=Nom|Number=Sing"}, index=1) + a.add( + [[{"ORTH": "a"}, {"ORTH": "test"}]], + {"LEMMA": "the", "MORPH": "Case=Nom|Number=Plur"}, + index=0, + ) + a.add( + [[{"ORTH": "This"}, {"ORTH": "is"}]], + {"LEMMA": "was", "MORPH": "Case=Nom|Number=Sing"}, + index=1, + ) a.add([[{"ORTH": "a"}, {"ORTH": "test"}]], {"LEMMA": "cat"}, index=-1) text = "This is a test." @@ -97,10 +139,19 @@ def test_attributeruler_indices(nlp): with pytest.raises(ValueError): doc = nlp(text) -def test_attributeruler_serialize(nlp): + +def test_attributeruler_serialize(nlp, pattern_dicts): a = nlp.add_pipe("attribute_ruler") - a.add([[{"ORTH": "a"}, {"ORTH": "test"}]], {"LEMMA": "the", "MORPH": "Case=Nom|Number=Plur"}, index=0) - a.add([[{"ORTH": "This"}, {"ORTH": "is"}]], {"LEMMA": "was", "MORPH": "Case=Nom|Number=Sing"}, index=1) + a.add( + [[{"ORTH": "a"}, {"ORTH": "test"}]], + {"LEMMA": "the", "MORPH": "Case=Nom|Number=Plur"}, + index=0, + ) + a.add( + [[{"ORTH": "This"}, {"ORTH": "is"}]], + {"LEMMA": "was", "MORPH": "Case=Nom|Number=Sing"}, + index=1, + ) a.add([[{"ORTH": "a"}, {"ORTH": "test"}]], {"LEMMA": "cat"}, index=-1) text = "This is a test."