mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 18:26:30 +03:00
fixing output sample to proper 2D array
This commit is contained in:
parent
35a3931064
commit
e4fc7e0222
|
@ -60,13 +60,12 @@ def debug_model_cli(
|
|||
msg.info(f"Fixing random seed: {seed}")
|
||||
fix_random_seed(seed)
|
||||
pipe = nlp.get_pipe(component)
|
||||
if hasattr(pipe, "model"):
|
||||
model = pipe.model
|
||||
else:
|
||||
if not hasattr(pipe, "model"):
|
||||
msg.fail(
|
||||
f"The component '{component}' does not specify an object that holds a Model.",
|
||||
exits=1,
|
||||
)
|
||||
model = pipe.model
|
||||
debug_model(model, print_settings=print_settings)
|
||||
|
||||
|
||||
|
@ -87,7 +86,7 @@ def debug_model(model: Model, *, print_settings: Optional[Dict[str, Any]] = None
|
|||
|
||||
# STEP 1: Initializing the model and printing again
|
||||
X = _get_docs()
|
||||
Y = _get_output(model.ops.xp)
|
||||
Y = _get_output(model.ops)
|
||||
# The output vector might differ from the official type of the output layer
|
||||
with data_validation(False):
|
||||
model.initialize(X=X, Y=Y)
|
||||
|
@ -113,9 +112,11 @@ def debug_model(model: Model, *, print_settings: Optional[Dict[str, Any]] = None
|
|||
msg.divider(f"STEP 3 - prediction")
|
||||
msg.info(str(prediction))
|
||||
|
||||
msg.good(f"Succesfully ended analysis - model looks good!")
|
||||
|
||||
|
||||
def get_gradient(model, Y):
|
||||
goldY = _get_output(model.ops.xp)
|
||||
goldY = _get_output(model.ops)
|
||||
return Y - goldY
|
||||
|
||||
|
||||
|
@ -133,8 +134,14 @@ def _get_docs(lang: str = "en"):
|
|||
return list(nlp.pipe(_sentences()))
|
||||
|
||||
|
||||
def _get_output(xp):
|
||||
return xp.asarray([i + 10 for i, _ in enumerate(_get_docs())], dtype="float32")
|
||||
def _get_output(ops):
|
||||
docs = len(_get_docs())
|
||||
labels = 6
|
||||
output = ops.alloc2f(d0=docs, d1=labels)
|
||||
for i in range(docs):
|
||||
for j in range(labels):
|
||||
output[i, j] = 1 / (i+j+0.01)
|
||||
return ops.xp.asarray(output)
|
||||
|
||||
|
||||
def _print_model(model, print_settings):
|
||||
|
|
Loading…
Reference in New Issue
Block a user