mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-18 12:12:20 +03:00
Use config to specify tok2vec_size
This commit is contained in:
parent
1a4dbb702d
commit
619b1102e6
|
@@ -25,7 +25,6 @@ def build_wl_coref_model(
|
||||||
tok2vec_size: int = 768, # tok2vec size
|
tok2vec_size: int = 768, # tok2vec size
|
||||||
):
|
):
|
||||||
# TODO add model return types
|
# TODO add model return types
|
||||||
tok2vec_size = 64
|
|
||||||
|
|
||||||
with Model.define_operators({">>": chain}):
|
with Model.define_operators({">>": chain}):
|
||||||
coref_clusterer = PyTorchWrapper(
|
coref_clusterer = PyTorchWrapper(
|
||||||
|
|
|
@@ -36,6 +36,7 @@ TRAIN_DATA = [
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
|
|
||||||
|
CONFIG = {"model": {"@architectures": "spacy.Coref.v1", "tok2vec_size": 64}}
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
@@ -66,7 +67,7 @@ def test_not_initialized(nlp):
|
||||||
|
|
||||||
@pytest.mark.skipif(not has_torch, reason="Torch not available")
|
@pytest.mark.skipif(not has_torch, reason="Torch not available")
|
||||||
def test_initialized(nlp):
|
def test_initialized(nlp):
|
||||||
nlp.add_pipe("coref")
|
nlp.add_pipe("coref", config=CONFIG)
|
||||||
nlp.initialize()
|
nlp.initialize()
|
||||||
assert nlp.pipe_names == ["coref"]
|
assert nlp.pipe_names == ["coref"]
|
||||||
text = "She gave me her pen."
|
text = "She gave me her pen."
|
||||||
|
@@ -78,7 +79,7 @@ def test_initialized(nlp):
|
||||||
|
|
||||||
@pytest.mark.skipif(not has_torch, reason="Torch not available")
|
@pytest.mark.skipif(not has_torch, reason="Torch not available")
|
||||||
def test_initialized_short(nlp):
|
def test_initialized_short(nlp):
|
||||||
nlp.add_pipe("coref")
|
nlp.add_pipe("coref", config=CONFIG)
|
||||||
nlp.initialize()
|
nlp.initialize()
|
||||||
assert nlp.pipe_names == ["coref"]
|
assert nlp.pipe_names == ["coref"]
|
||||||
text = "Hi there"
|
text = "Hi there"
|
||||||
|
@@ -88,7 +89,7 @@ def test_initialized_short(nlp):
|
||||||
@pytest.mark.skipif(not has_torch, reason="Torch not available")
|
@pytest.mark.skipif(not has_torch, reason="Torch not available")
|
||||||
def test_coref_serialization(nlp):
|
def test_coref_serialization(nlp):
|
||||||
# Test that the coref component can be serialized
|
# Test that the coref component can be serialized
|
||||||
nlp.add_pipe("coref", last=True)
|
nlp.add_pipe("coref", last=True, config=CONFIG)
|
||||||
nlp.initialize()
|
nlp.initialize()
|
||||||
assert nlp.pipe_names == ["coref"]
|
assert nlp.pipe_names == ["coref"]
|
||||||
text = "She gave me her pen."
|
text = "She gave me her pen."
|
||||||
|
@@ -110,7 +111,7 @@ def test_overfitting_IO(nlp):
|
||||||
for text, annot in TRAIN_DATA:
|
for text, annot in TRAIN_DATA:
|
||||||
train_examples.append(Example.from_dict(nlp.make_doc(text), annot))
|
train_examples.append(Example.from_dict(nlp.make_doc(text), annot))
|
||||||
|
|
||||||
nlp.add_pipe("coref")
|
nlp.add_pipe("coref", config=CONFIG)
|
||||||
optimizer = nlp.initialize()
|
optimizer = nlp.initialize()
|
||||||
test_text = TRAIN_DATA[0][0]
|
test_text = TRAIN_DATA[0][0]
|
||||||
doc = nlp(test_text)
|
doc = nlp(test_text)
|
||||||
|
@@ -165,7 +166,7 @@ def test_tokenization_mismatch(nlp):
|
||||||
|
|
||||||
train_examples.append(eg)
|
train_examples.append(eg)
|
||||||
|
|
||||||
nlp.add_pipe("coref")
|
nlp.add_pipe("coref", config=CONFIG)
|
||||||
optimizer = nlp.initialize()
|
optimizer = nlp.initialize()
|
||||||
test_text = TRAIN_DATA[0][0]
|
test_text = TRAIN_DATA[0][0]
|
||||||
doc = nlp(test_text)
|
doc = nlp(test_text)
|
||||||
|
|
|
@@ -44,6 +44,8 @@ TRAIN_DATA = [
|
||||||
]
|
]
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
|
CONFIG = {"model": {"@architectures": "spacy.SpanPredictor.v1", "tok2vec_size": 64}}
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def nlp():
|
def nlp():
|
||||||
|
@@ -74,7 +76,7 @@ def test_not_initialized(nlp):
|
||||||
@pytest.mark.skipif(not has_torch, reason="Torch not available")
|
@pytest.mark.skipif(not has_torch, reason="Torch not available")
|
||||||
def test_span_predictor_serialization(nlp):
|
def test_span_predictor_serialization(nlp):
|
||||||
# Test that the span predictor component can be serialized
|
# Test that the span predictor component can be serialized
|
||||||
nlp.add_pipe("span_predictor", last=True)
|
nlp.add_pipe("span_predictor", last=True, config=CONFIG)
|
||||||
nlp.initialize()
|
nlp.initialize()
|
||||||
assert nlp.pipe_names == ["span_predictor"]
|
assert nlp.pipe_names == ["span_predictor"]
|
||||||
text = "She gave me her pen."
|
text = "She gave me her pen."
|
||||||
|
@@ -96,7 +98,7 @@ def test_overfitting_IO(nlp):
|
||||||
for text, annot in TRAIN_DATA:
|
for text, annot in TRAIN_DATA:
|
||||||
train_examples.append(Example.from_dict(nlp.make_doc(text), annot))
|
train_examples.append(Example.from_dict(nlp.make_doc(text), annot))
|
||||||
|
|
||||||
nlp.add_pipe("span_predictor")
|
nlp.add_pipe("span_predictor", config=CONFIG)
|
||||||
optimizer = nlp.initialize()
|
optimizer = nlp.initialize()
|
||||||
test_text = TRAIN_DATA[0][0]
|
test_text = TRAIN_DATA[0][0]
|
||||||
doc = nlp(test_text)
|
doc = nlp(test_text)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user