Remove _spans_to_offsets

_spans_to_offsets is basically the same as get_clusters_from_doc, so the tests now use get_clusters_from_doc instead.
This commit is contained in:
Paul O'Leary McCann 2022-07-06 14:05:05 +09:00
parent 8f598d7b01
commit 6f5cf838ec
3 changed files with 14 additions and 28 deletions

View File

@ -203,17 +203,3 @@ def create_gold_scores(
# caller needs to convert to array, and add placeholder # caller needs to convert to array, and add placeholder
return out return out
def _spans_to_offsets(doc: Doc) -> List[List[Tuple[int, int]]]:
    """Convert doc.spans to nested lists of character offsets for comparison.

    Each span group in ``doc.spans`` becomes one inner list of
    ``(start_char, end_char)`` tuples, in the order the spans appear in the
    group. The groups themselves are ordered by their sorted key, so two docs
    holding the same groups produce equal results regardless of the order in
    which the groups were added.

    This is useful for checking consistency of predictions.

    doc: The document whose span groups should be converted.
    RETURNS: One list of (start_char, end_char) tuples per span group.
    """
    return [
        [(span.start_char, span.end_char) for span in doc.spans[key]]
        # Sort the keys so the output does not depend on insertion order.
        for key in sorted(doc.spans)
    ]

View File

@ -9,7 +9,7 @@ from spacy.ml.models.coref_util import (
DEFAULT_CLUSTER_PREFIX, DEFAULT_CLUSTER_PREFIX,
select_non_crossing_spans, select_non_crossing_spans,
get_sentence_ids, get_sentence_ids,
_spans_to_offsets, get_clusters_from_doc,
) )
from thinc.util import has_torch from thinc.util import has_torch
@ -101,7 +101,7 @@ def test_coref_serialization(nlp):
assert nlp2.pipe_names == ["coref"] assert nlp2.pipe_names == ["coref"]
doc2 = nlp2(text) doc2 = nlp2(text)
assert _spans_to_offsets(doc) == _spans_to_offsets(doc2) assert get_clusters_from_doc(doc) == get_clusters_from_doc(doc2)
@pytest.mark.skipif(not has_torch, reason="Torch not available") @pytest.mark.skipif(not has_torch, reason="Torch not available")
@ -140,8 +140,8 @@ def test_overfitting_IO(nlp):
docs1 = list(nlp.pipe(texts)) docs1 = list(nlp.pipe(texts))
docs2 = list(nlp.pipe(texts)) docs2 = list(nlp.pipe(texts))
docs3 = [nlp(text) for text in texts] docs3 = [nlp(text) for text in texts]
assert _spans_to_offsets(docs1[0]) == _spans_to_offsets(docs2[0]) assert get_clusters_from_doc(docs1[0]) == get_clusters_from_doc(docs2[0])
assert _spans_to_offsets(docs1[0]) == _spans_to_offsets(docs3[0]) assert get_clusters_from_doc(docs1[0]) == get_clusters_from_doc(docs3[0])
@pytest.mark.skipif(not has_torch, reason="Torch not available") @pytest.mark.skipif(not has_torch, reason="Torch not available")
@ -196,8 +196,8 @@ def test_tokenization_mismatch(nlp):
docs1 = list(nlp.pipe(texts)) docs1 = list(nlp.pipe(texts))
docs2 = list(nlp.pipe(texts)) docs2 = list(nlp.pipe(texts))
docs3 = [nlp(text) for text in texts] docs3 = [nlp(text) for text in texts]
assert _spans_to_offsets(docs1[0]) == _spans_to_offsets(docs2[0]) assert get_clusters_from_doc(docs1[0]) == get_clusters_from_doc(docs2[0])
assert _spans_to_offsets(docs1[0]) == _spans_to_offsets(docs3[0]) assert get_clusters_from_doc(docs1[0]) == get_clusters_from_doc(docs3[0])
@pytest.mark.skipif(not has_torch, reason="Torch not available") @pytest.mark.skipif(not has_torch, reason="Torch not available")

View File

@ -9,7 +9,7 @@ from spacy.ml.models.coref_util import (
DEFAULT_CLUSTER_PREFIX, DEFAULT_CLUSTER_PREFIX,
select_non_crossing_spans, select_non_crossing_spans,
get_sentence_ids, get_sentence_ids,
_spans_to_offsets, get_clusters_from_doc,
) )
from thinc.util import has_torch from thinc.util import has_torch
@ -88,7 +88,7 @@ def test_span_predictor_serialization(nlp):
assert nlp2.pipe_names == ["span_predictor"] assert nlp2.pipe_names == ["span_predictor"]
doc2 = nlp2(text) doc2 = nlp2(text)
assert _spans_to_offsets(doc) == _spans_to_offsets(doc2) assert get_clusters_from_doc(doc) == get_clusters_from_doc(doc2)
@pytest.mark.skipif(not has_torch, reason="Torch not available") @pytest.mark.skipif(not has_torch, reason="Torch not available")
@ -122,7 +122,7 @@ def test_overfitting_IO(nlp):
# test the trained model, using the pred since it has heads # test the trained model, using the pred since it has heads
doc = nlp(train_examples[0].predicted) doc = nlp(train_examples[0].predicted)
# XXX This actually tests that it can overfit # XXX This actually tests that it can overfit
assert _spans_to_offsets(doc) == _spans_to_offsets(train_examples[0].reference) assert get_clusters_from_doc(doc) == get_clusters_from_doc(train_examples[0].reference)
# Also test the results are still the same after IO # Also test the results are still the same after IO
with make_tempdir() as tmp_dir: with make_tempdir() as tmp_dir:
@ -140,8 +140,8 @@ def test_overfitting_IO(nlp):
docs1 = list(nlp.pipe(texts)) docs1 = list(nlp.pipe(texts))
docs2 = list(nlp.pipe(texts)) docs2 = list(nlp.pipe(texts))
docs3 = [nlp(text) for text in texts] docs3 = [nlp(text) for text in texts]
assert _spans_to_offsets(docs1[0]) == _spans_to_offsets(docs2[0]) assert get_clusters_from_doc(docs1[0]) == get_clusters_from_doc(docs2[0])
assert _spans_to_offsets(docs1[0]) == _spans_to_offsets(docs3[0]) assert get_clusters_from_doc(docs1[0]) == get_clusters_from_doc(docs3[0])
@pytest.mark.skipif(not has_torch, reason="Torch not available") @pytest.mark.skipif(not has_torch, reason="Torch not available")
@ -187,7 +187,7 @@ def test_tokenization_mismatch(nlp):
test_doc = train_examples[0].predicted test_doc = train_examples[0].predicted
doc = nlp(test_doc) doc = nlp(test_doc)
# XXX This actually tests that it can overfit # XXX This actually tests that it can overfit
assert _spans_to_offsets(doc) == _spans_to_offsets(train_examples[0].reference) assert get_clusters_from_doc(doc) == get_clusters_from_doc(train_examples[0].reference)
# Also test the results are still the same after IO # Also test the results are still the same after IO
with make_tempdir() as tmp_dir: with make_tempdir() as tmp_dir:
@ -206,8 +206,8 @@ def test_tokenization_mismatch(nlp):
docs1 = list(nlp.pipe(texts)) docs1 = list(nlp.pipe(texts))
docs2 = list(nlp.pipe(texts)) docs2 = list(nlp.pipe(texts))
docs3 = [nlp(text) for text in texts] docs3 = [nlp(text) for text in texts]
assert _spans_to_offsets(docs1[0]) == _spans_to_offsets(docs2[0]) assert get_clusters_from_doc(docs1[0]) == get_clusters_from_doc(docs2[0])
assert _spans_to_offsets(docs1[0]) == _spans_to_offsets(docs3[0]) assert get_clusters_from_doc(docs1[0]) == get_clusters_from_doc(docs3[0])
@pytest.mark.skipif(not has_torch, reason="Torch not available") @pytest.mark.skipif(not has_torch, reason="Torch not available")