mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 10:16:27 +03:00
* Add check for empty input file to CLI pretrain
* Raise error if JSONL is not a dict or contains neither `tokens` nor `text` key
* Skip empty values for correct pretrain keys and log a counter as warning
* Add tests for CLI pretrain core function make_docs
* Add a short hint for the `tokens` key to the CLI pretrain docs
* Add success message to CLI pretrain
* Update model loading to fix the tests
* Skip empty values and do not create docs from them
This commit is contained in:
parent
5accfbb938
commit
d8573ee715
|
@ -13,6 +13,7 @@ from thinc.neural.util import prefer_gpu, get_array_module
|
||||||
from wasabi import Printer
|
from wasabi import Printer
|
||||||
import srsly
|
import srsly
|
||||||
|
|
||||||
|
from ..errors import Errors
|
||||||
from ..tokens import Doc
|
from ..tokens import Doc
|
||||||
from ..attrs import ID, HEAD
|
from ..attrs import ID, HEAD
|
||||||
from .._ml import Tok2Vec, flatten, chain, create_default_optimizer
|
from .._ml import Tok2Vec, flatten, chain, create_default_optimizer
|
||||||
|
@ -101,6 +102,8 @@ def pretrain(
|
||||||
msg.fail("Input text file doesn't exist", texts_loc, exits=1)
|
msg.fail("Input text file doesn't exist", texts_loc, exits=1)
|
||||||
with msg.loading("Loading input texts..."):
|
with msg.loading("Loading input texts..."):
|
||||||
texts = list(srsly.read_jsonl(texts_loc))
|
texts = list(srsly.read_jsonl(texts_loc))
|
||||||
|
if not texts:
|
||||||
|
msg.fail("Input file is empty", texts_loc, exits=1)
|
||||||
msg.good("Loaded input texts")
|
msg.good("Loaded input texts")
|
||||||
random.shuffle(texts)
|
random.shuffle(texts)
|
||||||
else: # reading from stdin
|
else: # reading from stdin
|
||||||
|
@ -149,16 +152,18 @@ def pretrain(
|
||||||
with (output_dir / "log.jsonl").open("a") as file_:
|
with (output_dir / "log.jsonl").open("a") as file_:
|
||||||
file_.write(srsly.json_dumps(log) + "\n")
|
file_.write(srsly.json_dumps(log) + "\n")
|
||||||
|
|
||||||
|
skip_counter = 0
|
||||||
for epoch in range(n_iter):
|
for epoch in range(n_iter):
|
||||||
for batch_id, batch in enumerate(
|
for batch_id, batch in enumerate(
|
||||||
util.minibatch_by_words(((text, None) for text in texts), size=batch_size)
|
util.minibatch_by_words(((text, None) for text in texts), size=batch_size)
|
||||||
):
|
):
|
||||||
docs = make_docs(
|
docs, count = make_docs(
|
||||||
nlp,
|
nlp,
|
||||||
[text for (text, _) in batch],
|
[text for (text, _) in batch],
|
||||||
max_length=max_length,
|
max_length=max_length,
|
||||||
min_length=min_length,
|
min_length=min_length,
|
||||||
)
|
)
|
||||||
|
skip_counter += count
|
||||||
loss = make_update(
|
loss = make_update(
|
||||||
model, docs, optimizer, objective=loss_func, drop=dropout
|
model, docs, optimizer, objective=loss_func, drop=dropout
|
||||||
)
|
)
|
||||||
|
@ -174,6 +179,9 @@ def pretrain(
|
||||||
if texts_loc != "-":
|
if texts_loc != "-":
|
||||||
# Reshuffle the texts if texts were loaded from a file
|
# Reshuffle the texts if texts were loaded from a file
|
||||||
random.shuffle(texts)
|
random.shuffle(texts)
|
||||||
|
if skip_counter > 0:
|
||||||
|
msg.warn("Skipped {count} empty values".format(count=str(skip_counter)))
|
||||||
|
msg.good("Successfully finished pretrain")
|
||||||
|
|
||||||
|
|
||||||
def make_update(model, docs, optimizer, drop=0.0, objective="L2"):
|
def make_update(model, docs, optimizer, drop=0.0, objective="L2"):
|
||||||
|
@ -195,12 +203,24 @@ def make_update(model, docs, optimizer, drop=0.0, objective="L2"):
|
||||||
|
|
||||||
def make_docs(nlp, batch, min_length, max_length):
|
def make_docs(nlp, batch, min_length, max_length):
|
||||||
docs = []
|
docs = []
|
||||||
|
skip_count = 0
|
||||||
for record in batch:
|
for record in batch:
|
||||||
|
if not isinstance(record, dict):
|
||||||
|
raise TypeError(Errors.E137.format(type=type(record), line=record))
|
||||||
if "tokens" in record:
|
if "tokens" in record:
|
||||||
doc = Doc(nlp.vocab, words=record["tokens"])
|
words = record["tokens"]
|
||||||
else:
|
if not words:
|
||||||
|
skip_count += 1
|
||||||
|
continue
|
||||||
|
doc = Doc(nlp.vocab, words=words)
|
||||||
|
elif "text" in record:
|
||||||
text = record["text"]
|
text = record["text"]
|
||||||
|
if not text:
|
||||||
|
skip_count += 1
|
||||||
|
continue
|
||||||
doc = nlp.make_doc(text)
|
doc = nlp.make_doc(text)
|
||||||
|
else:
|
||||||
|
raise ValueError(Errors.E138.format(text=record))
|
||||||
if "heads" in record:
|
if "heads" in record:
|
||||||
heads = record["heads"]
|
heads = record["heads"]
|
||||||
heads = numpy.asarray(heads, dtype="uint64")
|
heads = numpy.asarray(heads, dtype="uint64")
|
||||||
|
@ -208,7 +228,7 @@ def make_docs(nlp, batch, min_length, max_length):
|
||||||
doc = doc.from_array([HEAD], heads)
|
doc = doc.from_array([HEAD], heads)
|
||||||
if len(doc) >= min_length and len(doc) < max_length:
|
if len(doc) >= min_length and len(doc) < max_length:
|
||||||
docs.append(doc)
|
docs.append(doc)
|
||||||
return docs
|
return docs, skip_count
|
||||||
|
|
||||||
|
|
||||||
def get_vectors_loss(ops, docs, prediction, objective="L2"):
|
def get_vectors_loss(ops, docs, prediction, objective="L2"):
|
||||||
|
|
|
@ -393,6 +393,12 @@ class Errors(object):
|
||||||
"`nlp.replace_pipe('{name}', nlp.create_pipe('{name}'))`")
|
"`nlp.replace_pipe('{name}', nlp.create_pipe('{name}'))`")
|
||||||
E136 = ("This additional feature requires the jsonschema library to be "
|
E136 = ("This additional feature requires the jsonschema library to be "
|
||||||
"installed:\npip install jsonschema")
|
"installed:\npip install jsonschema")
|
||||||
|
E137 = ("Expected 'dict' type, but got '{type}' from '{line}'. Make sure to provide a valid JSON "
|
||||||
|
"object as input with either the `text` or `tokens` key. For more info, see the docs:\n"
|
||||||
|
"https://spacy.io/api/cli#pretrain-jsonl")
|
||||||
|
E138 = ("Invalid JSONL format for raw text '{text}'. Make sure the input includes either the "
|
||||||
|
"`text` or `tokens` key. For more info, see the docs:\n"
|
||||||
|
"https://spacy.io/api/cli#pretrain-jsonl")
|
||||||
|
|
||||||
|
|
||||||
@add_codes
|
@add_codes
|
||||||
|
|
|
@ -1,7 +1,11 @@
|
||||||
# coding: utf-8
|
# coding: utf-8
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from spacy.lang.en import English
|
||||||
from spacy.cli.converters import conllu2json
|
from spacy.cli.converters import conllu2json
|
||||||
|
from spacy.cli.pretrain import make_docs
|
||||||
|
|
||||||
|
|
||||||
def test_cli_converters_conllu2json():
|
def test_cli_converters_conllu2json():
|
||||||
|
@ -26,3 +30,45 @@ def test_cli_converters_conllu2json():
|
||||||
assert [t["head"] for t in tokens] == [1, 2, -1, 0]
|
assert [t["head"] for t in tokens] == [1, 2, -1, 0]
|
||||||
assert [t["dep"] for t in tokens] == ["appos", "nsubj", "name", "ROOT"]
|
assert [t["dep"] for t in tokens] == ["appos", "nsubj", "name", "ROOT"]
|
||||||
assert [t["ner"] for t in tokens] == ["O", "B-PER", "L-PER", "O"]
|
assert [t["ner"] for t in tokens] == ["O", "B-PER", "L-PER", "O"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_pretrain_make_docs():
    """Check how make_docs treats valid, malformed, empty and out-of-range records.

    make_docs returns a `(docs, skip_count)` pair; each case below feeds it a
    single-record batch and checks both halves of that pair, or the exception
    raised for malformed input.
    """
    nlp = English()

    # Well-formed records (raw text or pre-tokenized) each produce one doc
    # and nothing is counted as skipped.
    for valid_record in ({"text": "Some text"}, {"tokens": ["Some", "tokens"]}):
        docs, skip_count = make_docs(nlp, [valid_record], 1, 10)
        assert len(docs) == 1
        assert skip_count == 0

    # A record that is not a dict at all is rejected with a TypeError.
    with pytest.raises(TypeError):
        make_docs(nlp, [0], 1, 100)

    # A dict carrying neither a `text` nor a `tokens` key raises a ValueError.
    with pytest.raises(ValueError):
        make_docs(nlp, [{"invalid": "Does not matter"}], 1, 100)

    # Empty values under a valid key produce no doc and bump the skip counter.
    for empty_record in ({"text": ""}, {"tokens": []}):
        docs, skip_count = make_docs(nlp, [empty_record], 1, 10)
        assert len(docs) == 0
        assert skip_count == 1

    # Records dropped by the min/max length filter are NOT counted as skipped:
    # the counter only tracks empty values.
    length_filtered_cases = (
        ({"text": "This text is not long enough"}, 10, 15),
        ({"text": "This text contains way too much tokens for this test"}, 1, 5),
    )
    for record, min_length, max_length in length_filtered_cases:
        docs, skip_count = make_docs(nlp, [record], min_length, max_length)
        assert len(docs) == 0
        assert skip_count == 0
|
||||||
|
|
|
@ -291,7 +291,7 @@ $ python -m spacy pretrain [texts_loc] [vectors_model] [output_dir] [--width]
|
||||||
|
|
||||||
| Argument | Type | Description |
|
| Argument | Type | Description |
|
||||||
| ----------------------- | ---------- | --------------------------------------------------------------------------------------------------------------------------------- |
|
| ----------------------- | ---------- | --------------------------------------------------------------------------------------------------------------------------------- |
|
||||||
| `texts_loc` | positional | Path to JSONL file with raw texts to learn from, with text provided as the key `"text"`. [See here](#pretrain-jsonl) for details. |
|
| `texts_loc` | positional | Path to JSONL file with raw texts to learn from, with text provided as the key `"text"` or tokens as the key `"tokens"`. [See here](#pretrain-jsonl) for details. |
|
||||||
| `vectors_model` | positional | Name or path to spaCy model with vectors to learn from. |
|
| `vectors_model` | positional | Name or path to spaCy model with vectors to learn from. |
|
||||||
| `output_dir` | positional | Directory to write models to on each epoch. |
|
| `output_dir` | positional | Directory to write models to on each epoch. |
|
||||||
| `--width`, `-cw` | option | Width of CNN layers. |
|
| `--width`, `-cw` | option | Width of CNN layers. |
|
||||||
|
|
Loading…
Reference in New Issue
Block a user