mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-17 19:52:18 +03:00
Merge pull request #11089 from polm/coref/dimension-inference
Dimension inference in Coref
This commit is contained in:
commit
90973faf9e
|
@ -1,6 +1,6 @@
|
|||
from typing import List, Tuple, Callable, cast
|
||||
|
||||
from thinc.api import Model, chain
|
||||
from thinc.api import Model, chain, get_width
|
||||
from thinc.api import PyTorchWrapper, ArgsKwargs
|
||||
from thinc.types import Floats2d, Ints2d
|
||||
from thinc.util import torch, xp2torch, torch2xp
|
||||
|
@ -22,13 +22,48 @@ def build_wl_coref_model(
|
|||
# pairs to keep per mention after rough scoring
|
||||
antecedent_limit: int = 50,
|
||||
antecedent_batch_size: int = 512,
|
||||
tok2vec_size: int = 768, # tok2vec size
|
||||
nI=None,
|
||||
) -> Model[List[Doc], Tuple[Floats2d, Ints2d]]:
|
||||
|
||||
with Model.define_operators({">>": chain}):
|
||||
coref_clusterer = PyTorchWrapper(
|
||||
coref_clusterer: Model[List[Floats2d], Tuple[Floats2d, Ints2d]] = Model(
|
||||
"coref_clusterer",
|
||||
forward=coref_forward,
|
||||
init=coref_init,
|
||||
dims={"nI": nI},
|
||||
attrs={
|
||||
"distance_embedding_size": distance_embedding_size,
|
||||
"hidden_size": hidden_size,
|
||||
"depth": depth,
|
||||
"dropout": dropout,
|
||||
"antecedent_limit": antecedent_limit,
|
||||
"antecedent_batch_size": antecedent_batch_size,
|
||||
},
|
||||
)
|
||||
|
||||
model = tok2vec >> coref_clusterer
|
||||
model.set_ref("coref_clusterer", coref_clusterer)
|
||||
return model
|
||||
|
||||
|
||||
def coref_init(model: Model, X=None, Y=None):
|
||||
if model.layers:
|
||||
return
|
||||
|
||||
if X is not None and model.has_dim("nI") is None:
|
||||
model.set_dim("nI", get_width(X))
|
||||
|
||||
hidden_size = model.attrs["hidden_size"]
|
||||
depth = model.attrs["depth"]
|
||||
dropout = model.attrs["dropout"]
|
||||
antecedent_limit = model.attrs["antecedent_limit"]
|
||||
antecedent_batch_size = model.attrs["antecedent_batch_size"]
|
||||
distance_embedding_size = model.attrs["distance_embedding_size"]
|
||||
|
||||
model._layers = [
|
||||
PyTorchWrapper(
|
||||
CorefClusterer(
|
||||
tok2vec_size,
|
||||
model.get_dim("nI"),
|
||||
distance_embedding_size,
|
||||
hidden_size,
|
||||
depth,
|
||||
|
@ -39,10 +74,13 @@ def build_wl_coref_model(
|
|||
convert_inputs=convert_coref_clusterer_inputs,
|
||||
convert_outputs=convert_coref_clusterer_outputs,
|
||||
)
|
||||
coref_model = tok2vec >> coref_clusterer
|
||||
return coref_model
|
||||
# TODO maybe we need mixed precision and grad scaling?
|
||||
]
|
||||
|
||||
|
||||
def coref_forward(model: Model, X, is_train: bool):
|
||||
return model.layers[0](X, is_train)
|
||||
|
||||
def convert_coref_clusterer_inputs(model: Model, X: List[Floats2d], is_train: bool):
|
||||
# The input here is List[Floats2d], one for each doc
|
||||
# just use the first
|
||||
|
|
|
@ -147,7 +147,9 @@ def get_clusters_from_doc(doc) -> List[List[Tuple[int, int]]]:
|
|||
ints are char spans, to be tokenization independent.
|
||||
"""
|
||||
out = []
|
||||
for key, val in doc.spans.items():
|
||||
keys = sorted(list(doc.spans.keys()))
|
||||
for key in keys:
|
||||
val = doc.spans[key]
|
||||
cluster = []
|
||||
for span in val:
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from typing import List, Tuple, cast
|
||||
|
||||
from thinc.api import Model, chain, tuplify
|
||||
from thinc.api import Model, chain, tuplify, get_width
|
||||
from thinc.api import PyTorchWrapper, ArgsKwargs
|
||||
from thinc.types import Floats2d, Ints1d
|
||||
from thinc.util import torch, xp2torch, torch2xp
|
||||
|
@ -13,7 +13,6 @@ from .coref_util import get_sentence_ids
|
|||
@registry.architectures("spacy.SpanPredictor.v1")
|
||||
def build_span_predictor(
|
||||
tok2vec: Model[List[Doc], List[Floats2d]],
|
||||
tok2vec_size: int = 768,
|
||||
hidden_size: int = 1024,
|
||||
distance_embedding_size: int = 64,
|
||||
conv_channels: int = 4,
|
||||
|
@ -23,10 +22,46 @@ def build_span_predictor(
|
|||
):
|
||||
# TODO add model return types
|
||||
|
||||
nI = None
|
||||
|
||||
with Model.define_operators({">>": chain, "&": tuplify}):
|
||||
span_predictor = PyTorchWrapper(
|
||||
span_predictor: Model[List[Floats2d], List[Floats2d]] = Model(
|
||||
"span_predictor",
|
||||
forward=span_predictor_forward,
|
||||
init=span_predictor_init,
|
||||
dims={"nI": nI},
|
||||
attrs={
|
||||
"distance_embedding_size": distance_embedding_size,
|
||||
"hidden_size": hidden_size,
|
||||
"conv_channels": conv_channels,
|
||||
"window_size": window_size,
|
||||
"max_distance": max_distance,
|
||||
},
|
||||
)
|
||||
head_info = build_get_head_metadata(prefix)
|
||||
model = (tok2vec & head_info) >> span_predictor
|
||||
model.set_ref("span_predictor", span_predictor)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def span_predictor_init(model: Model, X=None, Y=None):
|
||||
if model.layers:
|
||||
return
|
||||
|
||||
if X is not None and model.has_dim("nI") is None:
|
||||
model.set_dim("nI", get_width(X))
|
||||
|
||||
hidden_size = model.attrs["hidden_size"]
|
||||
distance_embedding_size = model.attrs["distance_embedding_size"]
|
||||
conv_channels = model.attrs["conv_channels"]
|
||||
window_size = model.attrs["window_size"]
|
||||
max_distance = model.attrs["max_distance"]
|
||||
|
||||
model._layers = [
|
||||
PyTorchWrapper(
|
||||
SpanPredictor(
|
||||
tok2vec_size,
|
||||
model.get_dim("nI"),
|
||||
hidden_size,
|
||||
distance_embedding_size,
|
||||
conv_channels,
|
||||
|
@ -35,10 +70,12 @@ def build_span_predictor(
|
|||
),
|
||||
convert_inputs=convert_span_predictor_inputs,
|
||||
)
|
||||
head_info = build_get_head_metadata(prefix)
|
||||
model = (tok2vec & head_info) >> span_predictor
|
||||
# TODO maybe we need mixed precision and grad scaling?
|
||||
]
|
||||
|
||||
return model
|
||||
|
||||
def span_predictor_forward(model: Model, X, is_train: bool):
|
||||
return model.layers[0](X, is_train)
|
||||
|
||||
|
||||
def convert_span_predictor_inputs(
|
||||
|
@ -61,7 +98,9 @@ def convert_span_predictor_inputs(
|
|||
else:
|
||||
head_ids_tensor = xp2torch(head_ids[0], requires_grad=False)
|
||||
|
||||
argskwargs = ArgsKwargs(args=(sent_ids_tensor, word_features, head_ids_tensor), kwargs={})
|
||||
argskwargs = ArgsKwargs(
|
||||
args=(sent_ids_tensor, word_features, head_ids_tensor), kwargs={}
|
||||
)
|
||||
return argskwargs, backprop
|
||||
|
||||
|
||||
|
|
|
@ -6,6 +6,7 @@ from thinc.api import Model, Config, Optimizer
|
|||
from thinc.api import set_dropout_rate, to_categorical
|
||||
from itertools import islice
|
||||
from statistics import mean
|
||||
import srsly
|
||||
|
||||
from .trainable_pipe import TrainablePipe
|
||||
from ..language import Language
|
||||
|
@ -13,7 +14,7 @@ from ..training import Example, validate_examples, validate_get_examples
|
|||
from ..errors import Errors
|
||||
from ..tokens import Doc
|
||||
from ..vocab import Vocab
|
||||
from ..util import registry
|
||||
from ..util import registry, from_disk, from_bytes
|
||||
|
||||
from ..ml.models.coref_util import (
|
||||
create_gold_scores,
|
||||
|
@ -30,7 +31,6 @@ from ..scorer import Scorer
|
|||
default_config = """
|
||||
[model]
|
||||
@architectures = "spacy.Coref.v1"
|
||||
tok2vec_size = 768
|
||||
distance_embedding_size = 20
|
||||
hidden_size = 1024
|
||||
depth = 1
|
||||
|
@ -340,3 +340,57 @@ class CoreferenceResolver(TrainablePipe):
|
|||
|
||||
assert len(X) > 0, Errors.E923.format(name=self.name)
|
||||
self.model.initialize(X=X, Y=Y)
|
||||
|
||||
# Store the input dimensionality. nI and nO are not stored explicitly
|
||||
# for PyTorch models. This makes it tricky to reconstruct the model
|
||||
# during deserialization. So, besides storing the labels, we also
|
||||
# store the number of inputs.
|
||||
coref_clusterer = self.model.get_ref("coref_clusterer")
|
||||
self.cfg["nI"] = coref_clusterer.get_dim("nI")
|
||||
|
||||
def from_bytes(self, bytes_data, *, exclude=tuple()):
|
||||
deserializers = {
|
||||
"cfg": lambda b: self.cfg.update(srsly.json_loads(b)),
|
||||
"vocab": lambda b: self.vocab.from_bytes(b, exclude=exclude),
|
||||
}
|
||||
from_bytes(bytes_data, deserializers, exclude)
|
||||
|
||||
self._initialize_from_disk()
|
||||
|
||||
model_deserializers = {
|
||||
"model": lambda b: self.model.from_bytes(b),
|
||||
}
|
||||
from_bytes(bytes_data, model_deserializers, exclude)
|
||||
|
||||
return self
|
||||
|
||||
def from_disk(self, path, exclude=tuple()):
|
||||
def load_model(p):
|
||||
try:
|
||||
with open(p, "rb") as mfile:
|
||||
self.model.from_bytes(mfile.read())
|
||||
except AttributeError:
|
||||
raise ValueError(Errors.E149) from None
|
||||
|
||||
deserializers = {
|
||||
"cfg": lambda p: self.cfg.update(srsly.read_json(p)),
|
||||
"vocab": lambda p: self.vocab.from_disk(p, exclude=exclude),
|
||||
}
|
||||
from_disk(path, deserializers, exclude)
|
||||
|
||||
self._initialize_from_disk()
|
||||
|
||||
model_deserializers = {
|
||||
"model": load_model,
|
||||
}
|
||||
from_disk(path, model_deserializers, exclude)
|
||||
|
||||
return self
|
||||
|
||||
def _initialize_from_disk(self):
|
||||
# The PyTorch model is constructed lazily, so we need to
|
||||
# explicitly initialize the model before deserialization.
|
||||
model = self.model.get_ref("coref_clusterer")
|
||||
if model.has_dim("nI") is None:
|
||||
model.set_dim("nI", self.cfg["nI"])
|
||||
self.model.initialize()
|
||||
|
|
|
@ -5,6 +5,7 @@ from thinc.types import Floats2d, Floats3d, Ints2d
|
|||
from thinc.api import Model, Config, Optimizer, CategoricalCrossentropy
|
||||
from thinc.api import set_dropout_rate, to_categorical
|
||||
from itertools import islice
|
||||
import srsly
|
||||
|
||||
from .trainable_pipe import TrainablePipe
|
||||
from ..language import Language
|
||||
|
@ -13,7 +14,7 @@ from ..errors import Errors
|
|||
from ..scorer import Scorer, doc2clusters
|
||||
from ..tokens import Doc
|
||||
from ..vocab import Vocab
|
||||
from ..util import registry
|
||||
from ..util import registry, from_bytes, from_disk
|
||||
|
||||
from ..ml.models.coref_util import (
|
||||
MentionClusters,
|
||||
|
@ -23,7 +24,6 @@ from ..ml.models.coref_util import (
|
|||
default_span_predictor_config = """
|
||||
[model]
|
||||
@architectures = "spacy.SpanPredictor.v1"
|
||||
tok2vec_size = 768
|
||||
hidden_size = 1024
|
||||
distance_embedding_size = 64
|
||||
conv_channels = 4
|
||||
|
@ -346,3 +346,57 @@ class SpanPredictor(TrainablePipe):
|
|||
|
||||
assert len(X) > 0, Errors.E923.format(name=self.name)
|
||||
self.model.initialize(X=X, Y=Y)
|
||||
|
||||
# Store the input dimensionality. nI and nO are not stored explicitly
|
||||
# for PyTorch models. This makes it tricky to reconstruct the model
|
||||
# during deserialization. So, besides storing the labels, we also
|
||||
# store the number of inputs.
|
||||
span_predictor = self.model.get_ref("span_predictor")
|
||||
self.cfg["nI"] = span_predictor.get_dim("nI")
|
||||
|
||||
def from_bytes(self, bytes_data, *, exclude=tuple()):
|
||||
deserializers = {
|
||||
"cfg": lambda b: self.cfg.update(srsly.json_loads(b)),
|
||||
"vocab": lambda b: self.vocab.from_bytes(b, exclude=exclude),
|
||||
}
|
||||
from_bytes(bytes_data, deserializers, exclude)
|
||||
|
||||
self._initialize_from_disk()
|
||||
|
||||
model_deserializers = {
|
||||
"model": lambda b: self.model.from_bytes(b),
|
||||
}
|
||||
from_bytes(bytes_data, model_deserializers, exclude)
|
||||
|
||||
return self
|
||||
|
||||
def from_disk(self, path, exclude=tuple()):
|
||||
def load_model(p):
|
||||
try:
|
||||
with open(p, "rb") as mfile:
|
||||
self.model.from_bytes(mfile.read())
|
||||
except AttributeError:
|
||||
raise ValueError(Errors.E149) from None
|
||||
|
||||
deserializers = {
|
||||
"cfg": lambda p: self.cfg.update(srsly.read_json(p)),
|
||||
"vocab": lambda p: self.vocab.from_disk(p, exclude=exclude),
|
||||
}
|
||||
from_disk(path, deserializers, exclude)
|
||||
|
||||
self._initialize_from_disk()
|
||||
|
||||
model_deserializers = {
|
||||
"model": load_model,
|
||||
}
|
||||
from_disk(path, model_deserializers, exclude)
|
||||
|
||||
return self
|
||||
|
||||
def _initialize_from_disk(self):
|
||||
# The PyTorch model is constructed lazily, so we need to
|
||||
# explicitly initialize the model before deserialization.
|
||||
model = self.model.get_ref("span_predictor")
|
||||
if model.has_dim("nI") is None:
|
||||
model.set_dim("nI", self.cfg["nI"])
|
||||
self.model.initialize()
|
||||
|
|
|
@ -36,9 +36,6 @@ TRAIN_DATA = [
|
|||
# fmt: on
|
||||
|
||||
|
||||
CONFIG = {"model": {"@architectures": "spacy.Coref.v1", "tok2vec_size": 64}}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def nlp():
|
||||
return English()
|
||||
|
@ -67,7 +64,7 @@ def test_not_initialized(nlp):
|
|||
|
||||
@pytest.mark.skipif(not has_torch, reason="Torch not available")
|
||||
def test_initialized(nlp):
|
||||
nlp.add_pipe("coref", config=CONFIG)
|
||||
nlp.add_pipe("coref")
|
||||
nlp.initialize()
|
||||
assert nlp.pipe_names == ["coref"]
|
||||
text = "She gave me her pen."
|
||||
|
@ -79,7 +76,7 @@ def test_initialized(nlp):
|
|||
|
||||
@pytest.mark.skipif(not has_torch, reason="Torch not available")
|
||||
def test_initialized_short(nlp):
|
||||
nlp.add_pipe("coref", config=CONFIG)
|
||||
nlp.add_pipe("coref")
|
||||
nlp.initialize()
|
||||
assert nlp.pipe_names == ["coref"]
|
||||
text = "Hi there"
|
||||
|
@ -89,7 +86,7 @@ def test_initialized_short(nlp):
|
|||
@pytest.mark.skipif(not has_torch, reason="Torch not available")
|
||||
def test_coref_serialization(nlp):
|
||||
# Test that the coref component can be serialized
|
||||
nlp.add_pipe("coref", last=True, config=CONFIG)
|
||||
nlp.add_pipe("coref", last=True)
|
||||
nlp.initialize()
|
||||
assert nlp.pipe_names == ["coref"]
|
||||
text = "She gave me her pen."
|
||||
|
@ -111,7 +108,7 @@ def test_overfitting_IO(nlp):
|
|||
for text, annot in TRAIN_DATA:
|
||||
train_examples.append(Example.from_dict(nlp.make_doc(text), annot))
|
||||
|
||||
nlp.add_pipe("coref", config=CONFIG)
|
||||
nlp.add_pipe("coref")
|
||||
optimizer = nlp.initialize()
|
||||
test_text = TRAIN_DATA[0][0]
|
||||
doc = nlp(test_text)
|
||||
|
@ -166,7 +163,7 @@ def test_tokenization_mismatch(nlp):
|
|||
|
||||
train_examples.append(eg)
|
||||
|
||||
nlp.add_pipe("coref", config=CONFIG)
|
||||
nlp.add_pipe("coref")
|
||||
optimizer = nlp.initialize()
|
||||
test_text = TRAIN_DATA[0][0]
|
||||
doc = nlp(test_text)
|
||||
|
@ -228,7 +225,7 @@ def test_whitespace_mismatch(nlp):
|
|||
eg.predicted = nlp.make_doc(" " + text)
|
||||
train_examples.append(eg)
|
||||
|
||||
nlp.add_pipe("coref", config=CONFIG)
|
||||
nlp.add_pipe("coref")
|
||||
optimizer = nlp.initialize()
|
||||
test_text = TRAIN_DATA[0][0]
|
||||
doc = nlp(test_text)
|
||||
|
|
|
@ -44,8 +44,6 @@ TRAIN_DATA = [
|
|||
]
|
||||
# fmt: on
|
||||
|
||||
CONFIG = {"model": {"@architectures": "spacy.SpanPredictor.v1", "tok2vec_size": 64}}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def nlp():
|
||||
|
@ -76,7 +74,7 @@ def test_not_initialized(nlp):
|
|||
@pytest.mark.skipif(not has_torch, reason="Torch not available")
|
||||
def test_span_predictor_serialization(nlp):
|
||||
# Test that the span predictor component can be serialized
|
||||
nlp.add_pipe("span_predictor", last=True, config=CONFIG)
|
||||
nlp.add_pipe("span_predictor", last=True)
|
||||
nlp.initialize()
|
||||
assert nlp.pipe_names == ["span_predictor"]
|
||||
text = "She gave me her pen."
|
||||
|
@ -109,7 +107,7 @@ def test_overfitting_IO(nlp):
|
|||
pred.spans[key] = [pred[span.start : span.end] for span in spans]
|
||||
|
||||
train_examples.append(eg)
|
||||
nlp.add_pipe("span_predictor", config=CONFIG)
|
||||
nlp.add_pipe("span_predictor")
|
||||
optimizer = nlp.initialize()
|
||||
test_text = TRAIN_DATA[0][0]
|
||||
doc = nlp(test_text)
|
||||
|
@ -173,7 +171,7 @@ def test_tokenization_mismatch(nlp):
|
|||
|
||||
train_examples.append(eg)
|
||||
|
||||
nlp.add_pipe("span_predictor", config=CONFIG)
|
||||
nlp.add_pipe("span_predictor")
|
||||
optimizer = nlp.initialize()
|
||||
test_text = TRAIN_DATA[0][0]
|
||||
doc = nlp(test_text)
|
||||
|
@ -218,7 +216,7 @@ def test_whitespace_mismatch(nlp):
|
|||
eg.predicted = nlp.make_doc(" " + text)
|
||||
train_examples.append(eg)
|
||||
|
||||
nlp.add_pipe("span_predictor", config=CONFIG)
|
||||
nlp.add_pipe("span_predictor")
|
||||
optimizer = nlp.initialize()
|
||||
test_text = TRAIN_DATA[0][0]
|
||||
doc = nlp(test_text)
|
||||
|
|
Loading…
Reference in New Issue
Block a user