Alignment: use a simplified ragged type for performance (#10319)

* Alignment: use a simplified ragged type for performance

This introduces the AlignmentArray type, which is a simplified version
of Ragged that performs better on the simple(r) indexing performed for
alignment.

* AlignmentArray: raise an error when using unsupported index

* AlignmentArray: move error messages to Errors

* AlignmentArray: remove simlified ... with simplifications

* AlignmentArray: fix typo that broke a[n:n] indexing
This commit is contained in:
Daniël de Kok 2022-04-01 09:02:06 +02:00 committed by GitHub
parent 03762b4b92
commit c90dd6f265
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 152 additions and 50 deletions

View File

@ -23,6 +23,7 @@ Options.docstrings = True
PACKAGES = find_packages() PACKAGES = find_packages()
MOD_NAMES = [ MOD_NAMES = [
"spacy.training.alignment_array",
"spacy.training.example", "spacy.training.example",
"spacy.parts_of_speech", "spacy.parts_of_speech",
"spacy.strings", "spacy.strings",

View File

@ -897,6 +897,8 @@ class Errors(metaclass=ErrorsWithCodes):
E1025 = ("Cannot intify the value '{value}' as an IOB string. The only " E1025 = ("Cannot intify the value '{value}' as an IOB string. The only "
"supported values are: 'I', 'O', 'B' and ''") "supported values are: 'I', 'O', 'B' and ''")
E1026 = ("Edit tree has an invalid format:\n{errors}") E1026 = ("Edit tree has an invalid format:\n{errors}")
E1027 = ("AlignmentArray only supports slicing with a step of 1.")
E1028 = ("AlignmentArray only supports indexing using an int or a slice.")
# Deprecated model shortcuts, only used in errors and warnings # Deprecated model shortcuts, only used in errors and warnings

View File

@ -218,7 +218,7 @@ def _get_aligned_sent_starts(example):
sent_starts = [False] * len(example.x) sent_starts = [False] * len(example.x)
seen_words = set() seen_words = set()
for y_sent in example.y.sents: for y_sent in example.y.sents:
x_indices = list(align[y_sent.start : y_sent.end].dataXd) x_indices = list(align[y_sent.start : y_sent.end])
if any(x_idx in seen_words for x_idx in x_indices): if any(x_idx in seen_words for x_idx in x_indices):
# If there are any tokens in X that align across two sentences, # If there are any tokens in X that align across two sentences,
# regard the sentence annotations as missing, as we can't # regard the sentence annotations as missing, as we can't

View File

@ -228,7 +228,7 @@ class Scorer:
if token.orth_.isspace(): if token.orth_.isspace():
continue continue
if align.x2y.lengths[token.i] == 1: if align.x2y.lengths[token.i] == 1:
gold_i = align.x2y[token.i].dataXd[0, 0] gold_i = align.x2y[token.i][0]
if gold_i not in missing_indices: if gold_i not in missing_indices:
pred_tags.add((gold_i, getter(token, attr))) pred_tags.add((gold_i, getter(token, attr)))
tag_score.score_set(pred_tags, gold_tags) tag_score.score_set(pred_tags, gold_tags)
@ -287,7 +287,7 @@ class Scorer:
if token.orth_.isspace(): if token.orth_.isspace():
continue continue
if align.x2y.lengths[token.i] == 1: if align.x2y.lengths[token.i] == 1:
gold_i = align.x2y[token.i].dataXd[0, 0] gold_i = align.x2y[token.i][0]
if gold_i not in missing_indices: if gold_i not in missing_indices:
value = getter(token, attr) value = getter(token, attr)
morph = gold_doc.vocab.strings[value] morph = gold_doc.vocab.strings[value]
@ -694,13 +694,13 @@ class Scorer:
if align.x2y.lengths[token.i] != 1: if align.x2y.lengths[token.i] != 1:
gold_i = None # type: ignore gold_i = None # type: ignore
else: else:
gold_i = align.x2y[token.i].dataXd[0, 0] gold_i = align.x2y[token.i][0]
if gold_i not in missing_indices: if gold_i not in missing_indices:
dep = getter(token, attr) dep = getter(token, attr)
head = head_getter(token, head_attr) head = head_getter(token, head_attr)
if dep not in ignore_labels and token.orth_.strip(): if dep not in ignore_labels and token.orth_.strip():
if align.x2y.lengths[head.i] == 1: if align.x2y.lengths[head.i] == 1:
gold_head = align.x2y[head.i].dataXd[0, 0] gold_head = align.x2y[head.i][0]
else: else:
gold_head = None gold_head = None
# None is indistinct, so we can't just add it to the set # None is indistinct, so we can't just add it to the set
@ -750,7 +750,7 @@ def get_ner_prf(examples: Iterable[Example], **kwargs) -> Dict[str, Any]:
for pred_ent in eg.x.ents: for pred_ent in eg.x.ents:
if pred_ent.label_ not in score_per_type: if pred_ent.label_ not in score_per_type:
score_per_type[pred_ent.label_] = PRFScore() score_per_type[pred_ent.label_] = PRFScore()
indices = align_x2y[pred_ent.start : pred_ent.end].dataXd.ravel() indices = align_x2y[pred_ent.start : pred_ent.end]
if len(indices): if len(indices):
g_span = eg.y[indices[0] : indices[-1] + 1] g_span = eg.y[indices[0] : indices[-1] + 1]
# Check we aren't missing annotation on this span. If so, # Check we aren't missing annotation on this span. If so,

View File

@ -8,6 +8,7 @@ from spacy.tokens import Doc, DocBin
from spacy.training import Alignment, Corpus, Example, biluo_tags_to_offsets from spacy.training import Alignment, Corpus, Example, biluo_tags_to_offsets
from spacy.training import biluo_tags_to_spans, docs_to_json, iob_to_biluo from spacy.training import biluo_tags_to_spans, docs_to_json, iob_to_biluo
from spacy.training import offsets_to_biluo_tags from spacy.training import offsets_to_biluo_tags
from spacy.training.alignment_array import AlignmentArray
from spacy.training.align import get_alignments from spacy.training.align import get_alignments
from spacy.training.converters import json_to_docs from spacy.training.converters import json_to_docs
from spacy.util import get_words_and_spaces, load_model_from_path, minibatch from spacy.util import get_words_and_spaces, load_model_from_path, minibatch
@ -908,9 +909,41 @@ def test_alignment():
spacy_tokens = ["i", "listened", "to", "obama", "'s", "podcasts", "."] spacy_tokens = ["i", "listened", "to", "obama", "'s", "podcasts", "."]
align = Alignment.from_strings(other_tokens, spacy_tokens) align = Alignment.from_strings(other_tokens, spacy_tokens)
assert list(align.x2y.lengths) == [1, 1, 1, 1, 1, 1, 1, 1] assert list(align.x2y.lengths) == [1, 1, 1, 1, 1, 1, 1, 1]
assert list(align.x2y.dataXd) == [0, 1, 2, 3, 4, 4, 5, 6] assert list(align.x2y.data) == [0, 1, 2, 3, 4, 4, 5, 6]
assert list(align.y2x.lengths) == [1, 1, 1, 1, 2, 1, 1] assert list(align.y2x.lengths) == [1, 1, 1, 1, 2, 1, 1]
assert list(align.y2x.dataXd) == [0, 1, 2, 3, 4, 5, 6, 7] assert list(align.y2x.data) == [0, 1, 2, 3, 4, 5, 6, 7]
def test_alignment_array():
a = AlignmentArray([[0, 1, 2], [3], [], [4, 5, 6, 7], [8, 9]])
assert list(a.data) == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
assert list(a.lengths) == [3, 1, 0, 4, 2]
assert list(a[3]) == [4, 5, 6, 7]
assert list(a[2]) == []
assert list(a[-2]) == [4, 5, 6, 7]
assert list(a[1:4]) == [3, 4, 5, 6, 7]
assert list(a[1:]) == [3, 4, 5, 6, 7, 8, 9]
assert list(a[:3]) == [0, 1, 2, 3]
assert list(a[:]) == list(a.data)
assert list(a[0:0]) == []
assert list(a[3:3]) == []
assert list(a[-1:-1]) == []
with pytest.raises(ValueError, match=r"only supports slicing with a step of 1"):
a[:4:-1]
with pytest.raises(
ValueError, match=r"only supports indexing using an int or a slice"
):
a[[0, 1, 3]]
a = AlignmentArray([[], [1, 2, 3], [4, 5]])
assert list(a[0]) == []
assert list(a[0:1]) == []
assert list(a[2]) == [4, 5]
assert list(a[0:2]) == [1, 2, 3]
a = AlignmentArray([[1, 2, 3], [4, 5], []])
assert list(a[-1]) == []
assert list(a[-2:]) == [4, 5]
def test_alignment_case_insensitive(): def test_alignment_case_insensitive():
@ -918,9 +951,9 @@ def test_alignment_case_insensitive():
spacy_tokens = ["i", "listened", "to", "Obama", "'s", "PODCASTS", "."] spacy_tokens = ["i", "listened", "to", "Obama", "'s", "PODCASTS", "."]
align = Alignment.from_strings(other_tokens, spacy_tokens) align = Alignment.from_strings(other_tokens, spacy_tokens)
assert list(align.x2y.lengths) == [1, 1, 1, 1, 1, 1, 1, 1] assert list(align.x2y.lengths) == [1, 1, 1, 1, 1, 1, 1, 1]
assert list(align.x2y.dataXd) == [0, 1, 2, 3, 4, 4, 5, 6] assert list(align.x2y.data) == [0, 1, 2, 3, 4, 4, 5, 6]
assert list(align.y2x.lengths) == [1, 1, 1, 1, 2, 1, 1] assert list(align.y2x.lengths) == [1, 1, 1, 1, 2, 1, 1]
assert list(align.y2x.dataXd) == [0, 1, 2, 3, 4, 5, 6, 7] assert list(align.y2x.data) == [0, 1, 2, 3, 4, 5, 6, 7]
def test_alignment_complex(): def test_alignment_complex():
@ -928,9 +961,9 @@ def test_alignment_complex():
spacy_tokens = ["i", "listened", "to", "obama", "'s", "podcasts."] spacy_tokens = ["i", "listened", "to", "obama", "'s", "podcasts."]
align = Alignment.from_strings(other_tokens, spacy_tokens) align = Alignment.from_strings(other_tokens, spacy_tokens)
assert list(align.x2y.lengths) == [3, 1, 1, 1, 1, 1] assert list(align.x2y.lengths) == [3, 1, 1, 1, 1, 1]
assert list(align.x2y.dataXd) == [0, 1, 2, 3, 4, 4, 5, 5] assert list(align.x2y.data) == [0, 1, 2, 3, 4, 4, 5, 5]
assert list(align.y2x.lengths) == [1, 1, 1, 1, 2, 2] assert list(align.y2x.lengths) == [1, 1, 1, 1, 2, 2]
assert list(align.y2x.dataXd) == [0, 0, 0, 1, 2, 3, 4, 5] assert list(align.y2x.data) == [0, 0, 0, 1, 2, 3, 4, 5]
def test_alignment_complex_example(en_vocab): def test_alignment_complex_example(en_vocab):
@ -947,9 +980,9 @@ def test_alignment_complex_example(en_vocab):
example = Example(predicted, reference) example = Example(predicted, reference)
align = example.alignment align = example.alignment
assert list(align.x2y.lengths) == [3, 1, 1, 1, 1, 1] assert list(align.x2y.lengths) == [3, 1, 1, 1, 1, 1]
assert list(align.x2y.dataXd) == [0, 1, 2, 3, 4, 4, 5, 5] assert list(align.x2y.data) == [0, 1, 2, 3, 4, 4, 5, 5]
assert list(align.y2x.lengths) == [1, 1, 1, 1, 2, 2] assert list(align.y2x.lengths) == [1, 1, 1, 1, 2, 2]
assert list(align.y2x.dataXd) == [0, 0, 0, 1, 2, 3, 4, 5] assert list(align.y2x.data) == [0, 0, 0, 1, 2, 3, 4, 5]
def test_alignment_different_texts(): def test_alignment_different_texts():
@ -965,70 +998,70 @@ def test_alignment_spaces(en_vocab):
spacy_tokens = ["i", "listened", "to", "obama", "'s", "podcasts."] spacy_tokens = ["i", "listened", "to", "obama", "'s", "podcasts."]
align = Alignment.from_strings(other_tokens, spacy_tokens) align = Alignment.from_strings(other_tokens, spacy_tokens)
assert list(align.x2y.lengths) == [0, 3, 1, 1, 1, 1, 1] assert list(align.x2y.lengths) == [0, 3, 1, 1, 1, 1, 1]
assert list(align.x2y.dataXd) == [0, 1, 2, 3, 4, 4, 5, 5] assert list(align.x2y.data) == [0, 1, 2, 3, 4, 4, 5, 5]
assert list(align.y2x.lengths) == [1, 1, 1, 1, 2, 2] assert list(align.y2x.lengths) == [1, 1, 1, 1, 2, 2]
assert list(align.y2x.dataXd) == [1, 1, 1, 2, 3, 4, 5, 6] assert list(align.y2x.data) == [1, 1, 1, 2, 3, 4, 5, 6]
# multiple leading whitespace tokens # multiple leading whitespace tokens
other_tokens = [" ", " ", "i listened to", "obama", "'", "s", "podcasts", "."] other_tokens = [" ", " ", "i listened to", "obama", "'", "s", "podcasts", "."]
spacy_tokens = ["i", "listened", "to", "obama", "'s", "podcasts."] spacy_tokens = ["i", "listened", "to", "obama", "'s", "podcasts."]
align = Alignment.from_strings(other_tokens, spacy_tokens) align = Alignment.from_strings(other_tokens, spacy_tokens)
assert list(align.x2y.lengths) == [0, 0, 3, 1, 1, 1, 1, 1] assert list(align.x2y.lengths) == [0, 0, 3, 1, 1, 1, 1, 1]
assert list(align.x2y.dataXd) == [0, 1, 2, 3, 4, 4, 5, 5] assert list(align.x2y.data) == [0, 1, 2, 3, 4, 4, 5, 5]
assert list(align.y2x.lengths) == [1, 1, 1, 1, 2, 2] assert list(align.y2x.lengths) == [1, 1, 1, 1, 2, 2]
assert list(align.y2x.dataXd) == [2, 2, 2, 3, 4, 5, 6, 7] assert list(align.y2x.data) == [2, 2, 2, 3, 4, 5, 6, 7]
# both with leading whitespace, not identical # both with leading whitespace, not identical
other_tokens = [" ", " ", "i listened to", "obama", "'", "s", "podcasts", "."] other_tokens = [" ", " ", "i listened to", "obama", "'", "s", "podcasts", "."]
spacy_tokens = [" ", "i", "listened", "to", "obama", "'s", "podcasts."] spacy_tokens = [" ", "i", "listened", "to", "obama", "'s", "podcasts."]
align = Alignment.from_strings(other_tokens, spacy_tokens) align = Alignment.from_strings(other_tokens, spacy_tokens)
assert list(align.x2y.lengths) == [1, 0, 3, 1, 1, 1, 1, 1] assert list(align.x2y.lengths) == [1, 0, 3, 1, 1, 1, 1, 1]
assert list(align.x2y.dataXd) == [0, 1, 2, 3, 4, 5, 5, 6, 6] assert list(align.x2y.data) == [0, 1, 2, 3, 4, 5, 5, 6, 6]
assert list(align.y2x.lengths) == [1, 1, 1, 1, 1, 2, 2] assert list(align.y2x.lengths) == [1, 1, 1, 1, 1, 2, 2]
assert list(align.y2x.dataXd) == [0, 2, 2, 2, 3, 4, 5, 6, 7] assert list(align.y2x.data) == [0, 2, 2, 2, 3, 4, 5, 6, 7]
# same leading whitespace, different tokenization # same leading whitespace, different tokenization
other_tokens = [" ", " ", "i listened to", "obama", "'", "s", "podcasts", "."] other_tokens = [" ", " ", "i listened to", "obama", "'", "s", "podcasts", "."]
spacy_tokens = [" ", "i", "listened", "to", "obama", "'s", "podcasts."] spacy_tokens = [" ", "i", "listened", "to", "obama", "'s", "podcasts."]
align = Alignment.from_strings(other_tokens, spacy_tokens) align = Alignment.from_strings(other_tokens, spacy_tokens)
assert list(align.x2y.lengths) == [1, 1, 3, 1, 1, 1, 1, 1] assert list(align.x2y.lengths) == [1, 1, 3, 1, 1, 1, 1, 1]
assert list(align.x2y.dataXd) == [0, 0, 1, 2, 3, 4, 5, 5, 6, 6] assert list(align.x2y.data) == [0, 0, 1, 2, 3, 4, 5, 5, 6, 6]
assert list(align.y2x.lengths) == [2, 1, 1, 1, 1, 2, 2] assert list(align.y2x.lengths) == [2, 1, 1, 1, 1, 2, 2]
assert list(align.y2x.dataXd) == [0, 1, 2, 2, 2, 3, 4, 5, 6, 7] assert list(align.y2x.data) == [0, 1, 2, 2, 2, 3, 4, 5, 6, 7]
# only one with trailing whitespace # only one with trailing whitespace
other_tokens = ["i listened to", "obama", "'", "s", "podcasts", ".", " "] other_tokens = ["i listened to", "obama", "'", "s", "podcasts", ".", " "]
spacy_tokens = ["i", "listened", "to", "obama", "'s", "podcasts."] spacy_tokens = ["i", "listened", "to", "obama", "'s", "podcasts."]
align = Alignment.from_strings(other_tokens, spacy_tokens) align = Alignment.from_strings(other_tokens, spacy_tokens)
assert list(align.x2y.lengths) == [3, 1, 1, 1, 1, 1, 0] assert list(align.x2y.lengths) == [3, 1, 1, 1, 1, 1, 0]
assert list(align.x2y.dataXd) == [0, 1, 2, 3, 4, 4, 5, 5] assert list(align.x2y.data) == [0, 1, 2, 3, 4, 4, 5, 5]
assert list(align.y2x.lengths) == [1, 1, 1, 1, 2, 2] assert list(align.y2x.lengths) == [1, 1, 1, 1, 2, 2]
assert list(align.y2x.dataXd) == [0, 0, 0, 1, 2, 3, 4, 5] assert list(align.y2x.data) == [0, 0, 0, 1, 2, 3, 4, 5]
# different trailing whitespace # different trailing whitespace
other_tokens = ["i listened to", "obama", "'", "s", "podcasts", ".", " ", " "] other_tokens = ["i listened to", "obama", "'", "s", "podcasts", ".", " ", " "]
spacy_tokens = ["i", "listened", "to", "obama", "'s", "podcasts.", " "] spacy_tokens = ["i", "listened", "to", "obama", "'s", "podcasts.", " "]
align = Alignment.from_strings(other_tokens, spacy_tokens) align = Alignment.from_strings(other_tokens, spacy_tokens)
assert list(align.x2y.lengths) == [3, 1, 1, 1, 1, 1, 1, 0] assert list(align.x2y.lengths) == [3, 1, 1, 1, 1, 1, 1, 0]
assert list(align.x2y.dataXd) == [0, 1, 2, 3, 4, 4, 5, 5, 6] assert list(align.x2y.data) == [0, 1, 2, 3, 4, 4, 5, 5, 6]
assert list(align.y2x.lengths) == [1, 1, 1, 1, 2, 2, 1] assert list(align.y2x.lengths) == [1, 1, 1, 1, 2, 2, 1]
assert list(align.y2x.dataXd) == [0, 0, 0, 1, 2, 3, 4, 5, 6] assert list(align.y2x.data) == [0, 0, 0, 1, 2, 3, 4, 5, 6]
# same trailing whitespace, different tokenization # same trailing whitespace, different tokenization
other_tokens = ["i listened to", "obama", "'", "s", "podcasts", ".", " ", " "] other_tokens = ["i listened to", "obama", "'", "s", "podcasts", ".", " ", " "]
spacy_tokens = ["i", "listened", "to", "obama", "'s", "podcasts.", " "] spacy_tokens = ["i", "listened", "to", "obama", "'s", "podcasts.", " "]
align = Alignment.from_strings(other_tokens, spacy_tokens) align = Alignment.from_strings(other_tokens, spacy_tokens)
assert list(align.x2y.lengths) == [3, 1, 1, 1, 1, 1, 1, 1] assert list(align.x2y.lengths) == [3, 1, 1, 1, 1, 1, 1, 1]
assert list(align.x2y.dataXd) == [0, 1, 2, 3, 4, 4, 5, 5, 6, 6] assert list(align.x2y.data) == [0, 1, 2, 3, 4, 4, 5, 5, 6, 6]
assert list(align.y2x.lengths) == [1, 1, 1, 1, 2, 2, 2] assert list(align.y2x.lengths) == [1, 1, 1, 1, 2, 2, 2]
assert list(align.y2x.dataXd) == [0, 0, 0, 1, 2, 3, 4, 5, 6, 7] assert list(align.y2x.data) == [0, 0, 0, 1, 2, 3, 4, 5, 6, 7]
# differing whitespace is allowed # differing whitespace is allowed
other_tokens = ["a", " \n ", "b", "c"] other_tokens = ["a", " \n ", "b", "c"]
spacy_tokens = ["a", "b", " ", "c"] spacy_tokens = ["a", "b", " ", "c"]
align = Alignment.from_strings(other_tokens, spacy_tokens) align = Alignment.from_strings(other_tokens, spacy_tokens)
assert list(align.x2y.dataXd) == [0, 1, 3] assert list(align.x2y.data) == [0, 1, 3]
assert list(align.y2x.dataXd) == [0, 2, 3] assert list(align.y2x.data) == [0, 2, 3]
# other differences in whitespace are allowed # other differences in whitespace are allowed
other_tokens = [" ", "a"] other_tokens = [" ", "a"]

View File

@ -1,31 +1,22 @@
from typing import List from typing import List
import numpy
from thinc.types import Ragged
from dataclasses import dataclass from dataclasses import dataclass
from .align import get_alignments from .align import get_alignments
from .alignment_array import AlignmentArray
@dataclass @dataclass
class Alignment: class Alignment:
x2y: Ragged x2y: AlignmentArray
y2x: Ragged y2x: AlignmentArray
@classmethod @classmethod
def from_indices(cls, x2y: List[List[int]], y2x: List[List[int]]) -> "Alignment": def from_indices(cls, x2y: List[List[int]], y2x: List[List[int]]) -> "Alignment":
x2y = _make_ragged(x2y) x2y = AlignmentArray(x2y)
y2x = _make_ragged(y2x) y2x = AlignmentArray(y2x)
return Alignment(x2y=x2y, y2x=y2x) return Alignment(x2y=x2y, y2x=y2x)
@classmethod @classmethod
def from_strings(cls, A: List[str], B: List[str]) -> "Alignment": def from_strings(cls, A: List[str], B: List[str]) -> "Alignment":
x2y, y2x = get_alignments(A, B) x2y, y2x = get_alignments(A, B)
return Alignment.from_indices(x2y=x2y, y2x=y2x) return Alignment.from_indices(x2y=x2y, y2x=y2x)
def _make_ragged(indices):
lengths = numpy.array([len(x) for x in indices], dtype="i")
flat = []
for x in indices:
flat.extend(x)
return Ragged(numpy.array(flat, dtype="i"), lengths)

View File

@ -0,0 +1,7 @@
from libcpp.vector cimport vector
cimport numpy as np
cdef class AlignmentArray:
cdef np.ndarray _data
cdef np.ndarray _lengths
cdef np.ndarray _starts_ends

View File

@ -0,0 +1,68 @@
from typing import List
from ..errors import Errors
import numpy
cdef class AlignmentArray:
"""AlignmentArray is similar to Thinc's Ragged with two simplfications:
indexing returns numpy arrays and this type can only be used for CPU arrays.
However, these changes make AlginmentArray more efficient for indexing in a
tight loop."""
__slots__ = []
def __init__(self, alignment: List[List[int]]):
self._lengths = None
self._starts_ends = numpy.zeros(len(alignment) + 1, dtype="i")
cdef int data_len = 0
cdef int outer_len
cdef int idx
for idx, outer in enumerate(alignment):
outer_len = len(outer)
self._starts_ends[idx + 1] = self._starts_ends[idx] + outer_len
data_len += outer_len
self._data = numpy.empty(data_len, dtype="i")
idx = 0
for outer in alignment:
for inner in outer:
self._data[idx] = inner
idx += 1
def __getitem__(self, idx):
starts = self._starts_ends[:-1]
ends = self._starts_ends[1:]
if isinstance(idx, int):
start = starts[idx]
end = ends[idx]
elif isinstance(idx, slice):
if not (idx.step is None or idx.step == 1):
raise ValueError(Errors.E1027)
start = starts[idx]
if len(start) == 0:
return self._data[0:0]
start = start[0]
end = ends[idx][-1]
else:
raise ValueError(Errors.E1028)
return self._data[start:end]
@property
def data(self):
return self._data
@property
def lengths(self):
if self._lengths is None:
self._lengths = self.ends - self.starts
return self._lengths
@property
def ends(self):
return self._starts_ends[1:]
@property
def starts(self):
return self._starts_ends[:-1]

View File

@ -159,7 +159,7 @@ cdef class Example:
gold_values = self.reference.to_array([field]) gold_values = self.reference.to_array([field])
output = [None] * len(self.predicted) output = [None] * len(self.predicted)
for token in self.predicted: for token in self.predicted:
values = gold_values[align[token.i].dataXd] values = gold_values[align[token.i]]
values = values.ravel() values = values.ravel()
if len(values) == 0: if len(values) == 0:
output[token.i] = None output[token.i] = None
@ -190,9 +190,9 @@ cdef class Example:
deps = [d if has_deps[i] else deps[i] for i, d in enumerate(proj_deps)] deps = [d if has_deps[i] else deps[i] for i, d in enumerate(proj_deps)]
for cand_i in range(self.x.length): for cand_i in range(self.x.length):
if cand_to_gold.lengths[cand_i] == 1: if cand_to_gold.lengths[cand_i] == 1:
gold_i = cand_to_gold[cand_i].dataXd[0, 0] gold_i = cand_to_gold[cand_i][0]
if gold_to_cand.lengths[heads[gold_i]] == 1: if gold_to_cand.lengths[heads[gold_i]] == 1:
aligned_heads[cand_i] = int(gold_to_cand[heads[gold_i]].dataXd[0, 0]) aligned_heads[cand_i] = int(gold_to_cand[heads[gold_i]][0])
aligned_deps[cand_i] = deps[gold_i] aligned_deps[cand_i] = deps[gold_i]
return aligned_heads, aligned_deps return aligned_heads, aligned_deps
@ -204,7 +204,7 @@ cdef class Example:
align = self.alignment.y2x align = self.alignment.y2x
sent_starts = [False] * len(self.x) sent_starts = [False] * len(self.x)
for y_sent in self.y.sents: for y_sent in self.y.sents:
x_start = int(align[y_sent.start].dataXd[0]) x_start = int(align[y_sent.start][0])
sent_starts[x_start] = True sent_starts[x_start] = True
return sent_starts return sent_starts
else: else:
@ -220,7 +220,7 @@ cdef class Example:
seen = set() seen = set()
output = [] output = []
for span in spans: for span in spans:
indices = align[span.start : span.end].data.ravel() indices = align[span.start : span.end]
if not allow_overlap: if not allow_overlap:
indices = [idx for idx in indices if idx not in seen] indices = [idx for idx in indices if idx not in seen]
if len(indices) >= 1: if len(indices) >= 1:
@ -316,7 +316,7 @@ cdef class Example:
seen_indices = set() seen_indices = set()
output = [] output = []
for y_sent in self.reference.sents: for y_sent in self.reference.sents:
indices = align[y_sent.start : y_sent.end].data.ravel() indices = align[y_sent.start : y_sent.end]
indices = [idx for idx in indices if idx not in seen_indices] indices = [idx for idx in indices if idx not in seen_indices]
if indices: if indices:
x_sent = self.predicted[indices[0] : indices[-1] + 1] x_sent = self.predicted[indices[0] : indices[-1] + 1]