From d1ff933e9b77b723b7ed326a6c339340fd47d318 Mon Sep 17 00:00:00 2001 From: Paul O'Leary McCann Date: Tue, 28 Jun 2022 19:15:33 +0900 Subject: [PATCH] Test works This may not be done yet, as the test is just for consistency, and not overfitting correctly yet. --- spacy/errors.py | 1 + spacy/ml/models/coref_util.py | 12 +++++---- spacy/pipeline/coref.py | 14 +++++++++- spacy/tests/pipeline/test_coref.py | 41 +++++++++++++++++------------- 4 files changed, 45 insertions(+), 23 deletions(-) diff --git a/spacy/errors.py b/spacy/errors.py index c82ffe882..837bfd740 100644 --- a/spacy/errors.py +++ b/spacy/errors.py @@ -919,6 +919,7 @@ class Errors(metaclass=ErrorsWithCodes): E1035 = ("Token index {i} out of bounds ({length})") E1036 = ("Cannot index into NoneNode") E1037 = ("Invalid attribute value '{attr}'.") + E1038 = ("Misalignment in coref. Head token has no match in training doc.") # Deprecated model shortcuts, only used in errors and warnings diff --git a/spacy/ml/models/coref_util.py b/spacy/ml/models/coref_util.py index a004a69d7..bd577e65f 100644 --- a/spacy/ml/models/coref_util.py +++ b/spacy/ml/models/coref_util.py @@ -143,16 +143,18 @@ def create_head_span_idxs(ops, doclen: int): def get_clusters_from_doc(doc) -> List[List[Tuple[int, int]]]: - """Given a Doc, convert the cluster spans to simple int tuple lists.""" + """Given a Doc, convert the cluster spans to simple int tuple lists. The + ints are char spans, to be tokenization independent. + """ out = [] for key, val in doc.spans.items(): cluster = [] for span in val: - # TODO check that there isn't an off-by-one error here - # cluster.append((span.start, span.end)) - # TODO This conversion should be happening earlier in processing + head_i = span.root.i - cluster.append((head_i, head_i + 1)) + head = doc[head_i] + char_span = (head.idx, head.idx + len(head)) + cluster.append(char_span) # don't want duplicates cluster = list(set(cluster)) diff --git a/spacy/pipeline/coref.py b/spacy/pipeline/coref.py index cd07f80e8..630502f6d 100644 --- a/spacy/pipeline/coref.py +++ b/spacy/pipeline/coref.py @@ -267,7 +267,19 @@ class CoreferenceResolver(TrainablePipe): example = list(examples)[0] cidx = mention_idx - clusters = get_clusters_from_doc(example.reference) + clusters_by_char = get_clusters_from_doc(example.reference) + # convert to token clusters, and give up if necessary + clusters = [] + for cluster in clusters_by_char: + cc = [] + for start_char, end_char in cluster: + span = example.predicted.char_span(start_char, end_char) + if span is None: + # TODO log more details + raise IndexError(Errors.E1038) + cc.append( (span.start, span.end) ) + clusters.append(cc) + span_idxs = create_head_span_idxs(ops, len(example.predicted)) gscores = create_gold_scores(span_idxs, clusters) # TODO fix type here. This is bools but asarray2f wants ints. diff --git a/spacy/tests/pipeline/test_coref.py b/spacy/tests/pipeline/test_coref.py index 358da6b03..584db99b8 100644 --- a/spacy/tests/pipeline/test_coref.py +++ b/spacy/tests/pipeline/test_coref.py @@ -34,6 +34,15 @@ TRAIN_DATA = [ ] # fmt: on +def spans2ints(doc): + """Convert doc.spans to nested list of ints for comparison. + + This is useful for checking consistency of predictions. + """ + out = [] + for key, cluster in doc.spans.items(): + out.append( [(ss.start, ss.end) for ss in cluster] ) + return out @pytest.fixture def nlp(): @@ -108,7 +117,7 @@ def test_coref_serialization(nlp): @pytest.mark.skipif(not has_torch, reason="Torch not available") def test_overfitting_IO(nlp): - # Simple test to try and quickly overfit the senter - ensuring the ML models work correctly + # Simple test to try and quickly overfit - ensuring the ML models work correctly train_examples = [] for text, annot in TRAIN_DATA: train_examples.append(Example.from_dict(nlp.make_doc(text), annot)) @@ -117,25 +126,21 @@ def test_overfitting_IO(nlp): optimizer = nlp.initialize() test_text = TRAIN_DATA[0][0] doc = nlp(test_text) - print("BEFORE", doc.spans) - for i in range(5): + # Needs ~12 epochs to converge + for i in range(15): losses = {} nlp.update(train_examples, sgd=optimizer, losses=losses) doc = nlp(test_text) - print(i, doc.spans) - print(losses["coref"]) # < 0.001 # test the trained model doc = nlp(test_text) - print("AFTER", doc.spans) # Also test the results are still the same after IO with make_tempdir() as tmp_dir: nlp.to_disk(tmp_dir) nlp2 = util.load_model_from_path(tmp_dir) doc2 = nlp2(test_text) - print("doc2", doc2.spans) # Make sure that running pipe twice, or comparing to call, always amounts to the same predictions texts = [ @@ -143,12 +148,16 @@ def test_overfitting_IO(nlp): "I noticed many friends around me", "They received it. They received the SMS.", ] - batch_deps_1 = [doc.spans for doc in nlp.pipe(texts)] + docs = list(nlp.pipe(texts)) + batch_deps_1 = [doc.spans for doc in docs] print(batch_deps_1) - batch_deps_2 = [doc.spans for doc in nlp.pipe(texts)] + docs = list(nlp.pipe(texts)) + batch_deps_2 = [doc.spans for doc in docs] print(batch_deps_2) - no_batch_deps = [doc.spans for doc in [nlp(text) for text in texts]] + docs = [nlp(text) for text in texts] + no_batch_deps = [doc.spans for doc in docs] print(no_batch_deps) + print("FINISH") # assert_equal(batch_deps_1, batch_deps_2) # assert_equal(batch_deps_1, no_batch_deps) @@ -183,7 +192,6 @@ def test_tokenization_mismatch(nlp): losses = {} nlp.update(train_examples, sgd=optimizer, losses=losses) doc = nlp(test_text) - print(i, doc.spans) # test the trained model doc = nlp(test_text) @@ -202,12 +210,11 @@ def test_tokenization_mismatch(nlp): ] # save the docs so they don't get garbage collected - docs = list(nlp.pipe(texts)) - batch_deps_1 = [doc.spans for doc in docs] - docs = list(nlp.pipe(texts)) - batch_deps_2 = [doc.spans for doc in docs] - docs = [nlp(text) for text in texts] - no_batch_deps = [doc.spans for doc in docs] + docs1 = list(nlp.pipe(texts)) + docs2 = list(nlp.pipe(texts)) + docs3 = [nlp(text) for text in texts] + assert spans2ints(docs1[0]) == spans2ints(docs2[0]) + assert spans2ints(docs1[0]) == spans2ints(docs3[0]) @pytest.mark.skipif(not has_torch, reason="Torch not available") def test_crossing_spans():