diff --git a/setup.py b/setup.py
index a5748e9b4..9023b9fa3 100755
--- a/setup.py
+++ b/setup.py
@@ -23,6 +23,7 @@ Options.docstrings = True
 
 PACKAGES = find_packages()
 MOD_NAMES = [
+    "spacy.training.alignment_array",
     "spacy.training.example",
     "spacy.parts_of_speech",
     "spacy.strings",
diff --git a/spacy/errors.py b/spacy/errors.py
index a0cd2ef34..24a9f0339 100644
--- a/spacy/errors.py
+++ b/spacy/errors.py
@@ -897,6 +897,8 @@ class Errors(metaclass=ErrorsWithCodes):
     E1025 = ("Cannot intify the value '{value}' as an IOB string. The only "
              "supported values are: 'I', 'O', 'B' and ''")
     E1026 = ("Edit tree has an invalid format:\n{errors}")
+    E1027 = ("AlignmentArray only supports slicing with a step of 1.")
+    E1028 = ("AlignmentArray only supports indexing using an int or a slice.")
 
 
 # Deprecated model shortcuts, only used in errors and warnings
diff --git a/spacy/pipeline/_parser_internals/arc_eager.pyx b/spacy/pipeline/_parser_internals/arc_eager.pyx
index 029e2e29e..f1165592e 100644
--- a/spacy/pipeline/_parser_internals/arc_eager.pyx
+++ b/spacy/pipeline/_parser_internals/arc_eager.pyx
@@ -218,7 +218,7 @@ def _get_aligned_sent_starts(example):
         sent_starts = [False] * len(example.x)
         seen_words = set()
         for y_sent in example.y.sents:
-            x_indices = list(align[y_sent.start : y_sent.end].dataXd)
+            x_indices = list(align[y_sent.start : y_sent.end])
             if any(x_idx in seen_words for x_idx in x_indices):
                 # If there are any tokens in X that align across two sentences,
                 # regard the sentence annotations as missing, as we can't
diff --git a/spacy/scorer.py b/spacy/scorer.py
index ae9338bd5..8cd755ac4 100644
--- a/spacy/scorer.py
+++ b/spacy/scorer.py
@@ -228,7 +228,7 @@ class Scorer:
             if token.orth_.isspace():
                 continue
             if align.x2y.lengths[token.i] == 1:
-                gold_i = align.x2y[token.i].dataXd[0, 0]
+                gold_i = align.x2y[token.i][0]
                 if gold_i not in missing_indices:
                     pred_tags.add((gold_i, getter(token, attr)))
         tag_score.score_set(pred_tags, gold_tags)
@@ -287,7 +287,7 @@ class Scorer:
             if token.orth_.isspace():
                 continue
             if align.x2y.lengths[token.i] == 1:
-                gold_i = align.x2y[token.i].dataXd[0, 0]
+                gold_i = align.x2y[token.i][0]
                 if gold_i not in missing_indices:
                     value = getter(token, attr)
                     morph = gold_doc.vocab.strings[value]
@@ -694,13 +694,13 @@ class Scorer:
                 if align.x2y.lengths[token.i] != 1:
                     gold_i = None  # type: ignore
                 else:
-                    gold_i = align.x2y[token.i].dataXd[0, 0]
+                    gold_i = align.x2y[token.i][0]
                 if gold_i not in missing_indices:
                     dep = getter(token, attr)
                     head = head_getter(token, head_attr)
                     if dep not in ignore_labels and token.orth_.strip():
                         if align.x2y.lengths[head.i] == 1:
-                            gold_head = align.x2y[head.i].dataXd[0, 0]
+                            gold_head = align.x2y[head.i][0]
                         else:
                             gold_head = None
                         # None is indistinct, so we can't just add it to the set
@@ -750,7 +750,7 @@ def get_ner_prf(examples: Iterable[Example], **kwargs) -> Dict[str, Any]:
         for pred_ent in eg.x.ents:
             if pred_ent.label_ not in score_per_type:
                 score_per_type[pred_ent.label_] = PRFScore()
-            indices = align_x2y[pred_ent.start : pred_ent.end].dataXd.ravel()
+            indices = align_x2y[pred_ent.start : pred_ent.end]
             if len(indices):
                 g_span = eg.y[indices[0] : indices[-1] + 1]
                 # Check we aren't missing annotation on this span. If so,
diff --git a/spacy/tests/training/test_training.py b/spacy/tests/training/test_training.py
index f1f8ce9d4..8e08a25fb 100644
--- a/spacy/tests/training/test_training.py
+++ b/spacy/tests/training/test_training.py
@@ -8,6 +8,7 @@ from spacy.tokens import Doc, DocBin
 from spacy.training import Alignment, Corpus, Example, biluo_tags_to_offsets
 from spacy.training import biluo_tags_to_spans, docs_to_json, iob_to_biluo
 from spacy.training import offsets_to_biluo_tags
+from spacy.training.alignment_array import AlignmentArray
 from spacy.training.align import get_alignments
 from spacy.training.converters import json_to_docs
 from spacy.util import get_words_and_spaces, load_model_from_path, minibatch
@@ -908,9 +909,41 @@ def test_alignment():
     spacy_tokens = ["i", "listened", "to", "obama", "'s", "podcasts", "."]
     align = Alignment.from_strings(other_tokens, spacy_tokens)
     assert list(align.x2y.lengths) == [1, 1, 1, 1, 1, 1, 1, 1]
-    assert list(align.x2y.dataXd) == [0, 1, 2, 3, 4, 4, 5, 6]
+    assert list(align.x2y.data) == [0, 1, 2, 3, 4, 4, 5, 6]
     assert list(align.y2x.lengths) == [1, 1, 1, 1, 2, 1, 1]
-    assert list(align.y2x.dataXd) == [0, 1, 2, 3, 4, 5, 6, 7]
+    assert list(align.y2x.data) == [0, 1, 2, 3, 4, 5, 6, 7]
+
+
+def test_alignment_array():
+    a = AlignmentArray([[0, 1, 2], [3], [], [4, 5, 6, 7], [8, 9]])
+    assert list(a.data) == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
+    assert list(a.lengths) == [3, 1, 0, 4, 2]
+    assert list(a[3]) == [4, 5, 6, 7]
+    assert list(a[2]) == []
+    assert list(a[-2]) == [4, 5, 6, 7]
+    assert list(a[1:4]) == [3, 4, 5, 6, 7]
+    assert list(a[1:]) == [3, 4, 5, 6, 7, 8, 9]
+    assert list(a[:3]) == [0, 1, 2, 3]
+    assert list(a[:]) == list(a.data)
+    assert list(a[0:0]) == []
+    assert list(a[3:3]) == []
+    assert list(a[-1:-1]) == []
+    with pytest.raises(ValueError, match=r"only supports slicing with a step of 1"):
+        a[:4:-1]
+    with pytest.raises(
+        ValueError, match=r"only supports indexing using an int or a slice"
+    ):
+        a[[0, 1, 3]]
+
+    a = AlignmentArray([[], [1, 2, 3], [4, 5]])
+    assert list(a[0]) == []
+    assert list(a[0:1]) == []
+    assert list(a[2]) == [4, 5]
+    assert list(a[0:2]) == [1, 2, 3]
+
+    a = AlignmentArray([[1, 2, 3], [4, 5], []])
+    assert list(a[-1]) == []
+    assert list(a[-2:]) == [4, 5]
 
 
 def test_alignment_case_insensitive():
@@ -918,9 +951,9 @@
     spacy_tokens = ["i", "listened", "to", "Obama", "'s", "PODCASTS", "."]
     align = Alignment.from_strings(other_tokens, spacy_tokens)
     assert list(align.x2y.lengths) == [1, 1, 1, 1, 1, 1, 1, 1]
-    assert list(align.x2y.dataXd) == [0, 1, 2, 3, 4, 4, 5, 6]
+    assert list(align.x2y.data) == [0, 1, 2, 3, 4, 4, 5, 6]
     assert list(align.y2x.lengths) == [1, 1, 1, 1, 2, 1, 1]
-    assert list(align.y2x.dataXd) == [0, 1, 2, 3, 4, 5, 6, 7]
+    assert list(align.y2x.data) == [0, 1, 2, 3, 4, 5, 6, 7]
 
 
 def test_alignment_complex():
@@ -928,9 +961,9 @@
     spacy_tokens = ["i", "listened", "to", "obama", "'s", "podcasts."]
     align = Alignment.from_strings(other_tokens, spacy_tokens)
     assert list(align.x2y.lengths) == [3, 1, 1, 1, 1, 1]
-    assert list(align.x2y.dataXd) == [0, 1, 2, 3, 4, 4, 5, 5]
+    assert list(align.x2y.data) == [0, 1, 2, 3, 4, 4, 5, 5]
     assert list(align.y2x.lengths) == [1, 1, 1, 1, 2, 2]
-    assert list(align.y2x.dataXd) == [0, 0, 0, 1, 2, 3, 4, 5]
+    assert list(align.y2x.data) == [0, 0, 0, 1, 2, 3, 4, 5]
 
 
 def test_alignment_complex_example(en_vocab):
@@ -947,9 +980,9 @@
     example = Example(predicted, reference)
     align = example.alignment
     assert list(align.x2y.lengths) == [3, 1, 1, 1, 1, 1]
-    assert list(align.x2y.dataXd) == [0, 1, 2, 3, 4, 4, 5, 5]
+    assert list(align.x2y.data) == [0, 1, 2, 3, 4, 4, 5, 5]
     assert list(align.y2x.lengths) == [1, 1, 1, 1, 2, 2]
-    assert list(align.y2x.dataXd) == [0, 0, 0, 1, 2, 3, 4, 5]
+    assert list(align.y2x.data) == [0, 0, 0, 1, 2, 3, 4, 5]
 
 
 def test_alignment_different_texts():
@@ -965,70 +998,70 @@ def test_alignment_spaces(en_vocab):
     spacy_tokens = ["i", "listened", "to", "obama", "'s", "podcasts."]
     align = Alignment.from_strings(other_tokens, spacy_tokens)
     assert list(align.x2y.lengths) == [0, 3, 1, 1, 1, 1, 1]
-    assert list(align.x2y.dataXd) == [0, 1, 2, 3, 4, 4, 5, 5]
+    assert list(align.x2y.data) == [0, 1, 2, 3, 4, 4, 5, 5]
     assert list(align.y2x.lengths) == [1, 1, 1, 1, 2, 2]
-    assert list(align.y2x.dataXd) == [1, 1, 1, 2, 3, 4, 5, 6]
+    assert list(align.y2x.data) == [1, 1, 1, 2, 3, 4, 5, 6]
 
     # multiple leading whitespace tokens
     other_tokens = [" ", " ", "i listened to", "obama", "'", "s", "podcasts", "."]
     spacy_tokens = ["i", "listened", "to", "obama", "'s", "podcasts."]
     align = Alignment.from_strings(other_tokens, spacy_tokens)
     assert list(align.x2y.lengths) == [0, 0, 3, 1, 1, 1, 1, 1]
-    assert list(align.x2y.dataXd) == [0, 1, 2, 3, 4, 4, 5, 5]
+    assert list(align.x2y.data) == [0, 1, 2, 3, 4, 4, 5, 5]
     assert list(align.y2x.lengths) == [1, 1, 1, 1, 2, 2]
-    assert list(align.y2x.dataXd) == [2, 2, 2, 3, 4, 5, 6, 7]
+    assert list(align.y2x.data) == [2, 2, 2, 3, 4, 5, 6, 7]
 
     # both with leading whitespace, not identical
     other_tokens = [" ", " ", "i listened to", "obama", "'", "s", "podcasts", "."]
     spacy_tokens = [" ", "i", "listened", "to", "obama", "'s", "podcasts."]
     align = Alignment.from_strings(other_tokens, spacy_tokens)
     assert list(align.x2y.lengths) == [1, 0, 3, 1, 1, 1, 1, 1]
-    assert list(align.x2y.dataXd) == [0, 1, 2, 3, 4, 5, 5, 6, 6]
+    assert list(align.x2y.data) == [0, 1, 2, 3, 4, 5, 5, 6, 6]
     assert list(align.y2x.lengths) == [1, 1, 1, 1, 1, 2, 2]
-    assert list(align.y2x.dataXd) == [0, 2, 2, 2, 3, 4, 5, 6, 7]
+    assert list(align.y2x.data) == [0, 2, 2, 2, 3, 4, 5, 6, 7]
 
     # same leading whitespace, different tokenization
     other_tokens = [" ", " ", "i listened to", "obama", "'", "s", "podcasts", "."]
     spacy_tokens = [" ", "i", "listened", "to", "obama", "'s", "podcasts."]
     align = Alignment.from_strings(other_tokens, spacy_tokens)
     assert list(align.x2y.lengths) == [1, 1, 3, 1, 1, 1, 1, 1]
-    assert list(align.x2y.dataXd) == [0, 0, 1, 2, 3, 4, 5, 5, 6, 6]
+    assert list(align.x2y.data) == [0, 0, 1, 2, 3, 4, 5, 5, 6, 6]
     assert list(align.y2x.lengths) == [2, 1, 1, 1, 1, 2, 2]
-    assert list(align.y2x.dataXd) == [0, 1, 2, 2, 2, 3, 4, 5, 6, 7]
+    assert list(align.y2x.data) == [0, 1, 2, 2, 2, 3, 4, 5, 6, 7]
 
     # only one with trailing whitespace
     other_tokens = ["i listened to", "obama", "'", "s", "podcasts", ".", " "]
     spacy_tokens = ["i", "listened", "to", "obama", "'s", "podcasts."]
     align = Alignment.from_strings(other_tokens, spacy_tokens)
     assert list(align.x2y.lengths) == [3, 1, 1, 1, 1, 1, 0]
-    assert list(align.x2y.dataXd) == [0, 1, 2, 3, 4, 4, 5, 5]
+    assert list(align.x2y.data) == [0, 1, 2, 3, 4, 4, 5, 5]
     assert list(align.y2x.lengths) == [1, 1, 1, 1, 2, 2]
-    assert list(align.y2x.dataXd) == [0, 0, 0, 1, 2, 3, 4, 5]
+    assert list(align.y2x.data) == [0, 0, 0, 1, 2, 3, 4, 5]
 
     # different trailing whitespace
     other_tokens = ["i listened to", "obama", "'", "s", "podcasts", ".", " ", " "]
     spacy_tokens = ["i", "listened", "to", "obama", "'s", "podcasts.", " "]
     align = Alignment.from_strings(other_tokens, spacy_tokens)
     assert list(align.x2y.lengths) == [3, 1, 1, 1, 1, 1, 1, 0]
-    assert list(align.x2y.dataXd) == [0, 1, 2, 3, 4, 4, 5, 5, 6]
+    assert list(align.x2y.data) == [0, 1, 2, 3, 4, 4, 5, 5, 6]
     assert list(align.y2x.lengths) == [1, 1, 1, 1, 2, 2, 1]
-    assert list(align.y2x.dataXd) == [0, 0, 0, 1, 2, 3, 4, 5, 6]
+    assert list(align.y2x.data) == [0, 0, 0, 1, 2, 3, 4, 5, 6]
 
     # same trailing whitespace, different tokenization
     other_tokens = ["i listened to", "obama", "'", "s", "podcasts", ".", " ", " "]
     spacy_tokens = ["i", "listened", "to", "obama", "'s", "podcasts.", " "]
     align = Alignment.from_strings(other_tokens, spacy_tokens)
     assert list(align.x2y.lengths) == [3, 1, 1, 1, 1, 1, 1, 1]
-    assert list(align.x2y.dataXd) == [0, 1, 2, 3, 4, 4, 5, 5, 6, 6]
+    assert list(align.x2y.data) == [0, 1, 2, 3, 4, 4, 5, 5, 6, 6]
     assert list(align.y2x.lengths) == [1, 1, 1, 1, 2, 2, 2]
-    assert list(align.y2x.dataXd) == [0, 0, 0, 1, 2, 3, 4, 5, 6, 7]
+    assert list(align.y2x.data) == [0, 0, 0, 1, 2, 3, 4, 5, 6, 7]
 
     # differing whitespace is allowed
     other_tokens = ["a", " \n ", "b", "c"]
     spacy_tokens = ["a", "b", " ", "c"]
     align = Alignment.from_strings(other_tokens, spacy_tokens)
-    assert list(align.x2y.dataXd) == [0, 1, 3]
-    assert list(align.y2x.dataXd) == [0, 2, 3]
+    assert list(align.x2y.data) == [0, 1, 3]
+    assert list(align.y2x.data) == [0, 2, 3]
 
     # other differences in whitespace are allowed
     other_tokens = [" ", "a"]
diff --git a/spacy/training/alignment.py b/spacy/training/alignment.py
index 3e3b60ca6..6d24714bf 100644
--- a/spacy/training/alignment.py
+++ b/spacy/training/alignment.py
@@ -1,31 +1,22 @@
 from typing import List
-import numpy
-from thinc.types import Ragged
 from dataclasses import dataclass
 
 from .align import get_alignments
+from .alignment_array import AlignmentArray
 
 
 @dataclass
 class Alignment:
-    x2y: Ragged
-    y2x: Ragged
+    x2y: AlignmentArray
+    y2x: AlignmentArray
 
     @classmethod
     def from_indices(cls, x2y: List[List[int]], y2x: List[List[int]]) -> "Alignment":
-        x2y = _make_ragged(x2y)
-        y2x = _make_ragged(y2x)
+        x2y = AlignmentArray(x2y)
+        y2x = AlignmentArray(y2x)
         return Alignment(x2y=x2y, y2x=y2x)
 
     @classmethod
     def from_strings(cls, A: List[str], B: List[str]) -> "Alignment":
         x2y, y2x = get_alignments(A, B)
         return Alignment.from_indices(x2y=x2y, y2x=y2x)
-
-
-def _make_ragged(indices):
-    lengths = numpy.array([len(x) for x in indices], dtype="i")
-    flat = []
-    for x in indices:
-        flat.extend(x)
-    return Ragged(numpy.array(flat, dtype="i"), lengths)
diff --git a/spacy/training/alignment_array.pxd b/spacy/training/alignment_array.pxd
new file mode 100644
index 000000000..056f5bef3
--- /dev/null
+++ b/spacy/training/alignment_array.pxd
@@ -0,0 +1,7 @@
+from libcpp.vector cimport vector
+cimport numpy as np
+
+cdef class AlignmentArray:
+    cdef np.ndarray _data
+    cdef np.ndarray _lengths
+    cdef np.ndarray _starts_ends
diff --git a/spacy/training/alignment_array.pyx b/spacy/training/alignment_array.pyx
new file mode 100644
index 000000000..b58f08786
--- /dev/null
+++ b/spacy/training/alignment_array.pyx
@@ -0,0 +1,68 @@
+from typing import List
+from ..errors import Errors
+import numpy
+
+
+cdef class AlignmentArray:
+    """AlignmentArray is similar to Thinc's Ragged with two simplifications:
+    indexing returns numpy arrays and this type can only be used for CPU arrays.
+    However, these changes make AlignmentArray more efficient for indexing in a
+    tight loop."""
+
+    __slots__ = []
+
+    def __init__(self, alignment: List[List[int]]):
+        self._lengths = None
+        self._starts_ends = numpy.zeros(len(alignment) + 1, dtype="i")
+
+        cdef int data_len = 0
+        cdef int outer_len
+        cdef int idx
+        for idx, outer in enumerate(alignment):
+            outer_len = len(outer)
+            self._starts_ends[idx + 1] = self._starts_ends[idx] + outer_len
+            data_len += outer_len
+
+        self._data = numpy.empty(data_len, dtype="i")
+        idx = 0
+        for outer in alignment:
+            for inner in outer:
+                self._data[idx] = inner
+                idx += 1
+
+    def __getitem__(self, idx):
+        starts = self._starts_ends[:-1]
+        ends = self._starts_ends[1:]
+        if isinstance(idx, int):
+            start = starts[idx]
+            end = ends[idx]
+        elif isinstance(idx, slice):
+            if not (idx.step is None or idx.step == 1):
+                raise ValueError(Errors.E1027)
+            start = starts[idx]
+            if len(start) == 0:
+                return self._data[0:0]
+            start = start[0]
+            end = ends[idx][-1]
+        else:
+            raise ValueError(Errors.E1028)
+
+        return self._data[start:end]
+
+    @property
+    def data(self):
+        return self._data
+
+    @property
+    def lengths(self):
+        if self._lengths is None:
+            self._lengths = self.ends - self.starts
+        return self._lengths
+
+    @property
+    def ends(self):
+        return self._starts_ends[1:]
+
+    @property
+    def starts(self):
+        return self._starts_ends[:-1]
diff --git a/spacy/training/example.pyx b/spacy/training/example.pyx
index 778dfd12a..ab92f78c6 100644
--- a/spacy/training/example.pyx
+++ b/spacy/training/example.pyx
@@ -159,7 +159,7 @@ cdef class Example:
         gold_values = self.reference.to_array([field])
         output = [None] * len(self.predicted)
         for token in self.predicted:
-            values = gold_values[align[token.i].dataXd]
+            values = gold_values[align[token.i]]
             values = values.ravel()
             if len(values) == 0:
                 output[token.i] = None
@@ -190,9 +190,9 @@ cdef class Example:
             deps = [d if has_deps[i] else deps[i] for i, d in enumerate(proj_deps)]
         for cand_i in range(self.x.length):
             if cand_to_gold.lengths[cand_i] == 1:
-                gold_i = cand_to_gold[cand_i].dataXd[0, 0]
+                gold_i = cand_to_gold[cand_i][0]
                 if gold_to_cand.lengths[heads[gold_i]] == 1:
-                    aligned_heads[cand_i] = int(gold_to_cand[heads[gold_i]].dataXd[0, 0])
+                    aligned_heads[cand_i] = int(gold_to_cand[heads[gold_i]][0])
                     aligned_deps[cand_i] = deps[gold_i]
         return aligned_heads, aligned_deps
 
@@ -204,7 +204,7 @@ cdef class Example:
             align = self.alignment.y2x
             sent_starts = [False] * len(self.x)
             for y_sent in self.y.sents:
-                x_start = int(align[y_sent.start].dataXd[0])
+                x_start = int(align[y_sent.start][0])
                 sent_starts[x_start] = True
             return sent_starts
         else:
@@ -220,7 +220,7 @@ cdef class Example:
         seen = set()
         output = []
         for span in spans:
-            indices = align[span.start : span.end].data.ravel()
+            indices = align[span.start : span.end]
             if not allow_overlap:
                 indices = [idx for idx in indices if idx not in seen]
             if len(indices) >= 1:
@@ -316,7 +316,7 @@ cdef class Example:
         seen_indices = set()
        output = []
         for y_sent in self.reference.sents:
-            indices = align[y_sent.start : y_sent.end].data.ravel()
+            indices = align[y_sent.start : y_sent.end]
             indices = [idx for idx in indices if idx not in seen_indices]
             if indices:
                 x_sent = self.predicted[indices[0] : indices[-1] + 1]
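Note (not part of the patch): a minimal usage sketch of the new indexing behaviour, based only on the tests above. Integer indexing and slicing on the x2y/y2x arrays now return flat 1-D numpy index arrays directly, replacing the old Ragged access through .dataXd and .data.ravel():

    from spacy.training import Alignment

    other_tokens = ["i", "listened", "to", "obama", "'", "s", "podcasts", "."]
    spacy_tokens = ["i", "listened", "to", "obama", "'s", "podcasts", "."]
    align = Alignment.from_strings(other_tokens, spacy_tokens)

    # Previously: align.x2y[4].dataXd[0, 0]; indexing now yields a 1-D array.
    assert align.x2y[4][0] == 4  # "'" aligns to "'s"
    # Previously: align.x2y[0:2].dataXd.ravel(); a slice now flattens directly.
    assert list(align.x2y[0:2]) == [0, 1]
    assert list(align.x2y.data) == [0, 1, 2, 3, 4, 4, 5, 6]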