mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-18 04:02:20 +03:00
Span predictor leftovers
This commit is contained in:
parent
b0800ea855
commit
da81a90d64
|
@ -1,6 +1,6 @@
|
||||||
from typing import List, Tuple
|
from typing import List, Tuple
|
||||||
|
|
||||||
from thinc.api import Model, chain, tuplify
|
from thinc.api import Model, chain, tuplify, get_width
|
||||||
from thinc.api import PyTorchWrapper, ArgsKwargs
|
from thinc.api import PyTorchWrapper, ArgsKwargs
|
||||||
from thinc.types import Floats2d, Ints1d
|
from thinc.types import Floats2d, Ints1d
|
||||||
from thinc.util import torch, xp2torch, torch2xp
|
from thinc.util import torch, xp2torch, torch2xp
|
||||||
|
@ -22,6 +22,8 @@ def build_span_predictor(
|
||||||
):
|
):
|
||||||
# TODO add model return types
|
# TODO add model return types
|
||||||
|
|
||||||
|
nI = None
|
||||||
|
|
||||||
with Model.define_operators({">>": chain, "&": tuplify}):
|
with Model.define_operators({">>": chain, "&": tuplify}):
|
||||||
span_predictor = Model(
|
span_predictor = Model(
|
||||||
"span_predictor",
|
"span_predictor",
|
||||||
|
@ -34,7 +36,6 @@ def build_span_predictor(
|
||||||
"conv_channels": conv_channels,
|
"conv_channels": conv_channels,
|
||||||
"window_size": window_size,
|
"window_size": window_size,
|
||||||
"max_distance": max_distance,
|
"max_distance": max_distance,
|
||||||
"prefix": prefix,
|
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
head_info = build_get_head_metadata(prefix)
|
head_info = build_get_head_metadata(prefix)
|
||||||
|
@ -55,7 +56,6 @@ def span_predictor_init(model: Model, X=None, Y=None):
|
||||||
conv_channels = model.attrs["conv_channels"]
|
conv_channels = model.attrs["conv_channels"]
|
||||||
window_size = model.attrs["window_size"]
|
window_size = model.attrs["window_size"]
|
||||||
max_distance = model.attrs["max_distance"]
|
max_distance = model.attrs["max_distance"]
|
||||||
prefix = model.attrs["prefix"]
|
|
||||||
|
|
||||||
model._layers = [
|
model._layers = [
|
||||||
PyTorchWrapper(
|
PyTorchWrapper(
|
||||||
|
@ -66,7 +66,6 @@ def span_predictor_init(model: Model, X=None, Y=None):
|
||||||
conv_channels,
|
conv_channels,
|
||||||
window_size,
|
window_size,
|
||||||
max_distance,
|
max_distance,
|
||||||
prefix,
|
|
||||||
),
|
),
|
||||||
convert_inputs=convert_span_predictor_inputs,
|
convert_inputs=convert_span_predictor_inputs,
|
||||||
)
|
)
|
||||||
|
|
|
@ -275,6 +275,12 @@ class SpanPredictor(TrainablePipe):
|
||||||
assert len(X) > 0, Errors.E923.format(name=self.name)
|
assert len(X) > 0, Errors.E923.format(name=self.name)
|
||||||
self.model.initialize(X=X, Y=Y)
|
self.model.initialize(X=X, Y=Y)
|
||||||
|
|
||||||
|
# Store the input dimensionality. nI and nO are not stored explicitly
|
||||||
|
# for PyTorch models. This makes it tricky to reconstruct the model
|
||||||
|
# during deserialization. So, besides storing the labels, we also
|
||||||
|
# store the number of inputs.
|
||||||
|
span_predictor = self.model.get_ref("span_predictor")
|
||||||
|
self.cfg["nI"] = span_predictor.get_dim("nI")
|
||||||
|
|
||||||
def from_bytes(self, bytes_data, *, exclude=tuple()):
|
def from_bytes(self, bytes_data, *, exclude=tuple()):
|
||||||
deserializers = {
|
deserializers = {
|
||||||
|
|
Loading…
Reference in New Issue
Block a user