Switch converters to generator functions (#6547)

* Switch converters to generator functions

To reduce the memory usage when converting large corpora, refactor the
convert methods to be generator functions.

* Update tests
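
In outline, the refactor swaps a collect-then-return pattern for yield; a minimal, self-contained sketch of the difference (illustration only, with parse_segment as a hypothetical stand-in for the real conversion work):

    def parse_segment(segment):
        return segment.upper()  # placeholder for building a Doc

    def convert_eager(segments):
        # Old style: the whole result list lives in memory before returning.
        docs = []
        for segment in segments:
            docs.append(parse_segment(segment))
        return docs

    def convert_lazy(segments):
        # New style: a generator function; items are produced one at a time
        # and can be consumed (and freed) as the caller iterates.
        for segment in segments:
            yield parse_segment(segment)

    for doc in convert_lazy(["a", "b", "c"]):
        print(doc)  # "A", "B", "C", one at a time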
Adriane Boyd 2020-12-15 09:47:16 +01:00 committed by GitHub
parent 8656a08777
commit 1ddf2f39c7
7 changed files with 23 additions and 28 deletions

View File

@@ -5,6 +5,7 @@ from wasabi import Printer
 import srsly
 import re
 import sys
+import itertools

 from ._util import app, Arg, Opt
 from ..training import docs_to_json
@@ -130,15 +131,16 @@ def convert(
         )
         doc_files.append((input_loc, docs))
     if concatenate:
-        all_docs = []
-        for _, docs in doc_files:
-            all_docs.extend(docs)
+        all_docs = itertools.chain.from_iterable([docs for _, docs in doc_files])
         doc_files = [(input_path, all_docs)]
     for input_loc, docs in doc_files:
         if file_type == "json":
             data = [docs_to_json(docs)]
+            len_docs = len(data)
         else:
-            data = DocBin(docs=docs, store_user_data=True).to_bytes()
+            db = DocBin(docs=docs, store_user_data=True)
+            len_docs = len(db)
+            data = db.to_bytes()
         if output_dir == "-":
             _print_docs_to_stdout(data, file_type)
         else:
@@ -149,7 +151,7 @@ def convert(
             output_file = Path(output_dir) / input_loc.parts[-1]
             output_file = output_file.with_suffix(f".{file_type}")
             _write_docs_to_file(data, output_file, file_type)
-            msg.good(f"Generated output file ({len(docs)} documents): {output_file}")
+            msg.good(f"Generated output file ({len_docs} documents): {output_file}")


 def _print_docs_to_stdout(data: Any, output_type: str) -> None:
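
For the concatenation change above, itertools.chain.from_iterable stitches the per-file streams into one lazy stream, so the documents are never gathered into a single list; and because a generator has no length, the document count now comes from the materialized output (len(data) for JSON, len(db) for the DocBin). A small standalone sketch with generic iterables rather than spaCy objects:

    import itertools

    def one_file(prefix, n):
        # Stand-in for a single file's converter output.
        for i in range(n):
            yield f"{prefix}{i}"

    doc_files = [("a.conllu", one_file("a", 2)), ("b.conllu", one_file("b", 3))]

    # Lazily concatenate the per-file streams; nothing is materialized here.
    all_docs = itertools.chain.from_iterable(docs for _, docs in doc_files)

    print(list(all_docs))  # ['a0', 'a1', 'b0', 'b1', 'b2']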

View File

@@ -24,7 +24,7 @@ def test_cli_converters_conllu_to_docs():
         "4\tavstår\tavstå\tVERB\t_\tMood=Ind|Tense=Pres|VerbForm=Fin\t0\troot\t_\tO",
     ]
     input_data = "\n".join(lines)
-    converted_docs = conllu_to_docs(input_data, n_sents=1)
+    converted_docs = list(conllu_to_docs(input_data, n_sents=1))
     assert len(converted_docs) == 1
     converted = [docs_to_json(converted_docs)]
     assert converted[0]["id"] == 0
@@ -65,8 +65,8 @@ def test_cli_converters_conllu_to_docs():
 )
 def test_cli_converters_conllu_to_docs_name_ner_map(lines):
     input_data = "\n".join(lines)
-    converted_docs = conllu_to_docs(
-        input_data, n_sents=1, ner_map={"PER": "PERSON", "BAD": ""}
+    converted_docs = list(
+        conllu_to_docs(input_data, n_sents=1, ner_map={"PER": "PERSON", "BAD": ""})
     )
     assert len(converted_docs) == 1
     converted = [docs_to_json(converted_docs)]
@@ -99,8 +99,10 @@ def test_cli_converters_conllu_to_docs_subtokens():
         "5\t.\t$.\tPUNCT\t_\t_\t4\tpunct\t_\tname=O",
     ]
     input_data = "\n".join(lines)
-    converted_docs = conllu_to_docs(
-        input_data, n_sents=1, merge_subtokens=True, append_morphology=True
+    converted_docs = list(
+        conllu_to_docs(
+            input_data, n_sents=1, merge_subtokens=True, append_morphology=True
+        )
     )
     assert len(converted_docs) == 1
     converted = [docs_to_json(converted_docs)]
@@ -145,7 +147,7 @@ def test_cli_converters_iob_to_docs():
         "I|PRP|O like|VBP|O London|NNP|B-GPE and|CC|O New|NNP|B-GPE York|NNP|I-GPE City|NNP|I-GPE .|.|O",
     ]
     input_data = "\n".join(lines)
-    converted_docs = iob_to_docs(input_data, n_sents=10)
+    converted_docs = list(iob_to_docs(input_data, n_sents=10))
     assert len(converted_docs) == 1
     converted = docs_to_json(converted_docs)
     assert converted["id"] == 0
@@ -212,7 +214,7 @@ def test_cli_converters_conll_ner_to_docs():
         ".\t.\t_\tO",
     ]
     input_data = "\n".join(lines)
-    converted_docs = conll_ner_to_docs(input_data, n_sents=10)
+    converted_docs = list(conll_ner_to_docs(input_data, n_sents=10))
     assert len(converted_docs) == 1
     converted = docs_to_json(converted_docs)
     assert converted["id"] == 0

View File

@@ -195,7 +195,7 @@ def test_json_to_docs_no_ner(en_vocab):
             ],
         }
     ]
-    docs = json_to_docs(data)
+    docs = list(json_to_docs(data))
     assert len(docs) == 1
     for doc in docs:
         assert not doc.has_annotation("ENT_IOB")

View File

@@ -87,7 +87,6 @@ def conll_ner_to_docs(
         nlp = load_model(model)
     else:
         nlp = get_lang_class("xx")()
-    output_docs = []
     for conll_doc in input_data.strip().split(doc_delimiter):
         conll_doc = conll_doc.strip()
         if not conll_doc:
@@ -116,8 +115,7 @@ def conll_ner_to_docs(
             token.is_sent_start = sent_starts[i]
         entities = tags_to_entities(biluo_tags)
         doc.ents = [Span(doc, start=s, end=e + 1, label=L) for L, s, e in entities]
-        output_docs.append(doc)
-    return output_docs
+        yield doc


 def segment_sents_and_docs(doc, n_sents, doc_delimiter, model=None, msg=None):
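
Worth noting for the change above: once yield appears in the body, calling conll_ner_to_docs() returns a generator immediately and runs none of the body until iteration starts, which is what removes the need for output_docs and the final return. A tiny demonstration of that behaviour:

    def docs():
        print("parsing...")
        yield "doc"

    g = docs()        # nothing printed yet; the body has not started
    print(next(g))    # prints "parsing..." and then "doc"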

View File

@@ -34,16 +34,14 @@ def conllu_to_docs(
         ner_map=ner_map,
         merge_subtokens=merge_subtokens,
     )
-    docs = []
     sent_docs_to_merge = []
     for sent_doc in sent_docs:
         sent_docs_to_merge.append(sent_doc)
         if len(sent_docs_to_merge) % n_sents == 0:
-            docs.append(Doc.from_docs(sent_docs_to_merge))
+            yield Doc.from_docs(sent_docs_to_merge)
             sent_docs_to_merge = []
     if sent_docs_to_merge:
-        docs.append(Doc.from_docs(sent_docs_to_merge))
-    return docs
+        yield Doc.from_docs(sent_docs_to_merge)


 def has_ner(input_data, ner_tag_pattern):
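
The accumulate-and-flush loop above is the usual way to turn a stream of items into merged batches without building the full result list: collect until n_sents sentences are buffered, yield the merged Doc, reset the buffer, and flush any remainder at the end. A generic sketch with plain lists standing in for sentence Docs:

    def merge_batches(items, n):
        batch = []
        for item in items:
            batch.append(item)
            if len(batch) % n == 0:
                yield sum(batch, [])  # stand-in for Doc.from_docs(batch)
                batch = []
        if batch:
            yield sum(batch, [])      # flush the final partial batch

    print(list(merge_batches([[1], [2], [3], [4], [5]], n=2)))
    # [[1, 2], [3, 4], [5]]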

View File

@@ -24,12 +24,10 @@ def iob_to_docs(input_data, n_sents=10, no_print=False, *args, **kwargs):
     msg = Printer(no_print=no_print)
     if n_sents > 0:
         n_sents_info(msg, n_sents)
-    docs = read_iob(input_data.split("\n"), vocab, n_sents)
-    return docs
+    yield from read_iob(input_data.split("\n"), vocab, n_sents)


 def read_iob(raw_sents, vocab, n_sents):
-    docs = []
     for group in minibatch(raw_sents, size=n_sents):
         tokens = []
         words = []
@@ -61,5 +59,4 @@ def read_iob(raw_sents, vocab, n_sents):
         biluo = iob_to_biluo(iob)
         entities = tags_to_entities(biluo)
         doc.ents = [Span(doc, start=s, end=e + 1, label=L) for (L, s, e) in entities]
-        docs.append(doc)
-    return docs
+        yield doc

View File

@@ -12,11 +12,9 @@ def json_to_docs(input_data, model=None, **kwargs):
     if not isinstance(input_data, str):
         input_data = srsly.json_dumps(input_data)
     input_data = input_data.encode("utf8")
-    docs = []
     for json_doc in json_iterate(input_data):
         for json_para in json_to_annotations(json_doc):
             example_dict = _fix_legacy_dict_data(json_para)
             tok_dict, doc_dict = _parse_example_dict_data(example_dict)
             doc = annotations_to_doc(nlp.vocab, tok_dict, doc_dict)
-            docs.append(doc)
-    return docs
+            yield doc