Span predictor leftovers

This commit is contained in:
Paul O'Leary McCann 2022-07-06 19:29:27 +09:00
parent b0800ea855
commit da81a90d64
2 changed files with 9 additions and 4 deletions

View File

@ -1,6 +1,6 @@
from typing import List, Tuple from typing import List, Tuple
from thinc.api import Model, chain, tuplify from thinc.api import Model, chain, tuplify, get_width
from thinc.api import PyTorchWrapper, ArgsKwargs from thinc.api import PyTorchWrapper, ArgsKwargs
from thinc.types import Floats2d, Ints1d from thinc.types import Floats2d, Ints1d
from thinc.util import torch, xp2torch, torch2xp from thinc.util import torch, xp2torch, torch2xp
@ -22,6 +22,8 @@ def build_span_predictor(
): ):
# TODO add model return types # TODO add model return types
nI = None
with Model.define_operators({">>": chain, "&": tuplify}): with Model.define_operators({">>": chain, "&": tuplify}):
span_predictor = Model( span_predictor = Model(
"span_predictor", "span_predictor",
@ -34,7 +36,6 @@ def build_span_predictor(
"conv_channels": conv_channels, "conv_channels": conv_channels,
"window_size": window_size, "window_size": window_size,
"max_distance": max_distance, "max_distance": max_distance,
"prefix": prefix,
}, },
) )
head_info = build_get_head_metadata(prefix) head_info = build_get_head_metadata(prefix)
@ -55,7 +56,6 @@ def span_predictor_init(model: Model, X=None, Y=None):
conv_channels = model.attrs["conv_channels"] conv_channels = model.attrs["conv_channels"]
window_size = model.attrs["window_size"] window_size = model.attrs["window_size"]
max_distance = model.attrs["max_distance"] max_distance = model.attrs["max_distance"]
prefix = model.attrs["prefix"]
model._layers = [ model._layers = [
PyTorchWrapper( PyTorchWrapper(
@ -66,7 +66,6 @@ def span_predictor_init(model: Model, X=None, Y=None):
conv_channels, conv_channels,
window_size, window_size,
max_distance, max_distance,
prefix,
), ),
convert_inputs=convert_span_predictor_inputs, convert_inputs=convert_span_predictor_inputs,
) )

View File

@ -275,6 +275,12 @@ class SpanPredictor(TrainablePipe):
assert len(X) > 0, Errors.E923.format(name=self.name) assert len(X) > 0, Errors.E923.format(name=self.name)
self.model.initialize(X=X, Y=Y) self.model.initialize(X=X, Y=Y)
# Store the input dimensionality. nI and nO are not stored explicitly
# for PyTorch models. This makes it tricky to reconstruct the model
# during deserialization. So, besides storing the labels, we also
# store the number of inputs.
span_predictor = self.model.get_ref("span_predictor")
self.cfg["nI"] = span_predictor.get_dim("nI")
def from_bytes(self, bytes_data, *, exclude=tuple()): def from_bytes(self, bytes_data, *, exclude=tuple()):
deserializers = { deserializers = {