Mirror of https://github.com/explosion/spaCy.git · synced 2025-07-18 20:22:25 +03:00
Test works

This may not be done yet: the test currently only checks prediction consistency, and the model does not yet overfit correctly.
This commit is contained in:
parent ef5762d78e · commit d1ff933e9b
@@ -919,6 +919,7 @@ class Errors(metaclass=ErrorsWithCodes):
     E1035 = ("Token index {i} out of bounds ({length})")
     E1036 = ("Cannot index into NoneNode")
     E1037 = ("Invalid attribute value '{attr}'.")
+    E1038 = ("Misalignment in coref. Head token has no match in training doc.")


 # Deprecated model shortcuts, only used in errors and warnings
@@ -143,16 +143,18 @@ def create_head_span_idxs(ops, doclen: int):


 def get_clusters_from_doc(doc) -> List[List[Tuple[int, int]]]:
-    """Given a Doc, convert the cluster spans to simple int tuple lists."""
+    """Given a Doc, convert the cluster spans to simple int tuple lists. The
+    ints are char spans, to be tokenization independent.
+    """
     out = []
     for key, val in doc.spans.items():
         cluster = []
         for span in val:
-            # TODO check that there isn't an off-by-one error here
-            # cluster.append((span.start, span.end))
-            # TODO This conversion should be happening earlier in processing
             head_i = span.root.i
-            cluster.append((head_i, head_i + 1))
+            head = doc[head_i]
+            char_span = (head.idx, head.idx + len(head))
+            cluster.append(char_span)

         # don't want duplicates
         cluster = list(set(cluster))
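The idea in this hunk is that character offsets survive retokenization while token indices do not: `head.idx` is the character offset of a token's first character, so `(head.idx, head.idx + len(head))` identifies the head regardless of tokenization. A minimal standalone sketch of the trick (the `spacy.blank` pipeline and sample text are illustrative assumptions, not part of the diff):

```python
# Sketch: character offsets are tokenization independent.
import spacy

nlp = spacy.blank("en")
doc = nlp("Alice said she would come.")
head = doc[2]  # the token "she"
char_span = (head.idx, head.idx + len(head))
# Map the character span back onto this doc's tokenization:
span = doc.char_span(*char_span)
assert (span.start, span.end) == (2, 3)
```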
@@ -267,7 +267,19 @@ class CoreferenceResolver(TrainablePipe):
         example = list(examples)[0]
         cidx = mention_idx

-        clusters = get_clusters_from_doc(example.reference)
+        clusters_by_char = get_clusters_from_doc(example.reference)
+        # convert to token clusters, and give up if necessary
+        clusters = []
+        for cluster in clusters_by_char:
+            cc = []
+            for start_char, end_char in cluster:
+                span = example.predicted.char_span(start_char, end_char)
+                if span is None:
+                    # TODO log more details
+                    raise IndexError(Errors.E1038)
+                cc.append((span.start, span.end))
+            clusters.append(cc)
+
         span_idxs = create_head_span_idxs(ops, len(example.predicted))
         gscores = create_gold_scores(span_idxs, clusters)
         # TODO fix type here. This is bools but asarray2f wants ints.
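Since the gold clusters now arrive as character spans, the loss code must project them onto the predicted doc's tokenization; `Doc.char_span` returns `None` when the character boundaries fall inside a token, which is exactly the misalignment case the new E1038 error covers. A standalone sketch of the same conversion (the helper name and sample text are assumptions for illustration):

```python
# Hypothetical standalone version of the char -> token conversion above.
import spacy

def char_clusters_to_token_clusters(doc, clusters_by_char):
    clusters = []
    for cluster in clusters_by_char:
        cc = []
        for start_char, end_char in cluster:
            # None means the char boundaries don't line up with tokens
            span = doc.char_span(start_char, end_char)
            if span is None:
                raise IndexError("Misalignment in coref.")
            cc.append((span.start, span.end))
        clusters.append(cc)
    return clusters

nlp = spacy.blank("en")
doc = nlp("They received it.")
# "They" is chars 0-4, "it" is chars 14-16
print(char_clusters_to_token_clusters(doc, [[(0, 4), (14, 16)]]))
# -> [[(0, 1), (2, 3)]]
```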
@@ -34,6 +34,15 @@ TRAIN_DATA = [
 ]
 # fmt: on

+def spans2ints(doc):
+    """Convert doc.spans to nested list of ints for comparison.
+
+    This is useful for checking consistency of predictions.
+    """
+    out = []
+    for key, cluster in doc.spans.items():
+        out.append([(ss.start, ss.end) for ss in cluster])
+    return out

 @pytest.fixture
 def nlp():
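The new `spans2ints` helper flattens each cluster in `doc.spans` to plain `(start, end)` token-index tuples, so predictions from two different docs can be compared with a simple `==`. A hedged usage sketch (assumes `nlp` is a loaded pipeline whose coref component writes clusters into `doc.spans`):

```python
# Two passes over the same text should predict identical clusters.
doc_a = nlp("They received it. They received the SMS.")
doc_b = nlp("They received it. They received the SMS.")
assert spans2ints(doc_a) == spans2ints(doc_b)
```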
@@ -108,7 +117,7 @@ def test_coref_serialization(nlp):

 @pytest.mark.skipif(not has_torch, reason="Torch not available")
 def test_overfitting_IO(nlp):
-    # Simple test to try and quickly overfit the senter - ensuring the ML models work correctly
+    # Simple test to try and quickly overfit - ensuring the ML models work correctly
     train_examples = []
     for text, annot in TRAIN_DATA:
         train_examples.append(Example.from_dict(nlp.make_doc(text), annot))
@@ -117,25 +126,21 @@ def test_overfitting_IO(nlp):
     optimizer = nlp.initialize()
     test_text = TRAIN_DATA[0][0]
     doc = nlp(test_text)
-    print("BEFORE", doc.spans)

-    for i in range(5):
+    # Needs ~12 epochs to converge
+    for i in range(15):
         losses = {}
         nlp.update(train_examples, sgd=optimizer, losses=losses)
         doc = nlp(test_text)
-        print(i, doc.spans)
-        print(losses["coref"])  # < 0.001

     # test the trained model
     doc = nlp(test_text)
-    print("AFTER", doc.spans)

     # Also test the results are still the same after IO
     with make_tempdir() as tmp_dir:
         nlp.to_disk(tmp_dir)
         nlp2 = util.load_model_from_path(tmp_dir)
         doc2 = nlp2(test_text)
-        print("doc2", doc2.spans)

     # Make sure that running pipe twice, or comparing to call, always amounts to the same predictions
     texts = [
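The removed debug print recorded that the coref loss drops below 0.001 once training converges, which hints at the eventual shape of this test. A hedged sketch of the assertion the comments point toward (the threshold comes from the removed `# < 0.001` note; nothing is asserted in this commit):

```python
# Train past the ~12 epochs noted above, then check the loss collapsed.
losses = {}
for i in range(15):
    losses = {}
    nlp.update(train_examples, sgd=optimizer, losses=losses)
assert losses["coref"] < 0.001  # threshold from the removed print comment
```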
@@ -143,12 +148,16 @@ def test_overfitting_IO(nlp):
         "I noticed many friends around me",
         "They received it. They received the SMS.",
     ]
-    batch_deps_1 = [doc.spans for doc in nlp.pipe(texts)]
+    docs = list(nlp.pipe(texts))
+    batch_deps_1 = [doc.spans for doc in docs]
     print(batch_deps_1)
-    batch_deps_2 = [doc.spans for doc in nlp.pipe(texts)]
+    docs = list(nlp.pipe(texts))
+    batch_deps_2 = [doc.spans for doc in docs]
     print(batch_deps_2)
-    no_batch_deps = [doc.spans for doc in [nlp(text) for text in texts]]
+    docs = [nlp(text) for text in texts]
+    no_batch_deps = [doc.spans for doc in docs]
     print(no_batch_deps)
+    print("FINISH")
     # assert_equal(batch_deps_1, batch_deps_2)
     # assert_equal(batch_deps_1, no_batch_deps)
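Binding the docs to a name before collecting their spans matters because a `SpanGroup` holds only a weak reference to its `Doc`: if the docs produced by `nlp.pipe` are never kept alive, they can be garbage collected and the collected spans become unusable (the next hunk's "save the docs" comment says the same). A minimal sketch of the safe pattern (assumes any pipeline `nlp` that writes to `doc.spans`):

```python
# Fragile: the Docs from nlp.pipe(...) may be garbage collected, since
# each SpanGroup in doc.spans only weakly references its Doc.
# spans = [doc.spans for doc in nlp.pipe(texts)]

# Safe: keep the Docs alive for as long as the spans are needed.
docs = list(nlp.pipe(texts))
spans = [doc.spans for doc in docs]
```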
@@ -183,7 +192,6 @@ def test_tokenization_mismatch(nlp):
         losses = {}
         nlp.update(train_examples, sgd=optimizer, losses=losses)
         doc = nlp(test_text)
-        print(i, doc.spans)

     # test the trained model
     doc = nlp(test_text)
@@ -202,12 +210,11 @@ def test_tokenization_mismatch(nlp):
     ]

     # save the docs so they don't get garbage collected
-    docs = list(nlp.pipe(texts))
-    batch_deps_1 = [doc.spans for doc in docs]
-    docs = list(nlp.pipe(texts))
-    batch_deps_2 = [doc.spans for doc in docs]
-    docs = [nlp(text) for text in texts]
-    no_batch_deps = [doc.spans for doc in docs]
+    docs1 = list(nlp.pipe(texts))
+    docs2 = list(nlp.pipe(texts))
+    docs3 = [nlp(text) for text in texts]
+    assert spans2ints(docs1[0]) == spans2ints(docs2[0])
+    assert spans2ints(docs1[0]) == spans2ints(docs3[0])


 @pytest.mark.skipif(not has_torch, reason="Torch not available")
 def test_crossing_spans():