Mirror of https://github.com/explosion/spaCy.git, synced 2025-07-18 12:12:20 +03:00
Initial test of mismatched tokenization
This runs, but the results are nonsense because the indices are off.
This commit is contained in:
parent 16894e665d
commit af6d5ae2fe
@@ -152,6 +152,62 @@ def test_overfitting_IO(nlp):
    # assert_equal(batch_deps_1, batch_deps_2)
    # assert_equal(batch_deps_1, no_batch_deps)


@pytest.mark.skipif(not has_torch, reason="Torch not available")
def test_tokenization_mismatch(nlp):
    train_examples = []
    for text, annot in TRAIN_DATA:
        eg = Example.from_dict(nlp.make_doc(text), annot)
        ref = eg.reference
        # Record the gold clusters as character offsets so they survive retokenization.
        char_spans = {}
        for key, cluster in ref.spans.items():
            char_spans[key] = []
            for span in cluster:
                char_spans[key].append((span[0].idx, span[-1].idx + len(span[-1])))
        with ref.retokenize() as retokenizer:
            # merge "many friends"
            retokenizer.merge(ref[5:7])

        # Note this works because it's the same doc and we know the keys
        for key, _ in ref.spans.items():
            spans = char_spans[key]
            ref.spans[key] = [ref.char_span(*span) for span in spans]

        train_examples.append(eg)

    nlp.add_pipe("coref")
    optimizer = nlp.initialize()
    test_text = TRAIN_DATA[0][0]
    doc = nlp(test_text)

    for i in range(15):
        losses = {}
        nlp.update(train_examples, sgd=optimizer, losses=losses)
        doc = nlp(test_text)
        print(i, doc.spans)

    # test the trained model
    doc = nlp(test_text)

    # Also test the results are still the same after IO
    with make_tempdir() as tmp_dir:
        nlp.to_disk(tmp_dir)
        nlp2 = util.load_model_from_path(tmp_dir)
        doc2 = nlp2(test_text)

    # Make sure that running pipe twice, or comparing to call, always amounts to the same predictions
    texts = [
        test_text,
        "I noticed many friends around me",
        "They received it. They received the SMS.",
    ]

    # save the docs so they don't get garbage collected
    docs = list(nlp.pipe(texts))
    batch_deps_1 = [doc.spans for doc in docs]
    docs = list(nlp.pipe(texts))
    batch_deps_2 = [doc.spans for doc in docs]
    docs = [nlp(text) for text in texts]
    no_batch_deps = [doc.spans for doc in docs]


@pytest.mark.skipif(not has_torch, reason="Torch not available")
def test_crossing_spans():
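
The test works around the mismatch by recording each gold span as character offsets before retokenizing the reference doc, then rebuilding the span groups from those offsets with char_span. A minimal sketch of that round trip in isolation, assuming a blank English pipeline; the span-group key, example text, and merge indices are only illustrative:

import spacy

nlp = spacy.blank("en")
doc = nlp("I noticed many friends around me")

# Pretend this span group is gold annotation covering "many friends".
doc.spans["clusters_1"] = [doc[2:4]]

# Character offsets stay valid across retokenization, unlike token indices.
char_offsets = [(span.start_char, span.end_char) for span in doc.spans["clusters_1"]]

# Merge "many friends" into a single token, changing the tokenization.
with doc.retokenize() as retokenizer:
    retokenizer.merge(doc[2:4])

# Rebuild the span group from the saved offsets. char_span returns None if an
# offset pair no longer lands on token boundaries, which is one way indices can
# end up "off" when two tokenizations don't line up.
doc.spans["clusters_1"] = [doc.char_span(start, end) for start, end in char_offsets]
print(list(doc.spans["clusters_1"]))  # the rebuilt span still covers "many friends"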