mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-24 00:46:28 +03:00
Add SpanGroup and Graph container types to represent arbitrary annotations (#6696)
* Draft out initial Spans data structure * Initial span group commit * Basic span group support on Doc * Basic test for span group * Compile span_group.pyx * Draft addition of SpanGroup to DocBin * Add deserialization for SpanGroup * Add tests for serializing SpanGroup * Fix serialization of SpanGroup * Add EdgeC and GraphC structs * Add draft Graph data structure * Compile graph * More work on Graph * Update GraphC * Upd graph * Fix walk functions * Let Graph take nodes and edges on construction * Fix walking and getting * Add graph tests * Fix import * Add module with the SpanGroups dict thingy * Update test * Rename 'span_groups' attribute * Try to fix c++11 compilation * Fix test * Update DocBin * Try to fix compilation * Try to fix graph * Improve SpanGroup docstrings * Add doc.spans to documentation * Fix serialization * Tidy up and add docs * Update docs [ci skip] * Add SpanGroup.has_overlap * WIP updated Graph API * Start testing new Graph API * Update Graph tests * Update Graph * Add docstring Co-authored-by: Ines Montani <ines@ines.io>
This commit is contained in:
parent
54e8e3c208
commit
f277bfdf0f
6
setup.py
6
setup.py
|
@ -55,6 +55,8 @@ MOD_NAMES = [
|
|||
"spacy.tokens.doc",
|
||||
"spacy.tokens.span",
|
||||
"spacy.tokens.token",
|
||||
"spacy.tokens.span_group",
|
||||
"spacy.tokens.graph",
|
||||
"spacy.tokens.morphanalysis",
|
||||
"spacy.tokens._retokenize",
|
||||
"spacy.matcher.matcher",
|
||||
|
@ -68,7 +70,7 @@ COMPILE_OPTIONS = {
|
|||
"mingw32": ["-O2", "-Wno-strict-prototypes", "-Wno-unused-function"],
|
||||
"other": ["-O2", "-Wno-strict-prototypes", "-Wno-unused-function"],
|
||||
}
|
||||
LINK_OPTIONS = {"msvc": [], "mingw32": [], "other": []}
|
||||
LINK_OPTIONS = {"msvc": ["-std=c++11"], "mingw32": ["-std=c++11"], "other": []}
|
||||
COMPILER_DIRECTIVES = {
|
||||
"language_level": -3,
|
||||
"embedsignature": True,
|
||||
|
@ -201,7 +203,7 @@ def setup_package():
|
|||
ext_modules = []
|
||||
for name in MOD_NAMES:
|
||||
mod_path = name.replace(".", "/") + ".pyx"
|
||||
ext = Extension(name, [mod_path], language="c++")
|
||||
ext = Extension(name, [mod_path], language="c++", extra_compile_args=["-std=c++11"])
|
||||
ext_modules.append(ext)
|
||||
print("Cythonizing sources")
|
||||
ext_modules = cythonize(ext_modules, compiler_directives=COMPILER_DIRECTIVES)
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
from libc.stdint cimport uint8_t, uint32_t, int32_t, uint64_t
|
||||
from libcpp.vector cimport vector
|
||||
from libcpp.unordered_set cimport unordered_set
|
||||
from libcpp.unordered_map cimport unordered_map
|
||||
from libc.stdint cimport int32_t, int64_t
|
||||
|
||||
from .typedefs cimport flags_t, attr_t, hash_t
|
||||
|
@ -91,3 +93,22 @@ cdef struct AliasC:
|
|||
|
||||
# Prior probability P(entity|alias) - should sum up to (at most) 1.
|
||||
vector[float] probs
|
||||
|
||||
|
||||
cdef struct EdgeC:
|
||||
hash_t label
|
||||
int32_t head
|
||||
int32_t tail
|
||||
|
||||
|
||||
cdef struct GraphC:
|
||||
vector[vector[int32_t]] nodes
|
||||
vector[EdgeC] edges
|
||||
vector[float] weights
|
||||
vector[int] n_heads
|
||||
vector[int] n_tails
|
||||
vector[int] first_head
|
||||
vector[int] first_tail
|
||||
unordered_set[int]* roots
|
||||
unordered_map[hash_t, int]* node_map
|
||||
unordered_map[hash_t, int]* edge_map
|
||||
|
|
|
@ -631,3 +631,24 @@ def test_doc_set_ents_invalid_spans(en_tokenizer):
|
|||
retokenizer.merge(span)
|
||||
with pytest.raises(IndexError):
|
||||
doc.ents = spans
|
||||
|
||||
|
||||
def test_span_groups(en_tokenizer):
|
||||
doc = en_tokenizer("Some text about Colombia and the Czech Republic")
|
||||
doc.spans["hi"] = [Span(doc, 3, 4, label="bye")]
|
||||
assert "hi" in doc.spans
|
||||
assert "bye" not in doc.spans
|
||||
assert len(doc.spans["hi"]) == 1
|
||||
assert doc.spans["hi"][0].label_ == "bye"
|
||||
doc.spans["hi"].append(doc[0:3])
|
||||
assert len(doc.spans["hi"]) == 2
|
||||
assert doc.spans["hi"][1].text == "Some text about"
|
||||
assert [span.text for span in doc.spans["hi"]] == ["Colombia", "Some text about"]
|
||||
assert not doc.spans["hi"].has_overlap
|
||||
doc.ents = [Span(doc, 3, 4, label="GPE"), Span(doc, 6, 8, label="GPE")]
|
||||
doc.spans["hi"].extend(doc.ents)
|
||||
assert len(doc.spans["hi"]) == 4
|
||||
assert [span.label_ for span in doc.spans["hi"]] == ["bye", "", "GPE", "GPE"]
|
||||
assert doc.spans["hi"].has_overlap
|
||||
del doc.spans["hi"]
|
||||
assert "hi" not in doc.spans
|
||||
|
|
57
spacy/tests/doc/test_graph.py
Normal file
57
spacy/tests/doc/test_graph.py
Normal file
|
@ -0,0 +1,57 @@
|
|||
from spacy.vocab import Vocab
|
||||
from spacy.tokens.doc import Doc
|
||||
from spacy.tokens.graph import Graph
|
||||
|
||||
|
||||
def test_graph_init():
|
||||
doc = Doc(Vocab(), words=["a", "b", "c", "d"])
|
||||
graph = Graph(doc, name="hello")
|
||||
assert graph.name == "hello"
|
||||
assert graph.doc is doc
|
||||
|
||||
|
||||
def test_graph_edges_and_nodes():
|
||||
doc = Doc(Vocab(), words=["a", "b", "c", "d"])
|
||||
graph = Graph(doc, name="hello")
|
||||
node1 = graph.add_node((0,))
|
||||
assert graph.get_node((0,)) == node1
|
||||
node2 = graph.add_node((1, 3))
|
||||
assert list(node2) == [1, 3]
|
||||
graph.add_edge(
|
||||
node1,
|
||||
node2,
|
||||
label="one",
|
||||
weight=-10.5
|
||||
)
|
||||
assert graph.has_edge(
|
||||
node1,
|
||||
node2,
|
||||
label="one"
|
||||
)
|
||||
assert node1.heads() == []
|
||||
assert [tuple(h) for h in node2.heads()] == [(0,)]
|
||||
assert [tuple(t) for t in node1.tails()] == [(1, 3)]
|
||||
assert [tuple(t) for t in node2.tails()] == []
|
||||
|
||||
|
||||
def test_graph_walk():
|
||||
doc = Doc(Vocab(), words=["a", "b", "c", "d"])
|
||||
graph = Graph(
|
||||
doc,
|
||||
name="hello",
|
||||
nodes=[(0,), (1,), (2,), (3,)],
|
||||
edges=[(0, 1), (0, 2), (0, 3), (3, 0)],
|
||||
labels=None,
|
||||
weights=None
|
||||
)
|
||||
node0, node1, node2, node3 = list(graph.nodes)
|
||||
assert [tuple(h) for h in node0.heads()] == [(3,)]
|
||||
assert [tuple(h) for h in node1.heads()] == [(0,)]
|
||||
assert [tuple(h) for h in node0.walk_heads()] == [(3,), (0,)]
|
||||
assert [tuple(h) for h in node1.walk_heads()] == [(0,), (3,), (0,)]
|
||||
assert [tuple(h) for h in node2.walk_heads()] == [(0,), (3,), (0,)]
|
||||
assert [tuple(h) for h in node3.walk_heads()] == [(0,), (3,)]
|
||||
assert [tuple(t) for t in node0.walk_tails()] == [(1,), (2,), (3,), (0,)]
|
||||
assert [tuple(t) for t in node1.walk_tails()] == []
|
||||
assert [tuple(t) for t in node2.walk_tails()] == []
|
||||
assert [tuple(t) for t in node3.walk_tails()] == [(0,), (1,), (2,), (3,)]
|
|
@ -56,6 +56,13 @@ def test_serialize_doc_exclude(en_vocab):
|
|||
assert not new_doc.user_data
|
||||
|
||||
|
||||
def test_serialize_doc_span_groups(en_vocab):
|
||||
doc = Doc(en_vocab, words=["hello", "world", "!"])
|
||||
doc.spans["content"] = [doc[0:2]]
|
||||
new_doc = Doc(en_vocab).from_bytes(doc.to_bytes())
|
||||
assert len(new_doc.spans["content"]) == 1
|
||||
|
||||
|
||||
def test_serialize_doc_bin():
|
||||
doc_bin = DocBin(attrs=["LEMMA", "ENT_IOB", "ENT_TYPE"], store_user_data=True)
|
||||
texts = ["Some text", "Lots of texts...", "..."]
|
||||
|
@ -63,6 +70,7 @@ def test_serialize_doc_bin():
|
|||
nlp = English()
|
||||
for doc in nlp.pipe(texts):
|
||||
doc.cats = cats
|
||||
doc.spans["start"] = [doc[0:2]]
|
||||
doc_bin.add(doc)
|
||||
bytes_data = doc_bin.to_bytes()
|
||||
|
||||
|
@ -73,6 +81,7 @@ def test_serialize_doc_bin():
|
|||
for i, doc in enumerate(reloaded_docs):
|
||||
assert doc.text == texts[i]
|
||||
assert doc.cats == cats
|
||||
assert len(doc.spans) == 1
|
||||
|
||||
|
||||
def test_serialize_doc_bin_unknown_spaces(en_vocab):
|
||||
|
|
49
spacy/tokens/_dict_proxies.py
Normal file
49
spacy/tokens/_dict_proxies.py
Normal file
|
@ -0,0 +1,49 @@
|
|||
from typing import Iterable, Tuple, Union, TYPE_CHECKING
|
||||
import weakref
|
||||
from collections import UserDict
|
||||
import srsly
|
||||
|
||||
from .span_group import SpanGroup
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# This lets us add type hints for mypy etc. without causing circular imports
|
||||
from .doc import Doc # noqa: F401
|
||||
from .span import Span # noqa: F401
|
||||
|
||||
|
||||
# Why inherit from UserDict instead of dict here?
|
||||
# Well, the 'dict' class doesn't necessarily delegate everything nicely,
|
||||
# for performance reasons. The UserDict is slower by better behaved.
|
||||
# See https://treyhunner.com/2019/04/why-you-shouldnt-inherit-from-list-and-dict-in-python/0ww
|
||||
class SpanGroups(UserDict):
|
||||
"""A dict-like proxy held by the Doc, to control access to span groups."""
|
||||
|
||||
def __init__(
|
||||
self, doc: "Doc", items: Iterable[Tuple[str, SpanGroup]] = tuple()
|
||||
) -> None:
|
||||
self.doc_ref = weakref.ref(doc)
|
||||
UserDict.__init__(self, items)
|
||||
|
||||
def __setitem__(self, key: str, value: Union[SpanGroup, Iterable["Span"]]) -> None:
|
||||
if not isinstance(value, SpanGroup):
|
||||
value = self._make_span_group(key, value)
|
||||
assert value.doc is self.doc_ref()
|
||||
UserDict.__setitem__(self, key, value)
|
||||
|
||||
def _make_span_group(self, name: str, spans: Iterable["Span"]) -> SpanGroup:
|
||||
return SpanGroup(self.doc_ref(), name=name, spans=spans)
|
||||
|
||||
def to_bytes(self) -> bytes:
|
||||
# We don't need to serialize this as a dict, because the groups
|
||||
# know their names.
|
||||
msg = [value.to_bytes() for value in self.values()]
|
||||
return srsly.msgpack_dumps(msg)
|
||||
|
||||
def from_bytes(self, bytes_data: bytes) -> "SpanGroups":
|
||||
msg = srsly.msgpack_loads(bytes_data)
|
||||
self.clear()
|
||||
doc = self.doc_ref()
|
||||
for value_bytes in msg:
|
||||
group = SpanGroup(doc).from_bytes(value_bytes)
|
||||
self[group.name] = group
|
||||
return self
|
|
@ -33,6 +33,7 @@ class DocBin:
|
|||
{
|
||||
"attrs": List[uint64], # e.g. [TAG, HEAD, ENT_IOB, ENT_TYPE]
|
||||
"tokens": bytes, # Serialized numpy uint64 array with the token data
|
||||
"spans": List[Dict[str, bytes]], # SpanGroups data for each doc
|
||||
"spaces": bytes, # Serialized numpy boolean array with spaces data
|
||||
"lengths": bytes, # Serialized numpy int32 array with the doc lengths
|
||||
"strings": List[unicode] # List of unique strings in the token data
|
||||
|
@ -70,6 +71,7 @@ class DocBin:
|
|||
self.tokens = []
|
||||
self.spaces = []
|
||||
self.cats = []
|
||||
self.span_groups = []
|
||||
self.user_data = []
|
||||
self.flags = []
|
||||
self.strings = set()
|
||||
|
@ -107,6 +109,10 @@ class DocBin:
|
|||
self.strings.add(token.ent_kb_id_)
|
||||
self.cats.append(doc.cats)
|
||||
self.user_data.append(srsly.msgpack_dumps(doc.user_data))
|
||||
self.span_groups.append(doc.spans.to_bytes())
|
||||
for key, group in doc.spans.items():
|
||||
for span in group:
|
||||
self.strings.add(span.label_)
|
||||
|
||||
def get_docs(self, vocab: Vocab) -> Iterator[Doc]:
|
||||
"""Recover Doc objects from the annotations, using the given vocab.
|
||||
|
@ -130,6 +136,10 @@ class DocBin:
|
|||
doc = Doc(vocab, words=tokens[:, orth_col], spaces=spaces)
|
||||
doc = doc.from_array(self.attrs, tokens)
|
||||
doc.cats = self.cats[i]
|
||||
if self.span_groups[i]:
|
||||
doc.spans.from_bytes(self.span_groups[i])
|
||||
else:
|
||||
doc.spans.clear()
|
||||
if i < len(self.user_data) and self.user_data[i] is not None:
|
||||
user_data = srsly.msgpack_loads(self.user_data[i], use_list=False)
|
||||
doc.user_data.update(user_data)
|
||||
|
@ -161,6 +171,7 @@ class DocBin:
|
|||
self.spaces.extend(other.spaces)
|
||||
self.strings.update(other.strings)
|
||||
self.cats.extend(other.cats)
|
||||
self.span_groups.extend(other.span_groups)
|
||||
self.flags.extend(other.flags)
|
||||
self.user_data.extend(other.user_data)
|
||||
|
||||
|
@ -185,6 +196,7 @@ class DocBin:
|
|||
"strings": list(sorted(self.strings)),
|
||||
"cats": self.cats,
|
||||
"flags": self.flags,
|
||||
"span_groups": self.span_groups,
|
||||
}
|
||||
if self.store_user_data:
|
||||
msg["user_data"] = self.user_data
|
||||
|
@ -213,6 +225,7 @@ class DocBin:
|
|||
self.tokens = NumpyOps().unflatten(flat_tokens, lengths)
|
||||
self.spaces = NumpyOps().unflatten(flat_spaces, lengths)
|
||||
self.cats = msg["cats"]
|
||||
self.span_groups = msg.get("span_groups", [b"" for _ in lengths])
|
||||
self.flags = msg.get("flags", [{} for _ in lengths])
|
||||
if "user_data" in msg:
|
||||
self.user_data = list(msg["user_data"])
|
||||
|
|
|
@ -2,7 +2,7 @@ from cymem.cymem cimport Pool
|
|||
cimport numpy as np
|
||||
|
||||
from ..vocab cimport Vocab
|
||||
from ..structs cimport TokenC, LexemeC
|
||||
from ..structs cimport TokenC, LexemeC, SpanC
|
||||
from ..typedefs cimport attr_t
|
||||
from ..attrs cimport attr_id_t
|
||||
|
||||
|
@ -33,6 +33,7 @@ cdef int token_by_end(const TokenC* tokens, int length, int end_char) except -2
|
|||
|
||||
cdef int [:,:] _get_lca_matrix(Doc, int start, int end)
|
||||
|
||||
|
||||
cdef class Doc:
|
||||
cdef readonly Pool mem
|
||||
cdef readonly Vocab vocab
|
||||
|
@ -43,6 +44,7 @@ cdef class Doc:
|
|||
cdef public object tensor
|
||||
cdef public object cats
|
||||
cdef public object user_data
|
||||
cdef readonly object spans
|
||||
|
||||
cdef TokenC* c
|
||||
|
||||
|
|
|
@ -16,6 +16,7 @@ from thinc.util import copy_array
|
|||
import warnings
|
||||
|
||||
from .span cimport Span
|
||||
from ._dict_proxies import SpanGroups
|
||||
from .token cimport Token
|
||||
from ..lexeme cimport Lexeme, EMPTY_LEXEME
|
||||
from ..typedefs cimport attr_t, flags_t
|
||||
|
@ -222,6 +223,7 @@ cdef class Doc:
|
|||
self.vocab = vocab
|
||||
size = max(20, (len(words) if words is not None else 0))
|
||||
self.mem = Pool()
|
||||
self.spans = SpanGroups(self)
|
||||
# Guarantee self.lex[i-x], for any i >= 0 and x < padding is in bounds
|
||||
# However, we need to remember the true starting places, so that we can
|
||||
# realloc.
|
||||
|
@ -1255,6 +1257,9 @@ cdef class Doc:
|
|||
strings.add(token.ent_kb_id_)
|
||||
strings.add(token.ent_id_)
|
||||
strings.add(token.norm_)
|
||||
for group in self.spans.values():
|
||||
for span in group:
|
||||
strings.add(span.label_)
|
||||
# Msgpack doesn't distinguish between lists and tuples, which is
|
||||
# vexing for user data. As a best guess, we *know* that within
|
||||
# keys, we must have tuples. In values we just have to hope
|
||||
|
@ -1266,6 +1271,7 @@ cdef class Doc:
|
|||
"sentiment": lambda: self.sentiment,
|
||||
"tensor": lambda: self.tensor,
|
||||
"cats": lambda: self.cats,
|
||||
"spans": lambda: self.spans.to_bytes(),
|
||||
"strings": lambda: list(strings),
|
||||
"has_unknown_spaces": lambda: self.has_unknown_spaces
|
||||
}
|
||||
|
@ -1290,18 +1296,6 @@ cdef class Doc:
|
|||
"""
|
||||
if self.length != 0:
|
||||
raise ValueError(Errors.E033.format(length=self.length))
|
||||
deserializers = {
|
||||
"text": lambda b: None,
|
||||
"array_head": lambda b: None,
|
||||
"array_body": lambda b: None,
|
||||
"sentiment": lambda b: None,
|
||||
"tensor": lambda b: None,
|
||||
"cats": lambda b: None,
|
||||
"strings": lambda b: None,
|
||||
"user_data_keys": lambda b: None,
|
||||
"user_data_values": lambda b: None,
|
||||
"has_unknown_spaces": lambda b: None
|
||||
}
|
||||
# Msgpack doesn't distinguish between lists and tuples, which is
|
||||
# vexing for user data. As a best guess, we *know* that within
|
||||
# keys, we must have tuples. In values we just have to hope
|
||||
|
@ -1336,9 +1330,12 @@ cdef class Doc:
|
|||
self.push_back(lex, has_space)
|
||||
start = end + has_space
|
||||
self.from_array(msg["array_head"][2:], attrs[:, 2:])
|
||||
if "spans" in msg:
|
||||
self.spans.from_bytes(msg["spans"])
|
||||
else:
|
||||
self.spans.clear()
|
||||
return self
|
||||
|
||||
|
||||
def extend_tensor(self, tensor):
|
||||
"""Concatenate a new tensor onto the doc.tensor object.
|
||||
|
||||
|
|
13
spacy/tokens/graph.pxd
Normal file
13
spacy/tokens/graph.pxd
Normal file
|
@ -0,0 +1,13 @@
|
|||
from libcpp.vector cimport vector
|
||||
from cymem.cymem cimport Pool
|
||||
from preshed.maps cimport PreshMap
|
||||
from ..structs cimport GraphC, EdgeC
|
||||
|
||||
|
||||
cdef class Graph:
|
||||
cdef GraphC c
|
||||
cdef Pool mem
|
||||
cdef PreshMap node_map
|
||||
cdef PreshMap edge_map
|
||||
cdef object doc_ref
|
||||
cdef public str name
|
709
spacy/tokens/graph.pyx
Normal file
709
spacy/tokens/graph.pyx
Normal file
|
@ -0,0 +1,709 @@
|
|||
# cython: infer_types=True, cdivision=True, boundscheck=False, binding=True
|
||||
from typing import List, Tuple, Generator
|
||||
from libc.stdint cimport int32_t, int64_t
|
||||
from libcpp.pair cimport pair
|
||||
from libcpp.unordered_map cimport unordered_map
|
||||
from libcpp.unordered_set cimport unordered_set
|
||||
from cython.operator cimport dereference
|
||||
cimport cython
|
||||
import weakref
|
||||
from preshed.maps cimport map_get_unless_missing
|
||||
from murmurhash.mrmr cimport hash64
|
||||
from ..typedefs cimport hash_t
|
||||
from ..strings import get_string_id
|
||||
from ..structs cimport EdgeC, GraphC
|
||||
from .token import Token
|
||||
|
||||
|
||||
@cython.freelist(8)
|
||||
cdef class Edge:
|
||||
cdef readonly Graph graph
|
||||
cdef readonly int i
|
||||
|
||||
def __init__(self, Graph graph, int i):
|
||||
self.graph = graph
|
||||
self.i = i
|
||||
|
||||
@property
|
||||
def is_none(self) -> bool:
|
||||
return False
|
||||
|
||||
@property
|
||||
def doc(self) -> "Doc":
|
||||
return self.graph.doc
|
||||
|
||||
@property
|
||||
def head(self) -> "Node":
|
||||
return Node(self.graph, self.graph.c.edges[self.i].head)
|
||||
|
||||
@property
|
||||
def tail(self) -> "Tail":
|
||||
return Node(self.graph, self.graph.c.edges[self.i].tail)
|
||||
|
||||
@property
|
||||
def label(self) -> int:
|
||||
return self.graph.c.edges[self.i].label
|
||||
|
||||
@property
|
||||
def weight(self) -> float:
|
||||
return self.graph.c.weights[self.i]
|
||||
|
||||
@property
|
||||
def label_(self) -> str:
|
||||
return self.doc.vocab.strings[self.label]
|
||||
|
||||
|
||||
@cython.freelist(8)
|
||||
cdef class Node:
|
||||
cdef readonly Graph graph
|
||||
cdef readonly int i
|
||||
|
||||
def __init__(self, Graph graph, int i):
|
||||
"""A reference to a node of an annotation graph. Each node is made up of
|
||||
an ordered set of zero or more token indices.
|
||||
|
||||
Node references are usually created by the Graph object itself, or from
|
||||
the Node or Edge objects. You usually won't need to instantiate this
|
||||
class yourself.
|
||||
"""
|
||||
cdef int length = graph.c.nodes.size()
|
||||
if i >= length or -i >= length:
|
||||
raise IndexError(f"Node index {i} out of bounds ({length})")
|
||||
if i < 0:
|
||||
i += length
|
||||
self.graph = graph
|
||||
self.i = i
|
||||
|
||||
def __eq__(self, other):
|
||||
if self.graph is not other.graph:
|
||||
return False
|
||||
else:
|
||||
return self.i == other.i
|
||||
|
||||
def __iter__(self) -> Generator[int]:
|
||||
for i in self.graph.c.nodes[self.i]:
|
||||
yield i
|
||||
|
||||
def __getitem__(self, int i) -> int:
|
||||
"""Get a token index from the node's set of tokens."""
|
||||
length = self.graph.c.nodes[self.i].size()
|
||||
if i >= length or -i >= length:
|
||||
raise IndexError(f"Token index {i} out of bounds ({length})")
|
||||
if i < 0:
|
||||
i += length
|
||||
return self.graph.c.nodes[self.i][i]
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""The number of tokens that make up the node."""
|
||||
return self.graph.c.nodes[self.i].size()
|
||||
|
||||
@property
|
||||
def is_none(self) -> bool:
|
||||
"""Whether the node is a special value, indicating 'none'.
|
||||
|
||||
The NoneNode type is returned by the Graph, Edge and Node objects when
|
||||
there is no match to a query. It has the same API as Node, but it always
|
||||
returns NoneNode, NoneEdge or empty lists for its queries.
|
||||
"""
|
||||
return False
|
||||
|
||||
@property
|
||||
def doc(self) -> "Doc":
|
||||
"""The Doc object that the graph refers to."""
|
||||
return self.graph.doc
|
||||
|
||||
@property
|
||||
def tokens(self) -> Tuple[Token]:
|
||||
"""A tuple of Token objects that make up the node."""
|
||||
doc = self.doc
|
||||
return tuple([doc[i] for i in self])
|
||||
|
||||
def head(self, i=None, label=None) -> "Node":
|
||||
"""Get the head of the first matching edge, searching by index, label,
|
||||
both or neither.
|
||||
|
||||
For instance, `node.head(i=1)` will get the head of the second edge that
|
||||
this node is a tail of. `node.head(i=1, label="ARG0")` will further
|
||||
check that the second edge has the label `"ARG0"`.
|
||||
|
||||
If no matching node can be found, the graph's NoneNode is returned.
|
||||
"""
|
||||
return self.headed(i=i, label=label)
|
||||
|
||||
def tail(self, i=None, label=None) -> "Node":
|
||||
"""Get the tail of the first matching edge, searching by index, label,
|
||||
both or neither.
|
||||
|
||||
If no matching node can be found, the graph's NoneNode is returned.
|
||||
"""
|
||||
return self.tailed(i=i, label=label).tail
|
||||
|
||||
def sibling(self, i=None, label=None):
|
||||
"""Get the first matching sibling node. Two nodes are siblings if they
|
||||
are both tails of the same head.
|
||||
If no matching node can be found, the graph's NoneNode is returned.
|
||||
"""
|
||||
if i is None:
|
||||
siblings = self.siblings(label=label)
|
||||
return siblings[0] if siblings else NoneNode(self)
|
||||
else:
|
||||
edges = []
|
||||
for h in self.headed():
|
||||
edges.extend([e for e in h.tailed() if e.tail.i != self.i])
|
||||
if i >= len(edges):
|
||||
return NoneNode(self)
|
||||
elif label is not None and edges[i].label != label:
|
||||
return NoneNode(self)
|
||||
else:
|
||||
return edges[i].tail
|
||||
|
||||
def heads(self, label=None) -> List["Node"]:
|
||||
"""Find all matching heads of this node."""
|
||||
cdef vector[int] edge_indices
|
||||
self._find_edges(edge_indices, "head", label)
|
||||
return [Node(self.graph, self.graph.c.edges[i].head) for i in edge_indices]
|
||||
|
||||
def tails(self, label=None) -> List["Node"]:
|
||||
"""Find all matching tails of this node."""
|
||||
cdef vector[int] edge_indices
|
||||
self._find_edges(edge_indices, "tail", label)
|
||||
return [Node(self.graph, self.graph.c.edges[i].tail) for i in edge_indices]
|
||||
|
||||
def siblings(self, label=None) -> List["Node"]:
|
||||
"""Find all maching siblings of this node. Two nodes are siblings if they
|
||||
are tails of the same head.
|
||||
"""
|
||||
edges = []
|
||||
for h in self.headed():
|
||||
edges.extend([e for e in h.tailed() if e.tail.i != self.i])
|
||||
if label is None:
|
||||
return [e.tail for e in edges]
|
||||
else:
|
||||
return [e.tail for e in edges if e.label == label]
|
||||
|
||||
def headed(self, i=None, label=None) -> Edge:
|
||||
"""Find the first matching edge headed by this node.
|
||||
If no matching edge can be found, the graph's NoneEdge is returned.
|
||||
"""
|
||||
start, end = self._find_range(i, self.c.n_head[self.i])
|
||||
idx = self._find_edge("head", start, end, label)
|
||||
if idx == -1:
|
||||
return NoneEdge(self.graph)
|
||||
else:
|
||||
return Edge(self.graph, idx)
|
||||
|
||||
def tailed(self, i=None, label=None) -> Edge:
|
||||
"""Find the first matching edge tailed by this node.
|
||||
If no matching edge can be found, the graph's NoneEdge is returned.
|
||||
"""
|
||||
start, end = self._find_range(i, self.c.n_tail[self.i])
|
||||
idx = self._find_edge("tail", start, end, label)
|
||||
if idx == -1:
|
||||
return NoneEdge(self.graph)
|
||||
else:
|
||||
return Edge(self.graph, idx)
|
||||
|
||||
def headeds(self, label=None) -> List[Edge]:
|
||||
"""Find all matching edges headed by this node."""
|
||||
cdef vector[int] edge_indices
|
||||
self._find_edges(edge_indices, "head", label)
|
||||
return [Edge(self.graph, i) for i in edge_indices]
|
||||
|
||||
def taileds(self, label=None) -> List["Edge"]:
|
||||
"""Find all matching edges headed by this node."""
|
||||
cdef vector[int] edge_indices
|
||||
self._find_edges(edge_indices, "tail", label)
|
||||
return [Edge(self.graph, i) for i in edge_indices]
|
||||
|
||||
def walk_heads(self):
|
||||
cdef vector[int] node_indices
|
||||
walk_head_nodes(node_indices, &self.graph.c, self.i)
|
||||
for i in node_indices:
|
||||
yield Node(self.graph, i)
|
||||
|
||||
def walk_tails(self):
|
||||
cdef vector[int] node_indices
|
||||
walk_tail_nodes(node_indices, &self.graph.c, self.i)
|
||||
for i in node_indices:
|
||||
yield Node(self.graph, i)
|
||||
|
||||
cdef (int, int) _get_range(self, i, n):
|
||||
if i is None:
|
||||
return (0, n)
|
||||
elif i < n:
|
||||
return (i, i+1)
|
||||
else:
|
||||
return (0, 0)
|
||||
|
||||
cdef int _find_edge(self, str direction, int start, int end, label) except -2:
|
||||
if direction == "head":
|
||||
get_edges = get_head_edges
|
||||
else:
|
||||
get_edges = get_tail_edges
|
||||
cdef vector[int] edge_indices
|
||||
get_edges(edge_indices, &self.graph.c, self.i)
|
||||
if label is None:
|
||||
return edge_indices[start]
|
||||
for edge_index in edge_indices[start:end]:
|
||||
if self.graph.c.edges[edge_index].label == label:
|
||||
return edge_index
|
||||
else:
|
||||
return -1
|
||||
|
||||
cdef int _find_edges(self, vector[int]& edge_indices, str direction, label):
|
||||
if direction == "head":
|
||||
get_edges = get_head_edges
|
||||
else:
|
||||
get_edges = get_tail_edges
|
||||
if label is None:
|
||||
get_edges(edge_indices, &self.graph.c, self.i)
|
||||
return edge_indices.size()
|
||||
cdef vector[int] unfiltered
|
||||
get_edges(unfiltered, &self.graph.c, self.i)
|
||||
for edge_index in unfiltered:
|
||||
if self.graph.c.edges[edge_index].label == label:
|
||||
edge_indices.push_back(edge_index)
|
||||
return edge_indices.size()
|
||||
|
||||
|
||||
cdef class NoneEdge(Edge):
|
||||
"""An Edge subclass, representing a non-result. The NoneEdge has the same
|
||||
API as other Edge instances, but always returns NoneEdge, NoneNode, or empty
|
||||
lists.
|
||||
"""
|
||||
def __init__(self, graph):
|
||||
self.graph = graph
|
||||
self.i = -1
|
||||
|
||||
@property
|
||||
def doc(self) -> "Doc":
|
||||
return self.graph.doc
|
||||
|
||||
@property
|
||||
def head(self) -> "NoneNode":
|
||||
return NoneNode(self.graph)
|
||||
|
||||
@property
|
||||
def tail(self) -> "NoneNode":
|
||||
return NoneNode(self.graph)
|
||||
|
||||
@property
|
||||
def label(self) -> int:
|
||||
return 0
|
||||
|
||||
@property
|
||||
def weight(self) -> float:
|
||||
return 0.0
|
||||
|
||||
@property
|
||||
def label_(self) -> str:
|
||||
return ""
|
||||
|
||||
|
||||
cdef class NoneNode(Node):
|
||||
def __init__(self, graph):
|
||||
self.graph = graph
|
||||
self.i = -1
|
||||
|
||||
def __getitem__(self, int i):
|
||||
raise IndexError("Cannot index into NoneNode.")
|
||||
|
||||
def __len__(self):
|
||||
return 0
|
||||
|
||||
@property
|
||||
def is_none(self):
|
||||
return -1
|
||||
|
||||
@property
|
||||
def doc(self):
|
||||
return self.graph.doc
|
||||
|
||||
@property
|
||||
def tokens(self):
|
||||
return tuple()
|
||||
|
||||
def head(self, i=None, label=None):
|
||||
return self
|
||||
|
||||
def tail(self, i=None, label=None):
|
||||
return self
|
||||
|
||||
def walk_heads(self):
|
||||
yield from []
|
||||
|
||||
def walk_tails(self):
|
||||
yield from []
|
||||
|
||||
|
||||
cdef class Graph:
|
||||
"""A set of directed labelled relationships between sets of tokens.
|
||||
|
||||
EXAMPLE:
|
||||
Construction 1
|
||||
>>> graph = Graph(doc, name="srl")
|
||||
|
||||
Construction 2
|
||||
>>> graph = Graph(
|
||||
doc,
|
||||
name="srl",
|
||||
nodes=[(0,), (1, 3), (,)],
|
||||
edges=[(0, 2), (2, 1)]
|
||||
)
|
||||
|
||||
Construction 3
|
||||
>>> graph = Graph(
|
||||
doc,
|
||||
name="srl",
|
||||
nodes=[(0,), (1, 3), (,)],
|
||||
edges=[(2, 0), (0, 1)],
|
||||
labels=["word sense ID 1675", "agent"],
|
||||
weights=[-42.6, -1.7]
|
||||
)
|
||||
>>> assert graph.has_node((0,))
|
||||
>>> assert graph.has_edge((0,), (1,3), label="agent")
|
||||
"""
|
||||
def __init__(self, doc, *, name="", nodes=[], edges=[], labels=None, weights=None):
|
||||
"""Create a Graph object.
|
||||
|
||||
doc (Doc): The Doc object the graph will refer to.
|
||||
name (str): A string name to help identify the graph. Defaults to "".
|
||||
nodes (List[Tuple[int]]): A list of token-index tuples to add to the graph
|
||||
as nodes. Defaults to [].
|
||||
edges (List[Tuple[int, int]]): A list of edges between the provided nodes.
|
||||
Each edge should be a (head, tail) tuple, where `head` and `tail`
|
||||
are integers pointing into the `nodes` list. Defaults to [].
|
||||
labels (Optional[List[str]]): A list of labels for the provided edges.
|
||||
If None, all of the edges specified by the edges argument will have
|
||||
be labelled with the empty string (""). If `labels` is not `None`,
|
||||
it must have the same length as the `edges` argument.
|
||||
weights (Optional[List[float]]): A list of weights for the provided edges.
|
||||
If None, all of the edges specified by the edges argument will
|
||||
have the weight 0.0. If `weights` is not `None`, it must have the
|
||||
same length as the `edges` argument.
|
||||
"""
|
||||
if weights is not None:
|
||||
assert len(weights) == len(edges)
|
||||
else:
|
||||
weights = [0.0] * len(edges)
|
||||
if labels is not None:
|
||||
assert len(labels) == len(edges)
|
||||
else:
|
||||
labels = [""] * len(edges)
|
||||
self.c.node_map = new unordered_map[hash_t, int]()
|
||||
self.c.edge_map = new unordered_map[hash_t, int]()
|
||||
self.c.roots = new unordered_set[int]()
|
||||
self.name = name
|
||||
self.doc_ref = weakref.ref(doc)
|
||||
for node in nodes:
|
||||
self.add_node(node)
|
||||
for (head, tail), label, weight in zip(edges, labels, weights):
|
||||
self.add_edge(
|
||||
Node(self, head),
|
||||
Node(self, tail),
|
||||
label=label,
|
||||
weight=weight
|
||||
)
|
||||
|
||||
def __dealloc__(self):
|
||||
del self.c.node_map
|
||||
del self.c.edge_map
|
||||
del self.c.roots
|
||||
|
||||
@property
|
||||
def doc(self) -> "Doc":
|
||||
"""The Doc object the graph refers to."""
|
||||
return self.doc_ref()
|
||||
|
||||
@property
|
||||
def edges(self) -> Generator[Edge]:
|
||||
"""Iterate over the edges in the graph."""
|
||||
for i in range(self.c.edges.size()):
|
||||
yield Edge(self, i)
|
||||
|
||||
@property
|
||||
def nodes(self) -> Generator[Node]:
|
||||
"""Iterate over the nodes in the graph."""
|
||||
for i in range(self.c.nodes.size()):
|
||||
yield Node(self, i)
|
||||
|
||||
def add_edge(self, head, tail, *, label="", weight=None) -> Edge:
|
||||
"""Add an edge to the graph, connecting two groups of tokens.
|
||||
|
||||
If there is already an edge for the (head, tail, label) triple, it will
|
||||
be returned, and no new edge will be created. The weight of the edge
|
||||
will be updated if a weight is specified.
|
||||
"""
|
||||
label_hash = self.doc.vocab.strings.as_int(label)
|
||||
weight_float = weight if weight is not None else 0.0
|
||||
edge_index = add_edge(
|
||||
&self.c,
|
||||
EdgeC(
|
||||
head=self.add_node(head).i,
|
||||
tail=self.add_node(tail).i,
|
||||
label=self.doc.vocab.strings.as_int(label),
|
||||
),
|
||||
weight=weight if weight is not None else 0.0
|
||||
)
|
||||
return Edge(self, edge_index)
|
||||
|
||||
def get_edge(self, head, tail, *, label="") -> Edge:
|
||||
"""Look up an edge in the graph. If the graph has no matching edge,
|
||||
the NoneEdge object is returned.
|
||||
"""
|
||||
head_node = self.get_node(head)
|
||||
if head_node.is_none:
|
||||
return NoneEdge(self)
|
||||
tail_node = self.get_node(tail)
|
||||
if tail_node.is_none:
|
||||
return NoneEdge(self)
|
||||
edge_index = get_edge(
|
||||
&self.c,
|
||||
EdgeC(head=head_node.i, tail=tail_node.i, label=get_string_id(label))
|
||||
)
|
||||
if edge_index < 0:
|
||||
return NoneEdge(self)
|
||||
else:
|
||||
return Edge(self, edge_index)
|
||||
|
||||
def has_edge(self, head, tail, label) -> bool:
|
||||
"""Check whether a (head, tail, label) triple is an edge in the graph."""
|
||||
return not self.get_edge(head, tail, label=label).is_none
|
||||
|
||||
def add_node(self, indices) -> Node:
|
||||
"""Add a node to the graph and return it. Nodes refer to ordered sets
|
||||
of token indices.
|
||||
|
||||
This method is idempotent: if there is already a node for the given
|
||||
indices, it is returned without a new node being created.
|
||||
"""
|
||||
if isinstance(indices, Node):
|
||||
return indices
|
||||
cdef vector[int32_t] node
|
||||
node.reserve(len(indices))
|
||||
for idx in indices:
|
||||
node.push_back(idx)
|
||||
i = add_node(&self.c, node)
|
||||
print("Add node", indices, i)
|
||||
return Node(self, i)
|
||||
|
||||
def get_node(self, indices) -> Node:
|
||||
"""Get a node from the graph, or the NoneNode if there is no node for
|
||||
the given indices.
|
||||
"""
|
||||
if isinstance(indices, Node):
|
||||
return indices
|
||||
cdef vector[int32_t] node
|
||||
node.reserve(len(indices))
|
||||
for idx in indices:
|
||||
node.push_back(idx)
|
||||
node_index = get_node(&self.c, node)
|
||||
if node_index < 0:
|
||||
return NoneNode(self)
|
||||
else:
|
||||
print("Get node", indices, node_index)
|
||||
return Node(self, node_index)
|
||||
|
||||
def has_node(self, tuple indices) -> bool:
|
||||
"""Check whether the graph has a node for the given indices."""
|
||||
return not self.get_node(indices).is_none
|
||||
|
||||
|
||||
cdef int add_edge(GraphC* graph, EdgeC edge, float weight) nogil:
|
||||
key = hash64(&edge, sizeof(edge), 0)
|
||||
it = graph.edge_map.find(key)
|
||||
if it != graph.edge_map.end():
|
||||
edge_index = dereference(it).second
|
||||
graph.weights[edge_index] = weight
|
||||
return edge_index
|
||||
else:
|
||||
edge_index = graph.edges.size()
|
||||
graph.edge_map.insert(pair[hash_t, int](key, edge_index))
|
||||
graph.edges.push_back(edge)
|
||||
if graph.n_tails[edge.head] == 0:
|
||||
graph.first_tail[edge.head] = edge_index
|
||||
if graph.n_heads[edge.tail] == 0:
|
||||
graph.first_head[edge.tail] = edge_index
|
||||
graph.n_tails[edge.head] += 1
|
||||
graph.n_heads[edge.tail] += 1
|
||||
graph.weights.push_back(weight)
|
||||
# If we had the tail marked as a root, remove it.
|
||||
tail_root_index = graph.roots.find(edge.tail)
|
||||
if tail_root_index != graph.roots.end():
|
||||
graph.roots.erase(tail_root_index)
|
||||
return edge_index
|
||||
|
||||
|
||||
cdef int get_edge(const GraphC* graph, EdgeC edge) nogil:
|
||||
key = hash64(&edge, sizeof(edge), 0)
|
||||
it = graph.edge_map.find(key)
|
||||
if it == graph.edge_map.end():
|
||||
return -1
|
||||
else:
|
||||
return dereference(it).second
|
||||
|
||||
|
||||
cdef int has_edge(const GraphC* graph, EdgeC edge) nogil:
|
||||
return get_edge(graph, edge) >= 0
|
||||
|
||||
|
||||
cdef int add_node(GraphC* graph, vector[int32_t]& node) nogil:
|
||||
key = hash64(&node[0], node.size() * sizeof(node[0]), 0)
|
||||
it = graph.node_map.find(key)
|
||||
if it != graph.node_map.end():
|
||||
# Item found. Convert the iterator to an index value.
|
||||
return dereference(it).second
|
||||
else:
|
||||
index = graph.nodes.size()
|
||||
graph.nodes.push_back(node)
|
||||
graph.n_heads.push_back(0)
|
||||
graph.n_tails.push_back(0)
|
||||
graph.first_head.push_back(0)
|
||||
graph.first_tail.push_back(0)
|
||||
graph.roots.insert(index)
|
||||
graph.node_map.insert(pair[hash_t, int](key, index))
|
||||
return index
|
||||
|
||||
|
||||
cdef int get_node(const GraphC* graph, vector[int32_t] node) nogil:
|
||||
key = hash64(&node[0], node.size() * sizeof(node[0]), 0)
|
||||
it = graph.node_map.find(key)
|
||||
if it == graph.node_map.end():
|
||||
return -1
|
||||
else:
|
||||
return dereference(it).second
|
||||
|
||||
|
||||
cdef int has_node(const GraphC* graph, vector[int32_t] node) nogil:
|
||||
return get_node(graph, node) >= 0
|
||||
|
||||
|
||||
cdef int get_head_nodes(vector[int]& output, const GraphC* graph, int node) nogil:
|
||||
todo = graph.n_heads[node]
|
||||
if todo == 0:
|
||||
return 0
|
||||
output.reserve(output.size() + todo)
|
||||
start = graph.first_head[node]
|
||||
end = graph.edges.size()
|
||||
for i in range(start, end):
|
||||
if todo <= 0:
|
||||
break
|
||||
elif graph.edges[i].tail == node:
|
||||
output.push_back(graph.edges[i].head)
|
||||
todo -= 1
|
||||
return todo
|
||||
|
||||
|
||||
cdef int get_tail_nodes(vector[int]& output, const GraphC* graph, int node) nogil:
|
||||
todo = graph.n_tails[node]
|
||||
if todo == 0:
|
||||
return 0
|
||||
output.reserve(output.size() + todo)
|
||||
start = graph.first_tail[node]
|
||||
end = graph.edges.size()
|
||||
for i in range(start, end):
|
||||
if todo <= 0:
|
||||
break
|
||||
elif graph.edges[i].head == node:
|
||||
output.push_back(graph.edges[i].tail)
|
||||
todo -= 1
|
||||
return todo
|
||||
|
||||
|
||||
cdef int get_sibling_nodes(vector[int]& output, const GraphC* graph, int node) nogil:
|
||||
cdef vector[int] heads
|
||||
cdef vector[int] tails
|
||||
get_head_nodes(heads, graph, node)
|
||||
for i in range(heads.size()):
|
||||
get_tail_nodes(tails, graph, heads[i])
|
||||
for j in range(tails.size()):
|
||||
if tails[j] != node:
|
||||
output.push_back(tails[j])
|
||||
tails.clear()
|
||||
return output.size()
|
||||
|
||||
|
||||
cdef int get_head_edges(vector[int]& output, const GraphC* graph, int node) nogil:
|
||||
todo = graph.n_heads[node]
|
||||
if todo == 0:
|
||||
return 0
|
||||
output.reserve(output.size() + todo)
|
||||
start = graph.first_head[node]
|
||||
end = graph.edges.size()
|
||||
for i in range(start, end):
|
||||
if todo <= 0:
|
||||
break
|
||||
elif graph.edges[i].tail == node:
|
||||
output.push_back(i)
|
||||
todo -= 1
|
||||
return todo
|
||||
|
||||
|
||||
cdef int get_tail_edges(vector[int]& output, const GraphC* graph, int node) nogil:
|
||||
todo = graph.n_tails[node]
|
||||
if todo == 0:
|
||||
return 0
|
||||
output.reserve(output.size() + todo)
|
||||
start = graph.first_tail[node]
|
||||
end = graph.edges.size()
|
||||
for i in range(start, end):
|
||||
if todo <= 0:
|
||||
break
|
||||
elif graph.edges[i].head == node:
|
||||
output.push_back(i)
|
||||
todo -= 1
|
||||
return todo
|
||||
|
||||
|
||||
cdef int walk_head_nodes(vector[int]& output, const GraphC* graph, int node) nogil:
|
||||
cdef unordered_set[int] seen = unordered_set[int]()
|
||||
get_head_nodes(output, graph, node)
|
||||
seen.insert(node)
|
||||
i = 0
|
||||
while i < output.size():
|
||||
with gil:
|
||||
print("Walk up from", output[i])
|
||||
if seen.find(output[i]) == seen.end():
|
||||
seen.insert(output[i])
|
||||
get_head_nodes(output, graph, output[i])
|
||||
i += 1
|
||||
return i
|
||||
|
||||
|
||||
cdef int walk_tail_nodes(vector[int]& output, const GraphC* graph, int node) nogil:
|
||||
cdef unordered_set[int] seen = unordered_set[int]()
|
||||
get_tail_nodes(output, graph, node)
|
||||
seen.insert(node)
|
||||
i = 0
|
||||
while i < output.size():
|
||||
if seen.find(output[i]) == seen.end():
|
||||
seen.insert(output[i])
|
||||
get_tail_nodes(output, graph, output[i])
|
||||
i += 1
|
||||
return i
|
||||
|
||||
|
||||
cdef int walk_head_edges(vector[int]& output, const GraphC* graph, int node) nogil:
|
||||
cdef unordered_set[int] seen = unordered_set[int]()
|
||||
get_head_edges(output, graph, node)
|
||||
seen.insert(node)
|
||||
i = 0
|
||||
while i < output.size():
|
||||
if seen.find(output[i]) == seen.end():
|
||||
seen.insert(output[i])
|
||||
get_head_edges(output, graph, output[i])
|
||||
i += 1
|
||||
return i
|
||||
|
||||
|
||||
cdef int walk_tail_edges(vector[int]& output, const GraphC* graph, int node) nogil:
|
||||
cdef unordered_set[int] seen = unordered_set[int]()
|
||||
get_tail_edges(output, graph, node)
|
||||
seen.insert(node)
|
||||
i = 0
|
||||
while i < output.size():
|
||||
if seen.find(output[i]) == seen.end():
|
||||
seen.insert(output[i])
|
||||
get_tail_edges(output, graph, output[i])
|
||||
i += 1
|
||||
return i
|
|
@ -2,18 +2,24 @@ cimport numpy as np
|
|||
|
||||
from .doc cimport Doc
|
||||
from ..typedefs cimport attr_t
|
||||
from ..structs cimport SpanC
|
||||
|
||||
|
||||
cdef class Span:
|
||||
cdef readonly Doc doc
|
||||
cdef readonly int start
|
||||
cdef readonly int end
|
||||
cdef readonly int start_char
|
||||
cdef readonly int end_char
|
||||
cdef readonly attr_t label
|
||||
cdef readonly attr_t kb_id
|
||||
|
||||
cdef SpanC c
|
||||
cdef public _vector
|
||||
cdef public _vector_norm
|
||||
|
||||
@staticmethod
|
||||
cdef inline Span cinit(Doc doc, SpanC span):
|
||||
cdef Span self = Span.__new__(
|
||||
Span,
|
||||
doc,
|
||||
start=span.start,
|
||||
end=span.end
|
||||
)
|
||||
self.c = span
|
||||
return self
|
||||
|
||||
cpdef np.ndarray to_array(self, object features)
|
||||
|
|
|
@ -97,23 +97,23 @@ cdef class Span:
|
|||
if not (0 <= start <= end <= len(doc)):
|
||||
raise IndexError(Errors.E035.format(start=start, end=end, length=len(doc)))
|
||||
self.doc = doc
|
||||
self.start = start
|
||||
self.start_char = self.doc[start].idx if start < self.doc.length else 0
|
||||
self.end = end
|
||||
if end >= 1:
|
||||
self.end_char = self.doc[end - 1].idx + len(self.doc[end - 1])
|
||||
else:
|
||||
self.end_char = 0
|
||||
if isinstance(label, str):
|
||||
label = doc.vocab.strings.add(label)
|
||||
if isinstance(kb_id, str):
|
||||
kb_id = doc.vocab.strings.add(kb_id)
|
||||
if label not in doc.vocab.strings:
|
||||
raise ValueError(Errors.E084.format(label=label))
|
||||
self.label = label
|
||||
|
||||
self.c = SpanC(
|
||||
label=label,
|
||||
kb_id=kb_id,
|
||||
start=start,
|
||||
end=end,
|
||||
start_char=doc[start].idx if start < doc.length else 0,
|
||||
end_char=doc[end - 1].idx + len(doc[end - 1]) if end >= 1 else 0,
|
||||
)
|
||||
self._vector = vector
|
||||
self._vector_norm = vector_norm
|
||||
self.kb_id = kb_id
|
||||
|
||||
def __richcmp__(self, Span other, int op):
|
||||
if other is None:
|
||||
|
@ -123,25 +123,39 @@ cdef class Span:
|
|||
return True
|
||||
# <
|
||||
if op == 0:
|
||||
return self.start_char < other.start_char
|
||||
return self.c.start_char < other.c.start_char
|
||||
# <=
|
||||
elif op == 1:
|
||||
return self.start_char <= other.start_char
|
||||
return self.c.start_char <= other.c.start_char
|
||||
# ==
|
||||
elif op == 2:
|
||||
return (self.doc, self.start_char, self.end_char, self.label, self.kb_id) == (other.doc, other.start_char, other.end_char, other.label, other.kb_id)
|
||||
# Do the cheap comparisons first
|
||||
return (
|
||||
(self.c.start_char == other.c.start_char) and \
|
||||
(self.c.end_char == other.c.end_char) and \
|
||||
(self.c.label == other.c.label) and \
|
||||
(self.c.kb_id == other.c.kb_id) and \
|
||||
(self.doc == other.doc)
|
||||
)
|
||||
# !=
|
||||
elif op == 3:
|
||||
return (self.doc, self.start_char, self.end_char, self.label, self.kb_id) != (other.doc, other.start_char, other.end_char, other.label, other.kb_id)
|
||||
# Do the cheap comparisons first
|
||||
return not (
|
||||
(self.c.start_char == other.c.start_char) and \
|
||||
(self.c.end_char == other.c.end_char) and \
|
||||
(self.c.label == other.c.label) and \
|
||||
(self.c.kb_id == other.c.kb_id) and \
|
||||
(self.doc == other.doc)
|
||||
)
|
||||
# >
|
||||
elif op == 4:
|
||||
return self.start_char > other.start_char
|
||||
return self.c.start_char > other.c.start_char
|
||||
# >=
|
||||
elif op == 5:
|
||||
return self.start_char >= other.start_char
|
||||
return self.c.start_char >= other.c.start_char
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self.doc, self.start_char, self.end_char, self.label, self.kb_id))
|
||||
return hash((self.doc, self.c.start_char, self.c.end_char, self.c.label, self.c.kb_id))
|
||||
|
||||
def __len__(self):
|
||||
"""Get the number of tokens in the span.
|
||||
|
@ -150,9 +164,9 @@ cdef class Span:
|
|||
|
||||
DOCS: https://nightly.spacy.io/api/span#len
|
||||
"""
|
||||
if self.end < self.start:
|
||||
if self.c.end < self.c.start:
|
||||
return 0
|
||||
return self.end - self.start
|
||||
return self.c.end - self.c.start
|
||||
|
||||
def __repr__(self):
|
||||
return self.text
|
||||
|
@ -171,10 +185,10 @@ cdef class Span:
|
|||
return Span(self.doc, start + self.start, end + self.start)
|
||||
else:
|
||||
if i < 0:
|
||||
token_i = self.end + i
|
||||
token_i = self.c.end + i
|
||||
else:
|
||||
token_i = self.start + i
|
||||
if self.start <= token_i < self.end:
|
||||
token_i = self.c.start + i
|
||||
if self.c.start <= token_i < self.c.end:
|
||||
return self.doc[token_i]
|
||||
else:
|
||||
raise IndexError(Errors.E1002)
|
||||
|
@ -186,7 +200,7 @@ cdef class Span:
|
|||
|
||||
DOCS: https://nightly.spacy.io/api/span#iter
|
||||
"""
|
||||
for i in range(self.start, self.end):
|
||||
for i in range(self.c.start, self.c.end):
|
||||
yield self.doc[i]
|
||||
|
||||
def __reduce__(self):
|
||||
|
@ -196,7 +210,7 @@ cdef class Span:
|
|||
def _(self):
|
||||
"""Custom extension attributes registered via `set_extension`."""
|
||||
return Underscore(Underscore.span_extensions, self,
|
||||
start=self.start_char, end=self.end_char)
|
||||
start=self.c.start_char, end=self.c.end_char)
|
||||
|
||||
def as_doc(self, *, bint copy_user_data=False):
|
||||
"""Create a `Doc` object with a copy of the `Span`'s data.
|
||||
|
@ -242,7 +256,7 @@ cdef class Span:
|
|||
for i in range(length):
|
||||
# if the HEAD refers to a token outside this span, find a more appropriate ancestor
|
||||
token = self[i]
|
||||
ancestor_i = token.head.i - self.start # span offset
|
||||
ancestor_i = token.head.i - self.c.start # span offset
|
||||
if ancestor_i not in range(length):
|
||||
if DEP in attrs:
|
||||
array[i, attrs.index(DEP)] = dep
|
||||
|
@ -250,7 +264,7 @@ cdef class Span:
|
|||
# try finding an ancestor within this span
|
||||
ancestors = token.ancestors
|
||||
for ancestor in ancestors:
|
||||
ancestor_i = ancestor.i - self.start
|
||||
ancestor_i = ancestor.i - self.c.start
|
||||
if ancestor_i in range(length):
|
||||
array[i, head_col] = ancestor_i - i
|
||||
|
||||
|
@ -279,7 +293,7 @@ cdef class Span:
|
|||
|
||||
DOCS: https://nightly.spacy.io/api/span#get_lca_matrix
|
||||
"""
|
||||
return numpy.asarray(_get_lca_matrix(self.doc, self.start, self.end))
|
||||
return numpy.asarray(_get_lca_matrix(self.doc, self.c.start, self.c.end))
|
||||
|
||||
def similarity(self, other):
|
||||
"""Make a semantic similarity estimate. The default estimate is cosine
|
||||
|
@ -373,10 +387,14 @@ cdef class Span:
|
|||
|
||||
DOCS: https://nightly.spacy.io/api/span#ents
|
||||
"""
|
||||
cdef Span ent
|
||||
ents = []
|
||||
for ent in self.doc.ents:
|
||||
if ent.start >= self.start and ent.end <= self.end:
|
||||
ents.append(ent)
|
||||
if ent.c.start >= self.c.start:
|
||||
if ent.c.end <= self.c.end:
|
||||
ents.append(ent)
|
||||
else:
|
||||
break
|
||||
return ents
|
||||
|
||||
@property
|
||||
|
@ -513,7 +531,7 @@ cdef class Span:
|
|||
# with head==0, i.e. a sentence root. If so, we can return it. The
|
||||
# longer the span, the more likely it contains a sentence root, and
|
||||
# in this case we return in linear time.
|
||||
for i in range(self.start, self.end):
|
||||
for i in range(self.c.start, self.c.end):
|
||||
if self.doc.c[i].head == 0:
|
||||
return self.doc[i]
|
||||
# If we don't have a sentence root, we do something that's not so
|
||||
|
@ -524,15 +542,15 @@ cdef class Span:
|
|||
# think this should be okay.
|
||||
cdef int current_best = self.doc.length
|
||||
cdef int root = -1
|
||||
for i in range(self.start, self.end):
|
||||
if self.start <= (i+self.doc.c[i].head) < self.end:
|
||||
for i in range(self.c.start, self.c.end):
|
||||
if self.c.start <= (i+self.doc.c[i].head) < self.c.end:
|
||||
continue
|
||||
words_to_root = _count_words_to_root(&self.doc.c[i], self.doc.length)
|
||||
if words_to_root < current_best:
|
||||
current_best = words_to_root
|
||||
root = i
|
||||
if root == -1:
|
||||
return self.doc[self.start]
|
||||
return self.doc[self.c.start]
|
||||
else:
|
||||
return self.doc[root]
|
||||
|
||||
|
@ -548,8 +566,8 @@ cdef class Span:
|
|||
the span.
|
||||
RETURNS (Span): The newly constructed object.
|
||||
"""
|
||||
start_idx += self.start_char
|
||||
end_idx += self.start_char
|
||||
start_idx += self.c.start_char
|
||||
end_idx += self.c.start_char
|
||||
return self.doc.char_span(start_idx, end_idx)
|
||||
|
||||
@property
|
||||
|
@ -628,6 +646,56 @@ cdef class Span:
|
|||
for word in self.rights:
|
||||
yield from word.subtree
|
||||
|
||||
property start:
|
||||
def __get__(self):
|
||||
return self.c.start
|
||||
|
||||
def __set__(self, int start):
|
||||
if start < 0:
|
||||
raise IndexError("TODO")
|
||||
self.c.start = start
|
||||
|
||||
property end:
|
||||
def __get__(self):
|
||||
return self.c.end
|
||||
|
||||
def __set__(self, int end):
|
||||
if end < 0:
|
||||
raise IndexError("TODO")
|
||||
self.c.end = end
|
||||
|
||||
property start_char:
|
||||
def __get__(self):
|
||||
return self.c.start_char
|
||||
|
||||
def __set__(self, int start_char):
|
||||
if start_char < 0:
|
||||
raise IndexError("TODO")
|
||||
self.c.start_char = start_char
|
||||
|
||||
property end_char:
|
||||
def __get__(self):
|
||||
return self.c.end_char
|
||||
|
||||
def __set__(self, int end_char):
|
||||
if end_char < 0:
|
||||
raise IndexError("TODO")
|
||||
self.c.end_char = end_char
|
||||
|
||||
property label:
|
||||
def __get__(self):
|
||||
return self.c.label
|
||||
|
||||
def __set__(self, attr_t label):
|
||||
self.c.label = label
|
||||
|
||||
property kb_id:
|
||||
def __get__(self):
|
||||
return self.c.kb_id
|
||||
|
||||
def __set__(self, attr_t kb_id):
|
||||
self.c.kb_id = kb_id
|
||||
|
||||
property ent_id:
|
||||
"""RETURNS (uint64): The entity ID."""
|
||||
def __get__(self):
|
||||
|
|
10
spacy/tokens/span_group.pxd
Normal file
10
spacy/tokens/span_group.pxd
Normal file
|
@ -0,0 +1,10 @@
|
|||
from libcpp.vector cimport vector
|
||||
from ..structs cimport SpanC
|
||||
|
||||
cdef class SpanGroup:
|
||||
cdef public object _doc_ref
|
||||
cdef public str name
|
||||
cdef public dict attrs
|
||||
cdef vector[SpanC] c
|
||||
|
||||
cdef void push_back(self, SpanC span) nogil
|
183
spacy/tokens/span_group.pyx
Normal file
183
spacy/tokens/span_group.pyx
Normal file
|
@ -0,0 +1,183 @@
|
|||
import weakref
|
||||
import struct
|
||||
import srsly
|
||||
from .span cimport Span
|
||||
from libc.stdint cimport uint64_t, uint32_t, int32_t
|
||||
|
||||
|
||||
cdef class SpanGroup:
|
||||
"""A group of spans that all belong to the same Doc object. The group
|
||||
can be named, and you can attach additional attributes to it. Span groups
|
||||
are generally accessed via the `doc.spans` attribute. The `doc.spans`
|
||||
attribute will convert lists of spans into a `SpanGroup` object for you
|
||||
automatically on assignment.
|
||||
|
||||
Example:
|
||||
Construction 1
|
||||
>>> doc = nlp("Their goi ng home")
|
||||
>>> doc.spans["errors"] = SpanGroup(
|
||||
doc,
|
||||
name="errors",
|
||||
spans=[doc[0:1], doc[2:4]],
|
||||
attrs={"annotator": "matt"}
|
||||
)
|
||||
|
||||
Construction 2
|
||||
>>> doc = nlp("Their goi ng home")
|
||||
>>> doc.spans["errors"] = [doc[0:1], doc[2:4]]
|
||||
>>> assert isinstance(doc.spans["errors"], SpanGroup)
|
||||
|
||||
DOCS: https://nightly.spacy.io/api/spangroup
|
||||
"""
|
||||
def __init__(self, doc, *, name="", attrs={}, spans=[]):
|
||||
"""Create a SpanGroup.
|
||||
|
||||
doc (Doc): The reference Doc object.
|
||||
name (str): The group name.
|
||||
attrs (Dict[str, Any]): Optional JSON-serializable attributes to attach.
|
||||
spans (Iterable[Span]): The spans to add to the group.
|
||||
|
||||
DOCS: https://nightly.spacy.io/api/spangroup#init
|
||||
"""
|
||||
# We need to make this a weak reference, so that the Doc object can
|
||||
# own the SpanGroup without circular references. We do want to get
|
||||
# the Doc though, because otherwise the API gets annoying.
|
||||
self._doc_ref = weakref.ref(doc)
|
||||
self.name = name
|
||||
self.attrs = dict(attrs) if attrs is not None else {}
|
||||
cdef Span span
|
||||
for span in spans:
|
||||
self.push_back(span.c)
|
||||
|
||||
def __repr__(self):
|
||||
return str(list(self))
|
||||
|
||||
@property
|
||||
def doc(self):
|
||||
"""RETURNS (Doc): The reference document.
|
||||
|
||||
DOCS: https://nightly.spacy.io/api/spangroup#doc
|
||||
"""
|
||||
return self._doc_ref()
|
||||
|
||||
@property
|
||||
def has_overlap(self):
|
||||
"""RETURNS (bool): Whether the group contains overlapping spans.
|
||||
|
||||
DOCS: https://nightly.spacy.io/api/spangroup#has_overlap
|
||||
"""
|
||||
if not len(self):
|
||||
return False
|
||||
sorted_spans = list(sorted(self))
|
||||
last_end = sorted_spans[0].end
|
||||
for span in sorted_spans[1:]:
|
||||
if span.start < last_end:
|
||||
return True
|
||||
last_end = span.end
|
||||
return False
|
||||
|
||||
def __len__(self):
|
||||
"""RETURNS (int): The number of spans in the group.
|
||||
|
||||
DOCS: https://nightly.spacy.io/api/spangroup#len
|
||||
"""
|
||||
return self.c.size()
|
||||
|
||||
def append(self, Span span):
|
||||
"""Add a span to the group. The span must refer to the same Doc
|
||||
object as the span group.
|
||||
|
||||
span (Span): The span to append.
|
||||
|
||||
DOCS: https://nightly.spacy.io/api/spangroup#append
|
||||
"""
|
||||
if span.doc is not self.doc:
|
||||
raise ValueError("Cannot add span to group: refers to different Doc.")
|
||||
self.push_back(span.c)
|
||||
|
||||
def extend(self, spans):
|
||||
"""Add multiple spans to the group. All spans must refer to the same
|
||||
Doc object as the span group.
|
||||
|
||||
spans (Iterable[Span]): The spans to add.
|
||||
|
||||
DOCS: https://nightly.spacy.io/api/spangroup#extend
|
||||
"""
|
||||
cdef Span span
|
||||
for span in spans:
|
||||
self.append(span)
|
||||
|
||||
def __getitem__(self, int i):
|
||||
"""Get a span from the group.
|
||||
|
||||
i (int): The item index.
|
||||
RETURNS (Span): The span at the given index.
|
||||
|
||||
DOCS: https://nightly.spacy.io/api/spangroup#getitem
|
||||
"""
|
||||
cdef int size = self.c.size()
|
||||
if i < -size or i >= size:
|
||||
raise IndexError(f"list index {i} out of range")
|
||||
if i < 0:
|
||||
i += size
|
||||
return Span.cinit(self.doc, self.c[i])
|
||||
|
||||
def to_bytes(self):
|
||||
"""Serialize the SpanGroup's contents to a byte string.
|
||||
|
||||
RETURNS (bytes): The serialized span group.
|
||||
|
||||
DOCS: https://nightly.spacy.io/api/spangroup#to_bytes
|
||||
"""
|
||||
output = {"name": self.name, "attrs": self.attrs, "spans": []}
|
||||
for i in range(self.c.size()):
|
||||
span = self.c[i]
|
||||
# The struct.pack here is probably overkill, but it might help if
|
||||
# you're saving tonnes of spans, and it doesn't really add any
|
||||
# complexity. We do take care to specify little-endian byte order
|
||||
# though, to ensure the message can be loaded back on a different
|
||||
# arch.
|
||||
# Q: uint64_t
|
||||
# q: int64_t
|
||||
# L: uint32_t
|
||||
# l: int32_t
|
||||
output["spans"].append(struct.pack(
|
||||
">QQQllll",
|
||||
span.id,
|
||||
span.kb_id,
|
||||
span.label,
|
||||
span.start,
|
||||
span.end,
|
||||
span.start_char,
|
||||
span.end_char
|
||||
))
|
||||
return srsly.msgpack_dumps(output)
|
||||
|
||||
def from_bytes(self, bytes_data):
|
||||
"""Deserialize the SpanGroup's contents from a byte string.
|
||||
|
||||
bytes_data (bytes): The span group to load.
|
||||
RETURNS (SpanGroup): The deserialized span group.
|
||||
|
||||
DOCS: https://nightly.spacy.io/api/spangroup#from_bytes
|
||||
"""
|
||||
msg = srsly.msgpack_loads(bytes_data)
|
||||
self.name = msg["name"]
|
||||
self.attrs = dict(msg["attrs"])
|
||||
self.c.clear()
|
||||
self.c.reserve(len(msg["spans"]))
|
||||
cdef SpanC span
|
||||
for span_data in msg["spans"]:
|
||||
items = struct.unpack(">QQQllll", span_data)
|
||||
span.id = items[0]
|
||||
span.kb_id = items[1]
|
||||
span.label = items[2]
|
||||
span.start = items[3]
|
||||
span.end = items[4]
|
||||
span.start_char = items[5]
|
||||
span.end_char = items[6]
|
||||
self.c.push_back(span)
|
||||
return self
|
||||
|
||||
cdef void push_back(self, SpanC span) nogil:
|
||||
self.c.push_back(span)
|
|
@ -575,6 +575,39 @@ objects, if the entity recognizer has been applied.
|
|||
| ----------- | --------------------------------------------------------------------- |
|
||||
| **RETURNS** | Entities in the document, one `Span` per entity. ~~Tuple[Span, ...]~~ |
|
||||
|
||||
## Doc.spans {#spans tag="property"}
|
||||
|
||||
A dictionary of named span groups, to store and access additional span
|
||||
annotations. You can write to it by assigning a list of [`Span`](/api/span)
|
||||
objects or a [`SpanGroup`](/api/spangroup) to a given key.
|
||||
|
||||
> #### Example
|
||||
>
|
||||
> ```python
|
||||
> doc = nlp("Their goi ng home")
|
||||
> doc.spans["errors"] = [doc[0:1], doc[2:4]]
|
||||
> ```
|
||||
|
||||
| Name | Description |
|
||||
| ----------- | ------------------------------------------------------------------ |
|
||||
| **RETURNS** | The span groups assigned to the document. ~~Dict[str, SpanGroup]~~ |
|
||||
|
||||
## Doc.cats {#cats tag="property" model="text classifier"}
|
||||
|
||||
Maps a label to a score for categories applied to the document. Typically set by
|
||||
the [`TextCategorizer`](/api/textcategorizer).
|
||||
|
||||
> #### Example
|
||||
>
|
||||
> ```python
|
||||
> doc = nlp("This is a text about football.")
|
||||
> print(doc.cats)
|
||||
> ```
|
||||
|
||||
| Name | Description |
|
||||
| ----------- | ---------------------------------------------------------- |
|
||||
| **RETURNS** | The text categories mapped to scores. ~~Dict[str, float]~~ |
|
||||
|
||||
## Doc.noun_chunks {#noun_chunks tag="property" model="parser"}
|
||||
|
||||
Iterate over the base noun phrases in the document. Yields base noun-phrase
|
||||
|
@ -668,23 +701,22 @@ The L2 norm of the document's vector representation.
|
|||
|
||||
## Attributes {#attributes}
|
||||
|
||||
| Name | Description |
|
||||
| ------------------------------------ | ------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `text` | A string representation of the document text. ~~str~~ |
|
||||
| `text_with_ws` | An alias of `Doc.text`, provided for duck-type compatibility with `Span` and `Token`. ~~str~~ |
|
||||
| `mem` | The document's local memory heap, for all C data it owns. ~~cymem.Pool~~ |
|
||||
| `vocab` | The store of lexical types. ~~Vocab~~ |
|
||||
| `tensor` <Tag variant="new">2</Tag> | Container for dense vector representations. ~~numpy.ndarray~~ |
|
||||
| `cats` <Tag variant="new">2</Tag> | Maps a label to a score for categories applied to the document. The label is a string and the score should be a float. ~~Dict[str, float]~~ |
|
||||
| `user_data` | A generic storage area, for user custom data. ~~Dict[str, Any]~~ |
|
||||
| `lang` <Tag variant="new">2.1</Tag> | Language of the document's vocabulary. ~~int~~ |
|
||||
| `lang_` <Tag variant="new">2.1</Tag> | Language of the document's vocabulary. ~~str~~ |
|
||||
| `sentiment` | The document's positivity/negativity score, if available. ~~float~~ |
|
||||
| `user_hooks` | A dictionary that allows customization of the `Doc`'s properties. ~~Dict[str, Callable]~~ |
|
||||
| `user_token_hooks` | A dictionary that allows customization of properties of `Token` children. ~~Dict[str, Callable]~~ |
|
||||
| `user_span_hooks` | A dictionary that allows customization of properties of `Span` children. ~~Dict[str, Callable]~~ |
|
||||
| `has_unknown_spaces` | Whether the document was constructed without known spacing between tokens (typically when created from gold tokenization). ~~bool~~ |
|
||||
| `_` | User space for adding custom [attribute extensions](/usage/processing-pipelines#custom-components-attributes). ~~Underscore~~ |
|
||||
| Name | Description |
|
||||
| ------------------------------------ | ----------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `text` | A string representation of the document text. ~~str~~ |
|
||||
| `text_with_ws` | An alias of `Doc.text`, provided for duck-type compatibility with `Span` and `Token`. ~~str~~ |
|
||||
| `mem` | The document's local memory heap, for all C data it owns. ~~cymem.Pool~~ |
|
||||
| `vocab` | The store of lexical types. ~~Vocab~~ |
|
||||
| `tensor` <Tag variant="new">2</Tag> | Container for dense vector representations. ~~numpy.ndarray~~ |
|
||||
| `user_data` | A generic storage area, for user custom data. ~~Dict[str, Any]~~ |
|
||||
| `lang` <Tag variant="new">2.1</Tag> | Language of the document's vocabulary. ~~int~~ |
|
||||
| `lang_` <Tag variant="new">2.1</Tag> | Language of the document's vocabulary. ~~str~~ |
|
||||
| `sentiment` | The document's positivity/negativity score, if available. ~~float~~ |
|
||||
| `user_hooks` | A dictionary that allows customization of the `Doc`'s properties. ~~Dict[str, Callable]~~ |
|
||||
| `user_token_hooks` | A dictionary that allows customization of properties of `Token` children. ~~Dict[str, Callable]~~ |
|
||||
| `user_span_hooks` | A dictionary that allows customization of properties of `Span` children. ~~Dict[str, Callable]~~ |
|
||||
| `has_unknown_spaces` | Whether the document was constructed without known spacing between tokens (typically when created from gold tokenization). ~~bool~~ |
|
||||
| `_` | User space for adding custom [attribute extensions](/usage/processing-pipelines#custom-components-attributes). ~~Underscore~~ |
|
||||
|
||||
## Serialization fields {#serialization-fields}
|
||||
|
||||
|
|
185
website/docs/api/spangroup.md
Normal file
185
website/docs/api/spangroup.md
Normal file
|
@ -0,0 +1,185 @@
|
|||
---
|
||||
title: SpanGroup
|
||||
tag: class
|
||||
source: spacy/tokens/span_group.pyx
|
||||
new: 3
|
||||
---
|
||||
|
||||
A group of arbitrary, potentially overlapping [`Span`](/api/span) objects that
|
||||
all belong to the same [`Doc`](/api/doc) object. The group can be named, and you
|
||||
can attach additional attributes to it. Span groups are generally accessed via
|
||||
the [`Doc.spans`](/api/doc#spans) attribute, which will convert lists of spans
|
||||
into a `SpanGroup` object for you automatically on assignment. `SpanGroup`
|
||||
objects behave similar to `list`s, so you can append `Span` objects to them or
|
||||
access a member at a given index.
|
||||
|
||||
## SpanGroup.\_\_init\_\_ {#init tag="method"}
|
||||
|
||||
Create a `SpanGroup`.
|
||||
|
||||
> #### Example
|
||||
>
|
||||
> ```python
|
||||
> doc = nlp("Their goi ng home")
|
||||
> spans = [doc[0:1], doc[2:4]]
|
||||
>
|
||||
> # Construction 1
|
||||
> from spacy.tokens import SpanGroup
|
||||
>
|
||||
> group = SpanGroup(doc, name="errors", spans=spans, attrs={"annotator": "matt"})
|
||||
> doc.spans["errors"] = group
|
||||
>
|
||||
> # Construction 2
|
||||
> doc.spans["errors"] = spans
|
||||
> assert isinstance(doc.spans["errors"], SpanGroup)
|
||||
> ```
|
||||
|
||||
| Name | Description |
|
||||
| -------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `doc` | The document the span group belongs to. ~~Doc~~ |
|
||||
| _keyword-only_ | |
|
||||
| `name` | The name of the span group. If the span group is created automatically on assignment to `doc.spans`, the key name is used. Defaults to `""`. ~~str~~ |
|
||||
| `attrs` | Optional JSON-serializable attributes to attach to the span group. ~~Dict[str, Any]~~ |
|
||||
| `spans` | The spans to add to the span group. ~~Iterable[Span]~~ |
|
||||
|
||||
## SpanGroup.doc {#doc tag="property"}
|
||||
|
||||
The [`Doc`](/api/doc) object the span group is referring to.
|
||||
|
||||
> #### Example
|
||||
>
|
||||
> ```python
|
||||
> doc = nlp("Their goi ng home")
|
||||
> doc.spans["errors"] = [doc[0:1], doc[2:4]]
|
||||
> assert doc.spans["errors"].doc == doc
|
||||
> ```
|
||||
|
||||
| Name | Description |
|
||||
| ----------- | ------------------------------- |
|
||||
| **RETURNS** | The reference document. ~~Doc~~ |
|
||||
|
||||
## SpanGroup.has_overlap {#has_overlap tag="property"}
|
||||
|
||||
Check whether the span group contains overlapping spans.
|
||||
|
||||
> #### Example
|
||||
>
|
||||
> ```python
|
||||
> doc = nlp("Their goi ng home")
|
||||
> doc.spans["errors"] = [doc[0:1], doc[2:4]]
|
||||
> assert not doc.spans["errors"].has_overlap
|
||||
> doc.spans["errors"].append(doc[1:2])
|
||||
> assert doc.spans["errors"].has_overlap
|
||||
> ```
|
||||
|
||||
| Name | Description |
|
||||
| ----------- | -------------------------------------------------- |
|
||||
| **RETURNS** | Whether the span group contains overlaps. ~~bool~~ |
|
||||
|
||||
## SpanGroup.\_\_len\_\_ {#len tag="method"}
|
||||
|
||||
Get the number of spans in the group.
|
||||
|
||||
> #### Example
|
||||
>
|
||||
> ```python
|
||||
> doc = nlp("Their goi ng home")
|
||||
> doc.spans["errors"] = [doc[0:1], doc[2:4]]
|
||||
> assert len(doc.spans["errors"]) == 2
|
||||
> ```
|
||||
|
||||
| Name | Description |
|
||||
| ----------- | ----------------------------------------- |
|
||||
| **RETURNS** | The number of spans in the group. ~~int~~ |
|
||||
|
||||
## SpanGroup.\_\_getitem\_\_ {#getitem tag="method"}
|
||||
|
||||
Get a span from the group.
|
||||
|
||||
> #### Example
|
||||
>
|
||||
> ```python
|
||||
> doc = nlp("Their goi ng home")
|
||||
> doc.spans["errors"] = [doc[0:1], doc[2:4]]
|
||||
> span = doc.spans["errors"][1]
|
||||
> assert span.text == "goi ng"
|
||||
> ```
|
||||
|
||||
| Name | Description |
|
||||
| ----------- | ------------------------------------- |
|
||||
| `i` | The item index. ~~int~~ |
|
||||
| **RETURNS** | The span at the given index. ~~Span~~ |
|
||||
|
||||
## SpanGroup.append {#append tag="method"}
|
||||
|
||||
Add a [`Span`](/api/span) object to the group. The span must refer to the same
|
||||
[`Doc`](/api/doc) object as the span group.
|
||||
|
||||
> #### Example
|
||||
>
|
||||
> ```python
|
||||
> doc = nlp("Their goi ng home")
|
||||
> doc.spans["errors"] = [doc[0:1]]
|
||||
> doc.spans["errors"].append(doc[2:4])
|
||||
> assert len(doc.spans["errors"]) == 2
|
||||
> ```
|
||||
|
||||
| Name | Description |
|
||||
| ------ | ---------------------------- |
|
||||
| `span` | The span to append. ~~Span~~ |
|
||||
|
||||
## SpanGroup.extend {#extend tag="method"}
|
||||
|
||||
Add multiple [`Span`](/api/span) objects to the group. All spans must refer to
|
||||
the same [`Doc`](/api/doc) object as the span group.
|
||||
|
||||
> #### Example
|
||||
>
|
||||
> ```python
|
||||
> doc = nlp("Their goi ng home")
|
||||
> doc.spans["errors"] = []
|
||||
> doc.spans["errors"].extend([doc[2:4], doc[0:1]])
|
||||
> assert len(doc.spans["errors"]) == 2
|
||||
> ```
|
||||
|
||||
| Name | Description |
|
||||
| ------- | ------------------------------------ |
|
||||
| `spans` | The spans to add. ~~Iterable[Span]~~ |
|
||||
|
||||
## SpanGroup.to_bytes {#to_bytes tag="method"}
|
||||
|
||||
Serialize the span group to a bytestring.
|
||||
|
||||
> #### Example
|
||||
>
|
||||
> ```python
|
||||
> doc = nlp("Their goi ng home")
|
||||
> doc.spans["errors"] = [doc[0:1], doc[2:4]]
|
||||
> group_bytes = doc.spans["errors"].to_bytes()
|
||||
> ```
|
||||
|
||||
| Name | Description |
|
||||
| ----------- | ------------------------------------- |
|
||||
| **RETURNS** | The serialized `SpanGroup`. ~~bytes~~ |
|
||||
|
||||
## SpanGroup.from_bytes {#from_bytes tag="method"}
|
||||
|
||||
Load the span group from a bytestring. Modifies the object in place and returns
|
||||
it.
|
||||
|
||||
> #### Example
|
||||
>
|
||||
> ```python
|
||||
> from spacy.tokens import SpanGroup
|
||||
>
|
||||
> doc = nlp("Their goi ng home")
|
||||
> doc.spans["errors"] = [doc[0:1], doc[2:4]]
|
||||
> group_bytes = doc.spans["errors"].to_bytes()
|
||||
> new_group = SpanGroup()
|
||||
> new_group.from_bytes(group_bytes)
|
||||
> ```
|
||||
|
||||
| Name | Description |
|
||||
| ------------ | ------------------------------------- |
|
||||
| `bytes_data` | The data to load from. ~~bytes~~ |
|
||||
| **RETURNS** | The `SpanGroup` object. ~~SpanGroup~~ |
|
|
@ -18,15 +18,16 @@ It also orchestrates training and serialization.
|
|||
|
||||
### Container objects {#architecture-containers}
|
||||
|
||||
| Name | Description |
|
||||
| --------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| [`Doc`](/api/doc) | A container for accessing linguistic annotations. |
|
||||
| [`DocBin`](/api/docbin) | A collection of `Doc` objects for efficient binary serialization. Also used for [training data](/api/data-formats#binary-training). |
|
||||
| [`Example`](/api/example) | A collection of training annotations, containing two `Doc` objects: the reference data and the predictions. |
|
||||
| [`Language`](/api/language) | Processing class that turns text into `Doc` objects. Different languages implement their own subclasses of it. The variable is typically called `nlp`. |
|
||||
| [`Lexeme`](/api/lexeme) | An entry in the vocabulary. It's a word type with no context, as opposed to a word token. It therefore has no part-of-speech tag, dependency parse etc. |
|
||||
| [`Span`](/api/span) | A slice from a `Doc` object. |
|
||||
| [`Token`](/api/token) | An individual token — i.e. a word, punctuation symbol, whitespace, etc. |
|
||||
| Name | Description |
|
||||
| ----------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| [`Doc`](/api/doc) | A container for accessing linguistic annotations. |
|
||||
| [`DocBin`](/api/docbin) | A collection of `Doc` objects for efficient binary serialization. Also used for [training data](/api/data-formats#binary-training). |
|
||||
| [`Example`](/api/example) | A collection of training annotations, containing two `Doc` objects: the reference data and the predictions. |
|
||||
| [`Language`](/api/language) | Processing class that turns text into `Doc` objects. Different languages implement their own subclasses of it. The variable is typically called `nlp`. |
|
||||
| [`Lexeme`](/api/lexeme) | An entry in the vocabulary. It's a word type with no context, as opposed to a word token. It therefore has no part-of-speech tag, dependency parse etc. |
|
||||
| [`Span`](/api/span) | A slice from a `Doc` object. |
|
||||
| [`SpanGroup`](/api/spangroup) | A named collection of spans belonging to a `Doc`. |
|
||||
| [`Token`](/api/token) | An individual token — i.e. a word, punctuation symbol, whitespace, etc. |
|
||||
|
||||
### Processing pipeline {#architecture-pipeline}
|
||||
|
||||
|
|
|
@ -501,7 +501,7 @@ format for documenting argument and return types.
|
|||
[`AttributeRuler`](/api/attributeruler),
|
||||
[`SentenceRecognizer`](/api/sentencerecognizer),
|
||||
[`DependencyMatcher`](/api/dependencymatcher), [`TrainablePipe`](/api/pipe),
|
||||
[`Corpus`](/api/corpus)
|
||||
[`Corpus`](/api/corpus), [`SpanGroup`](/api/spangroup),
|
||||
|
||||
</Infobox>
|
||||
|
||||
|
|
|
@ -77,6 +77,7 @@
|
|||
{ "text": "Language", "url": "/api/language" },
|
||||
{ "text": "Lexeme", "url": "/api/lexeme" },
|
||||
{ "text": "Span", "url": "/api/span" },
|
||||
{ "text": "SpanGroup", "url": "/api/spangroup" },
|
||||
{ "text": "Token", "url": "/api/token" }
|
||||
]
|
||||
},
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
"Doc": "/api/doc",
|
||||
"Token": "/api/token",
|
||||
"Span": "/api/span",
|
||||
"SpanGroup": "/api/spangroup",
|
||||
"Lexeme": "/api/lexeme",
|
||||
"Example": "/api/example",
|
||||
"Alignment": "/api/example#alignment-object",
|
||||
|
|
Loading…
Reference in New Issue
Block a user