mirror of https://github.com/explosion/spaCy.git
	Test works
This may not be done yet: the test only checks that predictions are consistent, not that the model overfits correctly.
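The core mechanism in this commit is a round trip: head spans are stored as character offsets (so they are tokenization-independent) and then mapped back to token indices on the predicted doc. A minimal sketch of that round trip on a blank pipeline (the example text and variable names are illustrative, not from this commit):

    import spacy

    nlp = spacy.blank("en")
    doc = nlp("Alice said she left")

    # token -> char offsets for a one-token head span
    head = doc[2]  # "she"
    start_char, end_char = head.idx, head.idx + len(head)

    # char offsets -> token span; char_span returns None when the
    # offsets do not line up with token boundaries
    span = doc.char_span(start_char, end_char)
    assert (span.start, span.end) == (2, 3)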
This commit is contained in:
parent ef5762d78e
commit d1ff933e9b
@@ -919,6 +919,7 @@ class Errors(metaclass=ErrorsWithCodes):
     E1035 = ("Token index {i} out of bounds ({length})")
     E1036 = ("Cannot index into NoneNode")
     E1037 = ("Invalid attribute value '{attr}'.")
+    E1038 = ("Misalignment in coref. Head token has no match in training doc.")


 # Deprecated model shortcuts, only used in errors and warnings
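For reference, the ErrorsWithCodes metaclass prefixes the code onto the message on attribute access, so the new entry can be used directly; a small sketch (standalone, not part of this diff):

    from spacy.errors import Errors

    # Attribute access prefixes the code, yielding:
    # "[E1038] Misalignment in coref. Head token has no match in training doc."
    print(Errors.E1038)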
@@ -143,16 +143,18 @@ def create_head_span_idxs(ops, doclen: int):


 def get_clusters_from_doc(doc) -> List[List[Tuple[int, int]]]:
-    """Given a Doc, convert the cluster spans to simple int tuple lists."""
+    """Given a Doc, convert the cluster spans to simple int tuple lists. The
+    ints are char spans, to be tokenization independent.
+    """
     out = []
     for key, val in doc.spans.items():
         cluster = []
         for span in val:
-            # TODO check that there isn't an off-by-one error here
-            # cluster.append((span.start, span.end))
             # TODO This conversion should be happening earlier in processing

             head_i = span.root.i
-            cluster.append((head_i, head_i + 1))
+            head = doc[head_i]
+            char_span = (head.idx, head.idx + len(head))
+            cluster.append(char_span)

         # don't want duplicates
         cluster = list(set(cluster))
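A quick sketch of the new behavior (illustrative doc and span-group key; this assumes `get_clusters_from_doc` is in scope, and that `span.root` falls back to the token itself on an unparsed doc):

    import spacy

    nlp = spacy.blank("en")
    doc = nlp("They received it")
    doc.spans["coref_clusters_1"] = [doc[0:1], doc[2:3]]  # "They", "it"

    # char offsets of the head tokens, e.g. [[(0, 4), (14, 16)]];
    # list(set(...)) dedupes, so order within a cluster may vary
    print(get_clusters_from_doc(doc))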
@@ -267,7 +267,19 @@ class CoreferenceResolver(TrainablePipe):
         example = list(examples)[0]
         cidx = mention_idx

-        clusters = get_clusters_from_doc(example.reference)
+        clusters_by_char = get_clusters_from_doc(example.reference)
+        # convert to token clusters, and give up if necessary
+        clusters = []
+        for cluster in clusters_by_char:
+            cc = []
+            for start_char, end_char in cluster:
+                span = example.predicted.char_span(start_char, end_char)
+                if span is None:
+                    # TODO log more details
+                    raise IndexError(Errors.E1038)
+                cc.append((span.start, span.end))
+            clusters.append(cc)

         span_idxs = create_head_span_idxs(ops, len(example.predicted))
         gscores = create_gold_scores(span_idxs, clusters)
         # TODO fix type here. This is bools but asarray2f wants ints.
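The `span is None` branch is exactly the tokenization-mismatch case the new E1038 error describes. `Doc.char_span` behaves like this (a standalone illustration):

    import spacy

    nlp = spacy.blank("en")
    doc = nlp("They received the SMS")

    # offsets on token boundaries yield a Span
    assert doc.char_span(0, 4).text == "They"
    # offsets that cut into a token yield None; the pipe turns this
    # into IndexError(Errors.E1038)
    assert doc.char_span(0, 3) is None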
@@ -34,6 +34,15 @@ TRAIN_DATA = [
 ]
 # fmt: on

+def spans2ints(doc):
+    """Convert doc.spans to nested list of ints for comparison.
+
+    This is useful for checking consistency of predictions.
+    """
+    out = []
+    for key, cluster in doc.spans.items():
+        out.append([(ss.start, ss.end) for ss in cluster])
+    return out

 @pytest.fixture
 def nlp():
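With the helper above in scope, a quick usage sketch (the doc and span-group key are illustrative):

    import spacy

    nlp = spacy.blank("en")
    doc = nlp("They received it")
    doc.spans["coref_clusters_1"] = [doc[0:1], doc[2:3]]

    # plain token-index tuples, safe to compare with ==
    assert spans2ints(doc) == [[(0, 1), (2, 3)]]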
@@ -108,7 +117,7 @@ def test_coref_serialization(nlp):

 @pytest.mark.skipif(not has_torch, reason="Torch not available")
 def test_overfitting_IO(nlp):
-    # Simple test to try and quickly overfit the senter - ensuring the ML models work correctly
+    # Simple test to try and quickly overfit - ensuring the ML models work correctly
     train_examples = []
     for text, annot in TRAIN_DATA:
         train_examples.append(Example.from_dict(nlp.make_doc(text), annot))
@@ -117,25 +126,21 @@ def test_overfitting_IO(nlp):
     optimizer = nlp.initialize()
     test_text = TRAIN_DATA[0][0]
    doc = nlp(test_text)
-    print("BEFORE", doc.spans)

-    for i in range(5):
+    # Needs ~12 epochs to converge
+    for i in range(15):
         losses = {}
         nlp.update(train_examples, sgd=optimizer, losses=losses)
         doc = nlp(test_text)
-        print(i, doc.spans)
+    print(losses["coref"])  # < 0.001

     # test the trained model
     doc = nlp(test_text)
-    print("AFTER", doc.spans)

     # Also test the results are still the same after IO
     with make_tempdir() as tmp_dir:
         nlp.to_disk(tmp_dir)
         nlp2 = util.load_model_from_path(tmp_dir)
         doc2 = nlp2(test_text)
-        print("doc2", doc2.spans)

     # Make sure that running pipe twice, or comparing to call, always amounts to the same predictions
     texts = [
@@ -143,12 +148,16 @@ def test_overfitting_IO(nlp):
         "I noticed many friends around me",
         "They received it. They received the SMS.",
     ]
-    batch_deps_1 = [doc.spans for doc in nlp.pipe(texts)]
+    docs = list(nlp.pipe(texts))
+    batch_deps_1 = [doc.spans for doc in docs]
     print(batch_deps_1)
-    batch_deps_2 = [doc.spans for doc in nlp.pipe(texts)]
+    docs = list(nlp.pipe(texts))
+    batch_deps_2 = [doc.spans for doc in docs]
     print(batch_deps_2)
-    no_batch_deps = [doc.spans for doc in [nlp(text) for text in texts]]
+    docs = [nlp(text) for text in texts]
+    no_batch_deps = [doc.spans for doc in docs]
     print(no_batch_deps)
+    print("FINISH")
     # assert_equal(batch_deps_1, batch_deps_2)
     # assert_equal(batch_deps_1, no_batch_deps)

@@ -183,7 +192,6 @@ def test_tokenization_mismatch(nlp):
         losses = {}
         nlp.update(train_examples, sgd=optimizer, losses=losses)
         doc = nlp(test_text)
-        print(i, doc.spans)

     # test the trained model
     doc = nlp(test_text)
@@ -202,12 +210,11 @@ def test_tokenization_mismatch(nlp):
     ]

     # save the docs so they don't get garbage collected
-    docs = list(nlp.pipe(texts))
-    batch_deps_1 = [doc.spans for doc in docs]
-    docs = list(nlp.pipe(texts))
-    batch_deps_2 = [doc.spans for doc in docs]
-    docs = [nlp(text) for text in texts]
-    no_batch_deps = [doc.spans for doc in docs]
+    docs1 = list(nlp.pipe(texts))
+    docs2 = list(nlp.pipe(texts))
+    docs3 = [nlp(text) for text in texts]
+    assert spans2ints(docs1[0]) == spans2ints(docs2[0])
+    assert spans2ints(docs1[0]) == spans2ints(docs3[0])

 @pytest.mark.skipif(not has_torch, reason="Torch not available")
 def test_crossing_spans():