Mirror of https://github.com/explosion/spaCy.git, synced 2025-07-18 12:12:20 +03:00
First take at dimension inference
This follows the pattern used in the Biaffine parser, which uses an init function to get the size only once the tok2vec output is available. This works at first, but serialization then fails with an error.
parent c59aeeb0ae
commit ba1bf8ae72
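For orientation, a minimal sketch of the pattern the message describes (my own illustration using current Thinc APIs; the names `downstream`, `tok2vec`, and `sample_docs` are placeholders, not the coref code):

    from thinc.api import Model, chain, get_width


    def init(model: Model, X=None, Y=None):
        # Once upstream layers have produced a sample, read its width.
        if X is not None and model.has_dim("nI") is None:
            model.set_dim("nI", get_width(X))
        # ...only now construct the real (e.g. PyTorch) sublayer...


    def forward(model: Model, X, is_train: bool):
        return model.layers[0](X, is_train)


    downstream = Model("downstream", forward, init=init, dims={"nI": None})
    # model = chain(tok2vec, downstream)
    # model.initialize(X=sample_docs)  # chain feeds tok2vec's output to init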
@@ -1,6 +1,6 @@
 from typing import List, Tuple
 
-from thinc.api import Model, chain
+from thinc.api import Model, chain, get_width
 from thinc.api import PyTorchWrapper, ArgsKwargs
 from thinc.types import Floats2d
 from thinc.util import torch, xp2torch, torch2xp
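The newly imported get_width is what reads the size off the sample data. A quick illustration of its behavior (my example, not part of the diff):

    import numpy
    from thinc.api import get_width

    X = [numpy.zeros((7, 96), dtype="f"), numpy.zeros((3, 96), dtype="f")]
    assert get_width(X) == 96  # for a list, the width of the first item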
@@ -25,12 +25,48 @@ def build_wl_coref_model(
     tok2vec_size: int = 768,  # tok2vec size
 ):
     # TODO add model return types
-    # dim = tok2vec.maybe_get_dim("n0")
+    nI = None
 
     with Model.define_operators({">>": chain}):
-        coref_clusterer = PyTorchWrapper(
+        coref_clusterer = Model(
+            "coref_clusterer",
+            forward=coref_forward,
+            init=coref_init,
+            dims={"nI": nI},
+            attrs={
+                "distance_embedding_size": distance_embedding_size,
+                "hidden_size": hidden_size,
+                "depth": depth,
+                "dropout": dropout,
+                "antecedent_limit": antecedent_limit,
+                "antecedent_batch_size": antecedent_batch_size,
+            },
+        )
+
+        coref_model = tok2vec >> coref_clusterer
+    return coref_model
+
+
+def coref_init(model: Model, X=None, Y=None):
+    if model.layers:
+        return
+
+    if X is not None and model.has_dim("nI") is None:
+        model.set_dim("nI", get_width(X))
+
+    hidden_size = model.attrs["hidden_size"]
+    depth = model.attrs["depth"]
+    dropout = model.attrs["dropout"]
+    antecedent_limit = model.attrs["antecedent_limit"]
+    antecedent_batch_size = model.attrs["antecedent_batch_size"]
+    distance_embedding_size = model.attrs["distance_embedding_size"]
+
+    PyTorchWrapper = registry.get("layers", "PyTorchWrapper.v2")
+    model._layers = [
+        PyTorchWrapper(
             CorefClusterer(
-                tok2vec_size,
+                model.get_dim("nI"),
                 distance_embedding_size,
                 hidden_size,
                 depth,
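As I understand the intended flow (a hedged sketch; the tok2vec argument and sample data below are stand-ins, and other parameters are left at their defaults): initializing the chained model runs the tok2vec over the sample, hands its output to coref_init, which sets nI and only then builds the PyTorch layer.

    coref = build_wl_coref_model(tok2vec)  # tok2vec: a placeholder Model
    coref.initialize(X=docs_sample)        # triggers coref_init with real widths
    nI = coref.layers[1].get_dim("nI")     # inferred from the tok2vec output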
@@ -41,13 +77,15 @@ def build_wl_coref_model(
             convert_inputs=convert_coref_clusterer_inputs,
             convert_outputs=convert_coref_clusterer_outputs,
         )
-        coref_model = tok2vec >> coref_clusterer
-    return coref_model
+        # TODO maybe we need mixed precision and grad scaling?
+    ]
 
 
-def convert_coref_clusterer_inputs(
-    model: Model, X: List[Floats2d], is_train: bool
-):
+def coref_forward(model: Model, X, is_train: bool):
+    return model.layers[0](X, is_train)
+
+
+def convert_coref_clusterer_inputs(model: Model, X: List[Floats2d], is_train: bool):
     # The input here is List[Floats2d], one for each doc
     # just use the first
     # TODO real batching
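A guess at the serialization error the commit message mentions: Thinc serializes a model by walking its nested layer tree, and with this design the layers only exist after init has run. Loading bytes saved from an initialized model into a freshly built one would then hit a structure mismatch. Schematically (the actual error text is not given in the commit):

    trained = build_wl_coref_model(tok2vec)  # hypothetical call, params elided
    trained.initialize(X=docs_sample)        # coref_init builds the PyTorch layer
    blob = trained.to_bytes()

    fresh = build_wl_coref_model(tok2vec)    # init has not run: no sublayers yet
    fresh.from_bytes(blob)                   # mismatched layer trees -> error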
@@ -55,7 +93,7 @@ def convert_coref_clusterer_inputs(
     word_features = xp2torch(X, requires_grad=is_train)
 
     # TODO fix or remove type annotations
-    def backprop(args: ArgsKwargs): #-> List[Floats2d]:
+    def backprop(args: ArgsKwargs):  # -> List[Floats2d]:
         # convert to xp and wrap in list
         gradients = torch2xp(args.args[0])
         return [gradients]
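For context on the contract being touched here (a restatement of how PyTorchWrapper conversion hooks work in general, not new behavior in this commit):

    # convert_inputs(model, X, is_train) -> (ArgsKwargs, backprop)
    # During the backward pass, PyTorchWrapper calls backprop with an
    # ArgsKwargs holding torch's gradients for those same arguments;
    # torch2xp moves them back to numpy/cupy, and the wrapping list
    # restores the List[Floats2d] shape the rest of the pipeline expects.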
@@ -63,9 +101,7 @@ def convert_coref_clusterer_inputs(
     return ArgsKwargs(args=(word_features,), kwargs={}), backprop
 
 
-def convert_coref_clusterer_outputs(
-    model: Model, inputs_outputs, is_train: bool
-):
+def convert_coref_clusterer_outputs(model: Model, inputs_outputs, is_train: bool):
     _, outputs = inputs_outputs
     scores, indices = outputs
 
@@ -115,9 +151,7 @@ class CorefClusterer(torch.nn.Module):
         self.pw = DistancePairwiseEncoder(dist_emb_size, dropout)
 
         pair_emb = dim * 3 + self.pw.shape
-        self.a_scorer = AnaphoricityScorer(
-            pair_emb, hidden_size, n_layers, dropout
-        )
+        self.a_scorer = AnaphoricityScorer(pair_emb, hidden_size, n_layers, dropout)
         self.lstm = torch.nn.LSTM(
             input_size=dim,
             hidden_size=dim,
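A gloss on the pair_emb arithmetic (following the word-level coreference design this module is based on, so treat it as my reading): the pair representation concatenates the mention vector, the candidate antecedent vector, and their elementwise product, hence dim * 3, plus distance features of width self.pw.shape. A runnable check with illustrative sizes:

    import torch

    dim, pw_width, n = 768, 20, 5  # pw_width is illustrative, not the real value
    mentions = torch.zeros(n, dim)
    antecedents = torch.zeros(n, dim)
    distance_feats = torch.zeros(n, pw_width)
    pair = torch.cat(
        (mentions, antecedents, mentions * antecedents, distance_feats), dim=-1
    )
    assert pair.shape[-1] == dim * 3 + pw_width  # == pair_emb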
@@ -156,10 +190,10 @@ class CorefClusterer(torch.nn.Module):
         a_scores_lst: List[torch.Tensor] = []
 
         for i in range(0, len(words), batch_size):
-            pw_batch = pw[i:i + batch_size]
-            words_batch = words[i:i + batch_size]
-            top_indices_batch = top_indices[i:i + batch_size]
-            top_rough_scores_batch = top_rough_scores[i:i + batch_size]
+            pw_batch = pw[i : i + batch_size]
+            words_batch = words[i : i + batch_size]
+            top_indices_batch = top_indices[i : i + batch_size]
+            top_rough_scores_batch = top_rough_scores[i : i + batch_size]
 
             # a_scores_batch [batch_size, n_ants]
             a_scores_batch = self.a_scorer(