mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 17:36:30 +03:00
Fix 'debug model' for transformers + generalize (#7973)
* add overrides to docs * fix debug model with transformer * assume training data is set in config
This commit is contained in:
parent
cc5aeaed29
commit
02a6a5fea0
|
@ -1,5 +1,6 @@
|
||||||
from typing import Dict, Any, Optional, Iterable
|
from typing import Dict, Any, Optional, Iterable
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
import itertools
|
||||||
|
|
||||||
from spacy.training import Example
|
from spacy.training import Example
|
||||||
from spacy.util import resolve_dot_names
|
from spacy.util import resolve_dot_names
|
||||||
|
@ -73,23 +74,24 @@ def debug_model_cli(
|
||||||
msg.info(f"Fixing random seed: {seed}")
|
msg.info(f"Fixing random seed: {seed}")
|
||||||
fix_random_seed(seed)
|
fix_random_seed(seed)
|
||||||
pipe = nlp.get_pipe(component)
|
pipe = nlp.get_pipe(component)
|
||||||
if not hasattr(pipe, "model"):
|
|
||||||
msg.fail(
|
debug_model(config, T, nlp, pipe, print_settings=print_settings)
|
||||||
f"The component '{component}' does not specify an object that holds a Model.",
|
|
||||||
exits=1,
|
|
||||||
)
|
|
||||||
model = pipe.model
|
|
||||||
debug_model(config, T, nlp, model, print_settings=print_settings)
|
|
||||||
|
|
||||||
|
|
||||||
def debug_model(
|
def debug_model(
|
||||||
config,
|
config,
|
||||||
resolved_train_config,
|
resolved_train_config,
|
||||||
nlp,
|
nlp,
|
||||||
model: Model,
|
pipe,
|
||||||
*,
|
*,
|
||||||
print_settings: Optional[Dict[str, Any]] = None,
|
print_settings: Optional[Dict[str, Any]] = None,
|
||||||
):
|
):
|
||||||
|
if not hasattr(pipe, "model"):
|
||||||
|
msg.fail(
|
||||||
|
f"The component '{pipe}' does not specify an object that holds a Model.",
|
||||||
|
exits=1,
|
||||||
|
)
|
||||||
|
model = pipe.model
|
||||||
if not isinstance(model, Model):
|
if not isinstance(model, Model):
|
||||||
msg.fail(
|
msg.fail(
|
||||||
f"Requires a Thinc Model to be analysed, but found {type(model)} instead.",
|
f"Requires a Thinc Model to be analysed, but found {type(model)} instead.",
|
||||||
|
@ -105,8 +107,6 @@ def debug_model(
|
||||||
_print_model(model, print_settings)
|
_print_model(model, print_settings)
|
||||||
|
|
||||||
# STEP 1: Initializing the model and printing again
|
# STEP 1: Initializing the model and printing again
|
||||||
X = _get_docs()
|
|
||||||
# The output vector might differ from the official type of the output layer
|
|
||||||
with data_validation(False):
|
with data_validation(False):
|
||||||
try:
|
try:
|
||||||
dot_names = [resolved_train_config["train_corpus"]]
|
dot_names = [resolved_train_config["train_corpus"]]
|
||||||
|
@ -114,15 +114,17 @@ def debug_model(
|
||||||
(train_corpus,) = resolve_dot_names(config, dot_names)
|
(train_corpus,) = resolve_dot_names(config, dot_names)
|
||||||
nlp.initialize(lambda: train_corpus(nlp))
|
nlp.initialize(lambda: train_corpus(nlp))
|
||||||
msg.info("Initialized the model with the training corpus.")
|
msg.info("Initialized the model with the training corpus.")
|
||||||
|
examples = list(itertools.islice(train_corpus(nlp), 5))
|
||||||
except ValueError:
|
except ValueError:
|
||||||
try:
|
try:
|
||||||
_set_output_dim(nO=7, model=model)
|
_set_output_dim(nO=7, model=model)
|
||||||
with show_validation_error():
|
with show_validation_error():
|
||||||
nlp.initialize(lambda: [Example.from_dict(x, {}) for x in X])
|
examples = [Example.from_dict(x, {}) for x in _get_docs()]
|
||||||
|
nlp.initialize(lambda: examples)
|
||||||
msg.info("Initialized the model with dummy data.")
|
msg.info("Initialized the model with dummy data.")
|
||||||
except Exception:
|
except Exception:
|
||||||
msg.fail(
|
msg.fail(
|
||||||
"Could not initialize the model: you'll have to provide a valid train_corpus argument in the config file.",
|
"Could not initialize the model: you'll have to provide a valid 'train_corpus' argument in the config file.",
|
||||||
exits=1,
|
exits=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -133,26 +135,23 @@ def debug_model(
|
||||||
# STEP 2: Updating the model and printing again
|
# STEP 2: Updating the model and printing again
|
||||||
optimizer = Adam(0.001)
|
optimizer = Adam(0.001)
|
||||||
set_dropout_rate(model, 0.2)
|
set_dropout_rate(model, 0.2)
|
||||||
# ugly hack to deal with Tok2Vec listeners
|
# ugly hack to deal with Tok2Vec/Transformer listeners
|
||||||
tok2vec = None
|
upstream_component = None
|
||||||
if model.has_ref("tok2vec") and model.get_ref("tok2vec").name == "tok2vec-listener":
|
if model.has_ref("tok2vec") and "tok2vec-listener" in model.get_ref("tok2vec").name:
|
||||||
tok2vec = nlp.get_pipe("tok2vec")
|
upstream_component = nlp.get_pipe("tok2vec")
|
||||||
|
if model.has_ref("tok2vec") and "transformer-listener" in model.get_ref("tok2vec").name:
|
||||||
|
upstream_component = nlp.get_pipe("transformer")
|
||||||
goldY = None
|
goldY = None
|
||||||
for e in range(3):
|
for e in range(3):
|
||||||
if tok2vec:
|
if upstream_component:
|
||||||
tok2vec.update([Example.from_dict(x, {}) for x in X])
|
upstream_component.update(examples)
|
||||||
Y, get_dX = model.begin_update(X)
|
pipe.update(examples)
|
||||||
if goldY is None:
|
|
||||||
goldY = _simulate_gold(Y)
|
|
||||||
dY = get_gradient(goldY, Y, model.ops)
|
|
||||||
get_dX(dY)
|
|
||||||
model.finish_update(optimizer)
|
|
||||||
if print_settings.get("print_after_training"):
|
if print_settings.get("print_after_training"):
|
||||||
msg.divider(f"STEP 2 - after training")
|
msg.divider(f"STEP 2 - after training")
|
||||||
_print_model(model, print_settings)
|
_print_model(model, print_settings)
|
||||||
|
|
||||||
# STEP 3: the final prediction
|
# STEP 3: the final prediction
|
||||||
prediction = model.predict(X)
|
prediction = model.predict([ex.predicted for ex in examples])
|
||||||
if print_settings.get("print_prediction"):
|
if print_settings.get("print_prediction"):
|
||||||
msg.divider(f"STEP 3 - prediction")
|
msg.divider(f"STEP 3 - prediction")
|
||||||
msg.info(str(prediction))
|
msg.info(str(prediction))
|
||||||
|
@ -160,19 +159,6 @@ def debug_model(
|
||||||
msg.good(f"Succesfully ended analysis - model looks good.")
|
msg.good(f"Succesfully ended analysis - model looks good.")
|
||||||
|
|
||||||
|
|
||||||
def get_gradient(goldY, Y, ops):
|
|
||||||
return ops.asarray(Y) - ops.asarray(goldY)
|
|
||||||
|
|
||||||
|
|
||||||
def _simulate_gold(element, counter=1):
|
|
||||||
if isinstance(element, Iterable):
|
|
||||||
for i in range(len(element)):
|
|
||||||
element[i] = _simulate_gold(element[i], counter + i)
|
|
||||||
return element
|
|
||||||
else:
|
|
||||||
return 1 / counter
|
|
||||||
|
|
||||||
|
|
||||||
def _sentences():
|
def _sentences():
|
||||||
return [
|
return [
|
||||||
"Apple is looking at buying U.K. startup for $1 billion",
|
"Apple is looking at buying U.K. startup for $1 billion",
|
||||||
|
@ -209,11 +195,7 @@ def _print_model(model, print_settings):
|
||||||
|
|
||||||
if dimensions:
|
if dimensions:
|
||||||
for name in node.dim_names:
|
for name in node.dim_names:
|
||||||
if node.has_dim(name):
|
msg.info(f" - dim {name}: {node.maybe_get_dim(name)}")
|
||||||
msg.info(f" - dim {name}: {node.get_dim(name)}")
|
|
||||||
else:
|
|
||||||
msg.info(f" - dim {name}: {node.has_dim(name)}")
|
|
||||||
|
|
||||||
if parameters:
|
if parameters:
|
||||||
for name in node.param_names:
|
for name in node.param_names:
|
||||||
if node.has_param(name):
|
if node.has_param(name):
|
||||||
|
|
|
@ -768,6 +768,7 @@ $ python -m spacy debug model ./config.cfg tagger -l "5,15" -DIM -PAR -P0 -P1 -P
|
||||||
| `--print-step3`, `-P3` | Print final predictions. ~~bool (flag)~~ |
|
| `--print-step3`, `-P3` | Print final predictions. ~~bool (flag)~~ |
|
||||||
| `--gpu-id`, `-g` | GPU ID or `-1` for CPU. Defaults to `-1`. ~~int (option)~~ |
|
| `--gpu-id`, `-g` | GPU ID or `-1` for CPU. Defaults to `-1`. ~~int (option)~~ |
|
||||||
| `--help`, `-h` | Show help message and available arguments. ~~bool (flag)~~ |
|
| `--help`, `-h` | Show help message and available arguments. ~~bool (flag)~~ |
|
||||||
|
| overrides | Config parameters to override. Should be options starting with `--` that correspond to the config section and value to override, e.g. `--paths.train ./train.spacy`. ~~Any (option/flag)~~ |
|
||||||
| **PRINTS** | Debugging information. |
|
| **PRINTS** | Debugging information. |
|
||||||
|
|
||||||
## train {#train tag="command"}
|
## train {#train tag="command"}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user