Add tag_map argument to CLI debug-data and train (#4750) (#5038)

Add an argument for a path to a JSON-formatted tag map, which is used to
update and extend the default language tag map.
This commit is contained in:
adrianeboyd 2020-02-26 12:10:38 +01:00 committed by GitHub
parent c7e3c034d2
commit ff184b7a9c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 17 additions and 1 deletions

View File

@ -26,6 +26,7 @@ BLANK_MODEL_THRESHOLD = 2000
lang=("model language", "positional", None, str), lang=("model language", "positional", None, str),
train_path=("location of JSON-formatted training data", "positional", None, Path), train_path=("location of JSON-formatted training data", "positional", None, Path),
dev_path=("location of JSON-formatted development data", "positional", None, Path), dev_path=("location of JSON-formatted development data", "positional", None, Path),
tag_map_path=("Location of JSON-formatted tag map", "option", "tm", Path),
base_model=("name of model to update (optional)", "option", "b", str), base_model=("name of model to update (optional)", "option", "b", str),
pipeline=( pipeline=(
"Comma-separated names of pipeline components to train", "Comma-separated names of pipeline components to train",
@ -41,6 +42,7 @@ def debug_data(
lang, lang,
train_path, train_path,
dev_path, dev_path,
tag_map_path=None,
base_model=None, base_model=None,
pipeline="tagger,parser,ner", pipeline="tagger,parser,ner",
ignore_warnings=False, ignore_warnings=False,
@ -60,6 +62,10 @@ def debug_data(
if not dev_path.exists(): if not dev_path.exists():
msg.fail("Development data not found", dev_path, exits=1) msg.fail("Development data not found", dev_path, exits=1)
tag_map = {}
if tag_map_path is not None:
tag_map = srsly.read_json(tag_map_path)
# Initialize the model and pipeline # Initialize the model and pipeline
pipeline = [p.strip() for p in pipeline.split(",")] pipeline = [p.strip() for p in pipeline.split(",")]
if base_model: if base_model:
@ -67,6 +73,8 @@ def debug_data(
else: else:
lang_cls = get_lang_class(lang) lang_cls = get_lang_class(lang)
nlp = lang_cls() nlp = lang_cls()
# Update tag map with provided mapping
nlp.vocab.morphology.tag_map.update(tag_map)
msg.divider("Data format validation") msg.divider("Data format validation")
@ -344,7 +352,7 @@ def debug_data(
if "tagger" in pipeline: if "tagger" in pipeline:
msg.divider("Part-of-speech Tagging") msg.divider("Part-of-speech Tagging")
labels = [label for label in gold_train_data["tags"]] labels = [label for label in gold_train_data["tags"]]
tag_map = nlp.Defaults.tag_map tag_map = nlp.vocab.morphology.tag_map
msg.info( msg.info(
"{} {} in data ({} {} in tag map)".format( "{} {} in data ({} {} in tag map)".format(
len(labels), len(labels),

View File

@ -57,6 +57,7 @@ from .. import about
textcat_multilabel=("Textcat classes aren't mutually exclusive (multilabel)", "flag", "TML", bool), textcat_multilabel=("Textcat classes aren't mutually exclusive (multilabel)", "flag", "TML", bool),
textcat_arch=("Textcat model architecture", "option", "ta", str), textcat_arch=("Textcat model architecture", "option", "ta", str),
textcat_positive_label=("Textcat positive label for binary classes with two labels", "option", "tpl", str), textcat_positive_label=("Textcat positive label for binary classes with two labels", "option", "tpl", str),
tag_map_path=("Location of JSON-formatted tag map", "option", "tm", Path),
verbose=("Display more information for debug", "flag", "VV", bool), verbose=("Display more information for debug", "flag", "VV", bool),
debug=("Run data diagnostics before training", "flag", "D", bool), debug=("Run data diagnostics before training", "flag", "D", bool),
# fmt: on # fmt: on
@ -95,6 +96,7 @@ def train(
textcat_multilabel=False, textcat_multilabel=False,
textcat_arch="bow", textcat_arch="bow",
textcat_positive_label=None, textcat_positive_label=None,
tag_map_path=None,
verbose=False, verbose=False,
debug=False, debug=False,
): ):
@ -132,6 +134,9 @@ def train(
output_path.mkdir() output_path.mkdir()
msg.good("Created output directory: {}".format(output_path)) msg.good("Created output directory: {}".format(output_path))
tag_map = {}
if tag_map_path is not None:
tag_map = srsly.read_json(tag_map_path)
# Take dropout and batch size as generators of values -- dropout # Take dropout and batch size as generators of values -- dropout
# starts high and decays sharply, to force the optimizer to explore. # starts high and decays sharply, to force the optimizer to explore.
# Batch size starts at 1 and grows, so that we make updates quickly # Batch size starts at 1 and grows, so that we make updates quickly
@ -238,6 +243,9 @@ def train(
pipe_cfg = {} pipe_cfg = {}
nlp.add_pipe(nlp.create_pipe(pipe, config=pipe_cfg)) nlp.add_pipe(nlp.create_pipe(pipe, config=pipe_cfg))
# Update tag map with provided mapping
nlp.vocab.morphology.tag_map.update(tag_map)
if vectors: if vectors:
msg.text("Loading vector from model '{}'".format(vectors)) msg.text("Loading vector from model '{}'".format(vectors))
_load_vectors(nlp, vectors) _load_vectors(nlp, vectors)