mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-26 09:56:28 +03:00
Add an argument for a path to a JSON-formatted tag map, which is used to update and extend the default language tag map.
This commit is contained in:
parent
c7e3c034d2
commit
ff184b7a9c
|
@ -26,6 +26,7 @@ BLANK_MODEL_THRESHOLD = 2000
|
||||||
lang=("model language", "positional", None, str),
|
lang=("model language", "positional", None, str),
|
||||||
train_path=("location of JSON-formatted training data", "positional", None, Path),
|
train_path=("location of JSON-formatted training data", "positional", None, Path),
|
||||||
dev_path=("location of JSON-formatted development data", "positional", None, Path),
|
dev_path=("location of JSON-formatted development data", "positional", None, Path),
|
||||||
|
tag_map_path=("Location of JSON-formatted tag map", "option", "tm", Path),
|
||||||
base_model=("name of model to update (optional)", "option", "b", str),
|
base_model=("name of model to update (optional)", "option", "b", str),
|
||||||
pipeline=(
|
pipeline=(
|
||||||
"Comma-separated names of pipeline components to train",
|
"Comma-separated names of pipeline components to train",
|
||||||
|
@ -41,6 +42,7 @@ def debug_data(
|
||||||
lang,
|
lang,
|
||||||
train_path,
|
train_path,
|
||||||
dev_path,
|
dev_path,
|
||||||
|
tag_map_path=None,
|
||||||
base_model=None,
|
base_model=None,
|
||||||
pipeline="tagger,parser,ner",
|
pipeline="tagger,parser,ner",
|
||||||
ignore_warnings=False,
|
ignore_warnings=False,
|
||||||
|
@ -60,6 +62,10 @@ def debug_data(
|
||||||
if not dev_path.exists():
|
if not dev_path.exists():
|
||||||
msg.fail("Development data not found", dev_path, exits=1)
|
msg.fail("Development data not found", dev_path, exits=1)
|
||||||
|
|
||||||
|
tag_map = {}
|
||||||
|
if tag_map_path is not None:
|
||||||
|
tag_map = srsly.read_json(tag_map_path)
|
||||||
|
|
||||||
# Initialize the model and pipeline
|
# Initialize the model and pipeline
|
||||||
pipeline = [p.strip() for p in pipeline.split(",")]
|
pipeline = [p.strip() for p in pipeline.split(",")]
|
||||||
if base_model:
|
if base_model:
|
||||||
|
@ -67,6 +73,8 @@ def debug_data(
|
||||||
else:
|
else:
|
||||||
lang_cls = get_lang_class(lang)
|
lang_cls = get_lang_class(lang)
|
||||||
nlp = lang_cls()
|
nlp = lang_cls()
|
||||||
|
# Update tag map with provided mapping
|
||||||
|
nlp.vocab.morphology.tag_map.update(tag_map)
|
||||||
|
|
||||||
msg.divider("Data format validation")
|
msg.divider("Data format validation")
|
||||||
|
|
||||||
|
@ -344,7 +352,7 @@ def debug_data(
|
||||||
if "tagger" in pipeline:
|
if "tagger" in pipeline:
|
||||||
msg.divider("Part-of-speech Tagging")
|
msg.divider("Part-of-speech Tagging")
|
||||||
labels = [label for label in gold_train_data["tags"]]
|
labels = [label for label in gold_train_data["tags"]]
|
||||||
tag_map = nlp.Defaults.tag_map
|
tag_map = nlp.vocab.morphology.tag_map
|
||||||
msg.info(
|
msg.info(
|
||||||
"{} {} in data ({} {} in tag map)".format(
|
"{} {} in data ({} {} in tag map)".format(
|
||||||
len(labels),
|
len(labels),
|
||||||
|
|
|
@ -57,6 +57,7 @@ from .. import about
|
||||||
textcat_multilabel=("Textcat classes aren't mutually exclusive (multilabel)", "flag", "TML", bool),
|
textcat_multilabel=("Textcat classes aren't mutually exclusive (multilabel)", "flag", "TML", bool),
|
||||||
textcat_arch=("Textcat model architecture", "option", "ta", str),
|
textcat_arch=("Textcat model architecture", "option", "ta", str),
|
||||||
textcat_positive_label=("Textcat positive label for binary classes with two labels", "option", "tpl", str),
|
textcat_positive_label=("Textcat positive label for binary classes with two labels", "option", "tpl", str),
|
||||||
|
tag_map_path=("Location of JSON-formatted tag map", "option", "tm", Path),
|
||||||
verbose=("Display more information for debug", "flag", "VV", bool),
|
verbose=("Display more information for debug", "flag", "VV", bool),
|
||||||
debug=("Run data diagnostics before training", "flag", "D", bool),
|
debug=("Run data diagnostics before training", "flag", "D", bool),
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
@ -95,6 +96,7 @@ def train(
|
||||||
textcat_multilabel=False,
|
textcat_multilabel=False,
|
||||||
textcat_arch="bow",
|
textcat_arch="bow",
|
||||||
textcat_positive_label=None,
|
textcat_positive_label=None,
|
||||||
|
tag_map_path=None,
|
||||||
verbose=False,
|
verbose=False,
|
||||||
debug=False,
|
debug=False,
|
||||||
):
|
):
|
||||||
|
@ -132,6 +134,9 @@ def train(
|
||||||
output_path.mkdir()
|
output_path.mkdir()
|
||||||
msg.good("Created output directory: {}".format(output_path))
|
msg.good("Created output directory: {}".format(output_path))
|
||||||
|
|
||||||
|
tag_map = {}
|
||||||
|
if tag_map_path is not None:
|
||||||
|
tag_map = srsly.read_json(tag_map_path)
|
||||||
# Take dropout and batch size as generators of values -- dropout
|
# Take dropout and batch size as generators of values -- dropout
|
||||||
# starts high and decays sharply, to force the optimizer to explore.
|
# starts high and decays sharply, to force the optimizer to explore.
|
||||||
# Batch size starts at 1 and grows, so that we make updates quickly
|
# Batch size starts at 1 and grows, so that we make updates quickly
|
||||||
|
@ -238,6 +243,9 @@ def train(
|
||||||
pipe_cfg = {}
|
pipe_cfg = {}
|
||||||
nlp.add_pipe(nlp.create_pipe(pipe, config=pipe_cfg))
|
nlp.add_pipe(nlp.create_pipe(pipe, config=pipe_cfg))
|
||||||
|
|
||||||
|
# Update tag map with provided mapping
|
||||||
|
nlp.vocab.morphology.tag_map.update(tag_map)
|
||||||
|
|
||||||
if vectors:
|
if vectors:
|
||||||
msg.text("Loading vector from model '{}'".format(vectors))
|
msg.text("Loading vector from model '{}'".format(vectors))
|
||||||
_load_vectors(nlp, vectors)
|
_load_vectors(nlp, vectors)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user