mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-30 23:47:31 +03:00 
			
		
		
		
	Use config to specify tok2vec_size
This commit is contained in:
		
							parent
							
								
									1a4dbb702d
								
							
						
					
					
						commit
						619b1102e6
					
				|  | @ -25,7 +25,6 @@ def build_wl_coref_model( | |||
|     tok2vec_size: int = 768,  # tok2vec size | ||||
| ): | ||||
|     # TODO add model return types | ||||
|     tok2vec_size = 64 | ||||
| 
 | ||||
|     with Model.define_operators({">>": chain}): | ||||
|         coref_clusterer = PyTorchWrapper( | ||||
|  |  | |||
|  | @ -36,6 +36,7 @@ TRAIN_DATA = [ | |||
| # fmt: on | ||||
| 
 | ||||
| 
 | ||||
| CONFIG = {"model": {"@architectures": "spacy.Coref.v1", "tok2vec_size": 64}} | ||||
| 
 | ||||
| 
 | ||||
| @pytest.fixture | ||||
|  | @ -66,7 +67,7 @@ def test_not_initialized(nlp): | |||
| 
 | ||||
| @pytest.mark.skipif(not has_torch, reason="Torch not available") | ||||
| def test_initialized(nlp): | ||||
|     nlp.add_pipe("coref") | ||||
|     nlp.add_pipe("coref", config=CONFIG) | ||||
|     nlp.initialize() | ||||
|     assert nlp.pipe_names == ["coref"] | ||||
|     text = "She gave me her pen." | ||||
|  | @ -78,7 +79,7 @@ def test_initialized(nlp): | |||
| 
 | ||||
| @pytest.mark.skipif(not has_torch, reason="Torch not available") | ||||
| def test_initialized_short(nlp): | ||||
|     nlp.add_pipe("coref") | ||||
|     nlp.add_pipe("coref", config=CONFIG) | ||||
|     nlp.initialize() | ||||
|     assert nlp.pipe_names == ["coref"] | ||||
|     text = "Hi there" | ||||
|  | @ -88,7 +89,7 @@ def test_initialized_short(nlp): | |||
| @pytest.mark.skipif(not has_torch, reason="Torch not available") | ||||
| def test_coref_serialization(nlp): | ||||
|     # Test that the coref component can be serialized | ||||
|     nlp.add_pipe("coref", last=True) | ||||
|     nlp.add_pipe("coref", last=True, config=CONFIG) | ||||
|     nlp.initialize() | ||||
|     assert nlp.pipe_names == ["coref"] | ||||
|     text = "She gave me her pen." | ||||
|  | @ -110,7 +111,7 @@ def test_overfitting_IO(nlp): | |||
|     for text, annot in TRAIN_DATA: | ||||
|         train_examples.append(Example.from_dict(nlp.make_doc(text), annot)) | ||||
| 
 | ||||
|     nlp.add_pipe("coref") | ||||
|     nlp.add_pipe("coref", config=CONFIG) | ||||
|     optimizer = nlp.initialize() | ||||
|     test_text = TRAIN_DATA[0][0] | ||||
|     doc = nlp(test_text) | ||||
|  | @ -165,7 +166,7 @@ def test_tokenization_mismatch(nlp): | |||
| 
 | ||||
|         train_examples.append(eg) | ||||
| 
 | ||||
|     nlp.add_pipe("coref") | ||||
|     nlp.add_pipe("coref", config=CONFIG) | ||||
|     optimizer = nlp.initialize() | ||||
|     test_text = TRAIN_DATA[0][0] | ||||
|     doc = nlp(test_text) | ||||
|  |  | |||
|  | @ -44,6 +44,8 @@ TRAIN_DATA = [ | |||
| ] | ||||
| # fmt: on | ||||
| 
 | ||||
| CONFIG = {"model": {"@architectures": "spacy.SpanPredictor.v1", "tok2vec_size": 64}} | ||||
| 
 | ||||
| 
 | ||||
| @pytest.fixture | ||||
| def nlp(): | ||||
|  | @ -74,7 +76,7 @@ def test_not_initialized(nlp): | |||
| @pytest.mark.skipif(not has_torch, reason="Torch not available") | ||||
| def test_span_predictor_serialization(nlp): | ||||
|     # Test that the span predictor component can be serialized | ||||
|     nlp.add_pipe("span_predictor", last=True) | ||||
|     nlp.add_pipe("span_predictor", last=True, config=CONFIG) | ||||
|     nlp.initialize() | ||||
|     assert nlp.pipe_names == ["span_predictor"] | ||||
|     text = "She gave me her pen." | ||||
|  | @ -96,7 +98,7 @@ def test_overfitting_IO(nlp): | |||
|     for text, annot in TRAIN_DATA: | ||||
|         train_examples.append(Example.from_dict(nlp.make_doc(text), annot)) | ||||
| 
 | ||||
|     nlp.add_pipe("span_predictor") | ||||
|     nlp.add_pipe("span_predictor", config=CONFIG) | ||||
|     optimizer = nlp.initialize() | ||||
|     test_text = TRAIN_DATA[0][0] | ||||
|     doc = nlp(test_text) | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user