mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-11 17:56:30 +03:00
Alignment: use a simplified ragged type for performance (#10319)
* Alignment: use a simplified ragged type for performance This introduces the AlignmentArray type, which is a simplified version of Ragged that performs better on the simple(r) indexing performed for alignment. * AlignmentArray: raise an error when using unsupported index * AlignmentArray: move error messages to Errors * AlignmentArray: remove simlified ... with simplifications * AlignmentArray: fix typo that broke a[n:n] indexing
This commit is contained in:
parent
03762b4b92
commit
c90dd6f265
1
setup.py
1
setup.py
|
@ -23,6 +23,7 @@ Options.docstrings = True
|
|||
|
||||
PACKAGES = find_packages()
|
||||
MOD_NAMES = [
|
||||
"spacy.training.alignment_array",
|
||||
"spacy.training.example",
|
||||
"spacy.parts_of_speech",
|
||||
"spacy.strings",
|
||||
|
|
|
@ -897,6 +897,8 @@ class Errors(metaclass=ErrorsWithCodes):
|
|||
E1025 = ("Cannot intify the value '{value}' as an IOB string. The only "
|
||||
"supported values are: 'I', 'O', 'B' and ''")
|
||||
E1026 = ("Edit tree has an invalid format:\n{errors}")
|
||||
E1027 = ("AlignmentArray only supports slicing with a step of 1.")
|
||||
E1028 = ("AlignmentArray only supports indexing using an int or a slice.")
|
||||
|
||||
|
||||
# Deprecated model shortcuts, only used in errors and warnings
|
||||
|
|
|
@ -218,7 +218,7 @@ def _get_aligned_sent_starts(example):
|
|||
sent_starts = [False] * len(example.x)
|
||||
seen_words = set()
|
||||
for y_sent in example.y.sents:
|
||||
x_indices = list(align[y_sent.start : y_sent.end].dataXd)
|
||||
x_indices = list(align[y_sent.start : y_sent.end])
|
||||
if any(x_idx in seen_words for x_idx in x_indices):
|
||||
# If there are any tokens in X that align across two sentences,
|
||||
# regard the sentence annotations as missing, as we can't
|
||||
|
|
|
@ -228,7 +228,7 @@ class Scorer:
|
|||
if token.orth_.isspace():
|
||||
continue
|
||||
if align.x2y.lengths[token.i] == 1:
|
||||
gold_i = align.x2y[token.i].dataXd[0, 0]
|
||||
gold_i = align.x2y[token.i][0]
|
||||
if gold_i not in missing_indices:
|
||||
pred_tags.add((gold_i, getter(token, attr)))
|
||||
tag_score.score_set(pred_tags, gold_tags)
|
||||
|
@ -287,7 +287,7 @@ class Scorer:
|
|||
if token.orth_.isspace():
|
||||
continue
|
||||
if align.x2y.lengths[token.i] == 1:
|
||||
gold_i = align.x2y[token.i].dataXd[0, 0]
|
||||
gold_i = align.x2y[token.i][0]
|
||||
if gold_i not in missing_indices:
|
||||
value = getter(token, attr)
|
||||
morph = gold_doc.vocab.strings[value]
|
||||
|
@ -694,13 +694,13 @@ class Scorer:
|
|||
if align.x2y.lengths[token.i] != 1:
|
||||
gold_i = None # type: ignore
|
||||
else:
|
||||
gold_i = align.x2y[token.i].dataXd[0, 0]
|
||||
gold_i = align.x2y[token.i][0]
|
||||
if gold_i not in missing_indices:
|
||||
dep = getter(token, attr)
|
||||
head = head_getter(token, head_attr)
|
||||
if dep not in ignore_labels and token.orth_.strip():
|
||||
if align.x2y.lengths[head.i] == 1:
|
||||
gold_head = align.x2y[head.i].dataXd[0, 0]
|
||||
gold_head = align.x2y[head.i][0]
|
||||
else:
|
||||
gold_head = None
|
||||
# None is indistinct, so we can't just add it to the set
|
||||
|
@ -750,7 +750,7 @@ def get_ner_prf(examples: Iterable[Example], **kwargs) -> Dict[str, Any]:
|
|||
for pred_ent in eg.x.ents:
|
||||
if pred_ent.label_ not in score_per_type:
|
||||
score_per_type[pred_ent.label_] = PRFScore()
|
||||
indices = align_x2y[pred_ent.start : pred_ent.end].dataXd.ravel()
|
||||
indices = align_x2y[pred_ent.start : pred_ent.end]
|
||||
if len(indices):
|
||||
g_span = eg.y[indices[0] : indices[-1] + 1]
|
||||
# Check we aren't missing annotation on this span. If so,
|
||||
|
|
|
@ -8,6 +8,7 @@ from spacy.tokens import Doc, DocBin
|
|||
from spacy.training import Alignment, Corpus, Example, biluo_tags_to_offsets
|
||||
from spacy.training import biluo_tags_to_spans, docs_to_json, iob_to_biluo
|
||||
from spacy.training import offsets_to_biluo_tags
|
||||
from spacy.training.alignment_array import AlignmentArray
|
||||
from spacy.training.align import get_alignments
|
||||
from spacy.training.converters import json_to_docs
|
||||
from spacy.util import get_words_and_spaces, load_model_from_path, minibatch
|
||||
|
@ -908,9 +909,41 @@ def test_alignment():
|
|||
spacy_tokens = ["i", "listened", "to", "obama", "'s", "podcasts", "."]
|
||||
align = Alignment.from_strings(other_tokens, spacy_tokens)
|
||||
assert list(align.x2y.lengths) == [1, 1, 1, 1, 1, 1, 1, 1]
|
||||
assert list(align.x2y.dataXd) == [0, 1, 2, 3, 4, 4, 5, 6]
|
||||
assert list(align.x2y.data) == [0, 1, 2, 3, 4, 4, 5, 6]
|
||||
assert list(align.y2x.lengths) == [1, 1, 1, 1, 2, 1, 1]
|
||||
assert list(align.y2x.dataXd) == [0, 1, 2, 3, 4, 5, 6, 7]
|
||||
assert list(align.y2x.data) == [0, 1, 2, 3, 4, 5, 6, 7]
|
||||
|
||||
|
||||
def test_alignment_array():
|
||||
a = AlignmentArray([[0, 1, 2], [3], [], [4, 5, 6, 7], [8, 9]])
|
||||
assert list(a.data) == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
|
||||
assert list(a.lengths) == [3, 1, 0, 4, 2]
|
||||
assert list(a[3]) == [4, 5, 6, 7]
|
||||
assert list(a[2]) == []
|
||||
assert list(a[-2]) == [4, 5, 6, 7]
|
||||
assert list(a[1:4]) == [3, 4, 5, 6, 7]
|
||||
assert list(a[1:]) == [3, 4, 5, 6, 7, 8, 9]
|
||||
assert list(a[:3]) == [0, 1, 2, 3]
|
||||
assert list(a[:]) == list(a.data)
|
||||
assert list(a[0:0]) == []
|
||||
assert list(a[3:3]) == []
|
||||
assert list(a[-1:-1]) == []
|
||||
with pytest.raises(ValueError, match=r"only supports slicing with a step of 1"):
|
||||
a[:4:-1]
|
||||
with pytest.raises(
|
||||
ValueError, match=r"only supports indexing using an int or a slice"
|
||||
):
|
||||
a[[0, 1, 3]]
|
||||
|
||||
a = AlignmentArray([[], [1, 2, 3], [4, 5]])
|
||||
assert list(a[0]) == []
|
||||
assert list(a[0:1]) == []
|
||||
assert list(a[2]) == [4, 5]
|
||||
assert list(a[0:2]) == [1, 2, 3]
|
||||
|
||||
a = AlignmentArray([[1, 2, 3], [4, 5], []])
|
||||
assert list(a[-1]) == []
|
||||
assert list(a[-2:]) == [4, 5]
|
||||
|
||||
|
||||
def test_alignment_case_insensitive():
|
||||
|
@ -918,9 +951,9 @@ def test_alignment_case_insensitive():
|
|||
spacy_tokens = ["i", "listened", "to", "Obama", "'s", "PODCASTS", "."]
|
||||
align = Alignment.from_strings(other_tokens, spacy_tokens)
|
||||
assert list(align.x2y.lengths) == [1, 1, 1, 1, 1, 1, 1, 1]
|
||||
assert list(align.x2y.dataXd) == [0, 1, 2, 3, 4, 4, 5, 6]
|
||||
assert list(align.x2y.data) == [0, 1, 2, 3, 4, 4, 5, 6]
|
||||
assert list(align.y2x.lengths) == [1, 1, 1, 1, 2, 1, 1]
|
||||
assert list(align.y2x.dataXd) == [0, 1, 2, 3, 4, 5, 6, 7]
|
||||
assert list(align.y2x.data) == [0, 1, 2, 3, 4, 5, 6, 7]
|
||||
|
||||
|
||||
def test_alignment_complex():
|
||||
|
@ -928,9 +961,9 @@ def test_alignment_complex():
|
|||
spacy_tokens = ["i", "listened", "to", "obama", "'s", "podcasts."]
|
||||
align = Alignment.from_strings(other_tokens, spacy_tokens)
|
||||
assert list(align.x2y.lengths) == [3, 1, 1, 1, 1, 1]
|
||||
assert list(align.x2y.dataXd) == [0, 1, 2, 3, 4, 4, 5, 5]
|
||||
assert list(align.x2y.data) == [0, 1, 2, 3, 4, 4, 5, 5]
|
||||
assert list(align.y2x.lengths) == [1, 1, 1, 1, 2, 2]
|
||||
assert list(align.y2x.dataXd) == [0, 0, 0, 1, 2, 3, 4, 5]
|
||||
assert list(align.y2x.data) == [0, 0, 0, 1, 2, 3, 4, 5]
|
||||
|
||||
|
||||
def test_alignment_complex_example(en_vocab):
|
||||
|
@ -947,9 +980,9 @@ def test_alignment_complex_example(en_vocab):
|
|||
example = Example(predicted, reference)
|
||||
align = example.alignment
|
||||
assert list(align.x2y.lengths) == [3, 1, 1, 1, 1, 1]
|
||||
assert list(align.x2y.dataXd) == [0, 1, 2, 3, 4, 4, 5, 5]
|
||||
assert list(align.x2y.data) == [0, 1, 2, 3, 4, 4, 5, 5]
|
||||
assert list(align.y2x.lengths) == [1, 1, 1, 1, 2, 2]
|
||||
assert list(align.y2x.dataXd) == [0, 0, 0, 1, 2, 3, 4, 5]
|
||||
assert list(align.y2x.data) == [0, 0, 0, 1, 2, 3, 4, 5]
|
||||
|
||||
|
||||
def test_alignment_different_texts():
|
||||
|
@ -965,70 +998,70 @@ def test_alignment_spaces(en_vocab):
|
|||
spacy_tokens = ["i", "listened", "to", "obama", "'s", "podcasts."]
|
||||
align = Alignment.from_strings(other_tokens, spacy_tokens)
|
||||
assert list(align.x2y.lengths) == [0, 3, 1, 1, 1, 1, 1]
|
||||
assert list(align.x2y.dataXd) == [0, 1, 2, 3, 4, 4, 5, 5]
|
||||
assert list(align.x2y.data) == [0, 1, 2, 3, 4, 4, 5, 5]
|
||||
assert list(align.y2x.lengths) == [1, 1, 1, 1, 2, 2]
|
||||
assert list(align.y2x.dataXd) == [1, 1, 1, 2, 3, 4, 5, 6]
|
||||
assert list(align.y2x.data) == [1, 1, 1, 2, 3, 4, 5, 6]
|
||||
|
||||
# multiple leading whitespace tokens
|
||||
other_tokens = [" ", " ", "i listened to", "obama", "'", "s", "podcasts", "."]
|
||||
spacy_tokens = ["i", "listened", "to", "obama", "'s", "podcasts."]
|
||||
align = Alignment.from_strings(other_tokens, spacy_tokens)
|
||||
assert list(align.x2y.lengths) == [0, 0, 3, 1, 1, 1, 1, 1]
|
||||
assert list(align.x2y.dataXd) == [0, 1, 2, 3, 4, 4, 5, 5]
|
||||
assert list(align.x2y.data) == [0, 1, 2, 3, 4, 4, 5, 5]
|
||||
assert list(align.y2x.lengths) == [1, 1, 1, 1, 2, 2]
|
||||
assert list(align.y2x.dataXd) == [2, 2, 2, 3, 4, 5, 6, 7]
|
||||
assert list(align.y2x.data) == [2, 2, 2, 3, 4, 5, 6, 7]
|
||||
|
||||
# both with leading whitespace, not identical
|
||||
other_tokens = [" ", " ", "i listened to", "obama", "'", "s", "podcasts", "."]
|
||||
spacy_tokens = [" ", "i", "listened", "to", "obama", "'s", "podcasts."]
|
||||
align = Alignment.from_strings(other_tokens, spacy_tokens)
|
||||
assert list(align.x2y.lengths) == [1, 0, 3, 1, 1, 1, 1, 1]
|
||||
assert list(align.x2y.dataXd) == [0, 1, 2, 3, 4, 5, 5, 6, 6]
|
||||
assert list(align.x2y.data) == [0, 1, 2, 3, 4, 5, 5, 6, 6]
|
||||
assert list(align.y2x.lengths) == [1, 1, 1, 1, 1, 2, 2]
|
||||
assert list(align.y2x.dataXd) == [0, 2, 2, 2, 3, 4, 5, 6, 7]
|
||||
assert list(align.y2x.data) == [0, 2, 2, 2, 3, 4, 5, 6, 7]
|
||||
|
||||
# same leading whitespace, different tokenization
|
||||
other_tokens = [" ", " ", "i listened to", "obama", "'", "s", "podcasts", "."]
|
||||
spacy_tokens = [" ", "i", "listened", "to", "obama", "'s", "podcasts."]
|
||||
align = Alignment.from_strings(other_tokens, spacy_tokens)
|
||||
assert list(align.x2y.lengths) == [1, 1, 3, 1, 1, 1, 1, 1]
|
||||
assert list(align.x2y.dataXd) == [0, 0, 1, 2, 3, 4, 5, 5, 6, 6]
|
||||
assert list(align.x2y.data) == [0, 0, 1, 2, 3, 4, 5, 5, 6, 6]
|
||||
assert list(align.y2x.lengths) == [2, 1, 1, 1, 1, 2, 2]
|
||||
assert list(align.y2x.dataXd) == [0, 1, 2, 2, 2, 3, 4, 5, 6, 7]
|
||||
assert list(align.y2x.data) == [0, 1, 2, 2, 2, 3, 4, 5, 6, 7]
|
||||
|
||||
# only one with trailing whitespace
|
||||
other_tokens = ["i listened to", "obama", "'", "s", "podcasts", ".", " "]
|
||||
spacy_tokens = ["i", "listened", "to", "obama", "'s", "podcasts."]
|
||||
align = Alignment.from_strings(other_tokens, spacy_tokens)
|
||||
assert list(align.x2y.lengths) == [3, 1, 1, 1, 1, 1, 0]
|
||||
assert list(align.x2y.dataXd) == [0, 1, 2, 3, 4, 4, 5, 5]
|
||||
assert list(align.x2y.data) == [0, 1, 2, 3, 4, 4, 5, 5]
|
||||
assert list(align.y2x.lengths) == [1, 1, 1, 1, 2, 2]
|
||||
assert list(align.y2x.dataXd) == [0, 0, 0, 1, 2, 3, 4, 5]
|
||||
assert list(align.y2x.data) == [0, 0, 0, 1, 2, 3, 4, 5]
|
||||
|
||||
# different trailing whitespace
|
||||
other_tokens = ["i listened to", "obama", "'", "s", "podcasts", ".", " ", " "]
|
||||
spacy_tokens = ["i", "listened", "to", "obama", "'s", "podcasts.", " "]
|
||||
align = Alignment.from_strings(other_tokens, spacy_tokens)
|
||||
assert list(align.x2y.lengths) == [3, 1, 1, 1, 1, 1, 1, 0]
|
||||
assert list(align.x2y.dataXd) == [0, 1, 2, 3, 4, 4, 5, 5, 6]
|
||||
assert list(align.x2y.data) == [0, 1, 2, 3, 4, 4, 5, 5, 6]
|
||||
assert list(align.y2x.lengths) == [1, 1, 1, 1, 2, 2, 1]
|
||||
assert list(align.y2x.dataXd) == [0, 0, 0, 1, 2, 3, 4, 5, 6]
|
||||
assert list(align.y2x.data) == [0, 0, 0, 1, 2, 3, 4, 5, 6]
|
||||
|
||||
# same trailing whitespace, different tokenization
|
||||
other_tokens = ["i listened to", "obama", "'", "s", "podcasts", ".", " ", " "]
|
||||
spacy_tokens = ["i", "listened", "to", "obama", "'s", "podcasts.", " "]
|
||||
align = Alignment.from_strings(other_tokens, spacy_tokens)
|
||||
assert list(align.x2y.lengths) == [3, 1, 1, 1, 1, 1, 1, 1]
|
||||
assert list(align.x2y.dataXd) == [0, 1, 2, 3, 4, 4, 5, 5, 6, 6]
|
||||
assert list(align.x2y.data) == [0, 1, 2, 3, 4, 4, 5, 5, 6, 6]
|
||||
assert list(align.y2x.lengths) == [1, 1, 1, 1, 2, 2, 2]
|
||||
assert list(align.y2x.dataXd) == [0, 0, 0, 1, 2, 3, 4, 5, 6, 7]
|
||||
assert list(align.y2x.data) == [0, 0, 0, 1, 2, 3, 4, 5, 6, 7]
|
||||
|
||||
# differing whitespace is allowed
|
||||
other_tokens = ["a", " \n ", "b", "c"]
|
||||
spacy_tokens = ["a", "b", " ", "c"]
|
||||
align = Alignment.from_strings(other_tokens, spacy_tokens)
|
||||
assert list(align.x2y.dataXd) == [0, 1, 3]
|
||||
assert list(align.y2x.dataXd) == [0, 2, 3]
|
||||
assert list(align.x2y.data) == [0, 1, 3]
|
||||
assert list(align.y2x.data) == [0, 2, 3]
|
||||
|
||||
# other differences in whitespace are allowed
|
||||
other_tokens = [" ", "a"]
|
||||
|
|
|
@ -1,31 +1,22 @@
|
|||
from typing import List
|
||||
import numpy
|
||||
from thinc.types import Ragged
|
||||
from dataclasses import dataclass
|
||||
|
||||
from .align import get_alignments
|
||||
from .alignment_array import AlignmentArray
|
||||
|
||||
|
||||
@dataclass
|
||||
class Alignment:
|
||||
x2y: Ragged
|
||||
y2x: Ragged
|
||||
x2y: AlignmentArray
|
||||
y2x: AlignmentArray
|
||||
|
||||
@classmethod
|
||||
def from_indices(cls, x2y: List[List[int]], y2x: List[List[int]]) -> "Alignment":
|
||||
x2y = _make_ragged(x2y)
|
||||
y2x = _make_ragged(y2x)
|
||||
x2y = AlignmentArray(x2y)
|
||||
y2x = AlignmentArray(y2x)
|
||||
return Alignment(x2y=x2y, y2x=y2x)
|
||||
|
||||
@classmethod
|
||||
def from_strings(cls, A: List[str], B: List[str]) -> "Alignment":
|
||||
x2y, y2x = get_alignments(A, B)
|
||||
return Alignment.from_indices(x2y=x2y, y2x=y2x)
|
||||
|
||||
|
||||
def _make_ragged(indices):
|
||||
lengths = numpy.array([len(x) for x in indices], dtype="i")
|
||||
flat = []
|
||||
for x in indices:
|
||||
flat.extend(x)
|
||||
return Ragged(numpy.array(flat, dtype="i"), lengths)
|
||||
|
|
7
spacy/training/alignment_array.pxd
Normal file
7
spacy/training/alignment_array.pxd
Normal file
|
@ -0,0 +1,7 @@
|
|||
from libcpp.vector cimport vector
|
||||
cimport numpy as np
|
||||
|
||||
cdef class AlignmentArray:
|
||||
cdef np.ndarray _data
|
||||
cdef np.ndarray _lengths
|
||||
cdef np.ndarray _starts_ends
|
68
spacy/training/alignment_array.pyx
Normal file
68
spacy/training/alignment_array.pyx
Normal file
|
@ -0,0 +1,68 @@
|
|||
from typing import List
|
||||
from ..errors import Errors
|
||||
import numpy
|
||||
|
||||
|
||||
cdef class AlignmentArray:
|
||||
"""AlignmentArray is similar to Thinc's Ragged with two simplfications:
|
||||
indexing returns numpy arrays and this type can only be used for CPU arrays.
|
||||
However, these changes make AlginmentArray more efficient for indexing in a
|
||||
tight loop."""
|
||||
|
||||
__slots__ = []
|
||||
|
||||
def __init__(self, alignment: List[List[int]]):
|
||||
self._lengths = None
|
||||
self._starts_ends = numpy.zeros(len(alignment) + 1, dtype="i")
|
||||
|
||||
cdef int data_len = 0
|
||||
cdef int outer_len
|
||||
cdef int idx
|
||||
for idx, outer in enumerate(alignment):
|
||||
outer_len = len(outer)
|
||||
self._starts_ends[idx + 1] = self._starts_ends[idx] + outer_len
|
||||
data_len += outer_len
|
||||
|
||||
self._data = numpy.empty(data_len, dtype="i")
|
||||
idx = 0
|
||||
for outer in alignment:
|
||||
for inner in outer:
|
||||
self._data[idx] = inner
|
||||
idx += 1
|
||||
|
||||
def __getitem__(self, idx):
|
||||
starts = self._starts_ends[:-1]
|
||||
ends = self._starts_ends[1:]
|
||||
if isinstance(idx, int):
|
||||
start = starts[idx]
|
||||
end = ends[idx]
|
||||
elif isinstance(idx, slice):
|
||||
if not (idx.step is None or idx.step == 1):
|
||||
raise ValueError(Errors.E1027)
|
||||
start = starts[idx]
|
||||
if len(start) == 0:
|
||||
return self._data[0:0]
|
||||
start = start[0]
|
||||
end = ends[idx][-1]
|
||||
else:
|
||||
raise ValueError(Errors.E1028)
|
||||
|
||||
return self._data[start:end]
|
||||
|
||||
@property
|
||||
def data(self):
|
||||
return self._data
|
||||
|
||||
@property
|
||||
def lengths(self):
|
||||
if self._lengths is None:
|
||||
self._lengths = self.ends - self.starts
|
||||
return self._lengths
|
||||
|
||||
@property
|
||||
def ends(self):
|
||||
return self._starts_ends[1:]
|
||||
|
||||
@property
|
||||
def starts(self):
|
||||
return self._starts_ends[:-1]
|
|
@ -159,7 +159,7 @@ cdef class Example:
|
|||
gold_values = self.reference.to_array([field])
|
||||
output = [None] * len(self.predicted)
|
||||
for token in self.predicted:
|
||||
values = gold_values[align[token.i].dataXd]
|
||||
values = gold_values[align[token.i]]
|
||||
values = values.ravel()
|
||||
if len(values) == 0:
|
||||
output[token.i] = None
|
||||
|
@ -190,9 +190,9 @@ cdef class Example:
|
|||
deps = [d if has_deps[i] else deps[i] for i, d in enumerate(proj_deps)]
|
||||
for cand_i in range(self.x.length):
|
||||
if cand_to_gold.lengths[cand_i] == 1:
|
||||
gold_i = cand_to_gold[cand_i].dataXd[0, 0]
|
||||
gold_i = cand_to_gold[cand_i][0]
|
||||
if gold_to_cand.lengths[heads[gold_i]] == 1:
|
||||
aligned_heads[cand_i] = int(gold_to_cand[heads[gold_i]].dataXd[0, 0])
|
||||
aligned_heads[cand_i] = int(gold_to_cand[heads[gold_i]][0])
|
||||
aligned_deps[cand_i] = deps[gold_i]
|
||||
return aligned_heads, aligned_deps
|
||||
|
||||
|
@ -204,7 +204,7 @@ cdef class Example:
|
|||
align = self.alignment.y2x
|
||||
sent_starts = [False] * len(self.x)
|
||||
for y_sent in self.y.sents:
|
||||
x_start = int(align[y_sent.start].dataXd[0])
|
||||
x_start = int(align[y_sent.start][0])
|
||||
sent_starts[x_start] = True
|
||||
return sent_starts
|
||||
else:
|
||||
|
@ -220,7 +220,7 @@ cdef class Example:
|
|||
seen = set()
|
||||
output = []
|
||||
for span in spans:
|
||||
indices = align[span.start : span.end].data.ravel()
|
||||
indices = align[span.start : span.end]
|
||||
if not allow_overlap:
|
||||
indices = [idx for idx in indices if idx not in seen]
|
||||
if len(indices) >= 1:
|
||||
|
@ -316,7 +316,7 @@ cdef class Example:
|
|||
seen_indices = set()
|
||||
output = []
|
||||
for y_sent in self.reference.sents:
|
||||
indices = align[y_sent.start : y_sent.end].data.ravel()
|
||||
indices = align[y_sent.start : y_sent.end]
|
||||
indices = [idx for idx in indices if idx not in seen_indices]
|
||||
if indices:
|
||||
x_sent = self.predicted[indices[0] : indices[-1] + 1]
|
||||
|
|
Loading…
Reference in New Issue
Block a user