EntityRuler improve disk load error message (#9658)

* added error string

* added serialization test

* added more to if statements

* wrote file to tempdir

* added tempdir

* changed parameter a bit

* Update spacy/tests/pipeline/test_entity_ruler.py

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
This commit is contained in:
Duygu Altinok 2021-11-23 16:26:05 +01:00 committed by GitHub
parent 9ac6d4991e
commit a7d7e80adb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 33 additions and 2 deletions

View File

@ -888,6 +888,7 @@ class Errors(metaclass=ErrorsWithCodes):
E1021 = ("`pos` value \"{pp}\" is not a valid Universal Dependencies tag. "
"Non-UD tags should use the `tag` property.")
E1022 = ("Words must be of type str or int, but input is of type '{wtype}'")
E1023 = ("Couldn't read EntityRuler from the {path}. This file doesn't exist.")
# Deprecated model shortcuts, only used in errors and warnings

View File

@ -431,10 +431,16 @@ class EntityRuler(Pipe):
path = ensure_path(path)
self.clear()
depr_patterns_path = path.with_suffix(".jsonl")
if depr_patterns_path.is_file():
patterns = srsly.read_jsonl(depr_patterns_path)
if path.suffix == ".jsonl":  # user provides a jsonl
    # `is_file` must be *called*: the bare bound method is always truthy,
    # which made the E1023 branch below unreachable for missing files.
    if path.is_file():
        patterns = srsly.read_jsonl(path)
        self.add_patterns(patterns)
    else:
        # The given .jsonl path does not point to an existing file.
        raise ValueError(Errors.E1023.format(path=path))
elif depr_patterns_path.is_file():
patterns = srsly.read_jsonl(depr_patterns_path)
self.add_patterns(patterns)
elif path.is_dir(): # path is a valid directory
cfg = {}
deserializers_patterns = {
"patterns": lambda p: self.add_patterns(
@ -451,6 +457,8 @@ class EntityRuler(Pipe):
self.nlp.vocab, attr=self.phrase_matcher_attr
)
from_disk(path, deserializers_patterns, {})
else: # path is not a valid directory or file
raise ValueError(Errors.E146.format(path=path))
return self
def to_disk(

View File

@ -5,6 +5,8 @@ from spacy.tokens import Span
from spacy.language import Language
from spacy.pipeline import EntityRuler
from spacy.errors import MatchPatternError
from spacy.tests.util import make_tempdir
from thinc.api import NumpyOps, get_current_ops
@ -238,3 +240,23 @@ def test_entity_ruler_multiprocessing(nlp, n_process):
for doc in nlp.pipe(texts, n_process=2):
for ent in doc.ents:
assert ent.ent_id_ == "1234"
def test_entity_ruler_serialize_jsonl(nlp, patterns):
    """Write an entity ruler to a .jsonl path, read it back, and check that
    reading from a .jsonl path that was never written raises ValueError."""
    ruler = nlp.add_pipe("entity_ruler")
    ruler.add_patterns(patterns)
    with make_tempdir() as tmp_dir:
        written_path = tmp_dir / "test_ruler.jsonl"
        missing_path = tmp_dir / "non_existing.jsonl"
        ruler.to_disk(written_path)
        # read from an existing jsonl file
        ruler.from_disk(written_path)
        # read from a bad jsonl file
        with pytest.raises(ValueError):
            ruler.from_disk(missing_path)
def test_entity_ruler_serialize_dir(nlp, patterns):
    """Write an entity ruler to a directory path, read it back, and check that
    reading from a path that was never written raises ValueError."""
    ruler = nlp.add_pipe("entity_ruler")
    ruler.add_patterns(patterns)
    with make_tempdir() as tmp_dir:
        written_path = tmp_dir / "test_ruler"
        missing_path = tmp_dir / "non_existing_dir"
        ruler.to_disk(written_path)
        # read from an existing directory
        ruler.from_disk(written_path)
        # read from a bad directory
        with pytest.raises(ValueError):
            ruler.from_disk(missing_path)