Mirror of https://github.com/explosion/spaCy.git

Commit cf33b48fe0 ("Update tests"), parent fd574a89c4
@@ -207,11 +207,13 @@ def create_gold_scores(
 def spans2ints(doc):
     """Convert doc.spans to nested list of ints for comparison.
 
-    The ints are token indices.
+    The ints are character indices, and the spans groups are sorted by key first.
 
     This is useful for checking consistency of predictions.
     """
     out = []
-    for key, cluster in doc.spans.items():
-        out.append([(ss.start, ss.end) for ss in cluster])
+    keys = sorted([key for key in doc.spans])
+    for key in keys:
+        cluster = doc.spans[key]
+        out.append([(ss.start_char, ss.end_char) for ss in cluster])
     return out
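For reference, a minimal sketch of how the updated helper behaves. The Doc text and span-group keys here are illustrative only, not part of this commit; spans2ints is repeated from the hunk above so the snippet runs on its own:

    import spacy

    def spans2ints(doc):
        # Same as the updated helper above: character offsets, groups sorted by key.
        out = []
        keys = sorted([key for key in doc.spans])
        for key in keys:
            cluster = doc.spans[key]
            out.append([(ss.start_char, ss.end_char) for ss in cluster])
        return out

    nlp = spacy.blank("en")
    doc = nlp("Alice met Bob. She greeted him.")
    doc.spans["b_group"] = [doc[2:3], doc[6:7]]  # "Bob", "him"
    doc.spans["a_group"] = [doc[0:1], doc[4:5]]  # "Alice", "She"

    print(spans2ints(doc))
    # Groups come back ordered by sorted key ("a_group" before "b_group"),
    # each as (start_char, end_char) tuples:
    # [[(0, 5), (15, 18)], [(10, 13), (27, 30)]]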
@@ -114,13 +114,15 @@ def test_overfitting_IO(nlp):
     test_text = TRAIN_DATA[0][0]
     doc = nlp(test_text)
 
-    for i in range(1500):
+    for i in range(15):
         losses = {}
         nlp.update(train_examples, sgd=optimizer, losses=losses)
         doc = nlp(test_text)
 
     # test the trained model, using the pred since it has heads
     doc = nlp(train_examples[0].predicted)
+    # XXX This actually tests that it can overfit
+    assert spans2ints(doc) == spans2ints(train_examples[0].reference)
 
     # Also test the results are still the same after IO
     with make_tempdir() as tmp_dir:
@@ -134,6 +136,7 @@ def test_overfitting_IO(nlp):
         "I noticed many friends around me",
         "They received it. They received the SMS.",
     ]
+    # XXX Note these have no predictions because they have no input spans
     docs1 = list(nlp.pipe(texts))
     docs2 = list(nlp.pipe(texts))
     docs3 = [nlp(text) for text in texts]
@@ -175,7 +178,7 @@ def test_tokenization_mismatch(nlp):
     test_text = TRAIN_DATA[0][0]
     doc = nlp(test_text)
 
-    for i in range(100):
+    for i in range(15):
         losses = {}
         nlp.update(train_examples, sgd=optimizer, losses=losses)
         doc = nlp(test_text)
@@ -183,12 +186,8 @@ def test_tokenization_mismatch(nlp):
     # test the trained model; need to use doc with head spans on it already
     test_doc = train_examples[0].predicted
     doc = nlp(test_doc)
-
-    # XXX DEBUG
-    print("SPANS", len(doc.spans))
-    for key, val in doc.spans.items():
-        print(key, val)
-    print("...")
+    # XXX This actually tests that it can overfit
+    assert spans2ints(doc) == spans2ints(train_examples[0].reference)
 
     # Also test the results are still the same after IO
     with make_tempdir() as tmp_dir:
@@ -209,5 +208,4 @@ def test_tokenization_mismatch(nlp):
     docs3 = [nlp(text) for text in texts]
     assert spans2ints(docs1[0]) == spans2ints(docs2[0])
     assert spans2ints(docs1[0]) == spans2ints(docs3[0])
-    assert False
 
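The range(1500)/range(100) to range(15) changes and the new asserts follow the usual overfitting-test shape in these files. A minimal sketch of that pattern, under the same assumptions as the hunks above (TRAIN_DATA as a list of (text, annotations) pairs, a registered "coref" pipe, and the spans2ints helper all come from the test module, not from this sketch):

    import spacy
    from spacy.training import Example

    nlp = spacy.blank("en")
    # Build Example objects from the (text, annotations) pairs in TRAIN_DATA.
    train_examples = [
        Example.from_dict(nlp.make_doc(text), annot) for text, annot in TRAIN_DATA
    ]
    nlp.add_pipe("coref")
    optimizer = nlp.initialize()

    # A handful of updates is enough to overfit the tiny training set.
    for i in range(15):
        losses = {}
        nlp.update(train_examples, sgd=optimizer, losses=losses)

    # Prediction on the training doc should now reproduce the reference spans.
    doc = nlp(train_examples[0].predicted)
    assert spans2ints(doc) == spans2ints(train_examples[0].reference)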