diff --git a/spacy/errors.py b/spacy/errors.py index 89fb9c289..a0da31a1e 100644 --- a/spacy/errors.py +++ b/spacy/errors.py @@ -60,6 +60,13 @@ class Warnings(object): "make displaCy start another one. Instead, you should be able to " "replace displacy.serve with displacy.render to show the " "visualization.") + W012 = ("A Doc object you're adding to the PhraseMatcher for pattern " + "'{key}' is parsed and/or tagged, but to match on '{attr}', you " + "don't actually need this information. This means that creating " + "the patterns is potentially much slower, because all pipeline " + "components are applied. To only create tokenized Doc objects, " + "try using `nlp.make_doc(text)` or process all texts as a stream " + "using `list(nlp.tokenizer.pipe(all_texts))`.") @add_codes diff --git a/spacy/matcher/phrasematcher.pyx b/spacy/matcher/phrasematcher.pyx index 3cc890dca..4abf275be 100644 --- a/spacy/matcher/phrasematcher.pyx +++ b/spacy/matcher/phrasematcher.pyx @@ -7,12 +7,12 @@ from murmurhash.mrmr cimport hash64 from preshed.maps cimport PreshMap from .matcher cimport Matcher -from ..attrs cimport ORTH, attr_id_t +from ..attrs cimport ORTH, POS, TAG, DEP, LEMMA, attr_id_t from ..vocab cimport Vocab from ..tokens.doc cimport Doc, get_token_attr from ..typedefs cimport attr_t, hash_t -from ..errors import Warnings, deprecation_warning +from ..errors import Warnings, deprecation_warning, user_warning from ..attrs import FLAG61 as U_ENT from ..attrs import FLAG60 as B2_ENT from ..attrs import FLAG59 as B3_ENT @@ -33,8 +33,9 @@ cdef class PhraseMatcher: cdef attr_id_t attr cdef public object _callbacks cdef public object _patterns + cdef public object _validate - def __init__(self, Vocab vocab, max_length=0, attr='ORTH'): + def __init__(self, Vocab vocab, max_length=0, attr='ORTH', validate=False): if max_length != 0: deprecation_warning(Warnings.W010) self.mem = Pool() @@ -54,6 +55,7 @@ cdef class PhraseMatcher: ] self.matcher.add('Candidate', None, *abstract_patterns) self._callbacks = {} + self._validate = validate def __len__(self): """Get the number of rules added to the matcher. Note that this only @@ -95,6 +97,10 @@ cdef class PhraseMatcher: length = doc.length if length == 0: continue + if self._validate and (doc.is_tagged or doc.is_parsed) \ + and self.attr not in (DEP, POS, TAG, LEMMA): + string_attr = self.vocab.strings[self.attr] + user_warning(Warnings.W012.format(key=key, attr=string_attr)) tags = get_bilou(length) phrase_key = mem.alloc(length, sizeof(attr_t)) for i, tag in enumerate(tags): diff --git a/spacy/tests/matcher/test_phrase_matcher.py b/spacy/tests/matcher/test_phrase_matcher.py index 9ecd61465..a5fdb345e 100644 --- a/spacy/tests/matcher/test_phrase_matcher.py +++ b/spacy/tests/matcher/test_phrase_matcher.py @@ -1,6 +1,7 @@ # coding: utf-8 from __future__ import unicode_literals +import pytest from spacy.matcher import PhraseMatcher from spacy.tokens import Doc from ..util import get_doc @@ -78,3 +79,23 @@ def test_phrase_matcher_bool_attrs(en_vocab): assert end1 == 3 assert start2 == 3 assert end2 == 6 + + +def test_phrase_matcher_validation(en_vocab): + doc1 = Doc(en_vocab, words=["Test"]) + doc1.is_parsed = True + doc2 = Doc(en_vocab, words=["Test"]) + doc2.is_tagged = True + doc3 = Doc(en_vocab, words=["Test"]) + matcher = PhraseMatcher(en_vocab, validate=True) + with pytest.warns(UserWarning): + matcher.add("TEST1", None, doc1) + with pytest.warns(UserWarning): + matcher.add("TEST2", None, doc2) + with pytest.warns(None) as record: + matcher.add("TEST3", None, doc3) + assert not record.list + matcher = PhraseMatcher(en_vocab, attr="POS", validate=True) + with pytest.warns(None) as record: + matcher.add("TEST4", None, doc2) + assert not record.list