mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-27 17:54:39 +03:00
Added an argument to EntityRuler
constructor to pass attrs to… (#3919)
* Perserve flags in EntityRuler The EntityRuler (explosion/spaCy#3526) does not preserve overwrite flags (or `ent_id_sep`) when serialized. This commit adds support for serialization/deserialization preserving overwrite and ent_id_sep flags. * add signed contributor agreement * flake8 cleanup mostly blank line issues. * mark test from the issue as needing a model The test from the issue needs some language model for serialization but the test wasn't originally marked correctly. * Adds `phrase_matcher_attr` to allow args to PhraseMatcher This is an added arg to pass to the `PhraseMatcher`. For example, this allows creation of a case insensitive phrase matcher when the `EntityRuler` is created. References explosion/spaCy#3822 * remove unneeded model loading The model didn't need to be loaded, and I replaced it with a change that doesn't require it (using existings fixtures) * updated docstring for new argument * updated docs to reflect new argument to the EntityRuler constructor * change tempdir handling to be compatible with python 2.7 * return conflicted code to entityruler Some stuff got cut out because of merge conflicts, this returns that code for the phrase_matcher_attr. * fixed typo in the code added back after conflicts * flake8 compliance When I deconflicted the branch there were some flake8 issues introduced. This resolves the spacing problems. * test changes: attempts to fix flaky test in python3.5 These tests seem to be alittle flaky in 3.5 so I changed the check to avoid the comparisons that seem to be fail sometimes.
This commit is contained in:
parent
a795fbd3b2
commit
2eb925bd05
|
@ -26,7 +26,7 @@ class EntityRuler(object):
|
||||||
|
|
||||||
name = "entity_ruler"
|
name = "entity_ruler"
|
||||||
|
|
||||||
def __init__(self, nlp, **cfg):
|
def __init__(self, nlp, phrase_matcher_attr=None, **cfg):
|
||||||
"""Initialize the entitiy ruler. If patterns are supplied here, they
|
"""Initialize the entitiy ruler. If patterns are supplied here, they
|
||||||
need to be a list of dictionaries with a `"label"` and `"pattern"`
|
need to be a list of dictionaries with a `"label"` and `"pattern"`
|
||||||
key. A pattern can either be a token pattern (list) or a phrase pattern
|
key. A pattern can either be a token pattern (list) or a phrase pattern
|
||||||
|
@ -34,6 +34,8 @@ class EntityRuler(object):
|
||||||
|
|
||||||
nlp (Language): The shared nlp object to pass the vocab to the matchers
|
nlp (Language): The shared nlp object to pass the vocab to the matchers
|
||||||
and process phrase patterns.
|
and process phrase patterns.
|
||||||
|
phrase_matcher_attr (int / unicode): Token attribute to match on, passed
|
||||||
|
to the internal PhraseMatcher as `attr`
|
||||||
patterns (iterable): Optional patterns to load in.
|
patterns (iterable): Optional patterns to load in.
|
||||||
overwrite_ents (bool): If existing entities are present, e.g. entities
|
overwrite_ents (bool): If existing entities are present, e.g. entities
|
||||||
added by the model, overwrite them by matches if necessary.
|
added by the model, overwrite them by matches if necessary.
|
||||||
|
@ -49,7 +51,12 @@ class EntityRuler(object):
|
||||||
self.token_patterns = defaultdict(list)
|
self.token_patterns = defaultdict(list)
|
||||||
self.phrase_patterns = defaultdict(list)
|
self.phrase_patterns = defaultdict(list)
|
||||||
self.matcher = Matcher(nlp.vocab)
|
self.matcher = Matcher(nlp.vocab)
|
||||||
self.phrase_matcher = PhraseMatcher(nlp.vocab)
|
if phrase_matcher_attr is not None:
|
||||||
|
self.phrase_matcher_attr = phrase_matcher_attr
|
||||||
|
self.phrase_matcher = PhraseMatcher(nlp.vocab, attr=self.phrase_matcher_attr)
|
||||||
|
else:
|
||||||
|
self.phrase_matcher_attr = None
|
||||||
|
self.phrase_matcher = PhraseMatcher(nlp.vocab)
|
||||||
self.ent_id_sep = cfg.get("ent_id_sep", DEFAULT_ENT_ID_SEP)
|
self.ent_id_sep = cfg.get("ent_id_sep", DEFAULT_ENT_ID_SEP)
|
||||||
patterns = cfg.get("patterns")
|
patterns = cfg.get("patterns")
|
||||||
if patterns is not None:
|
if patterns is not None:
|
||||||
|
@ -218,6 +225,10 @@ class EntityRuler(object):
|
||||||
if isinstance(cfg, dict):
|
if isinstance(cfg, dict):
|
||||||
self.add_patterns(cfg.get('patterns', cfg))
|
self.add_patterns(cfg.get('patterns', cfg))
|
||||||
self.overwrite = cfg.get('overwrite', False)
|
self.overwrite = cfg.get('overwrite', False)
|
||||||
|
self.phrase_matcher_attr = cfg.get('phrase_matcher_attr', None)
|
||||||
|
if self.phrase_matcher_attr is not None:
|
||||||
|
self.phrase_matcher = PhraseMatcher(self.nlp.vocab,
|
||||||
|
attr=self.phrase_matcher_attr)
|
||||||
self.ent_id_sep = cfg.get('ent_id_sep', DEFAULT_ENT_ID_SEP)
|
self.ent_id_sep = cfg.get('ent_id_sep', DEFAULT_ENT_ID_SEP)
|
||||||
else:
|
else:
|
||||||
self.add_patterns(cfg)
|
self.add_patterns(cfg)
|
||||||
|
@ -234,6 +245,7 @@ class EntityRuler(object):
|
||||||
serial = OrderedDict((
|
serial = OrderedDict((
|
||||||
('overwrite', self.overwrite),
|
('overwrite', self.overwrite),
|
||||||
('ent_id_sep', self.ent_id_sep),
|
('ent_id_sep', self.ent_id_sep),
|
||||||
|
('phrase_matcher_attr', self.phrase_matcher_attr),
|
||||||
('patterns', self.patterns)))
|
('patterns', self.patterns)))
|
||||||
return srsly.msgpack_dumps(serial)
|
return srsly.msgpack_dumps(serial)
|
||||||
|
|
||||||
|
@ -259,7 +271,12 @@ class EntityRuler(object):
|
||||||
}
|
}
|
||||||
from_disk(path, deserializers, {})
|
from_disk(path, deserializers, {})
|
||||||
self.overwrite = cfg.get('overwrite', False)
|
self.overwrite = cfg.get('overwrite', False)
|
||||||
|
self.phrase_matcher_attr = cfg.get('phrase_matcher_attr')
|
||||||
self.ent_id_sep = cfg.get('ent_id_sep', DEFAULT_ENT_ID_SEP)
|
self.ent_id_sep = cfg.get('ent_id_sep', DEFAULT_ENT_ID_SEP)
|
||||||
|
|
||||||
|
if self.phrase_matcher_attr is not None:
|
||||||
|
self.phrase_matcher = PhraseMatcher(self.nlp.vocab,
|
||||||
|
attr=self.phrase_matcher_attr)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def to_disk(self, path, **kwargs):
|
def to_disk(self, path, **kwargs):
|
||||||
|
@ -273,6 +290,7 @@ class EntityRuler(object):
|
||||||
DOCS: https://spacy.io/api/entityruler#to_disk
|
DOCS: https://spacy.io/api/entityruler#to_disk
|
||||||
"""
|
"""
|
||||||
cfg = {'overwrite': self.overwrite,
|
cfg = {'overwrite': self.overwrite,
|
||||||
|
'phrase_matcher_attr': self.phrase_matcher_attr,
|
||||||
'ent_id_sep': self.ent_id_sep}
|
'ent_id_sep': self.ent_id_sep}
|
||||||
serializers = {
|
serializers = {
|
||||||
'patterns': lambda p: srsly.write_jsonl(p.with_suffix('.jsonl'),
|
'patterns': lambda p: srsly.write_jsonl(p.with_suffix('.jsonl'),
|
||||||
|
|
|
@ -106,5 +106,24 @@ def test_entity_ruler_serialize_bytes(nlp, patterns):
|
||||||
assert len(new_ruler) == 0
|
assert len(new_ruler) == 0
|
||||||
assert len(new_ruler.labels) == 0
|
assert len(new_ruler.labels) == 0
|
||||||
new_ruler = new_ruler.from_bytes(ruler_bytes)
|
new_ruler = new_ruler.from_bytes(ruler_bytes)
|
||||||
|
assert len(new_ruler) == len(patterns)
|
||||||
|
assert len(new_ruler.labels) == 4
|
||||||
|
assert len(new_ruler.patterns) == len(ruler.patterns)
|
||||||
|
for pattern in ruler.patterns:
|
||||||
|
assert pattern in new_ruler.patterns
|
||||||
|
assert new_ruler.labels == ruler.labels
|
||||||
|
|
||||||
|
|
||||||
|
def test_entity_ruler_serialize_phrase_matcher_attr_bytes(nlp, patterns):
|
||||||
|
ruler = EntityRuler(nlp, phrase_matcher_attr="LOWER", patterns=patterns)
|
||||||
assert len(ruler) == len(patterns)
|
assert len(ruler) == len(patterns)
|
||||||
assert len(ruler.labels) == 4
|
assert len(ruler.labels) == 4
|
||||||
|
ruler_bytes = ruler.to_bytes()
|
||||||
|
new_ruler = EntityRuler(nlp)
|
||||||
|
assert len(new_ruler) == 0
|
||||||
|
assert len(new_ruler.labels) == 0
|
||||||
|
assert new_ruler.phrase_matcher_attr is None
|
||||||
|
new_ruler = new_ruler.from_bytes(ruler_bytes)
|
||||||
|
assert len(new_ruler) == len(patterns)
|
||||||
|
assert len(new_ruler.labels) == 4
|
||||||
|
assert new_ruler.phrase_matcher_attr == "LOWER"
|
||||||
|
|
|
@ -9,6 +9,7 @@ from spacy import load
|
||||||
import srsly
|
import srsly
|
||||||
from ..util import make_tempdir
|
from ..util import make_tempdir
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def patterns():
|
def patterns():
|
||||||
return [
|
return [
|
||||||
|
@ -28,6 +29,7 @@ def add_ent():
|
||||||
|
|
||||||
return add_ent_component
|
return add_ent_component
|
||||||
|
|
||||||
|
|
||||||
def test_entity_ruler_existing_overwrite_serialize_bytes(patterns, en_vocab):
|
def test_entity_ruler_existing_overwrite_serialize_bytes(patterns, en_vocab):
|
||||||
nlp = Language(vocab=en_vocab)
|
nlp = Language(vocab=en_vocab)
|
||||||
ruler = EntityRuler(nlp, patterns=patterns, overwrite_ents=True)
|
ruler = EntityRuler(nlp, patterns=patterns, overwrite_ents=True)
|
||||||
|
@ -50,7 +52,8 @@ def test_entity_ruler_existing_bytes_old_format_safe(patterns, en_vocab):
|
||||||
new_ruler = EntityRuler(nlp)
|
new_ruler = EntityRuler(nlp)
|
||||||
new_ruler = new_ruler.from_bytes(bytes_old_style)
|
new_ruler = new_ruler.from_bytes(bytes_old_style)
|
||||||
assert len(new_ruler) == len(ruler)
|
assert len(new_ruler) == len(ruler)
|
||||||
assert new_ruler.patterns == ruler.patterns
|
for pattern in ruler.patterns:
|
||||||
|
assert pattern in new_ruler.patterns
|
||||||
assert new_ruler.overwrite is not ruler.overwrite
|
assert new_ruler.overwrite is not ruler.overwrite
|
||||||
|
|
||||||
|
|
||||||
|
@ -62,7 +65,8 @@ def test_entity_ruler_from_disk_old_format_safe(patterns, en_vocab):
|
||||||
srsly.write_jsonl(out_file, ruler.patterns)
|
srsly.write_jsonl(out_file, ruler.patterns)
|
||||||
new_ruler = EntityRuler(nlp)
|
new_ruler = EntityRuler(nlp)
|
||||||
new_ruler = new_ruler.from_disk(out_file)
|
new_ruler = new_ruler.from_disk(out_file)
|
||||||
assert new_ruler.patterns == ruler.patterns
|
for pattern in ruler.patterns:
|
||||||
|
assert pattern in new_ruler.patterns
|
||||||
assert len(new_ruler) == len(ruler)
|
assert len(new_ruler) == len(ruler)
|
||||||
assert new_ruler.overwrite is not ruler.overwrite
|
assert new_ruler.overwrite is not ruler.overwrite
|
||||||
|
|
||||||
|
|
|
@ -34,6 +34,7 @@ be a token pattern (list) or a phrase pattern (string). For example:
|
||||||
| ---------------- | ------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------- |
|
| ---------------- | ------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||||
| `nlp` | `Language` | The shared nlp object to pass the vocab to the matchers and process phrase patterns. |
|
| `nlp` | `Language` | The shared nlp object to pass the vocab to the matchers and process phrase patterns. |
|
||||||
| `patterns` | iterable | Optional patterns to load in. |
|
| `patterns` | iterable | Optional patterns to load in. |
|
||||||
|
| `phrase_matcher_attr` | int / unicode | Optional attr to pass to the internal [`PhraseMatcher`](/api/phtasematcher). defaults to `None`
|
||||||
| `overwrite_ents` | bool | If existing entities are present, e.g. entities added by the model, overwrite them by matches if necessary. Defaults to `False`. |
|
| `overwrite_ents` | bool | If existing entities are present, e.g. entities added by the model, overwrite them by matches if necessary. Defaults to `False`. |
|
||||||
| `**cfg` | - | Other config parameters. If pipeline component is loaded as part of a model pipeline, this will include all keyword arguments passed to `spacy.load`. |
|
| `**cfg` | - | Other config parameters. If pipeline component is loaded as part of a model pipeline, this will include all keyword arguments passed to `spacy.load`. |
|
||||||
| **RETURNS** | `EntityRuler` | The newly constructed object. |
|
| **RETURNS** | `EntityRuler` | The newly constructed object. |
|
||||||
|
|
Loading…
Reference in New Issue
Block a user