mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-17 19:52:18 +03:00
Add type annotations for internal models
This commit is contained in:
parent
4d032396b8
commit
baeb35f31b
|
@ -22,11 +22,11 @@ def build_wl_coref_model(
|
|||
# pairs to keep per mention after rough scoring
|
||||
antecedent_limit: int = 50,
|
||||
antecedent_batch_size: int = 512,
|
||||
nI = None
|
||||
nI=None,
|
||||
) -> Model[List[Doc], Tuple[Floats2d, Ints2d]]:
|
||||
|
||||
with Model.define_operators({">>": chain}):
|
||||
coref_clusterer = Model(
|
||||
coref_clusterer: Model[List[Floats2d], Tuple[Floats2d, Ints2d]] = Model(
|
||||
"coref_clusterer",
|
||||
forward=coref_forward,
|
||||
init=coref_init,
|
||||
|
@ -81,6 +81,7 @@ def coref_init(model: Model, X=None, Y=None):
|
|||
def coref_forward(model: Model, X, is_train: bool):
|
||||
return model.layers[0](X, is_train)
|
||||
|
||||
|
||||
def convert_coref_clusterer_inputs(model: Model, X: List[Floats2d], is_train: bool):
|
||||
# The input here is List[Floats2d], one for each doc
|
||||
# just use the first
|
||||
|
|
|
@ -25,7 +25,7 @@ def build_span_predictor(
|
|||
nI = None
|
||||
|
||||
with Model.define_operators({">>": chain, "&": tuplify}):
|
||||
span_predictor = Model(
|
||||
span_predictor: Model[List[Floats2d], List[Floats2d]] = Model(
|
||||
"span_predictor",
|
||||
forward=span_predictor_forward,
|
||||
init=span_predictor_init,
|
||||
|
@ -44,6 +44,7 @@ def build_span_predictor(
|
|||
|
||||
return model
|
||||
|
||||
|
||||
def span_predictor_init(model: Model, X=None, Y=None):
|
||||
if model.layers:
|
||||
return
|
||||
|
@ -72,9 +73,11 @@ def span_predictor_init(model: Model, X=None, Y=None):
|
|||
# TODO maybe we need mixed precision and grad scaling?
|
||||
]
|
||||
|
||||
|
||||
def span_predictor_forward(model: Model, X, is_train: bool):
|
||||
return model.layers[0](X, is_train)
|
||||
|
||||
|
||||
def convert_span_predictor_inputs(
|
||||
model: Model,
|
||||
X: Tuple[List[Floats2d], Tuple[List[Ints1d], List[Ints1d]]],
|
||||
|
@ -95,7 +98,9 @@ def convert_span_predictor_inputs(
|
|||
else:
|
||||
head_ids_tensor = xp2torch(head_ids[0], requires_grad=False)
|
||||
|
||||
argskwargs = ArgsKwargs(args=(sent_ids_tensor, word_features, head_ids_tensor), kwargs={})
|
||||
argskwargs = ArgsKwargs(
|
||||
args=(sent_ids_tensor, word_features, head_ids_tensor), kwargs={}
|
||||
)
|
||||
return argskwargs, backprop
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user