Improve gold-standard alignment (#5711)

* Remove previous alignment

* Implement better alignment, using ragged data structure

* Use pytokenizations for alignment

* Fixes

* Fixes

* Fix overlapping entities in alignment

* Fix align split_sents

* Update test

* Commit align.py

* Try to appease setuptools

* Fix flake8

* use realistic entities for testing

* Update tests for better alignment

* Improve alignment heuristic

Co-authored-by: svlandeg <sofie.vanlandeghem@gmail.com>
This commit is contained in:
Matthew Honnibal 2020-07-06 17:39:31 +02:00 committed by GitHub
parent a35236e5f0
commit cc477be952
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 167 additions and 245 deletions

View File

@ -7,6 +7,7 @@ requires = [
"preshed>=3.0.2,<3.1.0",
"murmurhash>=0.28.0,<1.1.0",
"thinc>=8.0.0a12,<8.0.0a20",
"blis>=0.4.0,<0.5.0"
"blis>=0.4.0,<0.5.0",
"pytokenizations"
]
build-backend = "setuptools.build_meta"

View File

@ -14,6 +14,7 @@ numpy>=1.15.0
requests>=2.13.0,<3.0.0
tqdm>=4.38.0,<5.0.0
pydantic>=1.3.0,<2.0.0
pytokenizations
# Official Python utilities
setuptools
packaging

View File

@ -51,6 +51,7 @@ install_requires =
numpy>=1.15.0
requests>=2.13.0,<3.0.0
pydantic>=1.3.0,<2.0.0
pytokenizations
# Official Python utilities
setuptools
packaging

View File

@ -1,11 +1,11 @@
#!/usr/bin/env python
from setuptools import Extension, setup, find_packages
import sys
import platform
from distutils.command.build_ext import build_ext
from distutils.sysconfig import get_python_inc
import distutils.util
from distutils import ccompiler, msvccompiler
from setuptools import Extension, setup, find_packages
import numpy
from pathlib import Path
import shutil
@ -23,7 +23,6 @@ Options.docstrings = True
PACKAGES = find_packages()
MOD_NAMES = [
"spacy.gold.align",
"spacy.gold.example",
"spacy.parts_of_speech",
"spacy.strings",

View File

@ -1,6 +1,6 @@
from .corpus import Corpus
from .example import Example
from .align import align
from .align import Alignment
from .iob_utils import iob_to_biluo, biluo_to_iob
from .iob_utils import biluo_tags_from_offsets, offsets_from_biluo_tags

View File

@ -1,8 +0,0 @@
cdef class Alignment:
cdef public object cost
cdef public object i2j
cdef public object j2i
cdef public object i2j_multi
cdef public object j2i_multi
cdef public object cand_to_gold
cdef public object gold_to_cand

30
spacy/gold/align.py Normal file
View File

@ -0,0 +1,30 @@
from typing import List
import numpy
from thinc.types import Ragged
from dataclasses import dataclass
import tokenizations
@dataclass
class Alignment:
x2y: Ragged
y2x: Ragged
@classmethod
def from_indices(cls, x2y: List[List[int]], y2x: List[List[int]]) -> "Alignment":
x2y = _make_ragged(x2y)
y2x = _make_ragged(y2x)
return Alignment(x2y=x2y, y2x=y2x)
@classmethod
def from_strings(cls, A: List[str], B: List[str]) -> "Alignment":
x2y, y2x = tokenizations.get_alignments(A, B)
return Alignment.from_indices(x2y=x2y, y2x=y2x)
def _make_ragged(indices):
lengths = numpy.array([len(x) for x in indices], dtype="i")
flat = []
for x in indices:
flat.extend(x)
return Ragged(numpy.array(flat, dtype="i"), lengths)

View File

@ -1,101 +0,0 @@
import numpy
from ..errors import Errors, AlignmentError
cdef class Alignment:
def __init__(self, spacy_words, gold_words):
# Do many-to-one alignment for misaligned tokens.
# If we over-segment, we'll have one gold word that covers a sequence
# of predicted words
# If we under-segment, we'll have one predicted word that covers a
# sequence of gold words.
# If we "mis-segment", we'll have a sequence of predicted words covering
# a sequence of gold words. That's many-to-many -- we don't do that
# except for NER spans where the start and end can be aligned.
cost, i2j, j2i, i2j_multi, j2i_multi = align(spacy_words, gold_words)
self.cost = cost
self.i2j = i2j
self.j2i = j2i
self.i2j_multi = i2j_multi
self.j2i_multi = j2i_multi
self.cand_to_gold = [(j if j >= 0 else None) for j in i2j]
self.gold_to_cand = [(i if i >= 0 else None) for i in j2i]
def align(tokens_a, tokens_b):
"""Calculate alignment tables between two tokenizations.
tokens_a (List[str]): The candidate tokenization.
tokens_b (List[str]): The reference tokenization.
RETURNS: (tuple): A 5-tuple consisting of the following information:
* cost (int): The number of misaligned tokens.
* a2b (List[int]): Mapping of indices in `tokens_a` to indices in `tokens_b`.
For instance, if `a2b[4] == 6`, that means that `tokens_a[4]` aligns
to `tokens_b[6]`. If there's no one-to-one alignment for a token,
it has the value -1.
* b2a (List[int]): The same as `a2b`, but mapping the other direction.
* a2b_multi (Dict[int, int]): A dictionary mapping indices in `tokens_a`
to indices in `tokens_b`, where multiple tokens of `tokens_a` align to
the same token of `tokens_b`.
* b2a_multi (Dict[int, int]): As with `a2b_multi`, but mapping the other
direction.
"""
tokens_a = _normalize_for_alignment(tokens_a)
tokens_b = _normalize_for_alignment(tokens_b)
cost = 0
a2b = numpy.empty(len(tokens_a), dtype="i")
b2a = numpy.empty(len(tokens_b), dtype="i")
a2b.fill(-1)
b2a.fill(-1)
a2b_multi = {}
b2a_multi = {}
i = 0
j = 0
offset_a = 0
offset_b = 0
while i < len(tokens_a) and j < len(tokens_b):
a = tokens_a[i][offset_a:]
b = tokens_b[j][offset_b:]
if a == b:
if offset_a == offset_b == 0:
a2b[i] = j
b2a[j] = i
elif offset_a == 0:
cost += 2
a2b_multi[i] = j
elif offset_b == 0:
cost += 2
b2a_multi[j] = i
offset_a = offset_b = 0
i += 1
j += 1
elif a == "":
assert offset_a == 0
cost += 1
i += 1
elif b == "":
assert offset_b == 0
cost += 1
j += 1
elif b.startswith(a):
cost += 1
if offset_a == 0:
a2b_multi[i] = j
i += 1
offset_a = 0
offset_b += len(a)
elif a.startswith(b):
cost += 1
if offset_b == 0:
b2a_multi[j] = i
j += 1
offset_b = 0
offset_a += len(b)
else:
assert "".join(tokens_a) != "".join(tokens_b)
raise AlignmentError(Errors.E186.format(tok_a=tokens_a, tok_b=tokens_b))
return cost, a2b, b2a, a2b_multi, b2a_multi
def _normalize_for_alignment(tokens):
return [w.replace(" ", "").lower() for w in tokens]

View File

@ -1,8 +1,7 @@
from ..tokens.doc cimport Doc
from .align cimport Alignment
cdef class Example:
cdef readonly Doc x
cdef readonly Doc y
cdef readonly Alignment _alignment
cdef readonly object _alignment

View File

@ -6,10 +6,9 @@ from ..tokens.doc cimport Doc
from ..tokens.span cimport Span
from ..tokens.span import Span
from ..attrs import IDS
from .align cimport Alignment
from .align import Alignment
from .iob_utils import biluo_to_iob, biluo_tags_from_offsets, biluo_tags_from_doc
from .iob_utils import spans_from_biluo_tags
from .align import Alignment
from ..errors import Errors, Warnings
from ..syntax import nonproj
@ -28,7 +27,7 @@ cpdef Doc annotations2doc(vocab, tok_annot, doc_annot):
cdef class Example:
def __init__(self, Doc predicted, Doc reference, *, Alignment alignment=None):
def __init__(self, Doc predicted, Doc reference, *, alignment=None):
""" Doc can either be text, or an actual Doc """
if predicted is None:
raise TypeError(Errors.E972.format(arg="predicted"))
@ -83,34 +82,38 @@ cdef class Example:
gold_words = [token.orth_ for token in self.reference]
if gold_words == []:
gold_words = spacy_words
self._alignment = Alignment(spacy_words, gold_words)
self._alignment = Alignment.from_strings(spacy_words, gold_words)
return self._alignment
def get_aligned(self, field, as_string=False):
"""Return an aligned array for a token attribute."""
i2j_multi = self.alignment.i2j_multi
cand_to_gold = self.alignment.cand_to_gold
align = self.alignment.x2y
vocab = self.reference.vocab
gold_values = self.reference.to_array([field])
output = [None] * len(self.predicted)
for i, gold_i in enumerate(cand_to_gold):
if self.predicted[i].text.isspace():
output[i] = None
if gold_i is None:
if i in i2j_multi:
output[i] = gold_values[i2j_multi[i]]
else:
output[i] = None
for token in self.predicted:
if token.is_space:
output[token.i] = None
else:
output[i] = gold_values[gold_i]
values = gold_values[align[token.i].dataXd]
values = values.ravel()
if len(values) == 0:
output[token.i] = None
elif len(values) == 1:
output[token.i] = values[0]
elif len(set(list(values))) == 1:
# If all aligned tokens have the same value, use it.
output[token.i] = values[0]
else:
output[token.i] = None
if as_string and field not in ["ENT_IOB", "SENT_START"]:
output = [vocab.strings[o] if o is not None else o for o in output]
return output
def get_aligned_parse(self, projectivize=True):
cand_to_gold = self.alignment.cand_to_gold
gold_to_cand = self.alignment.gold_to_cand
cand_to_gold = self.alignment.x2y
gold_to_cand = self.alignment.y2x
aligned_heads = [None] * self.x.length
aligned_deps = [None] * self.x.length
heads = [token.head.i for token in self.y]
@ -118,52 +121,51 @@ cdef class Example:
if projectivize:
heads, deps = nonproj.projectivize(heads, deps)
for cand_i in range(self.x.length):
gold_i = cand_to_gold[cand_i]
if gold_i is not None: # Alignment found
gold_head = gold_to_cand[heads[gold_i]]
if gold_head is not None:
aligned_heads[cand_i] = gold_head
if cand_to_gold.lengths[cand_i] == 1:
gold_i = cand_to_gold[cand_i].dataXd[0, 0]
if gold_to_cand.lengths[heads[gold_i]] == 1:
aligned_heads[cand_i] = int(gold_to_cand[heads[gold_i]].dataXd[0, 0])
aligned_deps[cand_i] = deps[gold_i]
return aligned_heads, aligned_deps
def get_aligned_spans_x2y(self, x_spans):
return self._get_aligned_spans(self.y, x_spans, self.alignment.x2y)
def get_aligned_spans_y2x(self, y_spans):
return self._get_aligned_spans(self.x, y_spans, self.alignment.y2x)
def _get_aligned_spans(self, doc, spans, align):
seen = set()
output = []
for span in spans:
indices = align[span.start : span.end].data.ravel()
indices = [idx for idx in indices if idx not in seen]
if len(indices) >= 1:
aligned_span = Span(doc, indices[0], indices[-1] + 1, label=span.label)
target_text = span.text.lower().strip().replace(" ", "")
our_text = aligned_span.text.lower().strip().replace(" ", "")
if our_text == target_text:
output.append(aligned_span)
seen.update(indices)
return output
def get_aligned_ner(self):
if not self.y.is_nered:
return [None] * len(self.x) # should this be 'missing' instead of 'None' ?
x_text = self.x.text
# Get a list of entities, and make spans for non-entity tokens.
# We then work through the spans in order, trying to find them in
# the text and using that to get the offset. Any token that doesn't
# get a tag set this way is tagged None.
# This could maybe be improved? It at least feels easy to reason about.
y_spans = list(self.y.ents)
y_spans.sort()
x_text_offset = 0
x_spans = []
for y_span in y_spans:
if x_text.count(y_span.text) >= 1:
start_char = x_text.index(y_span.text) + x_text_offset
end_char = start_char + len(y_span.text)
x_span = self.x.char_span(start_char, end_char, label=y_span.label)
if x_span is not None:
x_spans.append(x_span)
x_text = self.x.text[end_char:]
x_text_offset = end_char
x_ents = self.get_aligned_spans_y2x(self.y.ents)
# Default to 'None' for missing values
x_tags = biluo_tags_from_offsets(
self.x,
[(e.start_char, e.end_char, e.label_) for e in x_spans],
[(e.start_char, e.end_char, e.label_) for e in x_ents],
missing=None
)
gold_to_cand = self.alignment.gold_to_cand
for token in self.y:
if token.ent_iob_ == "O":
cand_i = gold_to_cand[token.i]
if cand_i is not None and x_tags[cand_i] is None:
x_tags[cand_i] = "O"
i2j_multi = self.alignment.i2j_multi
for i, tag in enumerate(x_tags):
if tag is None and i in i2j_multi:
gold_i = i2j_multi[i]
if gold_i is not None and self.y[gold_i].ent_iob_ == "O":
# Now fill the tokens we can align to O.
O = 2 # I=1, O=2, B=3
for i, ent_iob in enumerate(self.get_aligned("ENT_IOB")):
if x_tags[i] is None:
if ent_iob == O:
x_tags[i] = "O"
elif self.x[i].is_space:
x_tags[i] = "O"
return x_tags
@ -194,25 +196,22 @@ cdef class Example:
links[(ent.start_char, ent.end_char)] = {ent.kb_id_: 1.0}
return links
def split_sents(self):
""" Split the token annotations into multiple Examples based on
sent_starts and return a list of the new Examples"""
if not self.reference.is_sentenced:
return [self]
sent_starts = self.get_aligned("SENT_START")
sent_starts.append(1) # appending virtual start of a next sentence to facilitate search
align = self.alignment.y2x
seen_indices = set()
output = []
pred_start = 0
for sent in self.reference.sents:
new_ref = sent.as_doc()
pred_end = sent_starts.index(1, pred_start+1) # find where the next sentence starts
new_pred = self.predicted[pred_start : pred_end].as_doc()
output.append(Example(new_pred, new_ref))
pred_start = pred_end
for y_sent in self.reference.sents:
indices = align[y_sent.start : y_sent.end].data.ravel()
indices = [idx for idx in indices if idx not in seen_indices]
if indices:
x_sent = self.predicted[indices[0] : indices[-1] + 1]
output.append(Example(x_sent.as_doc(), y_sent.as_doc()))
seen_indices.update(indices)
return output
property text:

View File

@ -326,10 +326,11 @@ class Scorer(object):
for token in doc:
if token.orth_.isspace():
continue
gold_i = align.cand_to_gold[token.i]
if gold_i is None:
if align.x2y.lengths[token.i] != 1:
self.tokens.fp += 1
gold_i = None
else:
gold_i = align.x2y[token.i].dataXd[0, 0]
self.tokens.tp += 1
cand_tags.add((gold_i, token.tag_))
cand_pos.add((gold_i, token.pos_))
@ -345,7 +346,10 @@ class Scorer(object):
if token.is_sent_start:
cand_sent_starts.add(gold_i)
if token.dep_.lower() not in punct_labels and token.orth_.strip():
gold_head = align.cand_to_gold[token.head.i]
if align.x2y.lengths[token.head.i] == 1:
gold_head = align.x2y[token.head.i].dataXd[0, 0]
else:
gold_head = None
# None is indistinct, so we can't just add it to the set
# Multiple (None, None) deps are possible
if gold_i is None or gold_head is None:
@ -381,15 +385,9 @@ class Scorer(object):
gold_ents.add(gold_ent)
gold_per_ents[ent.label_].add((ent.label_, ent.start, ent.end - 1))
cand_per_ents = {ent_label: set() for ent_label in ent_labels}
for ent in doc.ents:
first = align.cand_to_gold[ent.start]
last = align.cand_to_gold[ent.end - 1]
if first is None or last is None:
self.ner.fp += 1
self.ner_per_ents[ent.label_].fp += 1
else:
cand_ents.add((ent.label_, first, last))
cand_per_ents[ent.label_].add((ent.label_, first, last))
for ent in example.get_aligned_spans_x2y(doc.ents):
cand_ents.add((ent.label_, ent.start, ent.end - 1))
cand_per_ents[ent.label_].add((ent.label_, ent.start, ent.end - 1))
# Scores per ent
for k, v in self.ner_per_ents.items():
if k in cand_per_ents:

View File

@ -1,6 +1,6 @@
from spacy.errors import AlignmentError
from spacy.gold import biluo_tags_from_offsets, offsets_from_biluo_tags
from spacy.gold import spans_from_biluo_tags, iob_to_biluo, align
from spacy.gold import spans_from_biluo_tags, iob_to_biluo
from spacy.gold import Corpus, docs_to_json
from spacy.gold.example import Example
from spacy.gold.converters import json2docs
@ -271,75 +271,76 @@ def test_split_sentences(en_vocab):
assert split_examples[1].text == "had loads of fun "
@pytest.mark.xfail(reason="Alignment should be fixed after example refactor")
def test_gold_biluo_one_to_many(en_vocab, en_tokenizer):
words = ["I", "flew to", "San Francisco Valley", "."]
spaces = [True, True, False, False]
words = ["Mr. and ", "Mrs. Smith", "flew to", "San Francisco Valley", "."]
spaces = [True, True, True, False, False]
doc = Doc(en_vocab, words=words, spaces=spaces)
entities = [(len("I flew to "), len("I flew to San Francisco Valley"), "LOC")]
gold_words = ["I", "flew", "to", "San", "Francisco", "Valley", "."]
prefix = "Mr. and Mrs. Smith flew to "
entities = [(len(prefix), len(prefix + "San Francisco Valley"), "LOC")]
gold_words = ["Mr. and Mrs. Smith", "flew", "to", "San", "Francisco", "Valley", "."]
example = Example.from_dict(doc, {"words": gold_words, "entities": entities})
ner_tags = example.get_aligned_ner()
assert ner_tags == ["O", "O", "U-LOC", "O"]
assert ner_tags == ["O", "O", "O", "U-LOC", "O"]
entities = [
(len("I "), len("I flew to"), "ORG"),
(len("I flew to "), len("I flew to San Francisco Valley"), "LOC"),
(len("Mr. and "), len("Mr. and Mrs. Smith"), "PERSON"), # "Mrs. Smith" is a PERSON
(len(prefix), len(prefix + "San Francisco Valley"), "LOC"),
]
gold_words = ["I", "flew", "to", "San", "Francisco", "Valley", "."]
gold_words = ["Mr. and", "Mrs.", "Smith", "flew", "to", "San", "Francisco", "Valley", "."]
example = Example.from_dict(doc, {"words": gold_words, "entities": entities})
ner_tags = example.get_aligned_ner()
assert ner_tags == ["O", "U-ORG", "U-LOC", "O"]
assert ner_tags == ["O", "U-PERSON", "O", "U-LOC", "O"]
entities = [
(len("I "), len("I flew"), "ORG"),
(len("I flew to "), len("I flew to San Francisco Valley"), "LOC"),
(len("Mr. and "), len("Mr. and Mrs."), "PERSON"), # "Mrs." is a Person
(len(prefix), len(prefix + "San Francisco Valley"), "LOC"),
]
gold_words = ["I", "flew", "to", "San", "Francisco", "Valley", "."]
gold_words = ["Mr. and", "Mrs.", "Smith", "flew", "to", "San", "Francisco", "Valley", "."]
example = Example.from_dict(doc, {"words": gold_words, "entities": entities})
ner_tags = example.get_aligned_ner()
assert ner_tags == ["O", None, "U-LOC", "O"]
assert ner_tags == ["O", None, "O", "U-LOC", "O"]
def test_gold_biluo_many_to_one(en_vocab, en_tokenizer):
words = ["I", "flew", "to", "San", "Francisco", "Valley", "."]
words = ["Mr. and", "Mrs.", "Smith", "flew", "to", "San", "Francisco", "Valley", "."]
spaces = [True, True, True, True, True, True, True, False, False]
doc = Doc(en_vocab, words=words, spaces=spaces)
prefix = "Mr. and Mrs. Smith flew to "
entities = [(len(prefix), len(prefix + "San Francisco Valley"), "LOC")]
gold_words = ["Mr. and Mrs. Smith", "flew to", "San Francisco Valley", "."]
example = Example.from_dict(doc, {"words": gold_words, "entities": entities})
ner_tags = example.get_aligned_ner()
assert ner_tags == ["O", "O", "O", "O", "O", "B-LOC", "I-LOC", "L-LOC", "O"]
entities = [
(len("Mr. and "), len("Mr. and Mrs. Smith"), "PERSON"), # "Mrs. Smith" is a PERSON
(len(prefix), len(prefix + "San Francisco Valley"), "LOC"),
]
gold_words = ["Mr. and", "Mrs. Smith", "flew to", "San Francisco Valley", "."]
example = Example.from_dict(doc, {"words": gold_words, "entities": entities})
ner_tags = example.get_aligned_ner()
assert ner_tags == ["O", "B-PERSON", "L-PERSON", "O", "O", "B-LOC", "I-LOC", "L-LOC", "O"]
def test_gold_biluo_misaligned(en_vocab, en_tokenizer):
words = ["Mr. and Mrs.", "Smith", "flew", "to", "San Francisco", "Valley", "."]
spaces = [True, True, True, True, True, False, False]
doc = Doc(en_vocab, words=words, spaces=spaces)
entities = [(len("I flew to "), len("I flew to San Francisco Valley"), "LOC")]
gold_words = ["I", "flew to", "San Francisco Valley", "."]
prefix = "Mr. and Mrs. Smith flew to "
entities = [(len(prefix), len(prefix + "San Francisco Valley"), "LOC")]
gold_words = ["Mr.", "and Mrs. Smith", "flew to", "San", "Francisco Valley", "."]
example = Example.from_dict(doc, {"words": gold_words, "entities": entities})
ner_tags = example.get_aligned_ner()
assert ner_tags == ["O", "O", "O", "B-LOC", "I-LOC", "L-LOC", "O"]
assert ner_tags == ["O", "O", "O", "O", "B-LOC", "L-LOC", "O"]
entities = [
(len("I "), len("I flew to"), "ORG"),
(len("I flew to "), len("I flew to San Francisco Valley"), "LOC"),
(len("Mr. and "), len("Mr. and Mrs. Smith"), "PERSON"), # "Mrs. Smith" is a PERSON
(len(prefix), len(prefix + "San Francisco Valley"), "LOC"),
]
gold_words = ["I", "flew to", "San Francisco Valley", "."]
gold_words = ["Mr. and", "Mrs. Smith", "flew to", "San", "Francisco Valley", "."]
example = Example.from_dict(doc, {"words": gold_words, "entities": entities})
ner_tags = example.get_aligned_ner()
assert ner_tags == ["O", "B-ORG", "L-ORG", "B-LOC", "I-LOC", "L-LOC", "O"]
@pytest.mark.xfail(reason="Alignment should be fixed after example refactor")
def test_gold_biluo_misaligned(en_vocab, en_tokenizer):
words = ["I flew", "to", "San Francisco", "Valley", "."]
spaces = [True, True, True, False, False]
doc = Doc(en_vocab, words=words, spaces=spaces)
entities = [(len("I flew to "), len("I flew to San Francisco Valley"), "LOC")]
gold_words = ["I", "flew to", "San", "Francisco Valley", "."]
example = Example.from_dict(doc, {"words": gold_words, "entities": entities})
ner_tags = example.get_aligned_ner()
assert ner_tags == ["O", "O", "B-LOC", "L-LOC", "O"]
entities = [
(len("I "), len("I flew to"), "ORG"),
(len("I flew to "), len("I flew to San Francisco Valley"), "LOC"),
]
gold_words = ["I", "flew to", "San", "Francisco Valley", "."]
example = Example.from_dict(doc, {"words": gold_words, "entities": entities})
ner_tags = example.get_aligned_ner()
assert ner_tags == [None, None, "B-LOC", "L-LOC", "O"]
assert ner_tags == [None, None, "O", "O", "B-LOC", "L-LOC", "O"]
def test_gold_biluo_additional_whitespace(en_vocab, en_tokenizer):
@ -349,7 +350,8 @@ def test_gold_biluo_additional_whitespace(en_vocab, en_tokenizer):
"I flew to San Francisco Valley.",
)
doc = Doc(en_vocab, words=words, spaces=spaces)
entities = [(len("I flew to "), len("I flew to San Francisco Valley"), "LOC")]
prefix = "I flew to "
entities = [(len(prefix), len(prefix + "San Francisco Valley"), "LOC")]
gold_words = ["I", "flew", " ", "to", "San Francisco Valley", "."]
gold_spaces = [True, True, False, True, False, False]
example = Example.from_dict(
@ -514,6 +516,7 @@ def test_make_orth_variants(doc):
make_orth_variants_example(nlp, train_example, orth_variant_level=0.2)
@pytest.mark.skip("Outdated")
@pytest.mark.parametrize(
"tokens_a,tokens_b,expected",
[
@ -537,12 +540,12 @@ def test_make_orth_variants(doc):
([" ", "a"], ["a"], (1, [-1, 0], [1], {}, {})),
],
)
def test_align(tokens_a, tokens_b, expected):
cost, a2b, b2a, a2b_multi, b2a_multi = align(tokens_a, tokens_b)
assert (cost, list(a2b), list(b2a), a2b_multi, b2a_multi) == expected
def test_align(tokens_a, tokens_b, expected): # noqa
cost, a2b, b2a, a2b_multi, b2a_multi = align(tokens_a, tokens_b) # noqa
assert (cost, list(a2b), list(b2a), a2b_multi, b2a_multi) == expected # noqa
# check symmetry
cost, a2b, b2a, a2b_multi, b2a_multi = align(tokens_b, tokens_a)
assert (cost, list(b2a), list(a2b), b2a_multi, a2b_multi) == expected
cost, a2b, b2a, a2b_multi, b2a_multi = align(tokens_b, tokens_a) # noqa
assert (cost, list(b2a), list(a2b), b2a_multi, a2b_multi) == expected # noqa
def test_goldparse_startswith_space(en_tokenizer):
@ -556,7 +559,7 @@ def test_goldparse_startswith_space(en_tokenizer):
doc, {"words": gold_words, "entities": entities, "deps": deps, "heads": heads}
)
ner_tags = example.get_aligned_ner()
assert ner_tags == [None, "U-DATE"]
assert ner_tags == ["O", "U-DATE"]
assert example.get_aligned("DEP", as_string=True) == [None, "ROOT"]

View File

@ -55,7 +55,7 @@ def test_aligned_tags():
predicted = Doc(vocab, words=pred_words)
example = Example.from_dict(predicted, annots)
aligned_tags = example.get_aligned("tag", as_string=True)
assert aligned_tags == ["VERB", "DET", None, "SCONJ", "PRON", "VERB", "VERB"]
assert aligned_tags == ["VERB", "DET", "NOUN", "SCONJ", "PRON", "VERB", "VERB"]
def test_aligned_tags_multi():