Added an argument to EntityRuler constructor to pass attrs to… (#3919)

* Perserve flags in EntityRuler

The EntityRuler (explosion/spaCy#3526) does not preserve
overwrite flags (or `ent_id_sep`) when serialized.  This
commit adds support for serialization/deserialization preserving
overwrite and ent_id_sep flags.

* add signed contributor agreement

* flake8 cleanup

mostly blank line issues.

* mark test from the issue as needing a model

The test from the issue needs some language model for serialization
but the test wasn't originally marked correctly.

* Adds `phrase_matcher_attr` to allow args to PhraseMatcher

This is an added arg to pass to the `PhraseMatcher`. For example,
this allows creation of a case insensitive phrase matcher when the
`EntityRuler` is created.  References explosion/spaCy#3822

* remove unneeded model loading

The model didn't need to be loaded, and I replaced it with
a change that doesn't require it (using existings fixtures)

* updated docstring for new argument

* updated docs to reflect new argument to the EntityRuler constructor

* change tempdir handling to be compatible with python 2.7

* return conflicted code to entityruler

Some stuff got cut out because of merge conflicts, this
returns that code for the phrase_matcher_attr.

* fixed typo in the code added back after conflicts

* flake8 compliance

When I deconflicted the branch there were some flake8 issues
introduced. This resolves the spacing problems.

* test changes:  attempts to fix flaky test in python3.5

These tests seem to be alittle flaky in 3.5 so I changed the check to avoid
the comparisons that seem to be fail sometimes.
This commit is contained in:
Joshua Smith 2019-07-09 14:09:17 -04:00 committed by Ines Montani
parent a795fbd3b2
commit 2eb925bd05
4 changed files with 46 additions and 4 deletions

View File

@ -26,7 +26,7 @@ class EntityRuler(object):
name = "entity_ruler" name = "entity_ruler"
def __init__(self, nlp, **cfg): def __init__(self, nlp, phrase_matcher_attr=None, **cfg):
"""Initialize the entitiy ruler. If patterns are supplied here, they """Initialize the entitiy ruler. If patterns are supplied here, they
need to be a list of dictionaries with a `"label"` and `"pattern"` need to be a list of dictionaries with a `"label"` and `"pattern"`
key. A pattern can either be a token pattern (list) or a phrase pattern key. A pattern can either be a token pattern (list) or a phrase pattern
@ -34,6 +34,8 @@ class EntityRuler(object):
nlp (Language): The shared nlp object to pass the vocab to the matchers nlp (Language): The shared nlp object to pass the vocab to the matchers
and process phrase patterns. and process phrase patterns.
phrase_matcher_attr (int / unicode): Token attribute to match on, passed
to the internal PhraseMatcher as `attr`
patterns (iterable): Optional patterns to load in. patterns (iterable): Optional patterns to load in.
overwrite_ents (bool): If existing entities are present, e.g. entities overwrite_ents (bool): If existing entities are present, e.g. entities
added by the model, overwrite them by matches if necessary. added by the model, overwrite them by matches if necessary.
@ -49,7 +51,12 @@ class EntityRuler(object):
self.token_patterns = defaultdict(list) self.token_patterns = defaultdict(list)
self.phrase_patterns = defaultdict(list) self.phrase_patterns = defaultdict(list)
self.matcher = Matcher(nlp.vocab) self.matcher = Matcher(nlp.vocab)
self.phrase_matcher = PhraseMatcher(nlp.vocab) if phrase_matcher_attr is not None:
self.phrase_matcher_attr = phrase_matcher_attr
self.phrase_matcher = PhraseMatcher(nlp.vocab, attr=self.phrase_matcher_attr)
else:
self.phrase_matcher_attr = None
self.phrase_matcher = PhraseMatcher(nlp.vocab)
self.ent_id_sep = cfg.get("ent_id_sep", DEFAULT_ENT_ID_SEP) self.ent_id_sep = cfg.get("ent_id_sep", DEFAULT_ENT_ID_SEP)
patterns = cfg.get("patterns") patterns = cfg.get("patterns")
if patterns is not None: if patterns is not None:
@ -218,6 +225,10 @@ class EntityRuler(object):
if isinstance(cfg, dict): if isinstance(cfg, dict):
self.add_patterns(cfg.get('patterns', cfg)) self.add_patterns(cfg.get('patterns', cfg))
self.overwrite = cfg.get('overwrite', False) self.overwrite = cfg.get('overwrite', False)
self.phrase_matcher_attr = cfg.get('phrase_matcher_attr', None)
if self.phrase_matcher_attr is not None:
self.phrase_matcher = PhraseMatcher(self.nlp.vocab,
attr=self.phrase_matcher_attr)
self.ent_id_sep = cfg.get('ent_id_sep', DEFAULT_ENT_ID_SEP) self.ent_id_sep = cfg.get('ent_id_sep', DEFAULT_ENT_ID_SEP)
else: else:
self.add_patterns(cfg) self.add_patterns(cfg)
@ -234,6 +245,7 @@ class EntityRuler(object):
serial = OrderedDict(( serial = OrderedDict((
('overwrite', self.overwrite), ('overwrite', self.overwrite),
('ent_id_sep', self.ent_id_sep), ('ent_id_sep', self.ent_id_sep),
('phrase_matcher_attr', self.phrase_matcher_attr),
('patterns', self.patterns))) ('patterns', self.patterns)))
return srsly.msgpack_dumps(serial) return srsly.msgpack_dumps(serial)
@ -259,7 +271,12 @@ class EntityRuler(object):
} }
from_disk(path, deserializers, {}) from_disk(path, deserializers, {})
self.overwrite = cfg.get('overwrite', False) self.overwrite = cfg.get('overwrite', False)
self.phrase_matcher_attr = cfg.get('phrase_matcher_attr')
self.ent_id_sep = cfg.get('ent_id_sep', DEFAULT_ENT_ID_SEP) self.ent_id_sep = cfg.get('ent_id_sep', DEFAULT_ENT_ID_SEP)
if self.phrase_matcher_attr is not None:
self.phrase_matcher = PhraseMatcher(self.nlp.vocab,
attr=self.phrase_matcher_attr)
return self return self
def to_disk(self, path, **kwargs): def to_disk(self, path, **kwargs):
@ -273,6 +290,7 @@ class EntityRuler(object):
DOCS: https://spacy.io/api/entityruler#to_disk DOCS: https://spacy.io/api/entityruler#to_disk
""" """
cfg = {'overwrite': self.overwrite, cfg = {'overwrite': self.overwrite,
'phrase_matcher_attr': self.phrase_matcher_attr,
'ent_id_sep': self.ent_id_sep} 'ent_id_sep': self.ent_id_sep}
serializers = { serializers = {
'patterns': lambda p: srsly.write_jsonl(p.with_suffix('.jsonl'), 'patterns': lambda p: srsly.write_jsonl(p.with_suffix('.jsonl'),

View File

@ -106,5 +106,24 @@ def test_entity_ruler_serialize_bytes(nlp, patterns):
assert len(new_ruler) == 0 assert len(new_ruler) == 0
assert len(new_ruler.labels) == 0 assert len(new_ruler.labels) == 0
new_ruler = new_ruler.from_bytes(ruler_bytes) new_ruler = new_ruler.from_bytes(ruler_bytes)
assert len(new_ruler) == len(patterns)
assert len(new_ruler.labels) == 4
assert len(new_ruler.patterns) == len(ruler.patterns)
for pattern in ruler.patterns:
assert pattern in new_ruler.patterns
assert new_ruler.labels == ruler.labels
def test_entity_ruler_serialize_phrase_matcher_attr_bytes(nlp, patterns):
ruler = EntityRuler(nlp, phrase_matcher_attr="LOWER", patterns=patterns)
assert len(ruler) == len(patterns) assert len(ruler) == len(patterns)
assert len(ruler.labels) == 4 assert len(ruler.labels) == 4
ruler_bytes = ruler.to_bytes()
new_ruler = EntityRuler(nlp)
assert len(new_ruler) == 0
assert len(new_ruler.labels) == 0
assert new_ruler.phrase_matcher_attr is None
new_ruler = new_ruler.from_bytes(ruler_bytes)
assert len(new_ruler) == len(patterns)
assert len(new_ruler.labels) == 4
assert new_ruler.phrase_matcher_attr == "LOWER"

View File

@ -9,6 +9,7 @@ from spacy import load
import srsly import srsly
from ..util import make_tempdir from ..util import make_tempdir
@pytest.fixture @pytest.fixture
def patterns(): def patterns():
return [ return [
@ -28,6 +29,7 @@ def add_ent():
return add_ent_component return add_ent_component
def test_entity_ruler_existing_overwrite_serialize_bytes(patterns, en_vocab): def test_entity_ruler_existing_overwrite_serialize_bytes(patterns, en_vocab):
nlp = Language(vocab=en_vocab) nlp = Language(vocab=en_vocab)
ruler = EntityRuler(nlp, patterns=patterns, overwrite_ents=True) ruler = EntityRuler(nlp, patterns=patterns, overwrite_ents=True)
@ -50,7 +52,8 @@ def test_entity_ruler_existing_bytes_old_format_safe(patterns, en_vocab):
new_ruler = EntityRuler(nlp) new_ruler = EntityRuler(nlp)
new_ruler = new_ruler.from_bytes(bytes_old_style) new_ruler = new_ruler.from_bytes(bytes_old_style)
assert len(new_ruler) == len(ruler) assert len(new_ruler) == len(ruler)
assert new_ruler.patterns == ruler.patterns for pattern in ruler.patterns:
assert pattern in new_ruler.patterns
assert new_ruler.overwrite is not ruler.overwrite assert new_ruler.overwrite is not ruler.overwrite
@ -62,7 +65,8 @@ def test_entity_ruler_from_disk_old_format_safe(patterns, en_vocab):
srsly.write_jsonl(out_file, ruler.patterns) srsly.write_jsonl(out_file, ruler.patterns)
new_ruler = EntityRuler(nlp) new_ruler = EntityRuler(nlp)
new_ruler = new_ruler.from_disk(out_file) new_ruler = new_ruler.from_disk(out_file)
assert new_ruler.patterns == ruler.patterns for pattern in ruler.patterns:
assert pattern in new_ruler.patterns
assert len(new_ruler) == len(ruler) assert len(new_ruler) == len(ruler)
assert new_ruler.overwrite is not ruler.overwrite assert new_ruler.overwrite is not ruler.overwrite

View File

@ -34,6 +34,7 @@ be a token pattern (list) or a phrase pattern (string). For example:
| ---------------- | ------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------- | | ---------------- | ------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------- |
| `nlp` | `Language` | The shared nlp object to pass the vocab to the matchers and process phrase patterns. | | `nlp` | `Language` | The shared nlp object to pass the vocab to the matchers and process phrase patterns. |
| `patterns` | iterable | Optional patterns to load in. | | `patterns` | iterable | Optional patterns to load in. |
| `phrase_matcher_attr` | int / unicode | Optional attr to pass to the internal [`PhraseMatcher`](/api/phtasematcher). defaults to `None`
| `overwrite_ents` | bool | If existing entities are present, e.g. entities added by the model, overwrite them by matches if necessary. Defaults to `False`. | | `overwrite_ents` | bool | If existing entities are present, e.g. entities added by the model, overwrite them by matches if necessary. Defaults to `False`. |
| `**cfg` | - | Other config parameters. If pipeline component is loaded as part of a model pipeline, this will include all keyword arguments passed to `spacy.load`. | | `**cfg` | - | Other config parameters. If pipeline component is loaded as part of a model pipeline, this will include all keyword arguments passed to `spacy.load`. |
| **RETURNS** | `EntityRuler` | The newly constructed object. | | **RETURNS** | `EntityRuler` | The newly constructed object. |