mirror of
https://github.com/explosion/spaCy.git
synced 2025-05-30 18:53:36 +03:00
Update and auto-format
This commit is contained in:
parent
0389c34b81
commit
bfa8e11ffa
|
@ -1,10 +1,9 @@
|
||||||
from typing import List
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from wasabi import msg
|
from wasabi import msg
|
||||||
|
|
||||||
from ._app import app, Arg, Opt
|
|
||||||
from .. import util
|
|
||||||
from thinc.api import require_gpu, fix_random_seed, set_dropout_rate, Adam
|
from thinc.api import require_gpu, fix_random_seed, set_dropout_rate, Adam
|
||||||
|
|
||||||
|
from ._util import app, Arg, Opt
|
||||||
|
from .. import util
|
||||||
from ..lang.en import English
|
from ..lang.en import English
|
||||||
|
|
||||||
|
|
||||||
|
@ -50,16 +49,11 @@ def debug_model_cli(
|
||||||
msg.info(f"Using CPU")
|
msg.info(f"Using CPU")
|
||||||
|
|
||||||
debug_model(
|
debug_model(
|
||||||
config_path,
|
config_path, print_settings=print_settings,
|
||||||
print_settings=print_settings,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def debug_model(
|
def debug_model(config_path: Path, *, print_settings=None):
|
||||||
config_path: Path,
|
|
||||||
*,
|
|
||||||
print_settings=None
|
|
||||||
):
|
|
||||||
if print_settings is None:
|
if print_settings is None:
|
||||||
print_settings = {}
|
print_settings = {}
|
||||||
|
|
||||||
|
@ -83,7 +77,7 @@ def debug_model(
|
||||||
for e in range(3):
|
for e in range(3):
|
||||||
Y, get_dX = model.begin_update(_get_docs())
|
Y, get_dX = model.begin_update(_get_docs())
|
||||||
dY = get_gradient(model, Y)
|
dY = get_gradient(model, Y)
|
||||||
_ = get_dX(dY)
|
get_dX(dY)
|
||||||
model.finish_update(optimizer)
|
model.finish_update(optimizer)
|
||||||
if print_settings.get("print_after_training"):
|
if print_settings.get("print_after_training"):
|
||||||
msg.info(f"After training:")
|
msg.info(f"After training:")
|
||||||
|
@ -115,7 +109,12 @@ def _get_docs():
|
||||||
|
|
||||||
|
|
||||||
def _get_output(xp):
|
def _get_output(xp):
|
||||||
return xp.asarray([xp.asarray([i+10, i+20, i+30], dtype="float32") for i, _ in enumerate(_get_docs())])
|
return xp.asarray(
|
||||||
|
[
|
||||||
|
xp.asarray([i + 10, i + 20, i + 30], dtype="float32")
|
||||||
|
for i, _ in enumerate(_get_docs())
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _print_model(model, print_settings):
|
def _print_model(model, print_settings):
|
||||||
|
@ -161,7 +160,7 @@ def _print_matrix(value):
|
||||||
return value
|
return value
|
||||||
result = str(value.shape) + " - sample: "
|
result = str(value.shape) + " - sample: "
|
||||||
sample_matrix = value
|
sample_matrix = value
|
||||||
for d in range(value.ndim-1):
|
for d in range(value.ndim - 1):
|
||||||
sample_matrix = sample_matrix[0]
|
sample_matrix = sample_matrix[0]
|
||||||
sample_matrix = sample_matrix[0:5]
|
sample_matrix = sample_matrix[0:5]
|
||||||
result = result + str(sample_matrix)
|
result = result + str(sample_matrix)
|
||||||
|
|
|
@ -201,7 +201,7 @@ class ConfigSchemaTraining(BaseModel):
|
||||||
max_epochs: StrictInt = Field(..., title="Maximum number of epochs to train for")
|
max_epochs: StrictInt = Field(..., title="Maximum number of epochs to train for")
|
||||||
max_steps: StrictInt = Field(..., title="Maximum number of update steps to train for")
|
max_steps: StrictInt = Field(..., title="Maximum number of update steps to train for")
|
||||||
eval_frequency: StrictInt = Field(..., title="How often to evaluate during training (steps)")
|
eval_frequency: StrictInt = Field(..., title="How often to evaluate during training (steps)")
|
||||||
seed: StrictInt = Field(..., title="Random seed")
|
seed: Optional[StrictInt] = Field(..., title="Random seed")
|
||||||
accumulate_gradient: StrictInt = Field(..., title="Whether to divide the batch up into substeps")
|
accumulate_gradient: StrictInt = Field(..., title="Whether to divide the batch up into substeps")
|
||||||
use_pytorch_for_gpu_memory: StrictBool = Field(..., title="Allocate memory via PyTorch")
|
use_pytorch_for_gpu_memory: StrictBool = Field(..., title="Allocate memory via PyTorch")
|
||||||
use_gpu: StrictInt = Field(..., title="GPU ID or -1 for CPU")
|
use_gpu: StrictInt = Field(..., title="GPU ID or -1 for CPU")
|
||||||
|
|
Loading…
Reference in New Issue
Block a user