Auto-format

This commit is contained in:
Ines Montani 2019-07-10 12:03:05 +02:00
parent 6ba5ddbd5f
commit ea2050079b

View File

@ -10,7 +10,7 @@ from ..util import ensure_path, to_disk, from_disk
from ..tokens import Span from ..tokens import Span
from ..matcher import Matcher, PhraseMatcher from ..matcher import Matcher, PhraseMatcher
DEFAULT_ENT_ID_SEP = '||' DEFAULT_ENT_ID_SEP = "||"
class EntityRuler(object): class EntityRuler(object):
@ -53,7 +53,9 @@ class EntityRuler(object):
self.matcher = Matcher(nlp.vocab) self.matcher = Matcher(nlp.vocab)
if phrase_matcher_attr is not None: if phrase_matcher_attr is not None:
self.phrase_matcher_attr = phrase_matcher_attr self.phrase_matcher_attr = phrase_matcher_attr
self.phrase_matcher = PhraseMatcher(nlp.vocab, attr=self.phrase_matcher_attr) self.phrase_matcher = PhraseMatcher(
nlp.vocab, attr=self.phrase_matcher_attr
)
else: else:
self.phrase_matcher_attr = None self.phrase_matcher_attr = None
self.phrase_matcher = PhraseMatcher(nlp.vocab) self.phrase_matcher = PhraseMatcher(nlp.vocab)
@ -223,13 +225,14 @@ class EntityRuler(object):
""" """
cfg = srsly.msgpack_loads(patterns_bytes) cfg = srsly.msgpack_loads(patterns_bytes)
if isinstance(cfg, dict): if isinstance(cfg, dict):
self.add_patterns(cfg.get('patterns', cfg)) self.add_patterns(cfg.get("patterns", cfg))
self.overwrite = cfg.get('overwrite', False) self.overwrite = cfg.get("overwrite", False)
self.phrase_matcher_attr = cfg.get('phrase_matcher_attr', None) self.phrase_matcher_attr = cfg.get("phrase_matcher_attr", None)
if self.phrase_matcher_attr is not None: if self.phrase_matcher_attr is not None:
self.phrase_matcher = PhraseMatcher(self.nlp.vocab, self.phrase_matcher = PhraseMatcher(
attr=self.phrase_matcher_attr) self.nlp.vocab, attr=self.phrase_matcher_attr
self.ent_id_sep = cfg.get('ent_id_sep', DEFAULT_ENT_ID_SEP) )
self.ent_id_sep = cfg.get("ent_id_sep", DEFAULT_ENT_ID_SEP)
else: else:
self.add_patterns(cfg) self.add_patterns(cfg)
return self return self
@ -242,11 +245,14 @@ class EntityRuler(object):
DOCS: https://spacy.io/api/entityruler#to_bytes DOCS: https://spacy.io/api/entityruler#to_bytes
""" """
serial = OrderedDict(( serial = OrderedDict(
('overwrite', self.overwrite), (
('ent_id_sep', self.ent_id_sep), ("overwrite", self.overwrite),
('phrase_matcher_attr', self.phrase_matcher_attr), ("ent_id_sep", self.ent_id_sep),
('patterns', self.patterns))) ("phrase_matcher_attr", self.phrase_matcher_attr),
("patterns", self.patterns),
)
)
return srsly.msgpack_dumps(serial) return srsly.msgpack_dumps(serial)
def from_disk(self, path, **kwargs): def from_disk(self, path, **kwargs):
@ -266,17 +272,20 @@ class EntityRuler(object):
else: else:
cfg = {} cfg = {}
deserializers = { deserializers = {
'patterns': lambda p: self.add_patterns(srsly.read_jsonl(p.with_suffix('.jsonl'))), "patterns": lambda p: self.add_patterns(
'cfg': lambda p: cfg.update(srsly.read_json(p)) srsly.read_jsonl(p.with_suffix(".jsonl"))
),
"cfg": lambda p: cfg.update(srsly.read_json(p)),
} }
from_disk(path, deserializers, {}) from_disk(path, deserializers, {})
self.overwrite = cfg.get('overwrite', False) self.overwrite = cfg.get("overwrite", False)
self.phrase_matcher_attr = cfg.get('phrase_matcher_attr') self.phrase_matcher_attr = cfg.get("phrase_matcher_attr")
self.ent_id_sep = cfg.get('ent_id_sep', DEFAULT_ENT_ID_SEP) self.ent_id_sep = cfg.get("ent_id_sep", DEFAULT_ENT_ID_SEP)
if self.phrase_matcher_attr is not None: if self.phrase_matcher_attr is not None:
self.phrase_matcher = PhraseMatcher(self.nlp.vocab, self.phrase_matcher = PhraseMatcher(
attr=self.phrase_matcher_attr) self.nlp.vocab, attr=self.phrase_matcher_attr
)
return self return self
def to_disk(self, path, **kwargs): def to_disk(self, path, **kwargs):
@ -289,13 +298,16 @@ class EntityRuler(object):
DOCS: https://spacy.io/api/entityruler#to_disk DOCS: https://spacy.io/api/entityruler#to_disk
""" """
cfg = {'overwrite': self.overwrite, cfg = {
'phrase_matcher_attr': self.phrase_matcher_attr, "overwrite": self.overwrite,
'ent_id_sep': self.ent_id_sep} "phrase_matcher_attr": self.phrase_matcher_attr,
"ent_id_sep": self.ent_id_sep,
}
serializers = { serializers = {
'patterns': lambda p: srsly.write_jsonl(p.with_suffix('.jsonl'), "patterns": lambda p: srsly.write_jsonl(
self.patterns), p.with_suffix(".jsonl"), self.patterns
'cfg': lambda p: srsly.write_json(p, cfg) ),
"cfg": lambda p: srsly.write_json(p, cfg),
} }
path = ensure_path(path) path = ensure_path(path)
to_disk(path, serializers, {}) to_disk(path, serializers, {})