mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-27 09:44:36 +03:00
fixing output sample to proper 2D array
This commit is contained in:
parent
35a3931064
commit
e4fc7e0222
|
@ -60,13 +60,12 @@ def debug_model_cli(
|
||||||
msg.info(f"Fixing random seed: {seed}")
|
msg.info(f"Fixing random seed: {seed}")
|
||||||
fix_random_seed(seed)
|
fix_random_seed(seed)
|
||||||
pipe = nlp.get_pipe(component)
|
pipe = nlp.get_pipe(component)
|
||||||
if hasattr(pipe, "model"):
|
if not hasattr(pipe, "model"):
|
||||||
model = pipe.model
|
|
||||||
else:
|
|
||||||
msg.fail(
|
msg.fail(
|
||||||
f"The component '{component}' does not specify an object that holds a Model.",
|
f"The component '{component}' does not specify an object that holds a Model.",
|
||||||
exits=1,
|
exits=1,
|
||||||
)
|
)
|
||||||
|
model = pipe.model
|
||||||
debug_model(model, print_settings=print_settings)
|
debug_model(model, print_settings=print_settings)
|
||||||
|
|
||||||
|
|
||||||
|
@ -87,7 +86,7 @@ def debug_model(model: Model, *, print_settings: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
# STEP 1: Initializing the model and printing again
|
# STEP 1: Initializing the model and printing again
|
||||||
X = _get_docs()
|
X = _get_docs()
|
||||||
Y = _get_output(model.ops.xp)
|
Y = _get_output(model.ops)
|
||||||
# The output vector might differ from the official type of the output layer
|
# The output vector might differ from the official type of the output layer
|
||||||
with data_validation(False):
|
with data_validation(False):
|
||||||
model.initialize(X=X, Y=Y)
|
model.initialize(X=X, Y=Y)
|
||||||
|
@ -113,9 +112,11 @@ def debug_model(model: Model, *, print_settings: Optional[Dict[str, Any]] = None
|
||||||
msg.divider(f"STEP 3 - prediction")
|
msg.divider(f"STEP 3 - prediction")
|
||||||
msg.info(str(prediction))
|
msg.info(str(prediction))
|
||||||
|
|
||||||
|
msg.good(f"Succesfully ended analysis - model looks good!")
|
||||||
|
|
||||||
|
|
||||||
def get_gradient(model, Y):
|
def get_gradient(model, Y):
|
||||||
goldY = _get_output(model.ops.xp)
|
goldY = _get_output(model.ops)
|
||||||
return Y - goldY
|
return Y - goldY
|
||||||
|
|
||||||
|
|
||||||
|
@ -133,8 +134,14 @@ def _get_docs(lang: str = "en"):
|
||||||
return list(nlp.pipe(_sentences()))
|
return list(nlp.pipe(_sentences()))
|
||||||
|
|
||||||
|
|
||||||
def _get_output(xp):
|
def _get_output(ops):
|
||||||
return xp.asarray([i + 10 for i, _ in enumerate(_get_docs())], dtype="float32")
|
docs = len(_get_docs())
|
||||||
|
labels = 6
|
||||||
|
output = ops.alloc2f(d0=docs, d1=labels)
|
||||||
|
for i in range(docs):
|
||||||
|
for j in range(labels):
|
||||||
|
output[i, j] = 1 / (i+j+0.01)
|
||||||
|
return ops.xp.asarray(output)
|
||||||
|
|
||||||
|
|
||||||
def _print_model(model, print_settings):
|
def _print_model(model, print_settings):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user