From 619b1102e66c68a8ecb9db31e1959764b29035ab Mon Sep 17 00:00:00 2001 From: Paul O'Leary McCann Date: Sun, 3 Jul 2022 15:32:35 +0900 Subject: [PATCH] Use config to specify tok2vec_size --- spacy/ml/models/coref.py | 1 - spacy/tests/pipeline/test_coref.py | 11 ++++++----- spacy/tests/pipeline/test_span_predictor.py | 6 ++++-- 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/spacy/ml/models/coref.py b/spacy/ml/models/coref.py index 1963a4127..22234390e 100644 --- a/spacy/ml/models/coref.py +++ b/spacy/ml/models/coref.py @@ -25,7 +25,6 @@ def build_wl_coref_model( tok2vec_size: int = 768, # tok2vec size ): # TODO add model return types - tok2vec_size = 64 with Model.define_operators({">>": chain}): coref_clusterer = PyTorchWrapper( diff --git a/spacy/tests/pipeline/test_coref.py b/spacy/tests/pipeline/test_coref.py index 3bde6ad34..89906c87b 100644 --- a/spacy/tests/pipeline/test_coref.py +++ b/spacy/tests/pipeline/test_coref.py @@ -36,6 +36,7 @@ TRAIN_DATA = [ # fmt: on +CONFIG = {"model": {"@architectures": "spacy.Coref.v1", "tok2vec_size": 64}} @pytest.fixture @@ -66,7 +67,7 @@ def test_not_initialized(nlp): @pytest.mark.skipif(not has_torch, reason="Torch not available") def test_initialized(nlp): - nlp.add_pipe("coref") + nlp.add_pipe("coref", config=CONFIG) nlp.initialize() assert nlp.pipe_names == ["coref"] text = "She gave me her pen." @@ -78,7 +79,7 @@ def test_initialized(nlp): @pytest.mark.skipif(not has_torch, reason="Torch not available") def test_initialized_short(nlp): - nlp.add_pipe("coref") + nlp.add_pipe("coref", config=CONFIG) nlp.initialize() assert nlp.pipe_names == ["coref"] text = "Hi there" @@ -88,7 +89,7 @@ def test_initialized_short(nlp): @pytest.mark.skipif(not has_torch, reason="Torch not available") def test_coref_serialization(nlp): # Test that the coref component can be serialized - nlp.add_pipe("coref", last=True) + nlp.add_pipe("coref", last=True, config=CONFIG) nlp.initialize() assert nlp.pipe_names == ["coref"] text = "She gave me her pen." @@ -110,7 +111,7 @@ def test_overfitting_IO(nlp): for text, annot in TRAIN_DATA: train_examples.append(Example.from_dict(nlp.make_doc(text), annot)) - nlp.add_pipe("coref") + nlp.add_pipe("coref", config=CONFIG) optimizer = nlp.initialize() test_text = TRAIN_DATA[0][0] doc = nlp(test_text) @@ -165,7 +166,7 @@ def test_tokenization_mismatch(nlp): train_examples.append(eg) - nlp.add_pipe("coref") + nlp.add_pipe("coref", config=CONFIG) optimizer = nlp.initialize() test_text = TRAIN_DATA[0][0] doc = nlp(test_text) diff --git a/spacy/tests/pipeline/test_span_predictor.py b/spacy/tests/pipeline/test_span_predictor.py index 1adaecd3f..7d7a75279 100644 --- a/spacy/tests/pipeline/test_span_predictor.py +++ b/spacy/tests/pipeline/test_span_predictor.py @@ -44,6 +44,8 @@ TRAIN_DATA = [ ] # fmt: on +CONFIG = {"model": {"@architectures": "spacy.SpanPredictor.v1", "tok2vec_size": 64}} + @pytest.fixture def nlp(): @@ -74,7 +76,7 @@ def test_not_initialized(nlp): @pytest.mark.skipif(not has_torch, reason="Torch not available") def test_span_predictor_serialization(nlp): # Test that the span predictor component can be serialized - nlp.add_pipe("span_predictor", last=True) + nlp.add_pipe("span_predictor", last=True, config=CONFIG) nlp.initialize() assert nlp.pipe_names == ["span_predictor"] text = "She gave me her pen." @@ -96,7 +98,7 @@ def test_overfitting_IO(nlp): for text, annot in TRAIN_DATA: train_examples.append(Example.from_dict(nlp.make_doc(text), annot)) - nlp.add_pipe("span_predictor") + nlp.add_pipe("span_predictor", config=CONFIG) optimizer = nlp.initialize() test_text = TRAIN_DATA[0][0] doc = nlp(test_text)