diff --git a/spacy/tests/pipeline/test_coref.py b/spacy/tests/pipeline/test_coref.py
index 53f0b2011..358da6b03 100644
--- a/spacy/tests/pipeline/test_coref.py
+++ b/spacy/tests/pipeline/test_coref.py
@@ -152,6 +152,62 @@ def test_overfitting_IO(nlp):
     # assert_equal(batch_deps_1, batch_deps_2)
     # assert_equal(batch_deps_1, no_batch_deps)
 
+
+@pytest.mark.skipif(not has_torch, reason="Torch not available")
+def test_tokenization_mismatch(nlp):
+    train_examples = []
+    for text, annot in TRAIN_DATA:
+        eg = Example.from_dict(nlp.make_doc(text), annot)
+        ref = eg.reference
+        char_spans = {}
+        for key, cluster in ref.spans.items():
+            char_spans[key] = []
+            for span in cluster:
+                char_spans[key].append((span[0].idx, span[-1].idx + len(span[-1])))
+        with ref.retokenize() as retokenizer:
+            # merge "many friends"
+            retokenizer.merge(ref[5:7])
+
+        # Note this works because it's the same doc and we know the keys
+        for key, _ in ref.spans.items():
+            spans = char_spans[key]
+            ref.spans[key] = [ref.char_span(*span) for span in spans]
+        train_examples.append(eg)
+
+    nlp.add_pipe("coref")
+    optimizer = nlp.initialize()
+    test_text = TRAIN_DATA[0][0]
+    doc = nlp(test_text)
+
+    for i in range(15):
+        losses = {}
+        nlp.update(train_examples, sgd=optimizer, losses=losses)
+        doc = nlp(test_text)
+        print(i, doc.spans)
+
+    # test the trained model
+    doc = nlp(test_text)
+
+    # Also test the results are still the same after IO
+    with make_tempdir() as tmp_dir:
+        nlp.to_disk(tmp_dir)
+        nlp2 = util.load_model_from_path(tmp_dir)
+        doc2 = nlp2(test_text)
+
+    # Make sure that running pipe twice, or comparing to call, always amounts to the same predictions
+    texts = [
+        test_text,
+        "I noticed many friends around me",
+        "They received it. They received the SMS.",
+    ]
+    # save the docs so they don't get garbage collected
+    docs = list(nlp.pipe(texts))
+    batch_deps_1 = [doc.spans for doc in docs]
+    docs = list(nlp.pipe(texts))
+    batch_deps_2 = [doc.spans for doc in docs]
+    docs = [nlp(text) for text in texts]
+    no_batch_deps = [doc.spans for doc in docs]
+
 
 @pytest.mark.skipif(not has_torch, reason="Torch not available")
 def test_crossing_spans():
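
Note on the char-span round trip in the patch above: the test saves each gold span as character offsets before retokenizing, then rebuilds it with `Doc.char_span` afterwards, so the annotations survive the tokenization change. A minimal sketch of that trick in isolation (the blank "en" pipeline and the example sentence here are illustrative assumptions, not part of the patch):

    import spacy

    nlp = spacy.blank("en")
    doc = nlp("I noticed many friends around me")

    # save a token span as character offsets
    span = doc[2:4]  # "many friends"
    start_char = span[0].idx
    end_char = span[-1].idx + len(span[-1])

    # merge the two tokens into one, changing the tokenization
    with doc.retokenize() as retokenizer:
        retokenizer.merge(doc[2:4])

    # rebuild the span against the new tokenization
    new_span = doc.char_span(start_char, end_char)
    assert new_span.text == "many friends"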