Fix offsets in Span.get_lca_matrix (#8116)

* Fix range in Span.get_lca_matrix

Fix the adjusted token index and LCA matrix index ranges in
`_get_lca_matrix` when it is called for a span.

* The range for `k` should correspond to the adjusted indices in
`lca_matrix`, with the token at `start` indexed as `0` (see the sketch
below)

* Update test for v3.x
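For context, here is a minimal runnable sketch of the behavior the fix
restores. It reuses the sentence and head structure from the Span API docs
example (also exercised by the new test below) and builds the `Doc` by
hand, so no trained pipeline is required:

    from spacy.tokens import Doc
    from spacy.vocab import Vocab

    words = ["I", "like", "New", "York", "in", "Autumn"]
    # heads[i] is the index of token i's syntactic head: "like" is the
    # root; "I" and "York" -> "like"; "New" -> "York"; "in" -> "York";
    # "Autumn" -> "in"
    doc = Doc(Vocab(), words=words, heads=[1, 1, 3, 1, 3, 4], deps=["dep"] * 6)

    span = doc[1:4]  # "like New York"
    lca = span.get_lca_matrix()
    # lca[i, j] is the span-relative index of the lowest common ancestor
    # of span[i] and span[j]; ancestors outside the span come back as -1
    print(lca)
    # [[0 0 0]
    #  [0 1 2]
    #  [0 2 2]]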
commit 2c545c4c5b (parent 0dffc5d9e2)
Adriane Boyd, 2021-05-17 16:54:23 +02:00, committed by GitHub
2 changed files with 14 additions and 1 deletion

spacy/tests/doc/test_span.py

@@ -1,4 +1,6 @@
 import pytest
+import numpy
+from numpy.testing import assert_array_equal
 from spacy.attrs import ORTH, LENGTH
 from spacy.tokens import Doc, Span, Token
 from spacy.vocab import Vocab
@@ -120,6 +122,17 @@ def test_spans_lca_matrix(en_tokenizer):
     assert lca[1, 0] == 1  # slept & dog -> slept
     assert lca[1, 1] == 1  # slept & slept -> slept
 
+    # example from Span API docs
+    tokens = en_tokenizer("I like New York in Autumn")
+    doc = Doc(
+        tokens.vocab,
+        words=[t.text for t in tokens],
+        heads=[1, 1, 3, 1, 3, 4],
+        deps=["dep"] * len(tokens),
+    )
+    lca = doc[1:4].get_lca_matrix()
+    assert_array_equal(lca, numpy.asarray([[0, 0, 0], [0, 1, 2], [0, 2, 2]]))
+
 
 def test_span_similarity_match():
     doc = Doc(Vocab(), words=["a", "b", "a", "b"])

spacy/tokens/doc.pyx

@@ -1673,7 +1673,7 @@ cdef int [:,:] _get_lca_matrix(Doc doc, int start, int end):
         j_idx_in_sent = start + j - sent_start
         n_missing_tokens_in_sent = len(sent) - j_idx_in_sent
         # make sure we do not go past `end`, in cases where `end` < sent.end
-        max_range = min(j + n_missing_tokens_in_sent, end)
+        max_range = min(j + n_missing_tokens_in_sent, end - start)
         for k in range(j + 1, max_range):
             lca = _get_tokens_lca(token_j, doc[start + k])
             # if lca is outside of span, we set it to -1