Work in progress

This commit is contained in:
Richard Hudson 2021-11-28 15:39:07 +01:00
parent 93a4905b25
commit 2d0e916220

View File

@ -1,5 +1,41 @@
from spacy.tokens import Doc, Token from spacy.tokens import Doc, Token
from spacy.util import working_dir
HORIZONTAL_LINE_PART_AWAY_FROM_ROOT = 1
HORIZONTAL_LINE_PART_TOWARDS_ROOT = 2
UPPER_VERTICAL_LINE_PART = 4
LOWER_VERTICAL_LINE_PART = 8
ARROWHEAD = 16
SPACE = 0
HORIZONTAL_LINE = 3
VERTICAL_LINE = 12
ROOT_RIGHT_CHARS = {
0: " ",
3: "",
5: "",
7: "",
9: "",
11: "",
12: "",
13: "",
15: "",
16: "<",
}
ROOT_LEFT_CHARS = {
0: " ",
3: "",
5: "",
7: "",
9: "",
11: "",
12: "",
13: "",
15: "",
16: ">",
}
class Visualizer: class Visualizer:
@staticmethod @staticmethod
@ -12,17 +48,17 @@ class Visualizer:
Adapted from https://github.com/KoichiYasuoka/deplacy Adapted from https://github.com/KoichiYasuoka/deplacy
""" """
heads = [ heads = [
-1 None
if token.dep_.lower() == "root" or token.head.i == token.i if token.dep_.lower() == "root" or token.head.i == token.i
else token.head.i else token.head.i
for token in doc for token in doc
] ]
children_lists = [[] for _ in range(len(doc))] children_lists = [[] for _ in range(len(doc))]
for child, head in enumerate(heads): for child, head in enumerate(heads):
if head != -1: if head is not None:
children_lists[head].append(child) children_lists[head].append(child)
all_indices_ordered_by_column = [] all_indices_ordered_by_column = []
indices_in_current_column = [i for i, h in enumerate(heads) if h == -1] indices_in_current_column = [i for i, h in enumerate(heads) if h is None]
while len(indices_in_current_column) > 0: while len(indices_in_current_column) > 0:
all_indices_ordered_by_column = ( all_indices_ordered_by_column = (
indices_in_current_column + all_indices_ordered_by_column indices_in_current_column + all_indices_ordered_by_column
@ -36,7 +72,7 @@ class Visualizer:
# -1: root token with no arrow; 0: length not yet set # -1: root token with no arrow; 0: length not yet set
horizontal_line_lengths = [ horizontal_line_lengths = [
-1 -1
if heads[i] == -1 if heads[i] is None
else 1 else 1
if len(children_lists[i]) == 0 and abs(heads[i] - i) == 1 if len(children_lists[i]) == 0 and abs(heads[i] - i) == 1
else 0 else 0
@ -97,9 +133,83 @@ class Visualizer:
] = working_horizontal_line_length ] = working_horizontal_line_length
print(horizontal_line_lengths) print(horizontal_line_lengths)
max_horizontal_line_length = max(horizontal_line_lengths)
char_matrix = [
[SPACE] * max_horizontal_line_length * 2 for _ in range(len(doc))
]
for working_token_index in range(len(doc)):
head_token_index = heads[working_token_index]
if head_token_index is None:
continue
first_index_in_relation = min(working_token_index, head_token_index)
second_index_in_relation = max(working_token_index, head_token_index)
char_horizontal_line_length = (
2 * horizontal_line_lengths[working_token_index]
)
char_matrix[first_index_in_relation][char_horizontal_line_length - 1] |= (
HORIZONTAL_LINE_PART_AWAY_FROM_ROOT + LOWER_VERTICAL_LINE_PART
)
char_matrix[second_index_in_relation][char_horizontal_line_length - 1] |= (
HORIZONTAL_LINE_PART_TOWARDS_ROOT + UPPER_VERTICAL_LINE_PART
)
for working_horizontal_position in range(char_horizontal_line_length - 1):
if (
char_matrix[working_token_index][working_horizontal_position]
!= VERTICAL_LINE
):
char_matrix[working_token_index][
working_horizontal_position
] |= HORIZONTAL_LINE
for working_vertical_position in range(
first_index_in_relation + 1, second_index_in_relation
):
if (
char_matrix[working_vertical_position][
char_horizontal_line_length - 1
]
!= HORIZONTAL_LINE
):
char_matrix[working_vertical_position][
char_horizontal_line_length - 1
] |= VERTICAL_LINE
for working_token_index in (i for i in range(len(doc)) if heads[i] is not None):
for working_horizontal_position in range(
2 * horizontal_line_lengths[working_token_index], -1, -1
):
print(working_token_index, working_horizontal_position)
if (
working_horizontal_position > 0
and char_matrix[working_token_index][working_horizontal_position]
== VERTICAL_LINE
and char_matrix[
working_token_index][working_horizontal_position - 1
]
== 0
):
# jump over crossing line in non-projective parse case
continue
if (
char_matrix[working_token_index][working_horizontal_position]
!= SPACE
):
char_matrix[working_token_index][
working_horizontal_position + 1
] = ARROWHEAD
break
if working_horizontal_position == 0:
char_matrix[working_token_index][
working_horizontal_position
] = ARROWHEAD
else:
char_matrix[working_token_index][
working_horizontal_position
] |= HORIZONTAL_LINE
return [''.join(ROOT_RIGHT_CHARS[char_matrix[vertical_position][horizontal_position]] for horizontal_position in range((max_horizontal_line_length * 2 - 1))) for vertical_position in range(len(doc))]
import spacy import spacy
nlp = spacy.load("en_core_web_sm") nlp = spacy.load("en_core_web_sm")
doc = nlp("I saw a horse yesterday that was injured") doc = nlp("I saw a horse yesterday that was injured")
Visualizer().render_dependency_trees(doc)
print (Visualizer().render_dependency_trees(doc))