mirror of
https://github.com/explosion/spaCy.git
synced 2025-02-06 06:30:35 +03:00
avoid writing temp dir in json2docs, fixing 4402 test
This commit is contained in:
parent
ffddff03b8
commit
5e71919322
|
@ -2,7 +2,7 @@ import tempfile
|
||||||
import contextlib
|
import contextlib
|
||||||
import shutil
|
import shutil
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from ..gold_io import read_json_file
|
from ..gold_io import json_to_annotations
|
||||||
from ..example import annotations2doc
|
from ..example import annotations2doc
|
||||||
from ..example import _fix_legacy_dict_data, _parse_example_dict_data
|
from ..example import _fix_legacy_dict_data, _parse_example_dict_data
|
||||||
from ...util import load_model
|
from ...util import load_model
|
||||||
|
@ -19,11 +19,7 @@ def make_tempdir():
|
||||||
def json2docs(input_data, model=None, **kwargs):
|
def json2docs(input_data, model=None, **kwargs):
|
||||||
nlp = load_model(model) if model is not None else MultiLanguage()
|
nlp = load_model(model) if model is not None else MultiLanguage()
|
||||||
docs = []
|
docs = []
|
||||||
with make_tempdir() as tmp_dir:
|
for json_annot in json_to_annotations(input_data):
|
||||||
json_path = Path(tmp_dir) / "data.json"
|
|
||||||
with (json_path).open("w") as file_:
|
|
||||||
file_.write(input_data)
|
|
||||||
for json_annot in read_json_file(json_path):
|
|
||||||
example_dict = _fix_legacy_dict_data(json_annot)
|
example_dict = _fix_legacy_dict_data(json_annot)
|
||||||
tok_dict, doc_dict = _parse_example_dict_data(example_dict)
|
tok_dict, doc_dict = _parse_example_dict_data(example_dict)
|
||||||
doc = annotations2doc(nlp.vocab, tok_dict, doc_dict)
|
doc = annotations2doc(nlp.vocab, tok_dict, doc_dict)
|
||||||
|
|
|
@ -43,7 +43,7 @@ class Corpus:
|
||||||
locs.append(path)
|
locs.append(path)
|
||||||
return locs
|
return locs
|
||||||
|
|
||||||
def make_examples(self, nlp, reference_docs, **kwargs):
|
def make_examples(self, nlp, reference_docs):
|
||||||
for reference in reference_docs:
|
for reference in reference_docs:
|
||||||
predicted = nlp.make_doc(reference.text)
|
predicted = nlp.make_doc(reference.text)
|
||||||
yield Example(predicted, reference)
|
yield Example(predicted, reference)
|
||||||
|
@ -72,15 +72,15 @@ class Corpus:
|
||||||
i += 1
|
i += 1
|
||||||
return n
|
return n
|
||||||
|
|
||||||
def train_dataset(self, nlp, shuffle=True, **kwargs):
|
def train_dataset(self, nlp, shuffle=True):
|
||||||
ref_docs = self.read_docbin(nlp.vocab, self.walk_corpus(self.train_loc))
|
ref_docs = self.read_docbin(nlp.vocab, self.walk_corpus(self.train_loc))
|
||||||
examples = self.make_examples(nlp, ref_docs, **kwargs)
|
examples = self.make_examples(nlp, ref_docs)
|
||||||
if shuffle:
|
if shuffle:
|
||||||
examples = list(examples)
|
examples = list(examples)
|
||||||
random.shuffle(examples)
|
random.shuffle(examples)
|
||||||
yield from examples
|
yield from examples
|
||||||
|
|
||||||
def dev_dataset(self, nlp, **kwargs):
|
def dev_dataset(self, nlp):
|
||||||
ref_docs = self.read_docbin(nlp.vocab, self.walk_corpus(self.dev_loc))
|
ref_docs = self.read_docbin(nlp.vocab, self.walk_corpus(self.dev_loc))
|
||||||
examples = self.make_examples(nlp, ref_docs, **kwargs)
|
examples = self.make_examples(nlp, ref_docs)
|
||||||
yield from examples
|
yield from examples
|
||||||
|
|
|
@ -9,7 +9,6 @@ from .align cimport Alignment
|
||||||
from .iob_utils import biluo_to_iob, biluo_tags_from_offsets, biluo_tags_from_doc
|
from .iob_utils import biluo_to_iob, biluo_tags_from_offsets, biluo_tags_from_doc
|
||||||
from .align import Alignment
|
from .align import Alignment
|
||||||
from ..errors import Errors, AlignmentError
|
from ..errors import Errors, AlignmentError
|
||||||
from ..structs cimport TokenC
|
|
||||||
from ..syntax import nonproj
|
from ..syntax import nonproj
|
||||||
|
|
||||||
|
|
||||||
|
@ -19,6 +18,7 @@ cpdef Doc annotations2doc(vocab, tok_annot, doc_annot):
|
||||||
output = Doc(vocab, words=tok_annot["ORTH"], spaces=tok_annot["SPACY"])
|
output = Doc(vocab, words=tok_annot["ORTH"], spaces=tok_annot["SPACY"])
|
||||||
if array.size:
|
if array.size:
|
||||||
output = output.from_array(attrs, array)
|
output = output.from_array(attrs, array)
|
||||||
|
# TODO: links ?!
|
||||||
output.cats.update(doc_annot.get("cats", {}))
|
output.cats.update(doc_annot.get("cats", {}))
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
|
@ -2,7 +2,7 @@ import warnings
|
||||||
import srsly
|
import srsly
|
||||||
from .. import util
|
from .. import util
|
||||||
from ..errors import Warnings
|
from ..errors import Warnings
|
||||||
from ..tokens import Token, Doc
|
from ..tokens import Doc
|
||||||
from .iob_utils import biluo_tags_from_offsets
|
from .iob_utils import biluo_tags_from_offsets
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,24 +1,31 @@
|
||||||
import srsly
|
|
||||||
from spacy.gold import Corpus
|
from spacy.gold import Corpus
|
||||||
from spacy.lang.en import English
|
from spacy.lang.en import English
|
||||||
|
|
||||||
from ..util import make_tempdir
|
from ..util import make_tempdir
|
||||||
|
from ...gold.converters import json2docs
|
||||||
|
from ...tokens import DocBin
|
||||||
|
|
||||||
|
|
||||||
def test_issue4402():
|
def test_issue4402():
|
||||||
nlp = English()
|
nlp = English()
|
||||||
with make_tempdir() as tmpdir:
|
with make_tempdir() as tmpdir:
|
||||||
json_path = tmpdir / "test4402.json"
|
output_file = tmpdir / "test4402.spacy"
|
||||||
srsly.write_json(json_path, json_data)
|
docs = json2docs(json_data)
|
||||||
|
data = DocBin(docs=docs, attrs =["ORTH", "SENT_START", "ENT_IOB", "ENT_TYPE"]).to_bytes()
|
||||||
|
with output_file.open("wb") as file_:
|
||||||
|
file_.write(data)
|
||||||
|
corpus = Corpus(train_loc=str(output_file), dev_loc=str(output_file))
|
||||||
|
|
||||||
corpus = Corpus(str(json_path), str(json_path))
|
train_data = list(corpus.train_dataset(nlp))
|
||||||
|
assert len(train_data) == 2
|
||||||
|
|
||||||
train_data = list(corpus.train_dataset(nlp, gold_preproc=True, max_length=0))
|
split_train_data = []
|
||||||
# assert that the data got split into 4 sentences
|
for eg in train_data:
|
||||||
assert len(train_data) == 4
|
split_train_data.extend(eg.split_sents())
|
||||||
|
assert len(split_train_data) == 4
|
||||||
|
|
||||||
|
|
||||||
json_data = [
|
json_data =\
|
||||||
{
|
{
|
||||||
"id": 0,
|
"id": 0,
|
||||||
"paragraphs": [
|
"paragraphs": [
|
||||||
|
@ -89,4 +96,3 @@ json_data = [
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
]
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user