Do dimension inference in span predictor
parent b59b924e49
commit b0800ea855
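As the hunks below show, the span predictor no longer takes a fixed tok2vec_size; the input width ("nI") is inferred from sample data at initialization time. A minimal sketch of the thinc pattern involved, with invented names rather than the spaCy code:

import numpy
from thinc.api import Model, get_width


def sniff_init(model: Model, X=None, Y=None) -> None:
    # has_dim returns None while the dimension is registered but still unset
    if X is not None and model.has_dim("nI") is None:
        model.set_dim("nI", get_width(X))


def identity_forward(model: Model, X, is_train: bool):
    # identity forward pass; a real layer would allocate and apply weights here
    return X, lambda dY: dY


model = Model("width_sniffer", identity_forward, init=sniff_init, dims={"nI": None})
model.initialize(X=[numpy.zeros((4, 300), dtype="f")])
assert model.get_dim("nI") == 300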
@@ -13,7 +13,6 @@ from .coref_util import get_sentence_ids
 @registry.architectures("spacy.SpanPredictor.v1")
 def build_span_predictor(
     tok2vec: Model[List[Doc], List[Floats2d]],
-    tok2vec_size: int = 768,
     hidden_size: int = 1024,
     distance_embedding_size: int = 64,
     conv_channels: int = 4,
@@ -24,23 +23,58 @@ def build_span_predictor(
     # TODO add model return types

     with Model.define_operators({">>": chain, "&": tuplify}):
-        span_predictor = PyTorchWrapper(
+        span_predictor = Model(
+            "span_predictor",
+            forward=span_predictor_forward,
+            init=span_predictor_init,
+            dims={"nI": nI},
+            attrs={
+                "distance_embedding_size": distance_embedding_size,
+                "hidden_size": hidden_size,
+                "conv_channels": conv_channels,
+                "window_size": window_size,
+                "max_distance": max_distance,
+                "prefix": prefix,
+            },
+        )
+        head_info = build_get_head_metadata(prefix)
+        model = (tok2vec & head_info) >> span_predictor
+        model.set_ref("span_predictor", span_predictor)
+
+    return model
+
+def span_predictor_init(model: Model, X=None, Y=None):
+    if model.layers:
+        return
+
+    if X is not None and model.has_dim("nI") is None:
+        model.set_dim("nI", get_width(X))
+
+    hidden_size = model.attrs["hidden_size"]
+    distance_embedding_size = model.attrs["distance_embedding_size"]
+    conv_channels = model.attrs["conv_channels"]
+    window_size = model.attrs["window_size"]
+    max_distance = model.attrs["max_distance"]
+    prefix = model.attrs["prefix"]
+
+    model._layers = [
+        PyTorchWrapper(
             SpanPredictor(
-                tok2vec_size,
+                model.get_dim("nI"),
                 hidden_size,
                 distance_embedding_size,
                 conv_channels,
                 window_size,
                 max_distance,
                 prefix,
             ),
             convert_inputs=convert_span_predictor_inputs,
         )
-        # TODO use proper parameter for prefix
-        head_info = build_get_head_metadata(prefix)
-        model = (tok2vec & head_info) >> span_predictor
-
-    return model
+        # TODO maybe we need mixed precision and grad scaling?
+    ]
+
+def span_predictor_forward(model: Model, X, is_train: bool):
+    return model.layers[0](X, is_train)

 def convert_span_predictor_inputs(
     model: Model, X: Tuple[List[Floats2d], Tuple[List[Ints1d], List[Ints1d]]], is_train: bool
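A hedged sketch of the lazy-construction pattern that span_predictor_init and span_predictor_forward implement above, with a toy thinc Linear standing in for the wrapped PyTorch SpanPredictor (layer name and sizes are invented): the inner layer is only built once init knows the input width, and forward simply delegates to layers[0].

import numpy
from thinc.api import Linear, Model, get_width


def lazy_init(model: Model, X=None, Y=None) -> None:
    if model.layers:  # already built, e.g. on a repeated initialize() call
        return
    if X is not None and model.has_dim("nI") is None:
        model.set_dim("nI", get_width(X))
    inner = Linear(nO=model.attrs["hidden_size"], nI=model.get_dim("nI"))
    inner.initialize()
    model._layers = [inner]


def lazy_forward(model: Model, X, is_train: bool):
    return model.layers[0](X, is_train)


model = Model(
    "lazy_linear",
    lazy_forward,
    init=lazy_init,
    dims={"nI": None},
    attrs={"hidden_size": 8},
)
model.initialize(X=numpy.zeros((2, 5), dtype="f"))
assert model.get_dim("nI") == 5 and len(model.layers) == 1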
@@ -5,6 +5,7 @@ from thinc.types import Floats2d, Floats3d, Ints2d
 from thinc.api import Model, Config, Optimizer, CategoricalCrossentropy
 from thinc.api import set_dropout_rate, to_categorical
 from itertools import islice
+import srsly

 from .trainable_pipe import TrainablePipe
 from ..language import Language
@@ -13,7 +14,7 @@ from ..errors import Errors
 from ..scorer import Scorer, doc2clusters
 from ..tokens import Doc
 from ..vocab import Vocab
-from ..util import registry
+from ..util import registry, from_bytes, from_disk

 from ..ml.models.coref_util import (
     MentionClusters,
@@ -23,7 +24,6 @@ from ..ml.models.coref_util import (
 default_span_predictor_config = """
 [model]
 @architectures = "spacy.SpanPredictor.v1"
-tok2vec_size = 768
 hidden_size = 1024
 distance_embedding_size = 64
 conv_channels = 4
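The config hunk mirrors the signature change above: thinc resolves the keys of a [model] block against the registered function's parameters, so a parameter dropped from the function must also be dropped from the default config. A toy sketch of that mechanism; the "toy.SpanPredictor.v1" layer is invented for illustration and is not a spaCy architecture:

from thinc.api import Config, Model, registry


@registry.layers("toy.SpanPredictor.v1")
def build_toy_span_predictor(hidden_size: int = 1024) -> Model:
    # stand-in layer; the real builder composes tok2vec and the predictor
    return Model("toy_span_predictor", lambda model, X, is_train: (X, lambda dY: dY))


CFG = """
[model]
@layers = "toy.SpanPredictor.v1"
hidden_size = 1024
"""

model = registry.resolve(Config().from_str(CFG))["model"]
# adding a removed key such as tok2vec_size to [model] would now fail validation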
@@ -274,3 +274,51 @@ class SpanPredictor(TrainablePipe):

         assert len(X) > 0, Errors.E923.format(name=self.name)
         self.model.initialize(X=X, Y=Y)
+
+
+    def from_bytes(self, bytes_data, *, exclude=tuple()):
+        deserializers = {
+            "cfg": lambda b: self.cfg.update(srsly.json_loads(b)),
+            "vocab": lambda b: self.vocab.from_bytes(b, exclude=exclude),
+        }
+        from_bytes(bytes_data, deserializers, exclude)
+
+        self._initialize_from_disk()
+
+        model_deserializers = {
+            "model": lambda b: self.model.from_bytes(b),
+        }
+        from_bytes(bytes_data, model_deserializers, exclude)
+
+        return self
+
+    def from_disk(self, path, exclude=tuple()):
+        def load_model(p):
+            try:
+                with open(p, "rb") as mfile:
+                    self.model.from_bytes(mfile.read())
+            except AttributeError:
+                raise ValueError(Errors.E149) from None
+
+        deserializers = {
+            "cfg": lambda p: self.cfg.update(srsly.read_json(p)),
+            "vocab": lambda p: self.vocab.from_disk(p, exclude=exclude),
+        }
+        from_disk(path, deserializers, exclude)
+
+        self._initialize_from_disk()
+
+        model_deserializers = {
+            "model": load_model,
+        }
+        from_disk(path, model_deserializers, exclude)
+
+        return self
+
+    def _initialize_from_disk(self):
+        # The PyTorch model is constructed lazily, so we need to
+        # explicitly initialize the model before deserialization.
+        model = self.model.get_ref("span_predictor")
+        if model.has_dim("nI") is None:
+            model.set_dim("nI", self.cfg["nI"])
+        self.model.initialize()
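The new from_bytes/from_disk methods work around the lazily built PyTorch layer: cfg (which records "nI") and the vocab are loaded first, _initialize_from_disk rebuilds the inner model with that width, and only then are the saved weights read back. A round-trip sketch of that ordering with a plain thinc Linear rather than the spaCy pipe:

from thinc.api import Linear

saved = Linear(nO=2, nI=4)
saved.initialize()
payload = saved.to_bytes()
cfg = {"nI": saved.get_dim("nI")}   # analogous to the pipe recording self.cfg["nI"]

restored = Linear(nO=2)             # "nI" still unset, as after lazy construction
restored.set_dim("nI", cfg["nI"])   # what _initialize_from_disk does from cfg
restored.initialize()               # allocate parameters with the right shapes
restored.from_bytes(payload)        # read the saved weights back into the rebuilt layer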