spaCy/spacy/gold/corpus.py


import random
import shutil
import tempfile
import itertools
from pathlib import Path

import srsly

from ..tokens import Doc
from .. import util
from ..errors import Errors, AlignmentError
from .gold_io import read_json_file, json_to_annotations
from .augment import make_orth_variants, add_noise
from .example import Example


class GoldCorpus(object):
    """An annotated corpus, using the JSON file format. Manages
    annotations for tagging, dependency parsing and NER.

    DOCS: https://spacy.io/api/goldcorpus
    """

    def __init__(self, train, dev, gold_preproc=False, limit=None):
        """Create a GoldCorpus.

        train (str / Path): File or directory of training data.
        dev (str / Path): File or directory of development data.
        gold_preproc (bool): Unused here; gold preprocessing is selected per
            dataset via the gold_preproc argument of the dataset methods.
        limit (int): Maximum number of examples to read from each dataset.
        RETURNS (GoldCorpus): The newly created object.
        """
        self.limit = limit
        if isinstance(train, str) or isinstance(train, Path):
            train = self.read_annotations(self.walk_corpus(train))
            dev = self.read_annotations(self.walk_corpus(dev))
        # Write a temp directory with one doc per file, so we can shuffle and stream
        self.tmp_dir = Path(tempfile.mkdtemp())
        self.write_msgpack(self.tmp_dir / "train", train, limit=self.limit)
        self.write_msgpack(self.tmp_dir / "dev", dev, limit=self.limit)

    def __del__(self):
        shutil.rmtree(self.tmp_dir)

    @staticmethod
    def write_msgpack(directory, examples, limit=0):
        # Serialize one example per .msg file so the corpus can later be
        # shuffled and streamed without loading everything into memory
        if not directory.exists():
            directory.mkdir()
        n = 0
        for i, ex_dict in enumerate(examples):
            text = ex_dict["text"]
            srsly.write_msgpack(directory / f"{i}.msg", (text, ex_dict))
            n += 1
            if limit and n >= limit:
                break
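
    # For illustration (not part of the original module): after calling
    # write_msgpack(some_dir / "train", examples), the directory holds one
    # msgpack file per example, each storing a (text, ex_dict) tuple:
    #
    #     train/0.msg
    #     train/1.msg
    #     ...
    #
    # read_annotations() later streams these back one file at a time, which
    # is what lets train_dataset() shuffle cheaply by shuffling file paths.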

    @staticmethod
    def walk_corpus(path):
        path = util.ensure_path(path)
        if not path.is_dir():
            return [path]
        paths = [path]
        locs = []
        seen = set()
        for path in paths:
            if str(path) in seen:
                continue
            seen.add(str(path))
            if path.parts[-1].startswith("."):
                # Skip hidden files and directories
                continue
            elif path.is_dir():
                # Queue subdirectory contents for the walk
                paths.extend(path.iterdir())
            elif path.parts[-1].endswith((".json", ".jsonl")):
                locs.append(path)
        return locs
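
    # Illustration (hypothetical paths, not part of the original module):
    # given a directory layout of
    #
    #     corpus/
    #         train.json
    #         extra/nested.jsonl
    #         .hidden/skipped.json
    #
    # walk_corpus("corpus") returns [corpus/train.json, corpus/extra/nested.jsonl].
    # Hidden files and directories are skipped entirely, and a non-directory
    # path is returned as a single-element list.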

    @staticmethod
    def read_annotations(locs, limit=0):
        """Yield training examples as example dicts."""
        i = 0
        for loc in locs:
            loc = util.ensure_path(loc)
            file_name = loc.parts[-1]
            if file_name.endswith("json"):
                examples = read_json_file(loc)
            elif file_name.endswith("jsonl"):
                gold_tuples = srsly.read_jsonl(loc)
                # Peek at the first record to detect the format, then put it back
                first_gold_tuple = next(gold_tuples)
                gold_tuples = itertools.chain([first_gold_tuple], gold_tuples)
                # TODO: proper format checks with schemas
                if isinstance(first_gold_tuple, dict):
                    if first_gold_tuple.get("paragraphs", None):
                        examples = []
                        for json_doc in gold_tuples:
                            examples.extend(json_to_annotations(json_doc))
                    elif first_gold_tuple.get("doc_annotation", None):
                        examples = []
                        for ex_dict in gold_tuples:
                            doc = ex_dict.get("doc", None)
                            if doc is None:
                                doc = ex_dict.get("text", None)
                            if not (
                                doc is None
                                or isinstance(doc, Doc)
                                or isinstance(doc, str)
                            ):
                                raise ValueError(Errors.E987.format(type=type(doc)))
                            examples.append(ex_dict)
            elif file_name.endswith("msg"):
                text, ex_dict = srsly.read_msgpack(loc)
                examples = [ex_dict]
            else:
                supported = ("json", "jsonl", "msg")
                raise ValueError(Errors.E124.format(path=loc, formats=supported))
            try:
                for example in examples:
                    yield example
                    i += 1
                    if limit and i >= limit:
                        return
            except KeyError as e:
                msg = f"Missing key {e}"
                raise KeyError(Errors.E996.format(file=file_name, msg=msg))
            except UnboundLocalError:
                # "examples" was never assigned: the file had an unrecognized structure
                msg = "Unexpected document structure"
                raise ValueError(Errors.E996.format(file=file_name, msg=msg))
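
    # Sketch of the three accepted inputs (a rough outline inferred from the
    # branches above, not a full schema):
    #
    #     .json   -> spaCy's JSON training format, parsed by read_json_file()
    #     .jsonl  -> one dict per line; either {"paragraphs": [...]} documents
    #                converted via json_to_annotations(), or example dicts with
    #                a "doc_annotation" key plus a "doc" (Doc) or "text" (str)
    #     .msg    -> a (text, ex_dict) tuple written by write_msgpack()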

    @property
    def dev_annotations(self):
        locs = (self.tmp_dir / "dev").iterdir()
        yield from self.read_annotations(locs, limit=self.limit)

    @property
    def train_annotations(self):
        locs = (self.tmp_dir / "train").iterdir()
        yield from self.read_annotations(locs, limit=self.limit)

    def count_train(self):
        """Returns the number of words in the training examples."""
        n = 0
        i = 0
        for eg_dict in self.train_annotations:
            n += len(eg_dict["token_annotation"]["words"])
            i += 1
            if self.limit and i >= self.limit:
                break
        return n

    def train_dataset(
        self,
        nlp,
        gold_preproc=False,
        max_length=None,
        noise_level=0.0,
        orth_variant_level=0.0,
        ignore_misaligned=False,
    ):
        # Shuffle the per-example files so examples stream in random order
        locs = list((self.tmp_dir / "train").iterdir())
        random.shuffle(locs)
        train_annotations = self.read_annotations(locs, limit=self.limit)
        examples = self.iter_examples(
            nlp,
            train_annotations,
            gold_preproc,
            max_length=max_length,
            noise_level=noise_level,
            orth_variant_level=orth_variant_level,
            make_projective=True,
            ignore_misaligned=ignore_misaligned,
        )
        yield from examples

    def train_dataset_without_preprocessing(
        self, nlp, gold_preproc=False, ignore_misaligned=False
    ):
        examples = self.iter_examples(
            nlp,
            self.train_annotations,
            gold_preproc=gold_preproc,
            ignore_misaligned=ignore_misaligned,
        )
        yield from examples

    def dev_dataset(self, nlp, gold_preproc=False, ignore_misaligned=False):
        examples = self.iter_examples(
            nlp,
            self.dev_annotations,
            gold_preproc=gold_preproc,
            ignore_misaligned=ignore_misaligned,
        )
        yield from examples

    @classmethod
    def iter_examples(
        cls,
        nlp,
        annotations,
        gold_preproc,
        max_length=None,
        noise_level=0.0,
        orth_variant_level=0.0,
        make_projective=False,
        ignore_misaligned=False,
    ):
        """Setting gold_preproc will result in creating a doc per sentence.

        noise_level, orth_variant_level and make_projective are accepted for
        API compatibility but not yet applied (see the TODO below).
        """
        for eg_dict in annotations:
            if eg_dict["text"]:
                example = Example.from_dict(nlp.make_doc(eg_dict["text"]), eg_dict)
            else:
                # No raw text available: build the doc from the gold tokens
                example = Example.from_dict(
                    Doc(nlp.vocab, words=eg_dict["words"]), eg_dict
                )
            if gold_preproc:
                # TODO: Data augmentation
                examples = example.split_sents()
            else:
                examples = [example]
            for eg in examples:
                if (not max_length) or len(eg.predicted) < max_length:
                    if ignore_misaligned:
                        try:
                            # Raises AlignmentError if the predicted tokens
                            # can't be aligned with the gold annotations
                            _ = eg._deprecated_get_gold()
                        except AlignmentError:
                            continue
                    yield eg
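

# A minimal usage sketch (not part of the original module). The file names and
# the blank pipeline are assumptions for illustration; any training/dev data in
# one of the supported formats (.json, .jsonl, .msg) would do. Because this
# module uses relative imports, run it as `python -m spacy.gold.corpus`.
if __name__ == "__main__":
    import spacy

    nlp = spacy.blank("en")  # hypothetical pipeline; any Language object works
    corpus = GoldCorpus("train.json", "dev.json")  # hypothetical file names
    print("training words:", corpus.count_train())
    # train_dataset() shuffles the stored examples and streams Example objects
    for example in corpus.train_dataset(nlp, max_length=200):
        print(example.predicted)
        break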