Span predictor leftovers

This commit is contained in:
Paul O'Leary McCann 2022-07-06 19:29:27 +09:00
parent b0800ea855
commit da81a90d64
2 changed files with 9 additions and 4 deletions

View File

@ -1,6 +1,6 @@
from typing import List, Tuple
from thinc.api import Model, chain, tuplify
from thinc.api import Model, chain, tuplify, get_width
from thinc.api import PyTorchWrapper, ArgsKwargs
from thinc.types import Floats2d, Ints1d
from thinc.util import torch, xp2torch, torch2xp
@ -22,6 +22,8 @@ def build_span_predictor(
):
# TODO add model return types
nI = None
with Model.define_operators({">>": chain, "&": tuplify}):
span_predictor = Model(
"span_predictor",
@ -34,7 +36,6 @@ def build_span_predictor(
"conv_channels": conv_channels,
"window_size": window_size,
"max_distance": max_distance,
"prefix": prefix,
},
)
head_info = build_get_head_metadata(prefix)
@ -55,7 +56,6 @@ def span_predictor_init(model: Model, X=None, Y=None):
conv_channels = model.attrs["conv_channels"]
window_size = model.attrs["window_size"]
max_distance = model.attrs["max_distance"]
prefix = model.attrs["prefix"]
model._layers = [
PyTorchWrapper(
@ -66,7 +66,6 @@ def span_predictor_init(model: Model, X=None, Y=None):
conv_channels,
window_size,
max_distance,
prefix,
),
convert_inputs=convert_span_predictor_inputs,
)

View File

@ -275,6 +275,12 @@ class SpanPredictor(TrainablePipe):
assert len(X) > 0, Errors.E923.format(name=self.name)
self.model.initialize(X=X, Y=Y)
# Store the input dimensionality. nI and nO are not stored explicitly
# for PyTorch models. This makes it tricky to reconstruct the model
# during deserialization. So, besides storing the labels, we also
# store the number of inputs.
span_predictor = self.model.get_ref("span_predictor")
self.cfg["nI"] = span_predictor.get_dim("nI")
def from_bytes(self, bytes_data, *, exclude=tuple()):
deserializers = {