mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-24 17:06:29 +03:00
Improve gold-standard alignment (#5711)
* Remove previous alignment * Implement better alignment, using ragged data structure * Use pytokenizations for alignment * Fixes * Fixes * Fix overlapping entities in alignment * Fix align split_sents * Update test * Commit align.py * Try to appease setuptools * Fix flake8 * use realistic entities for testing * Update tests for better alignment * Improve alignment heuristic Co-authored-by: svlandeg <sofie.vanlandeghem@gmail.com>
This commit is contained in:
parent
a35236e5f0
commit
cc477be952
|
@ -7,6 +7,7 @@ requires = [
|
|||
"preshed>=3.0.2,<3.1.0",
|
||||
"murmurhash>=0.28.0,<1.1.0",
|
||||
"thinc>=8.0.0a12,<8.0.0a20",
|
||||
"blis>=0.4.0,<0.5.0"
|
||||
"blis>=0.4.0,<0.5.0",
|
||||
"pytokenizations"
|
||||
]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
|
|
@ -14,6 +14,7 @@ numpy>=1.15.0
|
|||
requests>=2.13.0,<3.0.0
|
||||
tqdm>=4.38.0,<5.0.0
|
||||
pydantic>=1.3.0,<2.0.0
|
||||
pytokenizations
|
||||
# Official Python utilities
|
||||
setuptools
|
||||
packaging
|
||||
|
|
|
@ -51,6 +51,7 @@ install_requires =
|
|||
numpy>=1.15.0
|
||||
requests>=2.13.0,<3.0.0
|
||||
pydantic>=1.3.0,<2.0.0
|
||||
pytokenizations
|
||||
# Official Python utilities
|
||||
setuptools
|
||||
packaging
|
||||
|
|
3
setup.py
3
setup.py
|
@ -1,11 +1,11 @@
|
|||
#!/usr/bin/env python
|
||||
from setuptools import Extension, setup, find_packages
|
||||
import sys
|
||||
import platform
|
||||
from distutils.command.build_ext import build_ext
|
||||
from distutils.sysconfig import get_python_inc
|
||||
import distutils.util
|
||||
from distutils import ccompiler, msvccompiler
|
||||
from setuptools import Extension, setup, find_packages
|
||||
import numpy
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
|
@ -23,7 +23,6 @@ Options.docstrings = True
|
|||
|
||||
PACKAGES = find_packages()
|
||||
MOD_NAMES = [
|
||||
"spacy.gold.align",
|
||||
"spacy.gold.example",
|
||||
"spacy.parts_of_speech",
|
||||
"spacy.strings",
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from .corpus import Corpus
|
||||
from .example import Example
|
||||
from .align import align
|
||||
from .align import Alignment
|
||||
|
||||
from .iob_utils import iob_to_biluo, biluo_to_iob
|
||||
from .iob_utils import biluo_tags_from_offsets, offsets_from_biluo_tags
|
||||
|
|
|
@ -1,8 +0,0 @@
|
|||
cdef class Alignment:
|
||||
cdef public object cost
|
||||
cdef public object i2j
|
||||
cdef public object j2i
|
||||
cdef public object i2j_multi
|
||||
cdef public object j2i_multi
|
||||
cdef public object cand_to_gold
|
||||
cdef public object gold_to_cand
|
30
spacy/gold/align.py
Normal file
30
spacy/gold/align.py
Normal file
|
@ -0,0 +1,30 @@
|
|||
from typing import List
|
||||
import numpy
|
||||
from thinc.types import Ragged
|
||||
from dataclasses import dataclass
|
||||
import tokenizations
|
||||
|
||||
|
||||
@dataclass
|
||||
class Alignment:
|
||||
x2y: Ragged
|
||||
y2x: Ragged
|
||||
|
||||
@classmethod
|
||||
def from_indices(cls, x2y: List[List[int]], y2x: List[List[int]]) -> "Alignment":
|
||||
x2y = _make_ragged(x2y)
|
||||
y2x = _make_ragged(y2x)
|
||||
return Alignment(x2y=x2y, y2x=y2x)
|
||||
|
||||
@classmethod
|
||||
def from_strings(cls, A: List[str], B: List[str]) -> "Alignment":
|
||||
x2y, y2x = tokenizations.get_alignments(A, B)
|
||||
return Alignment.from_indices(x2y=x2y, y2x=y2x)
|
||||
|
||||
|
||||
def _make_ragged(indices):
|
||||
lengths = numpy.array([len(x) for x in indices], dtype="i")
|
||||
flat = []
|
||||
for x in indices:
|
||||
flat.extend(x)
|
||||
return Ragged(numpy.array(flat, dtype="i"), lengths)
|
|
@ -1,101 +0,0 @@
|
|||
import numpy
|
||||
from ..errors import Errors, AlignmentError
|
||||
|
||||
|
||||
cdef class Alignment:
|
||||
def __init__(self, spacy_words, gold_words):
|
||||
# Do many-to-one alignment for misaligned tokens.
|
||||
# If we over-segment, we'll have one gold word that covers a sequence
|
||||
# of predicted words
|
||||
# If we under-segment, we'll have one predicted word that covers a
|
||||
# sequence of gold words.
|
||||
# If we "mis-segment", we'll have a sequence of predicted words covering
|
||||
# a sequence of gold words. That's many-to-many -- we don't do that
|
||||
# except for NER spans where the start and end can be aligned.
|
||||
cost, i2j, j2i, i2j_multi, j2i_multi = align(spacy_words, gold_words)
|
||||
self.cost = cost
|
||||
self.i2j = i2j
|
||||
self.j2i = j2i
|
||||
self.i2j_multi = i2j_multi
|
||||
self.j2i_multi = j2i_multi
|
||||
self.cand_to_gold = [(j if j >= 0 else None) for j in i2j]
|
||||
self.gold_to_cand = [(i if i >= 0 else None) for i in j2i]
|
||||
|
||||
|
||||
def align(tokens_a, tokens_b):
|
||||
"""Calculate alignment tables between two tokenizations.
|
||||
|
||||
tokens_a (List[str]): The candidate tokenization.
|
||||
tokens_b (List[str]): The reference tokenization.
|
||||
RETURNS: (tuple): A 5-tuple consisting of the following information:
|
||||
* cost (int): The number of misaligned tokens.
|
||||
* a2b (List[int]): Mapping of indices in `tokens_a` to indices in `tokens_b`.
|
||||
For instance, if `a2b[4] == 6`, that means that `tokens_a[4]` aligns
|
||||
to `tokens_b[6]`. If there's no one-to-one alignment for a token,
|
||||
it has the value -1.
|
||||
* b2a (List[int]): The same as `a2b`, but mapping the other direction.
|
||||
* a2b_multi (Dict[int, int]): A dictionary mapping indices in `tokens_a`
|
||||
to indices in `tokens_b`, where multiple tokens of `tokens_a` align to
|
||||
the same token of `tokens_b`.
|
||||
* b2a_multi (Dict[int, int]): As with `a2b_multi`, but mapping the other
|
||||
direction.
|
||||
"""
|
||||
tokens_a = _normalize_for_alignment(tokens_a)
|
||||
tokens_b = _normalize_for_alignment(tokens_b)
|
||||
cost = 0
|
||||
a2b = numpy.empty(len(tokens_a), dtype="i")
|
||||
b2a = numpy.empty(len(tokens_b), dtype="i")
|
||||
a2b.fill(-1)
|
||||
b2a.fill(-1)
|
||||
a2b_multi = {}
|
||||
b2a_multi = {}
|
||||
i = 0
|
||||
j = 0
|
||||
offset_a = 0
|
||||
offset_b = 0
|
||||
while i < len(tokens_a) and j < len(tokens_b):
|
||||
a = tokens_a[i][offset_a:]
|
||||
b = tokens_b[j][offset_b:]
|
||||
if a == b:
|
||||
if offset_a == offset_b == 0:
|
||||
a2b[i] = j
|
||||
b2a[j] = i
|
||||
elif offset_a == 0:
|
||||
cost += 2
|
||||
a2b_multi[i] = j
|
||||
elif offset_b == 0:
|
||||
cost += 2
|
||||
b2a_multi[j] = i
|
||||
offset_a = offset_b = 0
|
||||
i += 1
|
||||
j += 1
|
||||
elif a == "":
|
||||
assert offset_a == 0
|
||||
cost += 1
|
||||
i += 1
|
||||
elif b == "":
|
||||
assert offset_b == 0
|
||||
cost += 1
|
||||
j += 1
|
||||
elif b.startswith(a):
|
||||
cost += 1
|
||||
if offset_a == 0:
|
||||
a2b_multi[i] = j
|
||||
i += 1
|
||||
offset_a = 0
|
||||
offset_b += len(a)
|
||||
elif a.startswith(b):
|
||||
cost += 1
|
||||
if offset_b == 0:
|
||||
b2a_multi[j] = i
|
||||
j += 1
|
||||
offset_b = 0
|
||||
offset_a += len(b)
|
||||
else:
|
||||
assert "".join(tokens_a) != "".join(tokens_b)
|
||||
raise AlignmentError(Errors.E186.format(tok_a=tokens_a, tok_b=tokens_b))
|
||||
return cost, a2b, b2a, a2b_multi, b2a_multi
|
||||
|
||||
|
||||
def _normalize_for_alignment(tokens):
|
||||
return [w.replace(" ", "").lower() for w in tokens]
|
|
@ -1,8 +1,7 @@
|
|||
from ..tokens.doc cimport Doc
|
||||
from .align cimport Alignment
|
||||
|
||||
|
||||
cdef class Example:
|
||||
cdef readonly Doc x
|
||||
cdef readonly Doc y
|
||||
cdef readonly Alignment _alignment
|
||||
cdef readonly object _alignment
|
||||
|
|
|
@ -6,10 +6,9 @@ from ..tokens.doc cimport Doc
|
|||
from ..tokens.span cimport Span
|
||||
from ..tokens.span import Span
|
||||
from ..attrs import IDS
|
||||
from .align cimport Alignment
|
||||
from .align import Alignment
|
||||
from .iob_utils import biluo_to_iob, biluo_tags_from_offsets, biluo_tags_from_doc
|
||||
from .iob_utils import spans_from_biluo_tags
|
||||
from .align import Alignment
|
||||
from ..errors import Errors, Warnings
|
||||
from ..syntax import nonproj
|
||||
|
||||
|
@ -28,7 +27,7 @@ cpdef Doc annotations2doc(vocab, tok_annot, doc_annot):
|
|||
|
||||
|
||||
cdef class Example:
|
||||
def __init__(self, Doc predicted, Doc reference, *, Alignment alignment=None):
|
||||
def __init__(self, Doc predicted, Doc reference, *, alignment=None):
|
||||
""" Doc can either be text, or an actual Doc """
|
||||
if predicted is None:
|
||||
raise TypeError(Errors.E972.format(arg="predicted"))
|
||||
|
@ -83,34 +82,38 @@ cdef class Example:
|
|||
gold_words = [token.orth_ for token in self.reference]
|
||||
if gold_words == []:
|
||||
gold_words = spacy_words
|
||||
self._alignment = Alignment(spacy_words, gold_words)
|
||||
self._alignment = Alignment.from_strings(spacy_words, gold_words)
|
||||
return self._alignment
|
||||
|
||||
def get_aligned(self, field, as_string=False):
|
||||
"""Return an aligned array for a token attribute."""
|
||||
i2j_multi = self.alignment.i2j_multi
|
||||
cand_to_gold = self.alignment.cand_to_gold
|
||||
align = self.alignment.x2y
|
||||
|
||||
vocab = self.reference.vocab
|
||||
gold_values = self.reference.to_array([field])
|
||||
output = [None] * len(self.predicted)
|
||||
for i, gold_i in enumerate(cand_to_gold):
|
||||
if self.predicted[i].text.isspace():
|
||||
output[i] = None
|
||||
if gold_i is None:
|
||||
if i in i2j_multi:
|
||||
output[i] = gold_values[i2j_multi[i]]
|
||||
else:
|
||||
output[i] = None
|
||||
for token in self.predicted:
|
||||
if token.is_space:
|
||||
output[token.i] = None
|
||||
else:
|
||||
output[i] = gold_values[gold_i]
|
||||
values = gold_values[align[token.i].dataXd]
|
||||
values = values.ravel()
|
||||
if len(values) == 0:
|
||||
output[token.i] = None
|
||||
elif len(values) == 1:
|
||||
output[token.i] = values[0]
|
||||
elif len(set(list(values))) == 1:
|
||||
# If all aligned tokens have the same value, use it.
|
||||
output[token.i] = values[0]
|
||||
else:
|
||||
output[token.i] = None
|
||||
if as_string and field not in ["ENT_IOB", "SENT_START"]:
|
||||
output = [vocab.strings[o] if o is not None else o for o in output]
|
||||
return output
|
||||
|
||||
def get_aligned_parse(self, projectivize=True):
|
||||
cand_to_gold = self.alignment.cand_to_gold
|
||||
gold_to_cand = self.alignment.gold_to_cand
|
||||
cand_to_gold = self.alignment.x2y
|
||||
gold_to_cand = self.alignment.y2x
|
||||
aligned_heads = [None] * self.x.length
|
||||
aligned_deps = [None] * self.x.length
|
||||
heads = [token.head.i for token in self.y]
|
||||
|
@ -118,52 +121,51 @@ cdef class Example:
|
|||
if projectivize:
|
||||
heads, deps = nonproj.projectivize(heads, deps)
|
||||
for cand_i in range(self.x.length):
|
||||
gold_i = cand_to_gold[cand_i]
|
||||
if gold_i is not None: # Alignment found
|
||||
gold_head = gold_to_cand[heads[gold_i]]
|
||||
if gold_head is not None:
|
||||
aligned_heads[cand_i] = gold_head
|
||||
if cand_to_gold.lengths[cand_i] == 1:
|
||||
gold_i = cand_to_gold[cand_i].dataXd[0, 0]
|
||||
if gold_to_cand.lengths[heads[gold_i]] == 1:
|
||||
aligned_heads[cand_i] = int(gold_to_cand[heads[gold_i]].dataXd[0, 0])
|
||||
aligned_deps[cand_i] = deps[gold_i]
|
||||
return aligned_heads, aligned_deps
|
||||
|
||||
def get_aligned_spans_x2y(self, x_spans):
|
||||
return self._get_aligned_spans(self.y, x_spans, self.alignment.x2y)
|
||||
|
||||
def get_aligned_spans_y2x(self, y_spans):
|
||||
return self._get_aligned_spans(self.x, y_spans, self.alignment.y2x)
|
||||
|
||||
def _get_aligned_spans(self, doc, spans, align):
|
||||
seen = set()
|
||||
output = []
|
||||
for span in spans:
|
||||
indices = align[span.start : span.end].data.ravel()
|
||||
indices = [idx for idx in indices if idx not in seen]
|
||||
if len(indices) >= 1:
|
||||
aligned_span = Span(doc, indices[0], indices[-1] + 1, label=span.label)
|
||||
target_text = span.text.lower().strip().replace(" ", "")
|
||||
our_text = aligned_span.text.lower().strip().replace(" ", "")
|
||||
if our_text == target_text:
|
||||
output.append(aligned_span)
|
||||
seen.update(indices)
|
||||
return output
|
||||
|
||||
def get_aligned_ner(self):
|
||||
if not self.y.is_nered:
|
||||
return [None] * len(self.x) # should this be 'missing' instead of 'None' ?
|
||||
x_text = self.x.text
|
||||
# Get a list of entities, and make spans for non-entity tokens.
|
||||
# We then work through the spans in order, trying to find them in
|
||||
# the text and using that to get the offset. Any token that doesn't
|
||||
# get a tag set this way is tagged None.
|
||||
# This could maybe be improved? It at least feels easy to reason about.
|
||||
y_spans = list(self.y.ents)
|
||||
y_spans.sort()
|
||||
x_text_offset = 0
|
||||
x_spans = []
|
||||
for y_span in y_spans:
|
||||
if x_text.count(y_span.text) >= 1:
|
||||
start_char = x_text.index(y_span.text) + x_text_offset
|
||||
end_char = start_char + len(y_span.text)
|
||||
x_span = self.x.char_span(start_char, end_char, label=y_span.label)
|
||||
if x_span is not None:
|
||||
x_spans.append(x_span)
|
||||
x_text = self.x.text[end_char:]
|
||||
x_text_offset = end_char
|
||||
x_ents = self.get_aligned_spans_y2x(self.y.ents)
|
||||
# Default to 'None' for missing values
|
||||
x_tags = biluo_tags_from_offsets(
|
||||
self.x,
|
||||
[(e.start_char, e.end_char, e.label_) for e in x_spans],
|
||||
[(e.start_char, e.end_char, e.label_) for e in x_ents],
|
||||
missing=None
|
||||
)
|
||||
gold_to_cand = self.alignment.gold_to_cand
|
||||
for token in self.y:
|
||||
if token.ent_iob_ == "O":
|
||||
cand_i = gold_to_cand[token.i]
|
||||
if cand_i is not None and x_tags[cand_i] is None:
|
||||
x_tags[cand_i] = "O"
|
||||
i2j_multi = self.alignment.i2j_multi
|
||||
for i, tag in enumerate(x_tags):
|
||||
if tag is None and i in i2j_multi:
|
||||
gold_i = i2j_multi[i]
|
||||
if gold_i is not None and self.y[gold_i].ent_iob_ == "O":
|
||||
# Now fill the tokens we can align to O.
|
||||
O = 2 # I=1, O=2, B=3
|
||||
for i, ent_iob in enumerate(self.get_aligned("ENT_IOB")):
|
||||
if x_tags[i] is None:
|
||||
if ent_iob == O:
|
||||
x_tags[i] = "O"
|
||||
elif self.x[i].is_space:
|
||||
x_tags[i] = "O"
|
||||
return x_tags
|
||||
|
||||
|
@ -194,25 +196,22 @@ cdef class Example:
|
|||
links[(ent.start_char, ent.end_char)] = {ent.kb_id_: 1.0}
|
||||
return links
|
||||
|
||||
|
||||
def split_sents(self):
|
||||
""" Split the token annotations into multiple Examples based on
|
||||
sent_starts and return a list of the new Examples"""
|
||||
if not self.reference.is_sentenced:
|
||||
return [self]
|
||||
|
||||
sent_starts = self.get_aligned("SENT_START")
|
||||
sent_starts.append(1) # appending virtual start of a next sentence to facilitate search
|
||||
|
||||
align = self.alignment.y2x
|
||||
seen_indices = set()
|
||||
output = []
|
||||
pred_start = 0
|
||||
for sent in self.reference.sents:
|
||||
new_ref = sent.as_doc()
|
||||
pred_end = sent_starts.index(1, pred_start+1) # find where the next sentence starts
|
||||
new_pred = self.predicted[pred_start : pred_end].as_doc()
|
||||
output.append(Example(new_pred, new_ref))
|
||||
pred_start = pred_end
|
||||
|
||||
for y_sent in self.reference.sents:
|
||||
indices = align[y_sent.start : y_sent.end].data.ravel()
|
||||
indices = [idx for idx in indices if idx not in seen_indices]
|
||||
if indices:
|
||||
x_sent = self.predicted[indices[0] : indices[-1] + 1]
|
||||
output.append(Example(x_sent.as_doc(), y_sent.as_doc()))
|
||||
seen_indices.update(indices)
|
||||
return output
|
||||
|
||||
property text:
|
||||
|
|
|
@ -326,10 +326,11 @@ class Scorer(object):
|
|||
for token in doc:
|
||||
if token.orth_.isspace():
|
||||
continue
|
||||
gold_i = align.cand_to_gold[token.i]
|
||||
if gold_i is None:
|
||||
if align.x2y.lengths[token.i] != 1:
|
||||
self.tokens.fp += 1
|
||||
gold_i = None
|
||||
else:
|
||||
gold_i = align.x2y[token.i].dataXd[0, 0]
|
||||
self.tokens.tp += 1
|
||||
cand_tags.add((gold_i, token.tag_))
|
||||
cand_pos.add((gold_i, token.pos_))
|
||||
|
@ -345,7 +346,10 @@ class Scorer(object):
|
|||
if token.is_sent_start:
|
||||
cand_sent_starts.add(gold_i)
|
||||
if token.dep_.lower() not in punct_labels and token.orth_.strip():
|
||||
gold_head = align.cand_to_gold[token.head.i]
|
||||
if align.x2y.lengths[token.head.i] == 1:
|
||||
gold_head = align.x2y[token.head.i].dataXd[0, 0]
|
||||
else:
|
||||
gold_head = None
|
||||
# None is indistinct, so we can't just add it to the set
|
||||
# Multiple (None, None) deps are possible
|
||||
if gold_i is None or gold_head is None:
|
||||
|
@ -381,15 +385,9 @@ class Scorer(object):
|
|||
gold_ents.add(gold_ent)
|
||||
gold_per_ents[ent.label_].add((ent.label_, ent.start, ent.end - 1))
|
||||
cand_per_ents = {ent_label: set() for ent_label in ent_labels}
|
||||
for ent in doc.ents:
|
||||
first = align.cand_to_gold[ent.start]
|
||||
last = align.cand_to_gold[ent.end - 1]
|
||||
if first is None or last is None:
|
||||
self.ner.fp += 1
|
||||
self.ner_per_ents[ent.label_].fp += 1
|
||||
else:
|
||||
cand_ents.add((ent.label_, first, last))
|
||||
cand_per_ents[ent.label_].add((ent.label_, first, last))
|
||||
for ent in example.get_aligned_spans_x2y(doc.ents):
|
||||
cand_ents.add((ent.label_, ent.start, ent.end - 1))
|
||||
cand_per_ents[ent.label_].add((ent.label_, ent.start, ent.end - 1))
|
||||
# Scores per ent
|
||||
for k, v in self.ner_per_ents.items():
|
||||
if k in cand_per_ents:
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from spacy.errors import AlignmentError
|
||||
from spacy.gold import biluo_tags_from_offsets, offsets_from_biluo_tags
|
||||
from spacy.gold import spans_from_biluo_tags, iob_to_biluo, align
|
||||
from spacy.gold import spans_from_biluo_tags, iob_to_biluo
|
||||
from spacy.gold import Corpus, docs_to_json
|
||||
from spacy.gold.example import Example
|
||||
from spacy.gold.converters import json2docs
|
||||
|
@ -271,75 +271,76 @@ def test_split_sentences(en_vocab):
|
|||
assert split_examples[1].text == "had loads of fun "
|
||||
|
||||
|
||||
@pytest.mark.xfail(reason="Alignment should be fixed after example refactor")
|
||||
def test_gold_biluo_one_to_many(en_vocab, en_tokenizer):
|
||||
words = ["I", "flew to", "San Francisco Valley", "."]
|
||||
spaces = [True, True, False, False]
|
||||
words = ["Mr. and ", "Mrs. Smith", "flew to", "San Francisco Valley", "."]
|
||||
spaces = [True, True, True, False, False]
|
||||
doc = Doc(en_vocab, words=words, spaces=spaces)
|
||||
entities = [(len("I flew to "), len("I flew to San Francisco Valley"), "LOC")]
|
||||
gold_words = ["I", "flew", "to", "San", "Francisco", "Valley", "."]
|
||||
prefix = "Mr. and Mrs. Smith flew to "
|
||||
entities = [(len(prefix), len(prefix + "San Francisco Valley"), "LOC")]
|
||||
gold_words = ["Mr. and Mrs. Smith", "flew", "to", "San", "Francisco", "Valley", "."]
|
||||
example = Example.from_dict(doc, {"words": gold_words, "entities": entities})
|
||||
ner_tags = example.get_aligned_ner()
|
||||
assert ner_tags == ["O", "O", "U-LOC", "O"]
|
||||
assert ner_tags == ["O", "O", "O", "U-LOC", "O"]
|
||||
|
||||
entities = [
|
||||
(len("I "), len("I flew to"), "ORG"),
|
||||
(len("I flew to "), len("I flew to San Francisco Valley"), "LOC"),
|
||||
(len("Mr. and "), len("Mr. and Mrs. Smith"), "PERSON"), # "Mrs. Smith" is a PERSON
|
||||
(len(prefix), len(prefix + "San Francisco Valley"), "LOC"),
|
||||
]
|
||||
gold_words = ["I", "flew", "to", "San", "Francisco", "Valley", "."]
|
||||
gold_words = ["Mr. and", "Mrs.", "Smith", "flew", "to", "San", "Francisco", "Valley", "."]
|
||||
example = Example.from_dict(doc, {"words": gold_words, "entities": entities})
|
||||
ner_tags = example.get_aligned_ner()
|
||||
assert ner_tags == ["O", "U-ORG", "U-LOC", "O"]
|
||||
assert ner_tags == ["O", "U-PERSON", "O", "U-LOC", "O"]
|
||||
|
||||
entities = [
|
||||
(len("I "), len("I flew"), "ORG"),
|
||||
(len("I flew to "), len("I flew to San Francisco Valley"), "LOC"),
|
||||
(len("Mr. and "), len("Mr. and Mrs."), "PERSON"), # "Mrs." is a Person
|
||||
(len(prefix), len(prefix + "San Francisco Valley"), "LOC"),
|
||||
]
|
||||
gold_words = ["I", "flew", "to", "San", "Francisco", "Valley", "."]
|
||||
gold_words = ["Mr. and", "Mrs.", "Smith", "flew", "to", "San", "Francisco", "Valley", "."]
|
||||
example = Example.from_dict(doc, {"words": gold_words, "entities": entities})
|
||||
ner_tags = example.get_aligned_ner()
|
||||
assert ner_tags == ["O", None, "U-LOC", "O"]
|
||||
assert ner_tags == ["O", None, "O", "U-LOC", "O"]
|
||||
|
||||
|
||||
def test_gold_biluo_many_to_one(en_vocab, en_tokenizer):
|
||||
words = ["I", "flew", "to", "San", "Francisco", "Valley", "."]
|
||||
words = ["Mr. and", "Mrs.", "Smith", "flew", "to", "San", "Francisco", "Valley", "."]
|
||||
spaces = [True, True, True, True, True, True, True, False, False]
|
||||
doc = Doc(en_vocab, words=words, spaces=spaces)
|
||||
prefix = "Mr. and Mrs. Smith flew to "
|
||||
entities = [(len(prefix), len(prefix + "San Francisco Valley"), "LOC")]
|
||||
gold_words = ["Mr. and Mrs. Smith", "flew to", "San Francisco Valley", "."]
|
||||
example = Example.from_dict(doc, {"words": gold_words, "entities": entities})
|
||||
ner_tags = example.get_aligned_ner()
|
||||
assert ner_tags == ["O", "O", "O", "O", "O", "B-LOC", "I-LOC", "L-LOC", "O"]
|
||||
|
||||
entities = [
|
||||
(len("Mr. and "), len("Mr. and Mrs. Smith"), "PERSON"), # "Mrs. Smith" is a PERSON
|
||||
(len(prefix), len(prefix + "San Francisco Valley"), "LOC"),
|
||||
]
|
||||
gold_words = ["Mr. and", "Mrs. Smith", "flew to", "San Francisco Valley", "."]
|
||||
example = Example.from_dict(doc, {"words": gold_words, "entities": entities})
|
||||
ner_tags = example.get_aligned_ner()
|
||||
assert ner_tags == ["O", "B-PERSON", "L-PERSON", "O", "O", "B-LOC", "I-LOC", "L-LOC", "O"]
|
||||
|
||||
|
||||
def test_gold_biluo_misaligned(en_vocab, en_tokenizer):
|
||||
words = ["Mr. and Mrs.", "Smith", "flew", "to", "San Francisco", "Valley", "."]
|
||||
spaces = [True, True, True, True, True, False, False]
|
||||
doc = Doc(en_vocab, words=words, spaces=spaces)
|
||||
entities = [(len("I flew to "), len("I flew to San Francisco Valley"), "LOC")]
|
||||
gold_words = ["I", "flew to", "San Francisco Valley", "."]
|
||||
prefix = "Mr. and Mrs. Smith flew to "
|
||||
entities = [(len(prefix), len(prefix + "San Francisco Valley"), "LOC")]
|
||||
gold_words = ["Mr.", "and Mrs. Smith", "flew to", "San", "Francisco Valley", "."]
|
||||
example = Example.from_dict(doc, {"words": gold_words, "entities": entities})
|
||||
ner_tags = example.get_aligned_ner()
|
||||
assert ner_tags == ["O", "O", "O", "B-LOC", "I-LOC", "L-LOC", "O"]
|
||||
assert ner_tags == ["O", "O", "O", "O", "B-LOC", "L-LOC", "O"]
|
||||
|
||||
entities = [
|
||||
(len("I "), len("I flew to"), "ORG"),
|
||||
(len("I flew to "), len("I flew to San Francisco Valley"), "LOC"),
|
||||
(len("Mr. and "), len("Mr. and Mrs. Smith"), "PERSON"), # "Mrs. Smith" is a PERSON
|
||||
(len(prefix), len(prefix + "San Francisco Valley"), "LOC"),
|
||||
]
|
||||
gold_words = ["I", "flew to", "San Francisco Valley", "."]
|
||||
gold_words = ["Mr. and", "Mrs. Smith", "flew to", "San", "Francisco Valley", "."]
|
||||
example = Example.from_dict(doc, {"words": gold_words, "entities": entities})
|
||||
ner_tags = example.get_aligned_ner()
|
||||
assert ner_tags == ["O", "B-ORG", "L-ORG", "B-LOC", "I-LOC", "L-LOC", "O"]
|
||||
|
||||
|
||||
@pytest.mark.xfail(reason="Alignment should be fixed after example refactor")
|
||||
def test_gold_biluo_misaligned(en_vocab, en_tokenizer):
|
||||
words = ["I flew", "to", "San Francisco", "Valley", "."]
|
||||
spaces = [True, True, True, False, False]
|
||||
doc = Doc(en_vocab, words=words, spaces=spaces)
|
||||
entities = [(len("I flew to "), len("I flew to San Francisco Valley"), "LOC")]
|
||||
gold_words = ["I", "flew to", "San", "Francisco Valley", "."]
|
||||
example = Example.from_dict(doc, {"words": gold_words, "entities": entities})
|
||||
ner_tags = example.get_aligned_ner()
|
||||
assert ner_tags == ["O", "O", "B-LOC", "L-LOC", "O"]
|
||||
|
||||
entities = [
|
||||
(len("I "), len("I flew to"), "ORG"),
|
||||
(len("I flew to "), len("I flew to San Francisco Valley"), "LOC"),
|
||||
]
|
||||
gold_words = ["I", "flew to", "San", "Francisco Valley", "."]
|
||||
example = Example.from_dict(doc, {"words": gold_words, "entities": entities})
|
||||
ner_tags = example.get_aligned_ner()
|
||||
assert ner_tags == [None, None, "B-LOC", "L-LOC", "O"]
|
||||
assert ner_tags == [None, None, "O", "O", "B-LOC", "L-LOC", "O"]
|
||||
|
||||
|
||||
def test_gold_biluo_additional_whitespace(en_vocab, en_tokenizer):
|
||||
|
@ -349,7 +350,8 @@ def test_gold_biluo_additional_whitespace(en_vocab, en_tokenizer):
|
|||
"I flew to San Francisco Valley.",
|
||||
)
|
||||
doc = Doc(en_vocab, words=words, spaces=spaces)
|
||||
entities = [(len("I flew to "), len("I flew to San Francisco Valley"), "LOC")]
|
||||
prefix = "I flew to "
|
||||
entities = [(len(prefix), len(prefix + "San Francisco Valley"), "LOC")]
|
||||
gold_words = ["I", "flew", " ", "to", "San Francisco Valley", "."]
|
||||
gold_spaces = [True, True, False, True, False, False]
|
||||
example = Example.from_dict(
|
||||
|
@ -514,6 +516,7 @@ def test_make_orth_variants(doc):
|
|||
make_orth_variants_example(nlp, train_example, orth_variant_level=0.2)
|
||||
|
||||
|
||||
@pytest.mark.skip("Outdated")
|
||||
@pytest.mark.parametrize(
|
||||
"tokens_a,tokens_b,expected",
|
||||
[
|
||||
|
@ -537,12 +540,12 @@ def test_make_orth_variants(doc):
|
|||
([" ", "a"], ["a"], (1, [-1, 0], [1], {}, {})),
|
||||
],
|
||||
)
|
||||
def test_align(tokens_a, tokens_b, expected):
|
||||
cost, a2b, b2a, a2b_multi, b2a_multi = align(tokens_a, tokens_b)
|
||||
assert (cost, list(a2b), list(b2a), a2b_multi, b2a_multi) == expected
|
||||
def test_align(tokens_a, tokens_b, expected): # noqa
|
||||
cost, a2b, b2a, a2b_multi, b2a_multi = align(tokens_a, tokens_b) # noqa
|
||||
assert (cost, list(a2b), list(b2a), a2b_multi, b2a_multi) == expected # noqa
|
||||
# check symmetry
|
||||
cost, a2b, b2a, a2b_multi, b2a_multi = align(tokens_b, tokens_a)
|
||||
assert (cost, list(b2a), list(a2b), b2a_multi, a2b_multi) == expected
|
||||
cost, a2b, b2a, a2b_multi, b2a_multi = align(tokens_b, tokens_a) # noqa
|
||||
assert (cost, list(b2a), list(a2b), b2a_multi, a2b_multi) == expected # noqa
|
||||
|
||||
|
||||
def test_goldparse_startswith_space(en_tokenizer):
|
||||
|
@ -556,7 +559,7 @@ def test_goldparse_startswith_space(en_tokenizer):
|
|||
doc, {"words": gold_words, "entities": entities, "deps": deps, "heads": heads}
|
||||
)
|
||||
ner_tags = example.get_aligned_ner()
|
||||
assert ner_tags == [None, "U-DATE"]
|
||||
assert ner_tags == ["O", "U-DATE"]
|
||||
assert example.get_aligned("DEP", as_string=True) == [None, "ROOT"]
|
||||
|
||||
|
||||
|
|
|
@ -55,7 +55,7 @@ def test_aligned_tags():
|
|||
predicted = Doc(vocab, words=pred_words)
|
||||
example = Example.from_dict(predicted, annots)
|
||||
aligned_tags = example.get_aligned("tag", as_string=True)
|
||||
assert aligned_tags == ["VERB", "DET", None, "SCONJ", "PRON", "VERB", "VERB"]
|
||||
assert aligned_tags == ["VERB", "DET", "NOUN", "SCONJ", "PRON", "VERB", "VERB"]
|
||||
|
||||
|
||||
def test_aligned_tags_multi():
|
||||
|
|
Loading…
Reference in New Issue
Block a user