mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-18 04:02:20 +03:00
Merge pull request #11089 from polm/coref/dimension-inference
Dimension inference in Coref
This commit is contained in:
commit
90973faf9e
|
@ -1,6 +1,6 @@
|
||||||
from typing import List, Tuple, Callable, cast
|
from typing import List, Tuple, Callable, cast
|
||||||
|
|
||||||
from thinc.api import Model, chain
|
from thinc.api import Model, chain, get_width
|
||||||
from thinc.api import PyTorchWrapper, ArgsKwargs
|
from thinc.api import PyTorchWrapper, ArgsKwargs
|
||||||
from thinc.types import Floats2d, Ints2d
|
from thinc.types import Floats2d, Ints2d
|
||||||
from thinc.util import torch, xp2torch, torch2xp
|
from thinc.util import torch, xp2torch, torch2xp
|
||||||
|
@ -22,13 +22,48 @@ def build_wl_coref_model(
|
||||||
# pairs to keep per mention after rough scoring
|
# pairs to keep per mention after rough scoring
|
||||||
antecedent_limit: int = 50,
|
antecedent_limit: int = 50,
|
||||||
antecedent_batch_size: int = 512,
|
antecedent_batch_size: int = 512,
|
||||||
tok2vec_size: int = 768, # tok2vec size
|
nI=None,
|
||||||
) -> Model[List[Doc], Tuple[Floats2d, Ints2d]]:
|
) -> Model[List[Doc], Tuple[Floats2d, Ints2d]]:
|
||||||
|
|
||||||
with Model.define_operators({">>": chain}):
|
with Model.define_operators({">>": chain}):
|
||||||
coref_clusterer = PyTorchWrapper(
|
coref_clusterer: Model[List[Floats2d], Tuple[Floats2d, Ints2d]] = Model(
|
||||||
|
"coref_clusterer",
|
||||||
|
forward=coref_forward,
|
||||||
|
init=coref_init,
|
||||||
|
dims={"nI": nI},
|
||||||
|
attrs={
|
||||||
|
"distance_embedding_size": distance_embedding_size,
|
||||||
|
"hidden_size": hidden_size,
|
||||||
|
"depth": depth,
|
||||||
|
"dropout": dropout,
|
||||||
|
"antecedent_limit": antecedent_limit,
|
||||||
|
"antecedent_batch_size": antecedent_batch_size,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
model = tok2vec >> coref_clusterer
|
||||||
|
model.set_ref("coref_clusterer", coref_clusterer)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def coref_init(model: Model, X=None, Y=None):
|
||||||
|
if model.layers:
|
||||||
|
return
|
||||||
|
|
||||||
|
if X is not None and model.has_dim("nI") is None:
|
||||||
|
model.set_dim("nI", get_width(X))
|
||||||
|
|
||||||
|
hidden_size = model.attrs["hidden_size"]
|
||||||
|
depth = model.attrs["depth"]
|
||||||
|
dropout = model.attrs["dropout"]
|
||||||
|
antecedent_limit = model.attrs["antecedent_limit"]
|
||||||
|
antecedent_batch_size = model.attrs["antecedent_batch_size"]
|
||||||
|
distance_embedding_size = model.attrs["distance_embedding_size"]
|
||||||
|
|
||||||
|
model._layers = [
|
||||||
|
PyTorchWrapper(
|
||||||
CorefClusterer(
|
CorefClusterer(
|
||||||
tok2vec_size,
|
model.get_dim("nI"),
|
||||||
distance_embedding_size,
|
distance_embedding_size,
|
||||||
hidden_size,
|
hidden_size,
|
||||||
depth,
|
depth,
|
||||||
|
@ -39,10 +74,13 @@ def build_wl_coref_model(
|
||||||
convert_inputs=convert_coref_clusterer_inputs,
|
convert_inputs=convert_coref_clusterer_inputs,
|
||||||
convert_outputs=convert_coref_clusterer_outputs,
|
convert_outputs=convert_coref_clusterer_outputs,
|
||||||
)
|
)
|
||||||
coref_model = tok2vec >> coref_clusterer
|
# TODO maybe we need mixed precision and grad scaling?
|
||||||
return coref_model
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def coref_forward(model: Model, X, is_train: bool):
|
||||||
|
return model.layers[0](X, is_train)
|
||||||
|
|
||||||
def convert_coref_clusterer_inputs(model: Model, X: List[Floats2d], is_train: bool):
|
def convert_coref_clusterer_inputs(model: Model, X: List[Floats2d], is_train: bool):
|
||||||
# The input here is List[Floats2d], one for each doc
|
# The input here is List[Floats2d], one for each doc
|
||||||
# just use the first
|
# just use the first
|
||||||
|
|
|
@ -147,7 +147,9 @@ def get_clusters_from_doc(doc) -> List[List[Tuple[int, int]]]:
|
||||||
ints are char spans, to be tokenization independent.
|
ints are char spans, to be tokenization independent.
|
||||||
"""
|
"""
|
||||||
out = []
|
out = []
|
||||||
for key, val in doc.spans.items():
|
keys = sorted(list(doc.spans.keys()))
|
||||||
|
for key in keys:
|
||||||
|
val = doc.spans[key]
|
||||||
cluster = []
|
cluster = []
|
||||||
for span in val:
|
for span in val:
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
from typing import List, Tuple, cast
|
from typing import List, Tuple, cast
|
||||||
|
|
||||||
from thinc.api import Model, chain, tuplify
|
from thinc.api import Model, chain, tuplify, get_width
|
||||||
from thinc.api import PyTorchWrapper, ArgsKwargs
|
from thinc.api import PyTorchWrapper, ArgsKwargs
|
||||||
from thinc.types import Floats2d, Ints1d
|
from thinc.types import Floats2d, Ints1d
|
||||||
from thinc.util import torch, xp2torch, torch2xp
|
from thinc.util import torch, xp2torch, torch2xp
|
||||||
|
@ -13,7 +13,6 @@ from .coref_util import get_sentence_ids
|
||||||
@registry.architectures("spacy.SpanPredictor.v1")
|
@registry.architectures("spacy.SpanPredictor.v1")
|
||||||
def build_span_predictor(
|
def build_span_predictor(
|
||||||
tok2vec: Model[List[Doc], List[Floats2d]],
|
tok2vec: Model[List[Doc], List[Floats2d]],
|
||||||
tok2vec_size: int = 768,
|
|
||||||
hidden_size: int = 1024,
|
hidden_size: int = 1024,
|
||||||
distance_embedding_size: int = 64,
|
distance_embedding_size: int = 64,
|
||||||
conv_channels: int = 4,
|
conv_channels: int = 4,
|
||||||
|
@ -23,10 +22,46 @@ def build_span_predictor(
|
||||||
):
|
):
|
||||||
# TODO add model return types
|
# TODO add model return types
|
||||||
|
|
||||||
|
nI = None
|
||||||
|
|
||||||
with Model.define_operators({">>": chain, "&": tuplify}):
|
with Model.define_operators({">>": chain, "&": tuplify}):
|
||||||
span_predictor = PyTorchWrapper(
|
span_predictor: Model[List[Floats2d], List[Floats2d]] = Model(
|
||||||
|
"span_predictor",
|
||||||
|
forward=span_predictor_forward,
|
||||||
|
init=span_predictor_init,
|
||||||
|
dims={"nI": nI},
|
||||||
|
attrs={
|
||||||
|
"distance_embedding_size": distance_embedding_size,
|
||||||
|
"hidden_size": hidden_size,
|
||||||
|
"conv_channels": conv_channels,
|
||||||
|
"window_size": window_size,
|
||||||
|
"max_distance": max_distance,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
head_info = build_get_head_metadata(prefix)
|
||||||
|
model = (tok2vec & head_info) >> span_predictor
|
||||||
|
model.set_ref("span_predictor", span_predictor)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def span_predictor_init(model: Model, X=None, Y=None):
|
||||||
|
if model.layers:
|
||||||
|
return
|
||||||
|
|
||||||
|
if X is not None and model.has_dim("nI") is None:
|
||||||
|
model.set_dim("nI", get_width(X))
|
||||||
|
|
||||||
|
hidden_size = model.attrs["hidden_size"]
|
||||||
|
distance_embedding_size = model.attrs["distance_embedding_size"]
|
||||||
|
conv_channels = model.attrs["conv_channels"]
|
||||||
|
window_size = model.attrs["window_size"]
|
||||||
|
max_distance = model.attrs["max_distance"]
|
||||||
|
|
||||||
|
model._layers = [
|
||||||
|
PyTorchWrapper(
|
||||||
SpanPredictor(
|
SpanPredictor(
|
||||||
tok2vec_size,
|
model.get_dim("nI"),
|
||||||
hidden_size,
|
hidden_size,
|
||||||
distance_embedding_size,
|
distance_embedding_size,
|
||||||
conv_channels,
|
conv_channels,
|
||||||
|
@ -35,10 +70,12 @@ def build_span_predictor(
|
||||||
),
|
),
|
||||||
convert_inputs=convert_span_predictor_inputs,
|
convert_inputs=convert_span_predictor_inputs,
|
||||||
)
|
)
|
||||||
head_info = build_get_head_metadata(prefix)
|
# TODO maybe we need mixed precision and grad scaling?
|
||||||
model = (tok2vec & head_info) >> span_predictor
|
]
|
||||||
|
|
||||||
return model
|
|
||||||
|
def span_predictor_forward(model: Model, X, is_train: bool):
|
||||||
|
return model.layers[0](X, is_train)
|
||||||
|
|
||||||
|
|
||||||
def convert_span_predictor_inputs(
|
def convert_span_predictor_inputs(
|
||||||
|
@ -61,7 +98,9 @@ def convert_span_predictor_inputs(
|
||||||
else:
|
else:
|
||||||
head_ids_tensor = xp2torch(head_ids[0], requires_grad=False)
|
head_ids_tensor = xp2torch(head_ids[0], requires_grad=False)
|
||||||
|
|
||||||
argskwargs = ArgsKwargs(args=(sent_ids_tensor, word_features, head_ids_tensor), kwargs={})
|
argskwargs = ArgsKwargs(
|
||||||
|
args=(sent_ids_tensor, word_features, head_ids_tensor), kwargs={}
|
||||||
|
)
|
||||||
return argskwargs, backprop
|
return argskwargs, backprop
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -6,6 +6,7 @@ from thinc.api import Model, Config, Optimizer
|
||||||
from thinc.api import set_dropout_rate, to_categorical
|
from thinc.api import set_dropout_rate, to_categorical
|
||||||
from itertools import islice
|
from itertools import islice
|
||||||
from statistics import mean
|
from statistics import mean
|
||||||
|
import srsly
|
||||||
|
|
||||||
from .trainable_pipe import TrainablePipe
|
from .trainable_pipe import TrainablePipe
|
||||||
from ..language import Language
|
from ..language import Language
|
||||||
|
@ -13,7 +14,7 @@ from ..training import Example, validate_examples, validate_get_examples
|
||||||
from ..errors import Errors
|
from ..errors import Errors
|
||||||
from ..tokens import Doc
|
from ..tokens import Doc
|
||||||
from ..vocab import Vocab
|
from ..vocab import Vocab
|
||||||
from ..util import registry
|
from ..util import registry, from_disk, from_bytes
|
||||||
|
|
||||||
from ..ml.models.coref_util import (
|
from ..ml.models.coref_util import (
|
||||||
create_gold_scores,
|
create_gold_scores,
|
||||||
|
@ -30,7 +31,6 @@ from ..scorer import Scorer
|
||||||
default_config = """
|
default_config = """
|
||||||
[model]
|
[model]
|
||||||
@architectures = "spacy.Coref.v1"
|
@architectures = "spacy.Coref.v1"
|
||||||
tok2vec_size = 768
|
|
||||||
distance_embedding_size = 20
|
distance_embedding_size = 20
|
||||||
hidden_size = 1024
|
hidden_size = 1024
|
||||||
depth = 1
|
depth = 1
|
||||||
|
@ -340,3 +340,57 @@ class CoreferenceResolver(TrainablePipe):
|
||||||
|
|
||||||
assert len(X) > 0, Errors.E923.format(name=self.name)
|
assert len(X) > 0, Errors.E923.format(name=self.name)
|
||||||
self.model.initialize(X=X, Y=Y)
|
self.model.initialize(X=X, Y=Y)
|
||||||
|
|
||||||
|
# Store the input dimensionality. nI and nO are not stored explicitly
|
||||||
|
# for PyTorch models. This makes it tricky to reconstruct the model
|
||||||
|
# during deserialization. So, besides storing the labels, we also
|
||||||
|
# store the number of inputs.
|
||||||
|
coref_clusterer = self.model.get_ref("coref_clusterer")
|
||||||
|
self.cfg["nI"] = coref_clusterer.get_dim("nI")
|
||||||
|
|
||||||
|
def from_bytes(self, bytes_data, *, exclude=tuple()):
|
||||||
|
deserializers = {
|
||||||
|
"cfg": lambda b: self.cfg.update(srsly.json_loads(b)),
|
||||||
|
"vocab": lambda b: self.vocab.from_bytes(b, exclude=exclude),
|
||||||
|
}
|
||||||
|
from_bytes(bytes_data, deserializers, exclude)
|
||||||
|
|
||||||
|
self._initialize_from_disk()
|
||||||
|
|
||||||
|
model_deserializers = {
|
||||||
|
"model": lambda b: self.model.from_bytes(b),
|
||||||
|
}
|
||||||
|
from_bytes(bytes_data, model_deserializers, exclude)
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
def from_disk(self, path, exclude=tuple()):
|
||||||
|
def load_model(p):
|
||||||
|
try:
|
||||||
|
with open(p, "rb") as mfile:
|
||||||
|
self.model.from_bytes(mfile.read())
|
||||||
|
except AttributeError:
|
||||||
|
raise ValueError(Errors.E149) from None
|
||||||
|
|
||||||
|
deserializers = {
|
||||||
|
"cfg": lambda p: self.cfg.update(srsly.read_json(p)),
|
||||||
|
"vocab": lambda p: self.vocab.from_disk(p, exclude=exclude),
|
||||||
|
}
|
||||||
|
from_disk(path, deserializers, exclude)
|
||||||
|
|
||||||
|
self._initialize_from_disk()
|
||||||
|
|
||||||
|
model_deserializers = {
|
||||||
|
"model": load_model,
|
||||||
|
}
|
||||||
|
from_disk(path, model_deserializers, exclude)
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
def _initialize_from_disk(self):
|
||||||
|
# The PyTorch model is constructed lazily, so we need to
|
||||||
|
# explicitly initialize the model before deserialization.
|
||||||
|
model = self.model.get_ref("coref_clusterer")
|
||||||
|
if model.has_dim("nI") is None:
|
||||||
|
model.set_dim("nI", self.cfg["nI"])
|
||||||
|
self.model.initialize()
|
||||||
|
|
|
@ -5,6 +5,7 @@ from thinc.types import Floats2d, Floats3d, Ints2d
|
||||||
from thinc.api import Model, Config, Optimizer, CategoricalCrossentropy
|
from thinc.api import Model, Config, Optimizer, CategoricalCrossentropy
|
||||||
from thinc.api import set_dropout_rate, to_categorical
|
from thinc.api import set_dropout_rate, to_categorical
|
||||||
from itertools import islice
|
from itertools import islice
|
||||||
|
import srsly
|
||||||
|
|
||||||
from .trainable_pipe import TrainablePipe
|
from .trainable_pipe import TrainablePipe
|
||||||
from ..language import Language
|
from ..language import Language
|
||||||
|
@ -13,7 +14,7 @@ from ..errors import Errors
|
||||||
from ..scorer import Scorer, doc2clusters
|
from ..scorer import Scorer, doc2clusters
|
||||||
from ..tokens import Doc
|
from ..tokens import Doc
|
||||||
from ..vocab import Vocab
|
from ..vocab import Vocab
|
||||||
from ..util import registry
|
from ..util import registry, from_bytes, from_disk
|
||||||
|
|
||||||
from ..ml.models.coref_util import (
|
from ..ml.models.coref_util import (
|
||||||
MentionClusters,
|
MentionClusters,
|
||||||
|
@ -23,7 +24,6 @@ from ..ml.models.coref_util import (
|
||||||
default_span_predictor_config = """
|
default_span_predictor_config = """
|
||||||
[model]
|
[model]
|
||||||
@architectures = "spacy.SpanPredictor.v1"
|
@architectures = "spacy.SpanPredictor.v1"
|
||||||
tok2vec_size = 768
|
|
||||||
hidden_size = 1024
|
hidden_size = 1024
|
||||||
distance_embedding_size = 64
|
distance_embedding_size = 64
|
||||||
conv_channels = 4
|
conv_channels = 4
|
||||||
|
@ -346,3 +346,57 @@ class SpanPredictor(TrainablePipe):
|
||||||
|
|
||||||
assert len(X) > 0, Errors.E923.format(name=self.name)
|
assert len(X) > 0, Errors.E923.format(name=self.name)
|
||||||
self.model.initialize(X=X, Y=Y)
|
self.model.initialize(X=X, Y=Y)
|
||||||
|
|
||||||
|
# Store the input dimensionality. nI and nO are not stored explicitly
|
||||||
|
# for PyTorch models. This makes it tricky to reconstruct the model
|
||||||
|
# during deserialization. So, besides storing the labels, we also
|
||||||
|
# store the number of inputs.
|
||||||
|
span_predictor = self.model.get_ref("span_predictor")
|
||||||
|
self.cfg["nI"] = span_predictor.get_dim("nI")
|
||||||
|
|
||||||
|
def from_bytes(self, bytes_data, *, exclude=tuple()):
|
||||||
|
deserializers = {
|
||||||
|
"cfg": lambda b: self.cfg.update(srsly.json_loads(b)),
|
||||||
|
"vocab": lambda b: self.vocab.from_bytes(b, exclude=exclude),
|
||||||
|
}
|
||||||
|
from_bytes(bytes_data, deserializers, exclude)
|
||||||
|
|
||||||
|
self._initialize_from_disk()
|
||||||
|
|
||||||
|
model_deserializers = {
|
||||||
|
"model": lambda b: self.model.from_bytes(b),
|
||||||
|
}
|
||||||
|
from_bytes(bytes_data, model_deserializers, exclude)
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
def from_disk(self, path, exclude=tuple()):
|
||||||
|
def load_model(p):
|
||||||
|
try:
|
||||||
|
with open(p, "rb") as mfile:
|
||||||
|
self.model.from_bytes(mfile.read())
|
||||||
|
except AttributeError:
|
||||||
|
raise ValueError(Errors.E149) from None
|
||||||
|
|
||||||
|
deserializers = {
|
||||||
|
"cfg": lambda p: self.cfg.update(srsly.read_json(p)),
|
||||||
|
"vocab": lambda p: self.vocab.from_disk(p, exclude=exclude),
|
||||||
|
}
|
||||||
|
from_disk(path, deserializers, exclude)
|
||||||
|
|
||||||
|
self._initialize_from_disk()
|
||||||
|
|
||||||
|
model_deserializers = {
|
||||||
|
"model": load_model,
|
||||||
|
}
|
||||||
|
from_disk(path, model_deserializers, exclude)
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
def _initialize_from_disk(self):
|
||||||
|
# The PyTorch model is constructed lazily, so we need to
|
||||||
|
# explicitly initialize the model before deserialization.
|
||||||
|
model = self.model.get_ref("span_predictor")
|
||||||
|
if model.has_dim("nI") is None:
|
||||||
|
model.set_dim("nI", self.cfg["nI"])
|
||||||
|
self.model.initialize()
|
||||||
|
|
|
@ -36,9 +36,6 @@ TRAIN_DATA = [
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
|
|
||||||
CONFIG = {"model": {"@architectures": "spacy.Coref.v1", "tok2vec_size": 64}}
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def nlp():
|
def nlp():
|
||||||
return English()
|
return English()
|
||||||
|
@ -67,7 +64,7 @@ def test_not_initialized(nlp):
|
||||||
|
|
||||||
@pytest.mark.skipif(not has_torch, reason="Torch not available")
|
@pytest.mark.skipif(not has_torch, reason="Torch not available")
|
||||||
def test_initialized(nlp):
|
def test_initialized(nlp):
|
||||||
nlp.add_pipe("coref", config=CONFIG)
|
nlp.add_pipe("coref")
|
||||||
nlp.initialize()
|
nlp.initialize()
|
||||||
assert nlp.pipe_names == ["coref"]
|
assert nlp.pipe_names == ["coref"]
|
||||||
text = "She gave me her pen."
|
text = "She gave me her pen."
|
||||||
|
@ -79,7 +76,7 @@ def test_initialized(nlp):
|
||||||
|
|
||||||
@pytest.mark.skipif(not has_torch, reason="Torch not available")
|
@pytest.mark.skipif(not has_torch, reason="Torch not available")
|
||||||
def test_initialized_short(nlp):
|
def test_initialized_short(nlp):
|
||||||
nlp.add_pipe("coref", config=CONFIG)
|
nlp.add_pipe("coref")
|
||||||
nlp.initialize()
|
nlp.initialize()
|
||||||
assert nlp.pipe_names == ["coref"]
|
assert nlp.pipe_names == ["coref"]
|
||||||
text = "Hi there"
|
text = "Hi there"
|
||||||
|
@ -89,7 +86,7 @@ def test_initialized_short(nlp):
|
||||||
@pytest.mark.skipif(not has_torch, reason="Torch not available")
|
@pytest.mark.skipif(not has_torch, reason="Torch not available")
|
||||||
def test_coref_serialization(nlp):
|
def test_coref_serialization(nlp):
|
||||||
# Test that the coref component can be serialized
|
# Test that the coref component can be serialized
|
||||||
nlp.add_pipe("coref", last=True, config=CONFIG)
|
nlp.add_pipe("coref", last=True)
|
||||||
nlp.initialize()
|
nlp.initialize()
|
||||||
assert nlp.pipe_names == ["coref"]
|
assert nlp.pipe_names == ["coref"]
|
||||||
text = "She gave me her pen."
|
text = "She gave me her pen."
|
||||||
|
@ -111,7 +108,7 @@ def test_overfitting_IO(nlp):
|
||||||
for text, annot in TRAIN_DATA:
|
for text, annot in TRAIN_DATA:
|
||||||
train_examples.append(Example.from_dict(nlp.make_doc(text), annot))
|
train_examples.append(Example.from_dict(nlp.make_doc(text), annot))
|
||||||
|
|
||||||
nlp.add_pipe("coref", config=CONFIG)
|
nlp.add_pipe("coref")
|
||||||
optimizer = nlp.initialize()
|
optimizer = nlp.initialize()
|
||||||
test_text = TRAIN_DATA[0][0]
|
test_text = TRAIN_DATA[0][0]
|
||||||
doc = nlp(test_text)
|
doc = nlp(test_text)
|
||||||
|
@ -166,7 +163,7 @@ def test_tokenization_mismatch(nlp):
|
||||||
|
|
||||||
train_examples.append(eg)
|
train_examples.append(eg)
|
||||||
|
|
||||||
nlp.add_pipe("coref", config=CONFIG)
|
nlp.add_pipe("coref")
|
||||||
optimizer = nlp.initialize()
|
optimizer = nlp.initialize()
|
||||||
test_text = TRAIN_DATA[0][0]
|
test_text = TRAIN_DATA[0][0]
|
||||||
doc = nlp(test_text)
|
doc = nlp(test_text)
|
||||||
|
@ -228,7 +225,7 @@ def test_whitespace_mismatch(nlp):
|
||||||
eg.predicted = nlp.make_doc(" " + text)
|
eg.predicted = nlp.make_doc(" " + text)
|
||||||
train_examples.append(eg)
|
train_examples.append(eg)
|
||||||
|
|
||||||
nlp.add_pipe("coref", config=CONFIG)
|
nlp.add_pipe("coref")
|
||||||
optimizer = nlp.initialize()
|
optimizer = nlp.initialize()
|
||||||
test_text = TRAIN_DATA[0][0]
|
test_text = TRAIN_DATA[0][0]
|
||||||
doc = nlp(test_text)
|
doc = nlp(test_text)
|
||||||
|
|
|
@ -44,8 +44,6 @@ TRAIN_DATA = [
|
||||||
]
|
]
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
CONFIG = {"model": {"@architectures": "spacy.SpanPredictor.v1", "tok2vec_size": 64}}
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def nlp():
|
def nlp():
|
||||||
|
@ -76,7 +74,7 @@ def test_not_initialized(nlp):
|
||||||
@pytest.mark.skipif(not has_torch, reason="Torch not available")
|
@pytest.mark.skipif(not has_torch, reason="Torch not available")
|
||||||
def test_span_predictor_serialization(nlp):
|
def test_span_predictor_serialization(nlp):
|
||||||
# Test that the span predictor component can be serialized
|
# Test that the span predictor component can be serialized
|
||||||
nlp.add_pipe("span_predictor", last=True, config=CONFIG)
|
nlp.add_pipe("span_predictor", last=True)
|
||||||
nlp.initialize()
|
nlp.initialize()
|
||||||
assert nlp.pipe_names == ["span_predictor"]
|
assert nlp.pipe_names == ["span_predictor"]
|
||||||
text = "She gave me her pen."
|
text = "She gave me her pen."
|
||||||
|
@ -109,7 +107,7 @@ def test_overfitting_IO(nlp):
|
||||||
pred.spans[key] = [pred[span.start : span.end] for span in spans]
|
pred.spans[key] = [pred[span.start : span.end] for span in spans]
|
||||||
|
|
||||||
train_examples.append(eg)
|
train_examples.append(eg)
|
||||||
nlp.add_pipe("span_predictor", config=CONFIG)
|
nlp.add_pipe("span_predictor")
|
||||||
optimizer = nlp.initialize()
|
optimizer = nlp.initialize()
|
||||||
test_text = TRAIN_DATA[0][0]
|
test_text = TRAIN_DATA[0][0]
|
||||||
doc = nlp(test_text)
|
doc = nlp(test_text)
|
||||||
|
@ -173,7 +171,7 @@ def test_tokenization_mismatch(nlp):
|
||||||
|
|
||||||
train_examples.append(eg)
|
train_examples.append(eg)
|
||||||
|
|
||||||
nlp.add_pipe("span_predictor", config=CONFIG)
|
nlp.add_pipe("span_predictor")
|
||||||
optimizer = nlp.initialize()
|
optimizer = nlp.initialize()
|
||||||
test_text = TRAIN_DATA[0][0]
|
test_text = TRAIN_DATA[0][0]
|
||||||
doc = nlp(test_text)
|
doc = nlp(test_text)
|
||||||
|
@ -218,7 +216,7 @@ def test_whitespace_mismatch(nlp):
|
||||||
eg.predicted = nlp.make_doc(" " + text)
|
eg.predicted = nlp.make_doc(" " + text)
|
||||||
train_examples.append(eg)
|
train_examples.append(eg)
|
||||||
|
|
||||||
nlp.add_pipe("span_predictor", config=CONFIG)
|
nlp.add_pipe("span_predictor")
|
||||||
optimizer = nlp.initialize()
|
optimizer = nlp.initialize()
|
||||||
test_text = TRAIN_DATA[0][0]
|
test_text = TRAIN_DATA[0][0]
|
||||||
doc = nlp(test_text)
|
doc = nlp(test_text)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user