mirror of https://github.com/explosion/spaCy.git
synced 2025-07-10 08:12:24 +03:00

Fix tokens/

This commit is contained in:
parent 1ac29fd8df
commit 9f62a49ebb

.github/workflows/tests.yml (vendored): 2 lines changed
.github/workflows/tests.yml

@@ -47,7 +47,7 @@ jobs:
           python -m flake8 spacy --count --select=E901,E999,F821,F822,F823,W605 --show-source --statistics
       - name: cython-lint
         run: |
-          python -m pip install cython-lint -c requirements.txt --ignore E501,W291
+          python -m pip install cython-lint -c requirements.txt --ignore E501,W291,E266
           cython-lint spacy
 
   tests:
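For reference, the lint codes named here are pycodestyle codes: E501 is "line too long", W291 is "trailing whitespace", and the newly ignored E266 is "too many leading '#' for a block comment". A minimal sketch of what each flags:

    ## E266: block comments should start with a single "#", not "##"
    x = 1  # a trailing space after this comment would trigger W291
    # E501 fires once a line exceeds the configured maximum length (79 by default)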
spacy/tokens/_retokenize.pyx

@@ -1,7 +1,6 @@
 # cython: infer_types=True, bounds_check=False, profile=True
 from cymem.cymem cimport Pool
-from libc.stdlib cimport free, malloc
-from libc.string cimport memcpy, memset
+from libc.string cimport memset
 
 import numpy
 from thinc.api import get_array_module
@@ -10,7 +9,7 @@ from ..attrs cimport MORPH, NORM
 from ..lexeme cimport EMPTY_LEXEME, Lexeme
 from ..structs cimport LexemeC, TokenC
 from ..vocab cimport Vocab
-from .doc cimport Doc, set_children_from_heads, token_by_end, token_by_start
+from .doc cimport Doc, set_children_from_heads, token_by_start
 from .span cimport Span
 from .token cimport Token
 
@@ -147,7 +146,7 @@ def _merge(Doc doc, merges):
     syntactic root of the span.
     RETURNS (Token): The first newly merged token.
     """
-    cdef int i, merge_index, start, end, token_index, current_span_index, current_offset, offset, span_index
+    cdef int i, merge_index, start, token_index, current_span_index, current_offset, offset, span_index
     cdef Span span
     cdef const LexemeC* lex
     cdef TokenC* token
@@ -165,7 +164,6 @@ def _merge(Doc doc, merges):
     merges.sort(key=_get_start)
     for merge_index, (span, attributes) in enumerate(merges):
         start = span.start
-        end = span.end
         spans.append(span)
         # House the new merged token where it starts
         token = &doc.c[start]
@@ -203,8 +201,9 @@ def _merge(Doc doc, merges):
     # for the merged region. To do this, we create a boolean array indicating
     # whether the row is to be deleted, then use numpy.delete
     if doc.tensor is not None and doc.tensor.size != 0:
-        doc.tensor = _resize_tensor(doc.tensor,
-                [(m[0].start, m[0].end) for m in merges])
+        doc.tensor = _resize_tensor(
+            doc.tensor, [(m[0].start, m[0].end) for m in merges]
+        )
     # Memorize span roots and sets dependencies of the newly merged
     # tokens to the dependencies of their roots.
     span_roots = []
@@ -267,11 +266,11 @@ def _merge(Doc doc, merges):
             span_index += 1
         if span_index < len(spans) and i == spans[span_index].start:
             # First token in a span
-            doc.c[i - offset] = doc.c[i] # move token to its place
+            doc.c[i - offset] = doc.c[i]  # move token to its place
             offset += (spans[span_index].end - spans[span_index].start) - 1
             in_span = True
         if not in_span:
-            doc.c[i - offset] = doc.c[i] # move token to its place
+            doc.c[i - offset] = doc.c[i]  # move token to its place
 
     for i in range(doc.length - offset, doc.length):
         memset(&doc.c[i], 0, sizeof(TokenC))
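These `_merge` internals back the public retokenizer API. A hedged usage sketch of the entry point that exercises them:

    import spacy

    nlp = spacy.blank("en")
    doc = nlp("New York is a city")
    with doc.retokenize() as retokenizer:
        # Merge the first two tokens into one; attrs land on the merged token
        retokenizer.merge(doc[0:2], attrs={"LEMMA": "New York"})
    assert len(doc) == 4 and doc[0].text == "New York"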
@@ -345,7 +344,11 @@ def _split(Doc doc, int token_index, orths, heads, attrs):
     if to_process_tensor:
         xp = get_array_module(doc.tensor)
         if xp is numpy:
-            doc.tensor = xp.append(doc.tensor, xp.zeros((nb_subtokens,doc.tensor.shape[1]), dtype="float32"), axis=0)
+            doc.tensor = xp.append(
+                doc.tensor,
+                xp.zeros((nb_subtokens, doc.tensor.shape[1]), dtype="float32"),
+                axis=0
+            )
         else:
             shape = (doc.tensor.shape[0] + nb_subtokens, doc.tensor.shape[1])
             resized_array = xp.zeros(shape, dtype="float32")
@@ -367,7 +370,8 @@ def _split(Doc doc, int token_index, orths, heads, attrs):
         token.norm = 0 # reset norm
         if to_process_tensor:
             # setting the tensors of the split tokens to array of zeros
-            doc.tensor[token_index + i:token_index + i + 1] = xp.zeros((1,doc.tensor.shape[1]), dtype="float32")
+            doc.tensor[token_index + i:token_index + i + 1] = \
+                xp.zeros((1, doc.tensor.shape[1]), dtype="float32")
         # Update the character offset of the subtokens
         if i != 0:
             token.idx = orig_token.idx + idx_offset
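And the corresponding split path; per the retokenizer API, heads for the new subtokens are required. A hedged sketch:

    import spacy

    nlp = spacy.blank("en")
    doc = nlp("I live in NewYork")
    with doc.retokenize() as retokenizer:
        # "New" attaches to the second subtoken; "York" keeps "in" as its head
        heads = [(doc[3], 1), doc[2]]
        retokenizer.split(doc[3], ["New", "York"], heads=heads)
    assert [t.text for t in doc] == ["I", "live", "in", "New", "York"]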
@@ -455,7 +459,6 @@ def normalize_token_attrs(Vocab vocab, attrs):
 def set_token_attrs(Token py_token, attrs):
     cdef TokenC* token = py_token.c
     cdef const LexemeC* lex = token.lex
-    cdef Doc doc = py_token.doc
     # Assign attributes
     for attr_name, attr_value in attrs.items():
         if attr_name == "_": # Set extension attributes
spacy/tokens/doc.pxd

@@ -31,7 +31,7 @@ cdef int token_by_start(const TokenC* tokens, int length, int start_char) except -2
 cdef int token_by_end(const TokenC* tokens, int length, int end_char) except -2
 
 
-cdef int [:,:] _get_lca_matrix(Doc, int start, int end)
+cdef int [:, :] _get_lca_matrix(Doc, int start, int end)
 
 
 cdef class Doc:
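`int [:, :]` declares a Cython typed memoryview over a 2-D integer buffer; the change only inserts the space the linter expects after the comma. A small self-contained Cython sketch of the construct (illustrative only, would need to be compiled with cythonize):

    import numpy as np

    def fill_diagonal(int [:, :] mat, int value):
        # Typed memoryviews give C-speed indexed access to any 2-D int buffer
        cdef int i
        for i in range(min(mat.shape[0], mat.shape[1])):
            mat[i, i] = value

    arr = np.zeros((3, 3), dtype=np.intc)  # intc matches a C int
    fill_diagonal(arr, 7)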
@@ -61,7 +61,6 @@ cdef class Doc:
     cdef int length
     cdef int max_length
 
-
     cdef public object noun_chunks_iterator
 
     cdef object __weakref__
spacy/tokens/doc.pyx

@@ -43,14 +43,13 @@ from ..attrs cimport (
     attr_id_t,
 )
 from ..lexeme cimport EMPTY_LEXEME, Lexeme
-from ..typedefs cimport attr_t, flags_t
+from ..typedefs cimport attr_t
 from .token cimport Token
 
 from .. import parts_of_speech, schemas, util
 from ..attrs import IDS, intify_attr
-from ..compat import copy_reg, pickle
+from ..compat import copy_reg
 from ..errors import Errors, Warnings
-from ..morphology import Morphology
 from ..util import get_words_and_spaces
 from ._retokenize import Retokenizer
 from .underscore import Underscore, get_ext_args
@@ -784,7 +783,7 @@ cdef class Doc:
         # TODO:
         # 1. Test basic data-driven ORTH gazetteer
         # 2. Test more nuanced date and currency regex
-        cdef attr_t entity_type, kb_id, ent_id
+        cdef attr_t kb_id, ent_id
         cdef int ent_start, ent_end
         ent_spans = []
         for ent_info in ents:
@@ -987,7 +986,6 @@ cdef class Doc:
         >>> np_array = doc.to_array([LOWER, POS, ENT_TYPE, IS_ALPHA])
         """
         cdef int i, j
-        cdef attr_id_t feature
         cdef np.ndarray[attr_t, ndim=2] output
         # Handle scalar/list inputs of strings/ints for py_attr_ids
         # See also #3064
@@ -999,8 +997,10 @@ cdef class Doc:
             py_attr_ids = [py_attr_ids]
         # Allow strings, e.g. 'lemma' or 'LEMMA'
         try:
-            py_attr_ids = [(IDS[id_.upper()] if hasattr(id_, "upper") else id_)
-                           for id_ in py_attr_ids]
+            py_attr_ids = [
+                (IDS[id_.upper()] if hasattr(id_, "upper") else id_)
+                for id_ in py_attr_ids
+            ]
         except KeyError as msg:
             keys = [k for k in IDS.keys() if not k.startswith("FLAG")]
             raise KeyError(Errors.E983.format(dict="IDS", key=msg, keys=keys)) from None
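For context, `Doc.to_array` accepts attribute IDs, attribute names, or a mix, which is what the normalization above handles. A hedged usage sketch:

    import spacy
    from spacy.attrs import LOWER, POS

    nlp = spacy.blank("en")  # blank pipeline; POS values stay 0 without a tagger
    doc = nlp("Give it back")
    # Mixed string/ID attribute lists are normalized via IDS[id_.upper()]
    arr = doc.to_array([LOWER, "POS"])
    assert arr.shape == (len(doc), 2)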
@@ -1030,8 +1030,6 @@ cdef class Doc:
         DOCS: https://spacy.io/api/doc#count_by
         """
         cdef int i
-        cdef attr_t attr
-        cdef size_t count
 
         if counts is None:
             counts = Counter()
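`Doc.count_by` returns a mapping from attribute values (hashes, for string-valued attributes like ORTH) to their frequencies. A hedged usage sketch:

    import spacy
    from spacy.attrs import ORTH

    nlp = spacy.blank("en")
    doc = nlp("to be or not to be")
    counts = doc.count_by(ORTH)
    # Keys are hash values; map strings through the StringStore to look them up
    assert counts[nlp.vocab.strings["to"]] == 2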
@@ -1093,7 +1091,6 @@ cdef class Doc:
         cdef int i, col
         cdef int32_t abs_head_index
         cdef attr_id_t attr_id
-        cdef TokenC* tokens = self.c
         cdef int length = len(array)
         if length != len(self):
             raise ValueError(Errors.E971.format(array_length=length, doc_length=len(self)))
@@ -1225,7 +1222,7 @@ cdef class Doc:
                     span.label,
                     span.kb_id,
                     span.id,
-                    span.text, # included as a check
+                    span.text,  # included as a check
                 ))
             char_offset += len(doc.text)
             if len(doc) > 0 and ensure_whitespace and not doc[-1].is_space and not bool(doc[-1].whitespace_):
@@ -1508,7 +1505,6 @@ cdef class Doc:
         attributes are inherited from the syntactic root of the span.
         RETURNS (Token): The first newly merged token.
         """
-        cdef str tag, lemma, ent_type
         attr_len = len(attributes)
         span_len = len(spans)
         if not attr_len == span_len:
@@ -1624,7 +1620,6 @@ cdef class Doc:
                 for token in char_span[1:]:
                     token.is_sent_start = False
 
-
         for span_group in doc_json.get("spans", {}):
             spans = []
             for span in doc_json["spans"][span_group]:
@@ -1656,7 +1651,7 @@ cdef class Doc:
                 start = token_by_char(self.c, self.length, token_data["start"])
                 value = token_data["value"]
                 self[start]._.set(token_attr, value)
 
         for span_attr in doc_json.get("underscore_span", {}):
             if not Span.has_extension(span_attr):
                 Span.set_extension(span_attr)
@@ -1698,7 +1693,7 @@ cdef class Doc:
                 token_data["dep"] = token.dep_
                 token_data["head"] = token.head.i
             data["tokens"].append(token_data)
 
         if self.spans:
             data["spans"] = {}
             for span_group in self.spans:
@@ -1769,7 +1764,6 @@ cdef class Doc:
         output.fill(255)
         cdef int i, j, start_idx, end_idx
         cdef bytes byte_string
-        cdef unsigned char utf8_char
         for i, byte_string in enumerate(byte_strings):
             j = 0
             start_idx = 0
@@ -1822,8 +1816,6 @@ cdef int token_by_char(const TokenC* tokens, int length, int char_idx) except -2
 
 cdef int set_children_from_heads(TokenC* tokens, int start, int end) except -1:
     # note: end is exclusive
-    cdef TokenC* head
-    cdef TokenC* child
     cdef int i
     # Set number of left/right children to 0. We'll increment it in the loops.
     for i in range(start, end):
@@ -1923,7 +1915,7 @@ cdef int _get_tokens_lca(Token token_j, Token token_k):
     return -1
 
 
-cdef int [:,:] _get_lca_matrix(Doc doc, int start, int end):
+cdef int [:, :] _get_lca_matrix(Doc doc, int start, int end):
     """Given a doc and a start and end position defining a set of contiguous
     tokens within it, returns a matrix of Lowest Common Ancestors (LCA), where
     LCA[i, j] is the index of the lowest common ancestor among token i and j.
@@ -1936,7 +1928,7 @@ cdef int [:,:] _get_lca_matrix(Doc doc, int start, int end):
     RETURNS (int [:, :]): memoryview of numpy.array[ndim=2, dtype=numpy.int32],
         with shape (n, n), where n = len(doc).
     """
-    cdef int [:,:] lca_matrix
+    cdef int [:, :] lca_matrix
     cdef int j, k
     n_tokens= end - start
     lca_mat = numpy.empty((n_tokens, n_tokens), dtype=numpy.int32)
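`_get_lca_matrix` backs the public `Doc.get_lca_matrix`, which gives, for every pair of tokens, the index of their lowest common ancestor in the dependency parse. A hedged sketch (assumes the `en_core_web_sm` model is installed, since the matrix is only meaningful with a parse):

    import spacy

    nlp = spacy.load("en_core_web_sm")
    doc = nlp("She ate the pizza")
    lca = doc.get_lca_matrix()  # numpy array of shape (len(doc), len(doc))
    # lca[i, j] is the token index of the lowest common ancestor of tokens i and j;
    # on the diagonal, every token is its own ancestor
    assert lca[1, 1] == 1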
spacy/tokens/graph.pyx

@@ -3,7 +3,7 @@ from typing import Generator, List, Tuple
 
 cimport cython
 from cython.operator cimport dereference
-from libc.stdint cimport int32_t, int64_t
+from libc.stdint cimport int32_t
 from libcpp.pair cimport pair
 from libcpp.unordered_map cimport unordered_map
 from libcpp.unordered_set cimport unordered_set
@@ -11,7 +11,6 @@ from libcpp.unordered_set cimport unordered_set
 import weakref
 
 from murmurhash.mrmr cimport hash64
-from preshed.maps cimport map_get_unless_missing
 
 from .. import Errors
 
@@ -28,7 +27,7 @@ from .token import Token
 cdef class Edge:
     cdef readonly Graph graph
     cdef readonly int i
 
     def __init__(self, Graph graph, int i):
         self.graph = graph
         self.i = i
@@ -44,7 +43,7 @@ cdef class Edge:
     @property
     def head(self) -> "Node":
         return Node(self.graph, self.graph.c.edges[self.i].head)
 
     @property
     def tail(self) -> "Tail":
         return Node(self.graph, self.graph.c.edges[self.i].tail)
@@ -70,7 +69,7 @@ cdef class Node:
     def __init__(self, Graph graph, int i):
         """A reference to a node of an annotation graph. Each node is made up of
         an ordered set of zero or more token indices.
 
         Node references are usually created by the Graph object itself, or from
         the Node or Edge objects. You usually won't need to instantiate this
         class yourself.
@@ -109,13 +108,13 @@ cdef class Node:
     @property
     def is_none(self) -> bool:
         """Whether the node is a special value, indicating 'none'.
 
         The NoneNode type is returned by the Graph, Edge and Node objects when
         there is no match to a query. It has the same API as Node, but it always
         returns NoneNode, NoneEdge or empty lists for its queries.
         """
         return False
 
     @property
     def doc(self) -> "Doc":
         """The Doc object that the graph refers to."""
@@ -130,19 +129,19 @@ cdef class Node:
     def head(self, i=None, label=None) -> "Node":
         """Get the head of the first matching edge, searching by index, label,
         both or neither.
 
         For instance, `node.head(i=1)` will get the head of the second edge that
         this node is a tail of. `node.head(i=1, label="ARG0")` will further
         check that the second edge has the label `"ARG0"`.
 
         If no matching node can be found, the graph's NoneNode is returned.
         """
         return self.headed(i=i, label=label)
 
     def tail(self, i=None, label=None) -> "Node":
         """Get the tail of the first matching edge, searching by index, label,
         both or neither.
 
         If no matching node can be found, the graph's NoneNode is returned.
         """
         return self.tailed(i=i, label=label).tail
@@ -171,7 +170,7 @@ cdef class Node:
         cdef vector[int] edge_indices
         self._find_edges(edge_indices, "head", label)
         return [Node(self.graph, self.graph.c.edges[i].head) for i in edge_indices]
 
     def tails(self, label=None) -> List["Node"]:
         """Find all matching tails of this node."""
         cdef vector[int] edge_indices
@@ -200,7 +199,7 @@ cdef class Node:
             return NoneEdge(self.graph)
         else:
             return Edge(self.graph, idx)
 
     def tailed(self, i=None, label=None) -> Edge:
         """Find the first matching edge tailed by this node.
         If no matching edge can be found, the graph's NoneEdge is returned.
@@ -283,7 +282,7 @@ cdef class NoneEdge(Edge):
     def __init__(self, graph):
         self.graph = graph
         self.i = -1
 
     @property
     def doc(self) -> "Doc":
         return self.graph.doc
@@ -291,7 +290,7 @@ cdef class NoneEdge(Edge):
     @property
     def head(self) -> "NoneNode":
         return NoneNode(self.graph)
 
     @property
     def tail(self) -> "NoneNode":
         return NoneNode(self.graph)
@@ -319,7 +318,7 @@ cdef class NoneNode(Node):
 
     def __len__(self):
         return 0
 
     @property
     def is_none(self):
         return -1
@@ -340,14 +339,14 @@ cdef class NoneNode(Node):
 
     def walk_heads(self):
         yield from []
 
     def walk_tails(self):
         yield from []
 
 
 cdef class Graph:
     """A set of directed labelled relationships between sets of tokens.
 
     EXAMPLE:
         Construction 1
         >>> graph = Graph(doc, name="srl")
@@ -372,7 +371,9 @@ cdef class Graph:
     >>> assert graph.has_node((0,))
     >>> assert graph.has_edge((0,), (1,3), label="agent")
     """
-    def __init__(self, doc, *, name="", nodes=[], edges=[], labels=None, weights=None):
+    def __init__(
+        self, doc, *, name="", nodes=[], edges=[], labels=None, weights=None  # no-cython-lint
+    ):
         """Create a Graph object.
 
         doc (Doc): The Doc object the graph will refer to.
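The `# no-cython-lint` pragma suppresses the linter's complaint about the mutable default arguments (`nodes=[]`, `edges=[]`), which the signature presumably keeps to avoid changing the public API. As a reminder of why linters flag the pattern:

    def append_to(item, bucket=[]):  # the list is created once, at definition time
        bucket.append(item)
        return bucket

    print(append_to(1))  # [1]
    print(append_to(2))  # [1, 2] -- the same list is shared across calls

    def append_to_safe(item, bucket=None):  # the conventional fix
        if bucket is None:
            bucket = []
        bucket.append(item)
        return bucket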
@@ -438,13 +439,11 @@ cdef class Graph:
 
     def add_edge(self, head, tail, *, label="", weight=None) -> Edge:
         """Add an edge to the graph, connecting two groups of tokens.
 
         If there is already an edge for the (head, tail, label) triple, it will
         be returned, and no new edge will be created. The weight of the edge
         will be updated if a weight is specified.
         """
-        label_hash = self.doc.vocab.strings.as_int(label)
-        weight_float = weight if weight is not None else 0.0
         edge_index = add_edge(
             &self.c,
             EdgeC(
@@ -478,11 +477,11 @@ cdef class Graph:
     def has_edge(self, head, tail, label) -> bool:
         """Check whether a (head, tail, label) triple is an edge in the graph."""
         return not self.get_edge(head, tail, label=label).is_none
 
     def add_node(self, indices) -> Node:
         """Add a node to the graph and return it. Nodes refer to ordered sets
         of token indices.
 
         This method is idempotent: if there is already a node for the given
         indices, it is returned without a new node being created.
         """
@@ -510,7 +509,7 @@ cdef class Graph:
             return NoneNode(self)
         else:
             return Node(self, node_index)
 
     def has_node(self, tuple indices) -> bool:
         """Check whether the graph has a node for the given indices."""
         return not self.get_node(indices).is_none
@@ -570,7 +569,7 @@ cdef int add_node(GraphC* graph, vector[int32_t]& node) nogil:
     graph.roots.insert(index)
     graph.node_map.insert(pair[hash_t, int](key, index))
     return index
 
 
 cdef int get_node(const GraphC* graph, vector[int32_t] node) nogil:
     key = hash64(&node[0], node.size() * sizeof(node[0]), 0)
spacy/tokens/morphanalysis.pyx

@@ -89,4 +89,3 @@ cdef class MorphAnalysis:
 
     def __repr__(self):
         return self.to_json()
-
spacy/tokens/span.pyx

@@ -1,5 +1,4 @@
 cimport numpy as np
-from libc.math cimport sqrt
 
 import copy
 import warnings
@@ -10,11 +9,10 @@ from thinc.api import get_array_module
 from ..attrs cimport *
 from ..attrs cimport ORTH, attr_id_t
 from ..lexeme cimport Lexeme
-from ..parts_of_speech cimport univ_pos_t
-from ..structs cimport LexemeC, TokenC
+from ..structs cimport TokenC
 from ..symbols cimport dep
-from ..typedefs cimport attr_t, flags_t, hash_t
-from .doc cimport _get_lca_matrix, get_token_attr, token_by_end, token_by_start
+from ..typedefs cimport attr_t, hash_t
+from .doc cimport _get_lca_matrix, get_token_attr
 from .token cimport Token
 
 from ..errors import Errors, Warnings
@@ -595,7 +593,6 @@ cdef class Span:
         """
         return "".join([t.text_with_ws for t in self])
 
-
     @property
     def noun_chunks(self):
         """Iterate over the base noun phrases in the span. Yields base
spacy/tokens/span_group.pyx

@@ -1,7 +1,7 @@
 import struct
 import weakref
 from copy import deepcopy
-from typing import TYPE_CHECKING, Iterable, Optional, Tuple, Union
+from typing import Iterable, Optional, Union
 
 import srsly
 
@@ -34,7 +34,7 @@ cdef class SpanGroup:
 
     DOCS: https://spacy.io/api/spangroup
     """
-    def __init__(self, doc, *, name="", attrs={}, spans=[]):
+    def __init__(self, doc, *, name="", attrs={}, spans=[]):  # no-cython-lint
         """Create a SpanGroup.
 
         doc (Doc): The reference Doc object.
@@ -311,7 +311,7 @@ cdef class SpanGroup:
 
         other_attrs = deepcopy(other_group.attrs)
         span_group.attrs.update({
-            key: value for key, value in other_attrs.items() \
+            key: value for key, value in other_attrs.items()
             if key not in span_group.attrs
         })
         if len(other_group):
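The backslash removed above was redundant: inside brackets, Python and Cython continue lines implicitly, and cython-lint flags the explicit continuation (which also breaks if a stray space follows the backslash). A minimal sketch of the two styles:

    # Explicit continuation: works, but fragile
    total = 1 + \
        2

    # Implicit continuation inside (), [] or {} needs no backslash
    filtered = {
        key: value for key, value in {"a": 1, "b": 2}.items()
        if key != "a"
    }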
spacy/tokens/token.pxd

@@ -26,7 +26,7 @@ cdef class Token:
         cdef Token self = Token.__new__(Token, vocab, doc, offset)
         return self
 
-#cdef inline TokenC struct_from_attrs(Vocab vocab, attrs):
+# cdef inline TokenC struct_from_attrs(Vocab vocab, attrs):
 #    cdef TokenC token
 #    attrs = normalize_attrs(attrs)
 
@@ -98,12 +98,10 @@ cdef class Token:
         elif feat_name == SENT_START:
             token.sent_start = value
 
-
     @staticmethod
     cdef inline int missing_dep(const TokenC* token) nogil:
         return token.dep == MISSING_DEP
 
-
     @staticmethod
     cdef inline int missing_head(const TokenC* token) nogil:
         return Token.missing_dep(token)
spacy/tokens/token.pyx

@@ -1,13 +1,11 @@
 # cython: infer_types=True
 # Compiler crashes on memory view coercion without this. Should report bug.
 cimport numpy as np
-from cython.view cimport array as cvarray
 
 np.import_array()
 
 import warnings
 
-import numpy
 from thinc.api import get_array_module
 
 from ..attrs cimport (
@@ -238,7 +236,7 @@ cdef class Token:
         result = xp.dot(vector, other.vector) / (self.vector_norm * other.vector_norm)
         # ensure we get a scalar back (numpy does this automatically but cupy doesn't)
         return result.item()
 
     def has_morph(self):
         """Check whether the token has annotated morph information.
         Return False when the morph annotation is unset/missing.
@@ -545,9 +543,9 @@ cdef class Token:
         def __get__(self):
             if self.i + 1 == len(self.doc):
                 return True
-            elif self.doc[self.i+1].is_sent_start == None:
+            elif self.doc[self.i+1].is_sent_start is None:
                 return None
-            elif self.doc[self.i+1].is_sent_start == True:
+            elif self.doc[self.i+1].is_sent_start is True:
                 return True
             else:
                 return False
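The `is` fixes above address pycodestyle's E711/E712: `None` and `True` are singletons, so identity comparison is idiomatic and also safe against operands that overload `==`. A small illustration with numpy, whose arrays overload `==` elementwise:

    import numpy as np

    arr = np.zeros(3)
    # arr == None goes through numpy's broadcasting __eq__ and returns an array,
    # which is ambiguous in a boolean context; identity comparison is unambiguous
    print(arr is None)   # False
    print(arr == None)   # [False False False]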