Use config to specify tok2vec_size

This commit is contained in:
Paul O'Leary McCann 2022-07-03 15:32:35 +09:00
parent 1a4dbb702d
commit 619b1102e6
3 changed files with 10 additions and 8 deletions

View File

@ -25,7 +25,6 @@ def build_wl_coref_model(
tok2vec_size: int = 768,  # tok2vec size
):
# TODO add model return types
tok2vec_size = 64
with Model.define_operators({">>": chain}):
coref_clusterer = PyTorchWrapper(

View File

@ -36,6 +36,7 @@ TRAIN_DATA = [
# fmt: on
CONFIG = {"model": {"@architectures": "spacy.Coref.v1", "tok2vec_size": 64}}
@pytest.fixture
@ -66,7 +67,7 @@ def test_not_initialized(nlp):
@pytest.mark.skipif(not has_torch, reason="Torch not available")
def test_initialized(nlp):
nlp.add_pipe("coref") nlp.add_pipe("coref", config=CONFIG)
nlp.initialize()
assert nlp.pipe_names == ["coref"]
text = "She gave me her pen."
@ -78,7 +79,7 @@ def test_initialized(nlp):
@pytest.mark.skipif(not has_torch, reason="Torch not available")
def test_initialized_short(nlp):
nlp.add_pipe("coref") nlp.add_pipe("coref", config=CONFIG)
nlp.initialize()
assert nlp.pipe_names == ["coref"]
text = "Hi there"
@ -88,7 +89,7 @@ def test_initialized_short(nlp):
@pytest.mark.skipif(not has_torch, reason="Torch not available")
def test_coref_serialization(nlp):
# Test that the coref component can be serialized
nlp.add_pipe("coref", last=True) nlp.add_pipe("coref", last=True, config=CONFIG)
nlp.initialize()
assert nlp.pipe_names == ["coref"]
text = "She gave me her pen."
@ -110,7 +111,7 @@ def test_overfitting_IO(nlp):
for text, annot in TRAIN_DATA:
train_examples.append(Example.from_dict(nlp.make_doc(text), annot))
nlp.add_pipe("coref") nlp.add_pipe("coref", config=CONFIG)
optimizer = nlp.initialize()
test_text = TRAIN_DATA[0][0]
doc = nlp(test_text)
@ -165,7 +166,7 @@ def test_tokenization_mismatch(nlp):
train_examples.append(eg)
nlp.add_pipe("coref") nlp.add_pipe("coref", config=CONFIG)
optimizer = nlp.initialize()
test_text = TRAIN_DATA[0][0]
doc = nlp(test_text)

View File

@ -44,6 +44,8 @@ TRAIN_DATA = [
]
# fmt: on
CONFIG = {"model": {"@architectures": "spacy.SpanPredictor.v1", "tok2vec_size": 64}}
@pytest.fixture
def nlp():
@ -74,7 +76,7 @@ def test_not_initialized(nlp):
@pytest.mark.skipif(not has_torch, reason="Torch not available")
def test_span_predictor_serialization(nlp):
# Test that the span predictor component can be serialized
nlp.add_pipe("span_predictor", last=True) nlp.add_pipe("span_predictor", last=True, config=CONFIG)
nlp.initialize()
assert nlp.pipe_names == ["span_predictor"]
text = "She gave me her pen."
@ -96,7 +98,7 @@ def test_overfitting_IO(nlp):
for text, annot in TRAIN_DATA:
train_examples.append(Example.from_dict(nlp.make_doc(text), annot))
nlp.add_pipe("span_predictor") nlp.add_pipe("span_predictor", config=CONFIG)
optimizer = nlp.initialize()
test_text = TRAIN_DATA[0][0]
doc = nlp(test_text)