Fixed problems with non-projective trees

This commit is contained in:
Richard Hudson 2021-12-07 12:04:41 +01:00
parent 06a9939eb5
commit e04950ef3c
2 changed files with 79 additions and 7 deletions

View File

@ -1,4 +1,5 @@
import pytest
import deplacy
from spacy.visualization import Visualizer
from spacy.tokens import Span, Doc
@ -113,6 +114,8 @@ def test_dependency_tree_non_projective(en_vocab):
deps=["dep"] * 9,
)
dep_tree = Visualizer.render_dependency_tree(doc[0 : len(doc)], True)
for line in dep_tree:
print(line)
assert dep_tree == [
"<╗ ",
"═╩═══╗",
@ -159,3 +162,45 @@ def test_dependency_tree_input_not_span(en_vocab):
with pytest.raises(AssertionError):
Visualizer.render_dependency_tree(doc[1:3], True)
def test_dependency_tree_highly_nonprojective(en_vocab):
    """Render a highly non-projective tree (colloquial Polish) both ways.

    The same eight-token sentence is rendered twice, once with the second
    argument of ``Visualizer.render_dependency_tree`` (``root_right``) set
    to ``True`` and once set to ``False``, and each result is compared
    against the expected box-drawing rows (one row per token).
    """
    # NOTE(review): the sixth word is an empty string; presumably a token
    # was lost somewhere upstream of this file -- confirm before relying
    # on the surface forms. The head at that position is None.
    tokens = [
        "Owczarki",
        "przecież",
        "niemieckie",
        "zawsze",
        "wierne",
        "",
        "bardzo",
        ".",
    ]
    governors = [5, 5, 0, 5, 5, None, 4, 5]
    doc = Doc(
        en_vocab,
        words=tokens,
        heads=governors,
        deps=["dep"] * 8,
    )

    expected_root_right = [
        "═╗<╗",
        " ║<╣",
        "<╝ ║",
        "<══╣",
        "═╗<╣",
        "═══╣",
        "<╝ ║",
        "<══╝",
    ]
    expected_root_left = [
        "╔>╔═",
        "╠>║ ",
        "║ ╚>",
        "╠══>",
        "╠>╔═",
        "╠═══",
        "║ ╚>",
        "╚══>",
    ]

    # Second positional argument is ``root_right``.
    rendered_right = Visualizer.render_dependency_tree(doc[0 : len(doc)], True)
    assert rendered_right == expected_root_right

    rendered_left = Visualizer.render_dependency_tree(doc[0 : len(doc)], False)
    assert rendered_left == expected_root_left

View File

@ -37,6 +37,11 @@ ROOT_LEFT_CHARS = {
}
class TableColumn:
    """Placeholder describing one column of a rendered table.

    The constructor currently accepts its arguments and discards them; no
    attributes are set on the instance.
    """

    def __init__(self, entity: str, width: int, overflow_strategy: str = "truncate"):
        """Accept the column settings without storing them.

        entity: accepted but not stored.
        width: accepted but not stored.
        overflow_strategy: accepted but not stored; defaults to "truncate".
        """
class Visualizer:
@staticmethod
def render_dependency_tree(sent: Span, root_right: bool) -> list[str]:
@ -68,16 +73,38 @@ class Visualizer:
# start with the root column
indices_in_current_column = [i for i, h in enumerate(heads) if h is None]
while len(indices_in_current_column) > 0:
assert (
len(
[
i
for i in indices_in_current_column
if i in all_indices_ordered_by_column
]
)
== 0
)
all_indices_ordered_by_column = (
indices_in_current_column + all_indices_ordered_by_column
)
indices_in_current_column = [
i
for index_in_current_column in indices_in_current_column
for i in children_lists[index_in_current_column]
if i not in all_indices_ordered_by_column
]
# -1: root token with no arrow; 0: length not yet set
indices_in_next_column = []
# The horizontal lengths of the children on either side of a head must
# be calculated in an order that ensures that children closer to the
# head are processed first.
for index_in_current_column in indices_in_current_column:
following_children_indices = [
i
for i in children_lists[index_in_current_column]
if i > index_in_current_column
]
indices_in_next_column.extend(following_children_indices)
preceding_children_indices = [
i
for i in children_lists[index_in_current_column]
if i < index_in_current_column
]
preceding_children_indices.reverse()
indices_in_next_column.extend(preceding_children_indices)
indices_in_current_column = indices_in_next_column
horizontal_line_lengths = [
-1 if heads[i] is None else 1
# length == 1: governed by direct neighbour and has no children itself