Update and auto-format

This commit is contained in:
Ines Montani 2020-07-10 20:52:00 +02:00
parent 0389c34b81
commit bfa8e11ffa
2 changed files with 14 additions and 15 deletions

View File

@ -1,10 +1,9 @@
from typing import List
from pathlib import Path
from wasabi import msg
from ._app import app, Arg, Opt
from .. import util
from thinc.api import require_gpu, fix_random_seed, set_dropout_rate, Adam
from ._util import app, Arg, Opt
from .. import util
from ..lang.en import English
@ -50,16 +49,11 @@ def debug_model_cli(
msg.info(f"Using CPU")
debug_model(
config_path,
print_settings=print_settings,
config_path, print_settings=print_settings,
)
def debug_model(
config_path: Path,
*,
print_settings=None
):
def debug_model(config_path: Path, *, print_settings=None):
if print_settings is None:
print_settings = {}
@ -83,7 +77,7 @@ def debug_model(
for e in range(3):
Y, get_dX = model.begin_update(_get_docs())
dY = get_gradient(model, Y)
_ = get_dX(dY)
get_dX(dY)
model.finish_update(optimizer)
if print_settings.get("print_after_training"):
msg.info(f"After training:")
@ -115,7 +109,12 @@ def _get_docs():
def _get_output(xp):
return xp.asarray([xp.asarray([i+10, i+20, i+30], dtype="float32") for i, _ in enumerate(_get_docs())])
return xp.asarray(
[
xp.asarray([i + 10, i + 20, i + 30], dtype="float32")
for i, _ in enumerate(_get_docs())
]
)
def _print_model(model, print_settings):
@ -161,7 +160,7 @@ def _print_matrix(value):
return value
result = str(value.shape) + " - sample: "
sample_matrix = value
for d in range(value.ndim-1):
for d in range(value.ndim - 1):
sample_matrix = sample_matrix[0]
sample_matrix = sample_matrix[0:5]
result = result + str(sample_matrix)

View File

@ -201,7 +201,7 @@ class ConfigSchemaTraining(BaseModel):
max_epochs: StrictInt = Field(..., title="Maximum number of epochs to train for")
max_steps: StrictInt = Field(..., title="Maximum number of update steps to train for")
eval_frequency: StrictInt = Field(..., title="How often to evaluate during training (steps)")
seed: StrictInt = Field(..., title="Random seed")
seed: Optional[StrictInt] = Field(..., title="Random seed")
accumulate_gradient: StrictInt = Field(..., title="Whether to divide the batch up into substeps")
use_pytorch_for_gpu_memory: StrictBool = Field(..., title="Allocate memory via PyTorch")
use_gpu: StrictInt = Field(..., title="GPU ID or -1 for CPU")