mirror of
https://github.com/explosion/spaCy.git
synced 2025-04-28 04:43:42 +03:00
hack for tok2vec listener
This commit is contained in:
parent
e4fc7e0222
commit
73ff52b9ec
|
@ -66,10 +66,12 @@ def debug_model_cli(
|
||||||
exits=1,
|
exits=1,
|
||||||
)
|
)
|
||||||
model = pipe.model
|
model = pipe.model
|
||||||
debug_model(model, print_settings=print_settings)
|
# call _link_components directly as we won't call nlp.begin_training
|
||||||
|
nlp._link_components()
|
||||||
|
debug_model(nlp, model, print_settings=print_settings)
|
||||||
|
|
||||||
|
|
||||||
def debug_model(model: Model, *, print_settings: Optional[Dict[str, Any]] = None):
|
def debug_model(nlp, model: Model, *, print_settings: Optional[Dict[str, Any]] = None):
|
||||||
if not isinstance(model, Model):
|
if not isinstance(model, Model):
|
||||||
msg.fail(
|
msg.fail(
|
||||||
f"Requires a Thinc Model to be analysed, but found {type(model)} instead.",
|
f"Requires a Thinc Model to be analysed, but found {type(model)} instead.",
|
||||||
|
@ -86,10 +88,10 @@ def debug_model(model: Model, *, print_settings: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
# STEP 1: Initializing the model and printing again
|
# STEP 1: Initializing the model and printing again
|
||||||
X = _get_docs()
|
X = _get_docs()
|
||||||
Y = _get_output(model.ops)
|
goldY = _get_output(model.ops)
|
||||||
# The output vector might differ from the official type of the output layer
|
# The output vector might differ from the official type of the output layer
|
||||||
with data_validation(False):
|
with data_validation(False):
|
||||||
model.initialize(X=X, Y=Y)
|
model.initialize(X=X, Y=goldY)
|
||||||
if print_settings.get("print_after_init"):
|
if print_settings.get("print_after_init"):
|
||||||
msg.divider(f"STEP 1 - after initialization")
|
msg.divider(f"STEP 1 - after initialization")
|
||||||
_print_model(model, print_settings)
|
_print_model(model, print_settings)
|
||||||
|
@ -97,9 +99,16 @@ def debug_model(model: Model, *, print_settings: Optional[Dict[str, Any]] = None
|
||||||
# STEP 2: Updating the model and printing again
|
# STEP 2: Updating the model and printing again
|
||||||
optimizer = Adam(0.001)
|
optimizer = Adam(0.001)
|
||||||
set_dropout_rate(model, 0.2)
|
set_dropout_rate(model, 0.2)
|
||||||
|
# ugly hack to deal with Tok2Vec listeners
|
||||||
|
tok2vec = None
|
||||||
|
if model.has_ref("tok2vec") and model.get_ref("tok2vec").name == "tok2vec-listener":
|
||||||
|
tok2vec = nlp.get_pipe("tok2vec")
|
||||||
|
tok2vec.model.initialize(X=X)
|
||||||
for e in range(3):
|
for e in range(3):
|
||||||
Y, get_dX = model.begin_update(_get_docs())
|
if tok2vec:
|
||||||
dY = get_gradient(model, Y)
|
tok2vec.predict(X)
|
||||||
|
Y, get_dX = model.begin_update(X)
|
||||||
|
dY = get_gradient(goldY, Y)
|
||||||
get_dX(dY)
|
get_dX(dY)
|
||||||
model.finish_update(optimizer)
|
model.finish_update(optimizer)
|
||||||
if print_settings.get("print_after_training"):
|
if print_settings.get("print_after_training"):
|
||||||
|
@ -107,7 +116,7 @@ def debug_model(model: Model, *, print_settings: Optional[Dict[str, Any]] = None
|
||||||
_print_model(model, print_settings)
|
_print_model(model, print_settings)
|
||||||
|
|
||||||
# STEP 3: the final prediction
|
# STEP 3: the final prediction
|
||||||
prediction = model.predict(_get_docs())
|
prediction = model.predict(X)
|
||||||
if print_settings.get("print_prediction"):
|
if print_settings.get("print_prediction"):
|
||||||
msg.divider(f"STEP 3 - prediction")
|
msg.divider(f"STEP 3 - prediction")
|
||||||
msg.info(str(prediction))
|
msg.info(str(prediction))
|
||||||
|
@ -115,8 +124,7 @@ def debug_model(model: Model, *, print_settings: Optional[Dict[str, Any]] = None
|
||||||
msg.good(f"Succesfully ended analysis - model looks good!")
|
msg.good(f"Succesfully ended analysis - model looks good!")
|
||||||
|
|
||||||
|
|
||||||
def get_gradient(model, Y):
|
def get_gradient(goldY, Y):
|
||||||
goldY = _get_output(model.ops)
|
|
||||||
return Y - goldY
|
return Y - goldY
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -545,7 +545,8 @@ class Errors:
|
||||||
E949 = ("Can only create an alignment when the texts are the same.")
|
E949 = ("Can only create an alignment when the texts are the same.")
|
||||||
E952 = ("The section '{name}' is not a valid section in the provided config.")
|
E952 = ("The section '{name}' is not a valid section in the provided config.")
|
||||||
E953 = ("Mismatched IDs received by the Tok2Vec listener: {id1} vs. {id2}")
|
E953 = ("Mismatched IDs received by the Tok2Vec listener: {id1} vs. {id2}")
|
||||||
E954 = ("The Tok2Vec listener did not receive a valid input.")
|
E954 = ("The Tok2Vec listener did not receive any valid input from an upstream "
|
||||||
|
"component.")
|
||||||
E955 = ("Can't find table(s) '{table}' for language '{lang}' in spacy-lookups-data.")
|
E955 = ("Can't find table(s) '{table}' for language '{lang}' in spacy-lookups-data.")
|
||||||
E956 = ("Can't find component '{name}' in [components] block in the config. "
|
E956 = ("Can't find component '{name}' in [components] block in the config. "
|
||||||
"Available components: {opts}")
|
"Available components: {opts}")
|
||||||
|
|
Loading…
Reference in New Issue
Block a user