Mirror of https://github.com/explosion/spaCy.git (synced 2025-07-17 19:52:18 +03:00)
Span predictor leftovers
This commit is contained in:
Parent: b0800ea855
Commit: da81a90d64
|
@ -1,6 +1,6 @@
|
|||
from typing import List, Tuple
|
||||
|
||||
from thinc.api import Model, chain, tuplify
|
||||
from thinc.api import Model, chain, tuplify, get_width
|
||||
from thinc.api import PyTorchWrapper, ArgsKwargs
|
||||
from thinc.types import Floats2d, Ints1d
|
||||
from thinc.util import torch, xp2torch, torch2xp
|
||||
|
@ -22,6 +22,8 @@ def build_span_predictor(
|
|||
):
|
||||
# TODO add model return types
|
||||
|
||||
nI = None
|
||||
|
||||
with Model.define_operators({">>": chain, "&": tuplify}):
|
||||
span_predictor = Model(
|
||||
"span_predictor",
|
||||
|
@ -34,7 +36,6 @@ def build_span_predictor(
|
|||
"conv_channels": conv_channels,
|
||||
"window_size": window_size,
|
||||
"max_distance": max_distance,
|
||||
"prefix": prefix,
|
||||
},
|
||||
)
|
||||
head_info = build_get_head_metadata(prefix)
|
||||
|
@ -55,7 +56,6 @@ def span_predictor_init(model: Model, X=None, Y=None):
|
|||
conv_channels = model.attrs["conv_channels"]
|
||||
window_size = model.attrs["window_size"]
|
||||
max_distance = model.attrs["max_distance"]
|
||||
prefix = model.attrs["prefix"]
|
||||
|
||||
model._layers = [
|
||||
PyTorchWrapper(
|
||||
|
@ -66,7 +66,6 @@ def span_predictor_init(model: Model, X=None, Y=None):
|
|||
conv_channels,
|
||||
window_size,
|
||||
max_distance,
|
||||
prefix,
|
||||
),
|
||||
convert_inputs=convert_span_predictor_inputs,
|
||||
)
|
||||
|
|
|
@ -275,6 +275,12 @@ class SpanPredictor(TrainablePipe):
|
|||
assert len(X) > 0, Errors.E923.format(name=self.name)
|
||||
self.model.initialize(X=X, Y=Y)
|
||||
|
||||
# Store the input dimensionality. nI and nO are not stored explicitly
|
||||
# for PyTorch models. This makes it tricky to reconstruct the model
|
||||
# during deserialization. So, besides storing the labels, we also
|
||||
# store the number of inputs.
|
||||
span_predictor = self.model.get_ref("span_predictor")
|
||||
self.cfg["nI"] = span_predictor.get_dim("nI")
|
||||
|
||||
def from_bytes(self, bytes_data, *, exclude=tuple()):
|
||||
deserializers = {
|
||||
|
|
Loading…
Reference in New Issue
Block a user