mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-31 07:57:35 +03:00 
			
		
		
		
	Initial test of mismatched tokenization
This runs, but the results are nonsense because the indices are off.
This commit is contained in:
		
							parent
							
								
									16894e665d
								
							
						
					
					
						commit
						af6d5ae2fe
					
				|  | @ -152,6 +152,62 @@ def test_overfitting_IO(nlp): | |||
|     # assert_equal(batch_deps_1, batch_deps_2) | ||||
|     # assert_equal(batch_deps_1, no_batch_deps) | ||||
| 
 | ||||
| @pytest.mark.skipif(not has_torch, reason="Torch not available") | ||||
| def test_tokenization_mismatch(nlp): | ||||
|     train_examples = [] | ||||
|     for text, annot in TRAIN_DATA: | ||||
|         eg = Example.from_dict(nlp.make_doc(text), annot) | ||||
|         ref = eg.reference | ||||
|         char_spans = {} | ||||
|         for key, cluster in ref.spans.items(): | ||||
|             char_spans[key] = [] | ||||
|             for span in cluster: | ||||
|                 char_spans[key].append( (span[0].idx, span[-1].idx + len(span[-1])) ) | ||||
|         with ref.retokenize() as retokenizer: | ||||
|             # merge "many friends" | ||||
|             retokenizer.merge(ref[5:7]) | ||||
| 
 | ||||
|         # Note this works because it's the same doc and we know the keys | ||||
|         for key, _ in ref.spans.items(): | ||||
|             spans = char_spans[key] | ||||
|             ref.spans[key] = [ref.char_span(*span) for span in spans] | ||||
| 
 | ||||
|         train_examples.append(eg) | ||||
| 
 | ||||
|     nlp.add_pipe("coref") | ||||
|     optimizer = nlp.initialize() | ||||
|     test_text = TRAIN_DATA[0][0] | ||||
|     doc = nlp(test_text) | ||||
| 
 | ||||
|     for i in range(15): | ||||
|         losses = {} | ||||
|         nlp.update(train_examples, sgd=optimizer, losses=losses) | ||||
|         doc = nlp(test_text) | ||||
|         print(i, doc.spans) | ||||
| 
 | ||||
|     # test the trained model | ||||
|     doc = nlp(test_text) | ||||
| 
 | ||||
|     # Also test the results are still the same after IO | ||||
|     with make_tempdir() as tmp_dir: | ||||
|         nlp.to_disk(tmp_dir) | ||||
|         nlp2 = util.load_model_from_path(tmp_dir) | ||||
|         doc2 = nlp2(test_text) | ||||
| 
 | ||||
|     # Make sure that running pipe twice, or comparing to call, always amounts to the same predictions | ||||
|     texts = [ | ||||
|         test_text, | ||||
|         "I noticed many friends around me", | ||||
|         "They received it. They received the SMS.", | ||||
|     ] | ||||
| 
 | ||||
|     # save the docs so they don't get garbage collected | ||||
|     docs = list(nlp.pipe(texts)) | ||||
|     batch_deps_1 = [doc.spans for doc in docs] | ||||
|     docs = list(nlp.pipe(texts)) | ||||
|     batch_deps_2 = [doc.spans for doc in docs] | ||||
|     docs = [nlp(text) for text in texts] | ||||
|     no_batch_deps = [doc.spans for doc in docs] | ||||
| 
 | ||||
| @pytest.mark.skipif(not has_torch, reason="Torch not available") | ||||
| def test_crossing_spans(): | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user