mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-14 13:47:13 +03:00
f277bfdf0f
* Draft out initial Spans data structure * Initial span group commit * Basic span group support on Doc * Basic test for span group * Compile span_group.pyx * Draft addition of SpanGroup to DocBin * Add deserialization for SpanGroup * Add tests for serializing SpanGroup * Fix serialization of SpanGroup * Add EdgeC and GraphC structs * Add draft Graph data structure * Compile graph * More work on Graph * Update GraphC * Upd graph * Fix walk functions * Let Graph take nodes and edges on construction * Fix walking and getting * Add graph tests * Fix import * Add module with the SpanGroups dict thingy * Update test * Rename 'span_groups' attribute * Try to fix c++11 compilation * Fix test * Update DocBin * Try to fix compilation * Try to fix graph * Improve SpanGroup docstrings * Add doc.spans to documentation * Fix serialization * Tidy up and add docs * Update docs [ci skip] * Add SpanGroup.has_overlap * WIP updated Graph API * Start testing new Graph API * Update Graph tests * Update Graph * Add docstring Co-authored-by: Ines Montani <ines@ines.io>
58 lines
1.8 KiB
Python
58 lines
1.8 KiB
Python
from spacy.vocab import Vocab
|
|
from spacy.tokens.doc import Doc
|
|
from spacy.tokens.graph import Graph
|
|
|
|
|
|
def test_graph_init():
|
|
doc = Doc(Vocab(), words=["a", "b", "c", "d"])
|
|
graph = Graph(doc, name="hello")
|
|
assert graph.name == "hello"
|
|
assert graph.doc is doc
|
|
|
|
|
|
def test_graph_edges_and_nodes():
|
|
doc = Doc(Vocab(), words=["a", "b", "c", "d"])
|
|
graph = Graph(doc, name="hello")
|
|
node1 = graph.add_node((0,))
|
|
assert graph.get_node((0,)) == node1
|
|
node2 = graph.add_node((1, 3))
|
|
assert list(node2) == [1, 3]
|
|
graph.add_edge(
|
|
node1,
|
|
node2,
|
|
label="one",
|
|
weight=-10.5
|
|
)
|
|
assert graph.has_edge(
|
|
node1,
|
|
node2,
|
|
label="one"
|
|
)
|
|
assert node1.heads() == []
|
|
assert [tuple(h) for h in node2.heads()] == [(0,)]
|
|
assert [tuple(t) for t in node1.tails()] == [(1, 3)]
|
|
assert [tuple(t) for t in node2.tails()] == []
|
|
|
|
|
|
def test_graph_walk():
|
|
doc = Doc(Vocab(), words=["a", "b", "c", "d"])
|
|
graph = Graph(
|
|
doc,
|
|
name="hello",
|
|
nodes=[(0,), (1,), (2,), (3,)],
|
|
edges=[(0, 1), (0, 2), (0, 3), (3, 0)],
|
|
labels=None,
|
|
weights=None
|
|
)
|
|
node0, node1, node2, node3 = list(graph.nodes)
|
|
assert [tuple(h) for h in node0.heads()] == [(3,)]
|
|
assert [tuple(h) for h in node1.heads()] == [(0,)]
|
|
assert [tuple(h) for h in node0.walk_heads()] == [(3,), (0,)]
|
|
assert [tuple(h) for h in node1.walk_heads()] == [(0,), (3,), (0,)]
|
|
assert [tuple(h) for h in node2.walk_heads()] == [(0,), (3,), (0,)]
|
|
assert [tuple(h) for h in node3.walk_heads()] == [(0,), (3,)]
|
|
assert [tuple(t) for t in node0.walk_tails()] == [(1,), (2,), (3,), (0,)]
|
|
assert [tuple(t) for t in node1.walk_tails()] == []
|
|
assert [tuple(t) for t in node2.walk_tails()] == []
|
|
assert [tuple(t) for t in node3.walk_tails()] == [(0,), (1,), (2,), (3,)]
|