Clean up spacy.tokens (#6046)

* Clean up spacy.tokens

* Update `set_children_from_heads`:
  * Don't check `dep` when setting lr_* or sentence starts
  * Set all non-sentence starts to `False`

* Use `set_children_from_heads` in `Token.head` setter
  * Reduce similar/duplicate code (admittedly adds a bit of overhead)
  * Update sentence starts consistently

* Remove unused `Doc.set_parse`

* Minor changes:
  * Declare cython variables (to avoid cython warnings)
  * Clean up imports

* Modify set_children_from_heads to set token range

Modify `set_children_from_heads` so that it adjust tokens within a
specified range rather then the whole document.

Modify the `Token.head` setter to adjust only the tokens affected by the
new head assignment.
This commit is contained in:
Adriane Boyd 2020-09-16 20:32:38 +02:00 committed by GitHub
parent c776594ab1
commit a119667a36
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 85 additions and 142 deletions

View File

@ -680,7 +680,7 @@ cdef class ArcEager(TransitionSystem):
def finalize_doc(self, Doc doc):
doc.is_parsed = True
set_children_from_heads(doc.c, doc.length)
set_children_from_heads(doc.c, 0, doc.length)
def has_gold(self, Example eg, start=0, end=None):
for word in eg.y[start:end]:

View File

@ -119,7 +119,7 @@ cpdef deprojectivize(Doc doc):
new_head = _find_new_head(doc[i], head_label)
doc.c[i].head = new_head.i - i
doc.c[i].dep = doc.vocab.strings.add(new_label)
set_children_from_heads(doc.c, doc.length)
set_children_from_heads(doc.c, 0, doc.length)
return doc

View File

@ -265,17 +265,11 @@ def test_doc_is_nered(en_vocab):
def test_doc_from_array_sent_starts(en_vocab):
words = ["I", "live", "in", "New", "York", ".", "I", "like", "cats", "."]
heads = [0, 0, 0, 0, 0, 0, 6, 6, 6, 6]
heads = [0, -1, -2, -3, -4, -5, 0, -1, -2, -3]
# fmt: off
deps = ["ROOT", "dep", "dep", "dep", "dep", "dep", "ROOT", "dep", "dep", "dep", "dep"]
deps = ["ROOT", "dep", "dep", "dep", "dep", "dep", "ROOT", "dep", "dep", "dep"]
# fmt: on
doc = Doc(en_vocab, words=words)
for i, (dep, head) in enumerate(zip(deps, heads)):
doc[i].dep_ = dep
doc[i].head = doc[head]
if head == i:
doc[i].is_sent_start = True
doc.is_parsed
doc = get_doc(en_vocab, words=words, heads=heads, deps=deps)
attrs = [SENT_START, HEAD]
arr = doc.to_array(attrs)

View File

@ -112,7 +112,6 @@ def test_doc_token_api_ancestors(en_tokenizer):
def test_doc_token_api_head_setter(en_tokenizer):
# the structure of this sentence depends on the English annotation scheme
text = "Yesterday I saw a dog that barked loudly."
heads = [2, 1, 0, 1, -2, 1, -2, -1, -6]
tokens = en_tokenizer(text)
@ -169,6 +168,40 @@ def test_doc_token_api_head_setter(en_tokenizer):
with pytest.raises(ValueError):
doc[0].head = doc2[0]
# test sentence starts when two sentences are joined
text = "This is one sentence. This is another sentence."
heads = [0, -1, -2, -3, -4, 0, -1, -2, -3, -4]
tokens = en_tokenizer(text)
doc = get_doc(
tokens.vocab,
words=[t.text for t in tokens],
heads=heads,
deps=["dep"] * len(heads),
)
# initially two sentences
assert doc[0].is_sent_start
assert doc[5].is_sent_start
assert doc[0].left_edge == doc[0]
assert doc[0].right_edge == doc[4]
assert doc[5].left_edge == doc[5]
assert doc[5].right_edge == doc[9]
# modifying with a sentence doesn't change sent starts
doc[2].head = doc[3]
assert doc[0].is_sent_start
assert doc[5].is_sent_start
assert doc[0].left_edge == doc[0]
assert doc[0].right_edge == doc[4]
assert doc[5].left_edge == doc[5]
assert doc[5].right_edge == doc[9]
# attach the second sentence to the first, resulting in one sentence
doc[5].head = doc[0]
assert doc[0].is_sent_start
assert not doc[5].is_sent_start
assert doc[0].left_edge == doc[0]
assert doc[0].right_edge == doc[9]
def test_is_sent_start(en_tokenizer):
doc = en_tokenizer("This is a sentence. This is another.")

View File

@ -184,7 +184,7 @@ def test_parser_set_sent_starts(en_vocab):
if i == 0 or i == 3:
assert doc[i].is_sent_start is True
else:
assert doc[i].is_sent_start is None
assert not doc[i].is_sent_start
for sent in doc.sents:
for token in sent:
assert token.head in sent

View File

@ -123,7 +123,7 @@ def test_issue2772(en_vocab):
heads = [4, 1, 7, -1, -2, -1, 3, 2, 1, 0, 2, 1, -3, -4]
deps = ["dep"] * len(heads)
doc = get_doc(en_vocab, words=words, heads=heads, deps=deps)
assert doc[1].is_sent_start is None
assert not doc[1].is_sent_start
@pytest.mark.parametrize("text", ["-0.23", "+123,456", "±1"])

View File

@ -274,7 +274,7 @@ def _merge(Doc doc, merges):
for i in range(doc.length):
doc.c[i].head -= i
# Set the left/right children, left/right edges
set_children_from_heads(doc.c, doc.length)
set_children_from_heads(doc.c, 0, doc.length)
# Make sure ent_iob remains consistent
make_iob_consistent(doc.c, doc.length)
# Return the merged Python object
@ -381,7 +381,7 @@ def _split(Doc doc, int token_index, orths, heads, attrs):
for i in range(doc.length):
doc.c[i].head -= i
# set children from head
set_children_from_heads(doc.c, doc.length)
set_children_from_heads(doc.c, 0, doc.length)
def _validate_extensions(extensions):
@ -408,7 +408,6 @@ cdef make_iob_consistent(TokenC* tokens, int length):
def normalize_token_attrs(Vocab vocab, attrs):
if "_" in attrs: # Extension attributes
extensions = attrs["_"]
print("EXTENSIONS", extensions)
_validate_extensions(extensions)
attrs = {key: value for key, value in attrs.items() if key != "_"}
attrs = intify_attrs(attrs, strings_map=vocab.strings)

View File

@ -19,10 +19,10 @@ ctypedef fused LexemeOrToken:
const_TokenC_ptr
cdef int set_children_from_heads(TokenC* tokens, int length) except -1
cdef int set_children_from_heads(TokenC* tokens, int start, int end) except -1
cdef int _set_lr_kids_and_edges(TokenC* tokens, int length, int loop_count) except -1
cdef int _set_lr_kids_and_edges(TokenC* tokens, int start, int end, int loop_count) except -1
cdef int token_by_start(const TokenC* tokens, int length, int start_char) except -2
@ -31,9 +31,6 @@ cdef int token_by_start(const TokenC* tokens, int length, int start_char) except
cdef int token_by_end(const TokenC* tokens, int length, int end_char) except -2
cdef int set_children_from_heads(TokenC* tokens, int length) except -1
cdef int [:,:] _get_lca_matrix(Doc, int start, int end)
cdef class Doc:
@ -74,5 +71,3 @@ cdef class Doc:
cdef int push_back(self, LexemeOrToken lex_or_tok, bint has_space) except -1
cpdef np.ndarray to_array(self, object features)
cdef void set_parse(self, const TokenC* parsed) nogil

View File

@ -1,32 +1,27 @@
# cython: infer_types=True, bounds_check=False, profile=True
cimport cython
cimport numpy as np
from libc.string cimport memcpy, memset
from libc.string cimport memcpy
from libc.math cimport sqrt
from libc.stdint cimport int32_t, uint64_t
import copy
from collections import Counter
import numpy
import numpy.linalg
import struct
import srsly
from thinc.api import get_array_module
from thinc.util import copy_array
import warnings
import copy
from .span cimport Span
from .token cimport Token
from ..lexeme cimport Lexeme, EMPTY_LEXEME
from ..typedefs cimport attr_t, flags_t
from ..attrs cimport ID, ORTH, NORM, LOWER, SHAPE, PREFIX, SUFFIX, CLUSTER
from ..attrs cimport attr_id_t
from ..attrs cimport LENGTH, POS, LEMMA, TAG, MORPH, DEP, HEAD, SPACY, ENT_IOB
from ..attrs cimport ENT_TYPE, ENT_ID, ENT_KB_ID, SENT_START, IDX, attr_id_t
from ..parts_of_speech cimport CCONJ, PUNCT, NOUN, univ_pos_t
from ..attrs cimport ENT_TYPE, ENT_ID, ENT_KB_ID, SENT_START, IDX, NORM
from ..attrs import intify_attr, intify_attrs, IDS
from ..util import normalize_slice
from ..attrs import intify_attr, IDS
from ..compat import copy_reg, pickle
from ..errors import Errors, Warnings
from .. import util
@ -291,7 +286,7 @@ cdef class Doc:
DOCS: https://nightly.spacy.io/api/doc#getitem
"""
if isinstance(i, slice):
start, stop = normalize_slice(len(self), i.start, i.stop, i.step)
start, stop = util.normalize_slice(len(self), i.start, i.stop, i.step)
return Span(self, start, stop, label=0)
if i < 0:
i = self.length + i
@ -627,10 +622,7 @@ cdef class Doc:
@property
def sents(self):
"""Iterate over the sentences in the document. Yields sentence `Span`
objects. Sentence spans have no label. To improve accuracy on informal
texts, spaCy calculates sentence boundaries from the syntactic
dependency parse. If the parser is disabled, the `sents` iterator will
be unavailable.
objects. Sentence spans have no label.
YIELDS (Span): Sentences in the document.
@ -786,14 +778,6 @@ cdef class Doc:
for i in range(self.length, self.max_length + PADDING):
self.c[i].lex = &EMPTY_LEXEME
cdef void set_parse(self, const TokenC* parsed) nogil:
# TODO: This method is fairly misleading atm. It's used by Parser
# to actually apply the parse calculated. Need to rethink this.
# Probably we should use from_array?
self.is_parsed = True
for i in range(self.length):
self.c[i] = parsed[i]
def from_array(self, attrs, array):
"""Load attributes from a numpy array. Write to a `Doc` object, from an
`(M, N)` array of attributes.
@ -884,7 +868,7 @@ cdef class Doc:
self.is_tagged = bool(self.is_tagged or TAG in attrs or POS in attrs)
# If document is parsed, set children
if self.is_parsed:
set_children_from_heads(self.c, length)
set_children_from_heads(self.c, 0, length)
return self
@staticmethod
@ -1321,13 +1305,13 @@ cdef int token_by_char(const TokenC* tokens, int length, int char_idx) except -2
return mid
return -1
cdef int set_children_from_heads(TokenC* tokens, int length) except -1:
cdef int set_children_from_heads(TokenC* tokens, int start, int end) except -1:
# note: end is exclusive
cdef TokenC* head
cdef TokenC* child
cdef int i
# Set number of left/right children to 0. We'll increment it in the loops.
for i in range(length):
for i in range(start, end):
tokens[i].l_kids = 0
tokens[i].r_kids = 0
tokens[i].l_edge = i
@ -1341,38 +1325,40 @@ cdef int set_children_from_heads(TokenC* tokens, int length) except -1:
# without risking getting stuck in an infinite loop if something is
# terribly malformed.
while not heads_within_sents:
heads_within_sents = _set_lr_kids_and_edges(tokens, length, loop_count)
heads_within_sents = _set_lr_kids_and_edges(tokens, start, end, loop_count)
if loop_count > 10:
warnings.warn(Warnings.W026)
break
loop_count += 1
# Set sentence starts
for i in range(length):
if tokens[i].head == 0 and tokens[i].dep != 0:
for i in range(start, end):
tokens[i].sent_start = -1
for i in range(start, end):
if tokens[i].head == 0:
tokens[tokens[i].l_edge].sent_start = True
cdef int _set_lr_kids_and_edges(TokenC* tokens, int length, int loop_count) except -1:
cdef int _set_lr_kids_and_edges(TokenC* tokens, int start, int end, int loop_count) except -1:
# May be called multiple times due to non-projectivity. See issues #3170
# and #4688.
# Set left edges
cdef TokenC* head
cdef TokenC* child
cdef int i, j
for i in range(length):
for i in range(start, end):
child = &tokens[i]
head = &tokens[i + child.head]
if child < head and loop_count == 0:
if loop_count == 0 and child < head:
head.l_kids += 1
if child.l_edge < head.l_edge:
head.l_edge = child.l_edge
if child.r_edge > head.r_edge:
head.r_edge = child.r_edge
# Set right edges - same as above, but iterate in reverse
for i in range(length-1, -1, -1):
for i in range(end-1, start-1, -1):
child = &tokens[i]
head = &tokens[i + child.head]
if child > head and loop_count == 0:
if loop_count == 0 and child > head:
head.r_kids += 1
if child.r_edge > head.r_edge:
head.r_edge = child.r_edge
@ -1380,14 +1366,14 @@ cdef int _set_lr_kids_and_edges(TokenC* tokens, int length, int loop_count) exce
head.l_edge = child.l_edge
# Get sentence start positions according to current state
sent_starts = set()
for i in range(length):
if tokens[i].head == 0 and tokens[i].dep != 0:
for i in range(start, end):
if tokens[i].head == 0:
sent_starts.add(tokens[i].l_edge)
cdef int curr_sent_start = 0
cdef int curr_sent_end = 0
# Check whether any heads are not within the current sentence
for i in range(length):
if (i > 0 and i in sent_starts) or i == length - 1:
for i in range(start, end):
if (i > 0 and i in sent_starts) or i == end - 1:
curr_sent_end = i
for j in range(curr_sent_start, curr_sent_end):
if tokens[j].head + j < curr_sent_start or tokens[j].head + j >= curr_sent_end + 1:
@ -1436,6 +1422,7 @@ cdef int [:,:] _get_lca_matrix(Doc doc, int start, int end):
with shape (n, n), where n = len(doc).
"""
cdef int [:,:] lca_matrix
cdef int j, k
n_tokens= end - start
lca_mat = numpy.empty((n_tokens, n_tokens), dtype=numpy.int32)
lca_mat.fill(-1)

View File

@ -4,13 +4,10 @@ cimport numpy as np
from libc.math cimport sqrt
import numpy
import numpy.linalg
from thinc.api import get_array_module
from collections import defaultdict
import warnings
from .doc cimport token_by_start, token_by_end, get_token_attr, _get_lca_matrix
from .token cimport TokenC
from ..structs cimport TokenC, LexemeC
from ..typedefs cimport flags_t, attr_t, hash_t
from ..attrs cimport attr_id_t

View File

@ -1,6 +1,4 @@
# cython: infer_types=True
from libc.string cimport memcpy
from cpython.mem cimport PyMem_Malloc, PyMem_Free
# Compiler crashes on memory view coercion without this. Should report bug.
from cython.view cimport array as cvarray
cimport numpy as np
@ -14,14 +12,13 @@ from ..typedefs cimport hash_t
from ..lexeme cimport Lexeme
from ..attrs cimport IS_ALPHA, IS_ASCII, IS_DIGIT, IS_LOWER, IS_PUNCT, IS_SPACE
from ..attrs cimport IS_BRACKET, IS_QUOTE, IS_LEFT_PUNCT, IS_RIGHT_PUNCT
from ..attrs cimport IS_TITLE, IS_UPPER, IS_CURRENCY, LIKE_URL, LIKE_NUM, LIKE_EMAIL
from ..attrs cimport IS_STOP, ID, ORTH, NORM, LOWER, SHAPE, PREFIX, SUFFIX
from ..attrs cimport LENGTH, CLUSTER, LEMMA, POS, TAG, DEP
from ..attrs cimport IS_TITLE, IS_UPPER, IS_CURRENCY, IS_STOP
from ..attrs cimport LIKE_URL, LIKE_NUM, LIKE_EMAIL
from ..symbols cimport conj
from .morphanalysis cimport MorphAnalysis
from .doc cimport set_children_from_heads
from .. import parts_of_speech
from .. import util
from ..errors import Errors, Warnings
from .underscore import Underscore, get_ext_args
@ -658,78 +655,19 @@ cdef class Token:
# Do nothing if old head is new head
if self.i + self.c.head == new_head.i:
return
cdef Token old_head = self.head
cdef int rel_newhead_i = new_head.i - self.i
# Is the new head a descendant of the old head
cdef bint is_desc = old_head.is_ancestor(new_head)
cdef int new_edge
cdef Token anc, child
# Update number of deps of old head
if self.c.head > 0: # left dependent
old_head.c.l_kids -= 1
if self.c.l_edge == old_head.c.l_edge:
# The token dominates the left edge so the left edge of
# the head may change when the token is reattached, it may
# not change if the new head is a descendant of the current
# head.
new_edge = self.c.l_edge
# The new l_edge is the left-most l_edge on any of the
# other dependents where the l_edge is left of the head,
# otherwise it is the head
if not is_desc:
new_edge = old_head.i
for child in old_head.children:
if child == self:
continue
if child.c.l_edge < new_edge:
new_edge = child.c.l_edge
old_head.c.l_edge = new_edge
# Walk up the tree from old_head and assign new l_edge to
# ancestors until an ancestor already has an l_edge that's
# further left
for anc in old_head.ancestors:
if anc.c.l_edge <= new_edge:
break
anc.c.l_edge = new_edge
elif self.c.head < 0: # right dependent
old_head.c.r_kids -= 1
# Do the same thing as for l_edge
if self.c.r_edge == old_head.c.r_edge:
new_edge = self.c.r_edge
if not is_desc:
new_edge = old_head.i
for child in old_head.children:
if child == self:
continue
if child.c.r_edge > new_edge:
new_edge = child.c.r_edge
old_head.c.r_edge = new_edge
for anc in old_head.ancestors:
if anc.c.r_edge >= new_edge:
break
anc.c.r_edge = new_edge
# Update number of deps of new head
if rel_newhead_i > 0: # left dependent
new_head.c.l_kids += 1
# Walk up the tree from new head and set l_edge to self.l_edge
# until you hit a token with an l_edge further to the left
if self.c.l_edge < new_head.c.l_edge:
new_head.c.l_edge = self.c.l_edge
for anc in new_head.ancestors:
if anc.c.l_edge <= self.c.l_edge:
break
anc.c.l_edge = self.c.l_edge
elif rel_newhead_i < 0: # right dependent
new_head.c.r_kids += 1
# Do the same as for l_edge
if self.c.r_edge > new_head.c.r_edge:
new_head.c.r_edge = self.c.r_edge
for anc in new_head.ancestors:
if anc.c.r_edge >= self.c.r_edge:
break
anc.c.r_edge = self.c.r_edge
# Find the widest l/r_edges of the roots of the two tokens involved
# to limit the number of tokens for set_children_from_heads
cdef Token self_root, new_head_root
self_ancestors = list(self.ancestors)
new_head_ancestors = list(new_head.ancestors)
self_root = self_ancestors[-1] if self_ancestors else self
new_head_root = new_head_ancestors[-1] if new_head_ancestors else new_head
start = self_root.c.l_edge if self_root.c.l_edge < new_head_root.c.l_edge else new_head_root.c.l_edge
end = self_root.c.r_edge if self_root.c.r_edge > new_head_root.c.r_edge else new_head_root.c.r_edge
# Set new head
self.c.head = rel_newhead_i
self.c.head = new_head.i - self.i
# Adjust parse properties and sentence starts
set_children_from_heads(self.doc.c, start, end + 1)
@property
def conjuncts(self):