From da81a90d64dbb156ba3abafcf7aaa6ff8d126e81 Mon Sep 17 00:00:00 2001 From: Paul O'Leary McCann Date: Wed, 6 Jul 2022 19:29:27 +0900 Subject: [PATCH] Span predictor leftovers --- spacy/ml/models/span_predictor.py | 7 +++---- spacy/pipeline/span_predictor.py | 6 ++++++ 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/spacy/ml/models/span_predictor.py b/spacy/ml/models/span_predictor.py index 55a966c1d..4e394ed78 100644 --- a/spacy/ml/models/span_predictor.py +++ b/spacy/ml/models/span_predictor.py @@ -1,6 +1,6 @@ from typing import List, Tuple -from thinc.api import Model, chain, tuplify +from thinc.api import Model, chain, tuplify, get_width from thinc.api import PyTorchWrapper, ArgsKwargs from thinc.types import Floats2d, Ints1d from thinc.util import torch, xp2torch, torch2xp @@ -22,6 +22,8 @@ def build_span_predictor( ): # TODO add model return types + nI = None + with Model.define_operators({">>": chain, "&": tuplify}): span_predictor = Model( "span_predictor", @@ -34,7 +36,6 @@ def build_span_predictor( "conv_channels": conv_channels, "window_size": window_size, "max_distance": max_distance, - "prefix": prefix, }, ) head_info = build_get_head_metadata(prefix) @@ -55,7 +56,6 @@ def span_predictor_init(model: Model, X=None, Y=None): conv_channels = model.attrs["conv_channels"] window_size = model.attrs["window_size"] max_distance = model.attrs["max_distance"] - prefix = model.attrs["prefix"] model._layers = [ PyTorchWrapper( @@ -66,7 +66,6 @@ def span_predictor_init(model: Model, X=None, Y=None): conv_channels, window_size, max_distance, - prefix, ), convert_inputs=convert_span_predictor_inputs, ) diff --git a/spacy/pipeline/span_predictor.py b/spacy/pipeline/span_predictor.py index eed6ce9f8..b5f25cd81 100644 --- a/spacy/pipeline/span_predictor.py +++ b/spacy/pipeline/span_predictor.py @@ -275,6 +275,12 @@ class SpanPredictor(TrainablePipe): assert len(X) > 0, Errors.E923.format(name=self.name) self.model.initialize(X=X, Y=Y) + # Store the input dimensionality. nI and nO are not stored explicitly + # for PyTorch models. This makes it tricky to reconstruct the model + # during deserialization. So, besides storing the labels, we also + # store the number of inputs. + span_predictor = self.model.get_ref("span_predictor") + self.cfg["nI"] = span_predictor.get_dim("nI") def from_bytes(self, bytes_data, *, exclude=tuple()): deserializers = {