Merge pull request #5855 from svlandeg/fix/cli-debug

This commit is contained in:
Ines Montani 2020-08-03 13:09:20 +02:00 committed by GitHub
commit 934447a611
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 22 additions and 20 deletions

View File

@ -6,7 +6,7 @@ requires = [
"cymem>=2.0.2,<2.1.0", "cymem>=2.0.2,<2.1.0",
"preshed>=3.0.2,<3.1.0", "preshed>=3.0.2,<3.1.0",
"murmurhash>=0.28.0,<1.1.0", "murmurhash>=0.28.0,<1.1.0",
"thinc>=8.0.0a20,<8.0.0a30", "thinc>=8.0.0a21,<8.0.0a30",
"blis>=0.4.0,<0.5.0", "blis>=0.4.0,<0.5.0",
"pytokenizations", "pytokenizations",
"smart_open>=2.0.0,<3.0.0" "smart_open>=2.0.0,<3.0.0"

View File

@ -1,7 +1,7 @@
# Our libraries # Our libraries
cymem>=2.0.2,<2.1.0 cymem>=2.0.2,<2.1.0
preshed>=3.0.2,<3.1.0 preshed>=3.0.2,<3.1.0
thinc>=8.0.0a20,<8.0.0a30 thinc>=8.0.0a21,<8.0.0a30
blis>=0.4.0,<0.5.0 blis>=0.4.0,<0.5.0
ml_datasets>=0.1.1 ml_datasets>=0.1.1
murmurhash>=0.28.0,<1.1.0 murmurhash>=0.28.0,<1.1.0

View File

@ -34,13 +34,13 @@ setup_requires =
cymem>=2.0.2,<2.1.0 cymem>=2.0.2,<2.1.0
preshed>=3.0.2,<3.1.0 preshed>=3.0.2,<3.1.0
murmurhash>=0.28.0,<1.1.0 murmurhash>=0.28.0,<1.1.0
thinc>=8.0.0a20,<8.0.0a30 thinc>=8.0.0a21,<8.0.0a30
install_requires = install_requires =
# Our libraries # Our libraries
murmurhash>=0.28.0,<1.1.0 murmurhash>=0.28.0,<1.1.0
cymem>=2.0.2,<2.1.0 cymem>=2.0.2,<2.1.0
preshed>=3.0.2,<3.1.0 preshed>=3.0.2,<3.1.0
thinc>=8.0.0a20,<8.0.0a30 thinc>=8.0.0a21,<8.0.0a30
blis>=0.4.0,<0.5.0 blis>=0.4.0,<0.5.0
wasabi>=0.7.1,<1.1.0 wasabi>=0.7.1,<1.1.0
srsly>=2.1.0,<3.0.0 srsly>=2.1.0,<3.0.0

View File

@ -2,7 +2,7 @@ from typing import Dict, Any, Optional
from pathlib import Path from pathlib import Path
from wasabi import msg from wasabi import msg
from thinc.api import require_gpu, fix_random_seed, set_dropout_rate, Adam, Config from thinc.api import require_gpu, fix_random_seed, set_dropout_rate, Adam, Config
from thinc.api import Model from thinc.api import Model, data_validation
import typer import typer
from ._util import Arg, Opt, debug_cli, show_validation_error, parse_config_overrides from ._util import Arg, Opt, debug_cli, show_validation_error, parse_config_overrides
@ -16,7 +16,7 @@ def debug_model_cli(
# fmt: off # fmt: off
ctx: typer.Context, # This is only used to read additional arguments ctx: typer.Context, # This is only used to read additional arguments
config_path: Path = Arg(..., help="Path to config file", exists=True), config_path: Path = Arg(..., help="Path to config file", exists=True),
section: str = Arg(..., help="Section that defines the model to be analysed"), component: str = Arg(..., help="Name of the pipeline component of which the model should be analysed"),
layers: str = Opt("", "--layers", "-l", help="Comma-separated names of layer IDs to print"), layers: str = Opt("", "--layers", "-l", help="Comma-separated names of layer IDs to print"),
dimensions: bool = Opt(False, "--dimensions", "-DIM", help="Show dimensions"), dimensions: bool = Opt(False, "--dimensions", "-DIM", help="Show dimensions"),
parameters: bool = Opt(False, "--parameters", "-PAR", help="Show parameters"), parameters: bool = Opt(False, "--parameters", "-PAR", help="Show parameters"),
@ -25,7 +25,7 @@ def debug_model_cli(
P0: bool = Opt(False, "--print-step0", "-P0", help="Print model before training"), P0: bool = Opt(False, "--print-step0", "-P0", help="Print model before training"),
P1: bool = Opt(False, "--print-step1", "-P1", help="Print model after initialization"), P1: bool = Opt(False, "--print-step1", "-P1", help="Print model after initialization"),
P2: bool = Opt(False, "--print-step2", "-P2", help="Print model after training"), P2: bool = Opt(False, "--print-step2", "-P2", help="Print model after training"),
P3: bool = Opt(True, "--print-step3", "-P3", help="Print final predictions"), P3: bool = Opt(False, "--print-step3", "-P3", help="Print final predictions"),
use_gpu: int = Opt(-1, "--gpu-id", "-g", help="GPU ID or -1 for CPU") use_gpu: int = Opt(-1, "--gpu-id", "-g", help="GPU ID or -1 for CPU")
# fmt: on # fmt: on
): ):
@ -53,20 +53,20 @@ def debug_model_cli(
with show_validation_error(config_path): with show_validation_error(config_path):
cfg = Config().from_disk(config_path) cfg = Config().from_disk(config_path)
try: try:
_, config = util.load_model_from_config(cfg, overrides=config_overrides) nlp, config = util.load_model_from_config(cfg, overrides=config_overrides)
except ValueError as e: except ValueError as e:
msg.fail(str(e), exits=1) msg.fail(str(e), exits=1)
seed = config["pretraining"]["seed"] seed = config.get("training", {}).get("seed", None)
if seed is not None: if seed is not None:
msg.info(f"Fixing random seed: {seed}") msg.info(f"Fixing random seed: {seed}")
fix_random_seed(seed) fix_random_seed(seed)
component = dot_to_object(config, section) pipe = nlp.get_pipe(component)
if hasattr(component, "model"): if hasattr(pipe, "model"):
model = component.model model = pipe.model
else: else:
msg.fail( msg.fail(
f"The section '{section}' does not specify an object that holds a Model.", f"The component '{component}' does not specify an object that holds a Model.",
exits=1, exits=1,
) )
debug_model(model, print_settings=print_settings) debug_model(model, print_settings=print_settings)
@ -84,15 +84,17 @@ def debug_model(model: Model, *, print_settings: Optional[Dict[str, Any]] = None
# STEP 0: Printing before training # STEP 0: Printing before training
msg.info(f"Analysing model with ID {model.id}") msg.info(f"Analysing model with ID {model.id}")
if print_settings.get("print_before_training"): if print_settings.get("print_before_training"):
msg.info(f"Before training:") msg.divider(f"STEP 0 - before training")
_print_model(model, print_settings) _print_model(model, print_settings)
# STEP 1: Initializing the model and printing again # STEP 1: Initializing the model and printing again
Y = _get_output(model.ops.xp) Y = _get_output(model.ops.xp)
_set_output_dim(nO=Y.shape[-1], model=model) _set_output_dim(nO=Y.shape[-1], model=model)
# The output vector might differ from the official type of the output layer
with data_validation(False):
model.initialize(X=_get_docs(), Y=Y) model.initialize(X=_get_docs(), Y=Y)
if print_settings.get("print_after_init"): if print_settings.get("print_after_init"):
msg.info(f"After initialization:") msg.divider(f"STEP 1 - after initialization")
_print_model(model, print_settings) _print_model(model, print_settings)
# STEP 2: Updating the model and printing again # STEP 2: Updating the model and printing again
@ -104,13 +106,14 @@ def debug_model(model: Model, *, print_settings: Optional[Dict[str, Any]] = None
get_dX(dY) get_dX(dY)
model.finish_update(optimizer) model.finish_update(optimizer)
if print_settings.get("print_after_training"): if print_settings.get("print_after_training"):
msg.info(f"After training:") msg.divider(f"STEP 2 - after training")
_print_model(model, print_settings) _print_model(model, print_settings)
# STEP 3: the final prediction # STEP 3: the final prediction
prediction = model.predict(_get_docs()) prediction = model.predict(_get_docs())
if print_settings.get("print_prediction"): if print_settings.get("print_prediction"):
msg.info(f"Prediction:", str(prediction)) msg.divider(f"STEP 3 - prediction")
msg.info(str(prediction))
def get_gradient(model, Y): def get_gradient(model, Y):

View File

@ -51,7 +51,7 @@ def train_cli(
referenced in the config. referenced in the config.
""" """
util.set_env_log(verbose) util.set_env_log(verbose)
verify_cli_args(train_path, dev_path, config_path) verify_cli_args(train_path, dev_path, config_path, output_path)
overrides = parse_config_overrides(ctx.args) overrides = parse_config_overrides(ctx.args)
import_code(code_path) import_code(code_path)
train( train(
@ -174,7 +174,6 @@ def train(
progress = tqdm.tqdm(total=training["eval_frequency"], leave=False) progress = tqdm.tqdm(total=training["eval_frequency"], leave=False)
except Exception as e: except Exception as e:
if output_path is not None: if output_path is not None:
raise e
msg.warn( msg.warn(
f"Aborting and saving the final best model. " f"Aborting and saving the final best model. "
f"Encountered exception: {str(e)}", f"Encountered exception: {str(e)}",