Switch converters to generator functions (#6547)
* Switch converters to generator functions

To reduce the memory usage when converting large corpora, refactor the convert methods to be generator functions.

* Update tests
parent 8656a08777
commit 1ddf2f39c7
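The core pattern of the refactor, shown as a minimal standalone sketch (illustration only, not code from this commit): instead of accumulating every converted document in a list and returning it, a converter yields documents one at a time, so only the item currently being converted has to be held in memory.

def convert_eager(records):
    # Before: the full list of converted items is built up in memory.
    docs = []
    for record in records:
        docs.append(record.upper())  # stand-in for the real conversion work
    return docs


def convert_lazy(records):
    # After: a generator function; items are produced one at a time and can
    # be consumed (e.g. serialized to disk) without keeping the whole corpus
    # in memory.
    for record in records:
        yield record.upper()


assert convert_eager(["one", "two"]) == list(convert_lazy(["one", "two"]))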
@@ -5,6 +5,7 @@ from wasabi import Printer
 import srsly
 import re
 import sys
+import itertools

 from ._util import app, Arg, Opt
 from ..training import docs_to_json
@@ -130,15 +131,16 @@ def convert(
         )
         doc_files.append((input_loc, docs))
     if concatenate:
-        all_docs = []
-        for _, docs in doc_files:
-            all_docs.extend(docs)
+        all_docs = itertools.chain.from_iterable([docs for _, docs in doc_files])
         doc_files = [(input_path, all_docs)]
     for input_loc, docs in doc_files:
         if file_type == "json":
             data = [docs_to_json(docs)]
+            len_docs = len(data)
         else:
-            data = DocBin(docs=docs, store_user_data=True).to_bytes()
+            db = DocBin(docs=docs, store_user_data=True)
+            len_docs = len(db)
+            data = db.to_bytes()
         if output_dir == "-":
             _print_docs_to_stdout(data, file_type)
         else:
@@ -149,7 +151,7 @@ def convert(
             output_file = Path(output_dir) / input_loc.parts[-1]
             output_file = output_file.with_suffix(f".{file_type}")
             _write_docs_to_file(data, output_file, file_type)
-            msg.good(f"Generated output file ({len(docs)} documents): {output_file}")
+            msg.good(f"Generated output file ({len_docs} documents): {output_file}")


 def _print_docs_to_stdout(data: Any, output_type: str) -> None:
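Two details of the CLI change above are easy to miss: concatenation now goes through itertools.chain.from_iterable so the combined stream stays lazy, and the document count for the log message has to be read from the serialized container (the JSON list or the DocBin), because a generator has no len(). A rough sketch of that bookkeeping, assuming spaCy is installed (the converter here is a stand-in, not one of the real ones):

import itertools
import spacy
from spacy.tokens import Doc, DocBin

nlp = spacy.blank("en")


def fake_converter(texts):
    # Stand-in for conllu_to_docs & co.: yields Docs lazily.
    for text in texts:
        yield Doc(nlp.vocab, words=text.split())


doc_files = [
    ("a.txt", fake_converter(["hello world"])),
    ("b.txt", fake_converter(["another doc", "and one more"])),
]

# Lazy concatenation: nothing is materialized yet.
all_docs = itertools.chain.from_iterable([docs for _, docs in doc_files])

# len(all_docs) would raise TypeError, so the count comes from the DocBin,
# which consumes the generator exactly once while adding the docs.
db = DocBin(docs=all_docs, store_user_data=True)
len_docs = len(db)
data = db.to_bytes()
print(f"Generated output ({len_docs} documents, {len(data)} bytes)")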
@@ -24,7 +24,7 @@ def test_cli_converters_conllu_to_docs():
         "4\tavstår\tavstå\tVERB\t_\tMood=Ind|Tense=Pres|VerbForm=Fin\t0\troot\t_\tO",
     ]
     input_data = "\n".join(lines)
-    converted_docs = conllu_to_docs(input_data, n_sents=1)
+    converted_docs = list(conllu_to_docs(input_data, n_sents=1))
     assert len(converted_docs) == 1
     converted = [docs_to_json(converted_docs)]
     assert converted[0]["id"] == 0
@@ -65,8 +65,8 @@ def test_cli_converters_conllu_to_docs_name_ner_map(lines):
 )
 def test_cli_converters_conllu_to_docs_name_ner_map(lines):
     input_data = "\n".join(lines)
-    converted_docs = conllu_to_docs(
-        input_data, n_sents=1, ner_map={"PER": "PERSON", "BAD": ""}
+    converted_docs = list(
+        conllu_to_docs(input_data, n_sents=1, ner_map={"PER": "PERSON", "BAD": ""})
     )
     assert len(converted_docs) == 1
     converted = [docs_to_json(converted_docs)]
@@ -99,8 +99,10 @@ def test_cli_converters_conllu_to_docs_subtokens():
         "5\t.\t$.\tPUNCT\t_\t_\t4\tpunct\t_\tname=O",
     ]
     input_data = "\n".join(lines)
-    converted_docs = conllu_to_docs(
-        input_data, n_sents=1, merge_subtokens=True, append_morphology=True
+    converted_docs = list(
+        conllu_to_docs(
+            input_data, n_sents=1, merge_subtokens=True, append_morphology=True
+        )
     )
     assert len(converted_docs) == 1
     converted = [docs_to_json(converted_docs)]
@@ -145,7 +147,7 @@ def test_cli_converters_iob_to_docs():
         "I|PRP|O like|VBP|O London|NNP|B-GPE and|CC|O New|NNP|B-GPE York|NNP|I-GPE City|NNP|I-GPE .|.|O",
     ]
     input_data = "\n".join(lines)
-    converted_docs = iob_to_docs(input_data, n_sents=10)
+    converted_docs = list(iob_to_docs(input_data, n_sents=10))
     assert len(converted_docs) == 1
     converted = docs_to_json(converted_docs)
     assert converted["id"] == 0
@@ -212,7 +214,7 @@ def test_cli_converters_conll_ner_to_docs():
         ".\t.\t_\tO",
     ]
     input_data = "\n".join(lines)
-    converted_docs = conll_ner_to_docs(input_data, n_sents=10)
+    converted_docs = list(conll_ner_to_docs(input_data, n_sents=10))
     assert len(converted_docs) == 1
     converted = docs_to_json(converted_docs)
     assert converted["id"] == 0
@@ -195,7 +195,7 @@ def test_json_to_docs_no_ner(en_vocab):
             ],
         }
     ]
-    docs = json_to_docs(data)
+    docs = list(json_to_docs(data))
     assert len(docs) == 1
     for doc in docs:
         assert not doc.has_annotation("ENT_IOB")
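All of the test updates above follow from the same property: a generator has no len(), cannot be indexed, and is exhausted after a single pass, so the tests materialize the converter output with list() before asserting on it. A small generic sketch of the behaviour being worked around (plain Python, not spaCy code):

def gen():
    yield "doc"


g = gen()
try:
    len(g)
except TypeError:
    pass                      # generators have no len()

assert list(g) == ["doc"]     # the first pass consumes the generator
assert list(g) == []          # a second pass yields nothing
docs = list(gen())            # hence: materialize once, then assert freely
assert len(docs) == 1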
@@ -87,7 +87,6 @@ def conll_ner_to_docs(
         nlp = load_model(model)
     else:
         nlp = get_lang_class("xx")()
-    output_docs = []
     for conll_doc in input_data.strip().split(doc_delimiter):
         conll_doc = conll_doc.strip()
         if not conll_doc:
@@ -116,8 +115,7 @@ def conll_ner_to_docs(
             token.is_sent_start = sent_starts[i]
         entities = tags_to_entities(biluo_tags)
         doc.ents = [Span(doc, start=s, end=e + 1, label=L) for L, s, e in entities]
-        output_docs.append(doc)
-    return output_docs
+        yield doc


 def segment_sents_and_docs(doc, n_sents, doc_delimiter, model=None, msg=None):
@@ -34,16 +34,14 @@ def conllu_to_docs(
         ner_map=ner_map,
         merge_subtokens=merge_subtokens,
     )
-    docs = []
     sent_docs_to_merge = []
     for sent_doc in sent_docs:
         sent_docs_to_merge.append(sent_doc)
         if len(sent_docs_to_merge) % n_sents == 0:
-            docs.append(Doc.from_docs(sent_docs_to_merge))
+            yield Doc.from_docs(sent_docs_to_merge)
             sent_docs_to_merge = []
     if sent_docs_to_merge:
-        docs.append(Doc.from_docs(sent_docs_to_merge))
-    return docs
+        yield Doc.from_docs(sent_docs_to_merge)


 def has_ner(input_data, ner_tag_pattern):
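In conllu_to_docs the accumulator for the final result disappears, but the small sent_docs_to_merge buffer stays: sentence Docs are still collected in groups of n_sents and merged with Doc.from_docs, only now each merged Doc is yielded immediately instead of being appended to a list. The grouping logic in isolation might look like the following sketch (plain strings stand in for the sentence Docs):

def merge_in_batches(items, n):
    # Yield items merged in groups of n, plus a final partial group.
    batch = []
    for item in items:
        batch.append(item)
        if len(batch) % n == 0:
            yield " ".join(batch)  # stand-in for Doc.from_docs(batch)
            batch = []
    if batch:
        yield " ".join(batch)


assert list(merge_in_batches(["a", "b", "c"], 2)) == ["a b", "c"]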
@@ -24,12 +24,10 @@ def iob_to_docs(input_data, n_sents=10, no_print=False, *args, **kwargs):
     msg = Printer(no_print=no_print)
     if n_sents > 0:
         n_sents_info(msg, n_sents)
-    docs = read_iob(input_data.split("\n"), vocab, n_sents)
-    return docs
+    yield from read_iob(input_data.split("\n"), vocab, n_sents)


 def read_iob(raw_sents, vocab, n_sents):
-    docs = []
     for group in minibatch(raw_sents, size=n_sents):
         tokens = []
         words = []
@@ -61,5 +59,4 @@ def read_iob(raw_sents, vocab, n_sents):
         biluo = iob_to_biluo(iob)
         entities = tags_to_entities(biluo)
         doc.ents = [Span(doc, start=s, end=e + 1, label=L) for (L, s, e) in entities]
-        docs.append(doc)
-    return docs
+        yield doc
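Because read_iob is itself turned into a generator, iob_to_docs no longer needs to collect and return its output; it can delegate with yield from, which forwards every item the inner generator produces while keeping the whole pipeline lazy. A minimal illustration of that delegation (placeholder names, not the real functions):

def read_lines(raw_sents):
    for sent in raw_sents:
        yield sent.strip()


def to_docs(raw):
    # Delegation: every item from read_lines is re-yielded unchanged.
    yield from read_lines(raw.split("\n"))


assert list(to_docs("a \n b")) == ["a", "b"]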
@@ -12,11 +12,9 @@ def json_to_docs(input_data, model=None, **kwargs):
     if not isinstance(input_data, str):
         input_data = srsly.json_dumps(input_data)
     input_data = input_data.encode("utf8")
-    docs = []
     for json_doc in json_iterate(input_data):
         for json_para in json_to_annotations(json_doc):
             example_dict = _fix_legacy_dict_data(json_para)
             tok_dict, doc_dict = _parse_example_dict_data(example_dict)
             doc = annotations_to_doc(nlp.vocab, tok_dict, doc_dict)
-            docs.append(doc)
-    return docs
+            yield doc
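One behavioural consequence worth keeping in mind for callers of json_to_docs and the other converters: the body of a generator function does not run until the returned generator is iterated, so parsing work, and any errors it raises, now happen at consumption time rather than at the call site. A tiny generic example:

def convert(data):
    print("parsing starts")    # runs only once iteration begins
    yield data


result = convert("payload")    # returns immediately; nothing printed yet
first = next(result)           # now "parsing starts" is printed
assert first == "payload"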