Refactor alignment into its own class

Matthew Honnibal 2018-04-02 21:54:29 +02:00
parent 9c3612d40b
commit e6641a11b1
2 changed files with 195 additions and 143 deletions

View File

@@ -85,12 +85,134 @@ i.e. D[i,j+1] + 1
from __future__ import unicode_literals
from libc.stdint cimport uint32_t
import numpy
import copy
cimport numpy as np
from .compat import unicode_
from murmurhash.mrmr cimport hash32
from collections import Counter
def align(S, T):
class Alignment(object):
def __init__(self, your_words, their_words):
cost, your2their, their2your = self._align(your_words, their_words)
self.cost = cost
self._y2t = your2their
self._t2y = their2your
def to_yours(self, items):
'''Translate a list of token annotations into your tokenization. Returns
a list of length equal to your tokens. When one of your tokens aligns
to multiple items, the entry will be a list. When multiple of your
tokens align to one item, you'll get a tuple (value, index, n_to_go),
where index is an int starting from 0, and n_to_go tells you how
many remaining subtokens align to the same value.
'''
output = []
for i, alignment in enumerate(self._y2t):
if len(alignment) == 1 and alignment[0][1] == 0:
output.append(items[alignment[0][0]])
else:
output.append([])
for j1, j2 in alignment:
output[-1].append((items[j1], j2))
return output
def index_to_yours(self, index):
'''Translate an index that points into their tokens to point into yours'''
alignment = self._t2y[index]
if len(alignment) == 1 and alignment[0][2] == 0:
return alignment[0][0]
else:
output = []
for i1, i2, n_to_go in alignment:
output.append((i1, i2, n_to_go))
return output
def to_theirs(self, items):
raise NotImplementedError
def index_to_theirs(self, index):
raise NotImplementedError
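# Editorial sketch of the intended call pattern. The word lists and tag values
# are hypothetical, and the results are spelled out from the address format
# documented in _align() below rather than produced by running this
# work-in-progress class:
#
#   align = Alignment(your_words=['ab', 'c', 'd'], their_words=['a', 'b', 'cd'])
#   align.to_yours(['DT', 'NN', 'VBZ'])
#   # your 'ab' covers their 'a' and 'b'  -> a list entry for that slot
#   # your 'c' and 'd' share their 'cd'   -> (value, subtoken) style entries
#   align.index_to_yours(2)
#   # their token 2 ('cd') points back at addresses of your tokens 1 and 2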
@classmethod
def _align(cls, cand_words, gold_words):
'''Find best alignment between candidate tokenization and gold tokenization.
Returns the alignment cost and alignment tables in both directions:
cand_to_gold and gold_to_cand
Alignment entries are lists of addresses, where an address is a tuple
(position, subtoken). This allows one-to-many and many-to-one alignment.
For instance, let's say we align:
Cand: ['ab', 'c', 'd']
Gold: ['a', 'b', 'cd']
The cand_to_gold alignment would be:
[[0, 1], (2, 0), (2, 1)]
And the gold_to_cand alignment:
[(0, 0), (0, 1), [1, 2]]
'''
if cand_words == gold_words:
alignment = numpy.arange(len(cand_words))
return 0, alignment, alignment
cand_words = [w.replace(' ', '') for w in cand_words]
gold_words = [w.replace(' ', '') for w in gold_words]
cost, i2j, j2i, matrix = levenshtein_align(cand_words, gold_words)
i_lengths = [len(w) for w in cand_words]
j_lengths = [len(w) for w in gold_words]
cand2gold, gold2cand = multi_align(i2j, j2i, i_lengths, j_lengths)
return cost, cand2gold, gold2cand
@staticmethod
def flatten(heads):
'''Let's say we have a heads array with fused tokens. We might have
something like:
[[(0, 1), 1], 1]
This indicates that token 0 aligns to two gold tokens. The head of the
first subtoken is the second subtoken. The head of the second subtoken
is the second token.
So we expand to a tree:
[1, 2, 2]
This is helpful for preventing other functions from knowing our weird
format.
'''
# Get an alignment -- normalize to the more complicated format; so
# if we have an int i, treat it as [(i, 0)]
j = 0
alignment = {(None, 0): None}
for i, tokens in enumerate(heads):
if not isinstance(tokens, list):
alignment[(i, 0)] = j
j += 1
else:
for sub_i in range(len(tokens)):
alignment[(i, sub_i)] = j
j += 1
# Apply the alignment to get the new values
new = []
for head_vals in heads:
if not isinstance(head_vals, list):
head_vals = [(head_vals, 0)]
for head_val in head_vals:
if not isinstance(head_val, tuple):
head_val = (head_val, 0)
new.append(alignment[head_val])
return new
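# Worked check of flatten(), using the example from its docstring (flatten is a
# @staticmethod, so it can be called without building an Alignment instance):
fused_heads = [[(0, 1), 1], 1]
# (0, 0) -> 0, (0, 1) -> 1 and (1, 0) -> 2, so the heads expand to:
assert Alignment.flatten(fused_heads) == [1, 2, 2]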
def levenshtein_align(S, T):
cdef int m = len(S)
cdef int n = len(T)
cdef np.ndarray matrix = numpy.zeros((m+1, n+1), dtype='int32')
@@ -129,8 +251,38 @@ def multi_align(np.ndarray i2j, np.ndarray j2i, i_lengths, j_lengths):
i2j_miss = _get_regions(i2j, i_lengths)
j2i_miss = _get_regions(j2i, j_lengths)
i2j_multi, j2i_multi = _get_mapping(i2j_miss, j2i_miss, i_lengths, j_lengths)
return i2j_multi, j2i_multi
i2j_many2one = _get_many2one(i2j_miss, j2i_miss, i_lengths, j_lengths)
j2i_many2one = _get_many2one(j2i_miss, i2j_miss, j_lengths, i_lengths)
i2j_one2many = _get_one2many(j2i_many2one)
j2i_one2many = _get_one2many(i2j_many2one)
i2j_one2part = _get_one2part(j2i_many2one)
j2i_one2part = _get_one2part(i2j_many2one)
# Now get the more usable format we'll return
cand2gold = _convert_multi_align(i2j, i2j_many2one, i2j_one2many, i2j_one2part)
gold2cand = _convert_multi_align(j2i, j2i_many2one, j2i_one2many, j2i_one2part)
return cand2gold, gold2cand
def _convert_multi_align(one2one, many2one, one2many, one2part):
output = []
seen_j = Counter()
for i, j in enumerate(one2one):
if j != -1:
output.append(j)
elif i in many2one:
j = many2one[i]
output.append((j, seen_j[j]))
seen_j[j] += 1
elif i in one2many:
output.append([])
for j in one2many[i]:
output[-1].append(j)
elif i in one2part:
output.append(one2part[i])
else:
output.append(None)
return output
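# Editorial illustration of the mixed output format with hand-built tables
# (the one2one array normally comes from levenshtein_align): token 0 matches
# their token 0 exactly, while tokens 1 and 2 together make up their token 1.
one2one_example = [0, -1, -1]
many2one_example = {1: 1, 2: 1}
assert _convert_multi_align(one2one_example, many2one_example, {}, {}) == \
    [0, (1, 0), (1, 1)]
# Plain ints mark clean matches, (position, subtoken) tuples mark tokens that
# share a target, lists mark one-to-many cases, and None marks no alignment.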
def _get_regions(alignment, lengths):
@@ -149,9 +301,10 @@ def _get_regions(alignment, lengths):
return regions
def _get_mapping(miss1, miss2, lengths1, lengths2):
i2j = {}
j2i = {}
def _get_many2one(miss1, miss2, lengths1, lengths2):
miss1 = copy.deepcopy(miss1)
miss2 = copy.deepcopy(miss2)
i2j_many2one = {}
for start, region1 in miss1.items():
if not region1 or start not in miss2:
continue
@@ -166,8 +319,7 @@ def _get_mapping(miss1, miss2, lengths1, lengths2):
buff.append(region1.pop(0))
if sum(lengths1[i] for i in buff) == lengths2[j]:
for i in buff:
i2j[i] = j
j2i[j] = buff[-1]
i2j_many2one[i] = j
j += 1
buff = []
elif sum(lengths1[i] for i in buff) > lengths2[j]:
@@ -175,11 +327,25 @@ def _get_mapping(miss1, miss2, lengths1, lengths2):
else:
if buff and sum(lengths1[i] for i in buff) == lengths2[j]:
for i in buff:
i2j[i] = j
j2i[j] = buff[-1]
return i2j, j2i
i2j_many2one[i] = j
return i2j_many2one
def _get_one2many(many2one):
one2many = {}
for j, i in many2one.items():
one2many.setdefault(i, []).append(j)
return one2many
def _get_one2part(many2one):
one2part = {}
seen_j = Counter()
for i, j in many2one.items():
one2part[i] = (j, seen_j[j])
seen_j[j] += 1
return one2part
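# Editorial illustration of the two helpers above with a hand-built table:
# tokens 1 and 2 on one side both collapse into token 0 on the other side.
many2one_example = {1: 0, 2: 0}
assert _get_one2many(many2one_example) == {0: [1, 2]}
assert _get_one2part(many2one_example) == {1: (0, 0), 2: (0, 1)}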
def _convert_sequence(seq):
if isinstance(seq, numpy.ndarray):
return numpy.ascontiguousarray(seq, dtype='uint32_t')

View File

@@ -11,10 +11,11 @@ import tempfile
import shutil
from pathlib import Path
import msgpack
from collections import Counter
import ujson
from . import _align
from ._align import Alignment
from .syntax import nonproj
from .tokens import Doc
from . import util
@@ -23,6 +24,7 @@ from .compat import json_dumps
from libc.stdio cimport FILE, fopen, fclose, fread, fwrite, feof, fseek
def tags_to_entities(tags):
entities = []
start = None
@@ -68,32 +70,6 @@ def merge_sents(sents):
return [(m_deps, m_brackets)]
punct_re = re.compile(r'\W')
def align(cand_words, gold_words):
if cand_words == gold_words:
alignment = numpy.arange(len(cand_words))
return 0, alignment, alignment, {}, {}, []
cand_words = [w.replace(' ', '') for w in cand_words]
gold_words = [w.replace(' ', '') for w in gold_words]
cost, i2j, j2i, matrix = _align.align(cand_words, gold_words)
i2j_multi, j2i_multi = _align.multi_align(i2j, j2i, [len(w) for w in cand_words],
[len(w) for w in gold_words])
for i, j in list(i2j_multi.items()):
if i2j_multi.get(i+1) != j and i2j_multi.get(i-1) != j:
i2j[i] = j
i2j_multi.pop(i)
for j, i in list(j2i_multi.items()):
if j2i_multi.get(j+1) != i and j2i_multi.get(j-1) != i:
j2i[j] = i
j2i_multi.pop(j)
reverse_j2i, reverse_i2j = _align.multi_align(j2i, i2j, [len(w) for w in gold_words],
[len(w) for w in cand_words])
undersegmented = {}
for j, i in reverse_j2i.items():
undersegmented.setdefault(i, []).append(j)
return cost, i2j, j2i, i2j_multi, j2i_multi, undersegmented
class GoldCorpus(object):
"""An annotated corpus, using the JSON file format. Manages
annotations for tagging, dependency parsing and NER."""
@@ -385,47 +361,6 @@ def _consume_ent(tags):
return [start] + middle + [end]
def _flatten_fused_heads(heads):
'''Let's say we have a heads array with fused tokens. We might have
something like:
[[(0, 1), 1], 1]
This indicates that token 0 aligns to two gold tokens. The head of the
first subtoken is the second subtoken. The head of the second subtoken
is the second token.
So we expand to a tree:
[1, 2, 2]
This is helpful for preventing other functions from knowing our weird
format.
'''
# Get an alignment -- normalize to the more complicated format; so
# if we have an int i, treat it as [(i, 0)]
j = 0
alignment = {(None, 0): None}
for i, tokens in enumerate(heads):
if not isinstance(tokens, list):
alignment[(i, 0)] = j
j += 1
else:
for sub_i in range(len(tokens)):
alignment[(i, sub_i)] = j
j += 1
# Apply the alignment to get the new values
new = []
for head_vals in heads:
if not isinstance(head_vals, list):
head_vals = [(head_vals, 0)]
for head_val in head_vals:
if not isinstance(head_val, tuple):
head_val = (head_val, 0)
new.append(alignment[head_val])
return new
cdef class GoldParse:
"""Collection for training annotations."""
@classmethod
@@ -508,80 +443,30 @@ cdef class GoldParse:
# sequence of gold words.
# If we "mis-segment", we'll have a sequence of predicted words covering
# a sequence of gold words. That's many-to-many -- we don't do that.
cost, i2j, j2i, i2j_multi, j2i_multi, undersegmented = align([t.orth_ for t in doc], words)
self.cand_to_gold = [(j if j >= 0 else None) for j in i2j]
self.gold_to_cand = [(i if i >= 0 else None) for i in j2i]
self._alignment = Alignment([t.orth_ for t in doc], words)
annot_tuples = (range(len(words)), words, tags, heads, deps, entities)
self.orig_annot = list(zip(*annot_tuples))
for i, gold_i in enumerate(self.cand_to_gold):
self.words = self._alignment.to_yours(words)
self.tags = self._alignment.to_yours(tags)
self.labels = self._alignment.to_yours(deps)
self.ner = self._alignment.to_yours(entities)
aligned_heads = [self._alignment.index_to_yours(h) for h in heads]
self.heads = self._alignment.to_yours(aligned_heads)
for i in range(len(doc)):
# Fix spaces
if doc[i].text.isspace():
self.words[i] = doc[i].text
self.tags[i] = '_SP'
self.heads[i] = None
self.labels[i] = None
self.ner[i] = 'O'
if gold_i is None:
if i in undersegmented:
self.words[i] = [words[j] for j in undersegmented[i]]
self.tags[i] = [tags[j] for j in undersegmented[i]]
self.labels[i] = [deps[j] for j in undersegmented[i]]
self.ner[i] = [entities[j] for j in undersegmented[i]]
self.heads[i] = []
for h in [heads[j] for j in undersegmented[i]]:
if heads[h] is None:
self.heads[i].append(None)
else:
self.heads[i].append(self.gold_to_cand[heads[h]])
elif i in i2j_multi:
self.words[i] = words[i2j_multi[i]]
self.tags[i] = tags[i2j_multi[i]]
is_last = i2j_multi[i] != i2j_multi.get(i+1)
is_first = i2j_multi[i] != i2j_multi.get(i-1)
# Set next word in multi-token span as head, until last
if not is_last:
self.heads[i] = i+1
self.labels[i] = 'subtok'
else:
self.heads[i] = self.gold_to_cand[heads[i2j_multi[i]]]
self.labels[i] = deps[i2j_multi[i]]
# Now set NER...This is annoying because if we've split
# got an entity word split into two, we need to adjust the
# BILOU tags. We can't have BB or LL etc.
# Case 1: O -- easy.
ner_tag = entities[i2j_multi[i]]
if ner_tag is not None:
if ner_tag == 'O':
self.ner[i] = 'O'
# Case 2: U. This has to become a B I* L sequence.
elif ner_tag.startswith('U-'):
if is_first:
self.ner[i] = ner_tag.replace('U-', 'B-', 1)
elif is_last:
self.ner[i] = ner_tag.replace('U-', 'L-', 1)
else:
self.ner[i] = ner_tag.replace('U-', 'I-', 1)
# Case 3: L. If not last, change to I.
elif ner_tag.startswith('L-'):
if is_last:
self.ner[i] = ner_tag
else:
self.ner[i] = ner_tag.replace('L-', 'I-', 1)
# Case 4: I. Stays correct
elif ner_tag.startswith('I-'):
self.ner[i] = ner_tag
else:
self.words[i] = words[gold_i]
self.tags[i] = tags[gold_i]
if heads[gold_i] is None:
self.heads[i] = None
else:
self.heads[i] = self.gold_to_cand[heads[gold_i]]
self.labels[i] = deps[gold_i]
self.ner[i] = entities[gold_i]
cycle = nonproj.contains_cycle(_flatten_fused_heads(self.heads))
cycle = nonproj.contains_cycle(self._alignment.flatten(self.heads))
if cycle is not None:
raise Exception("Cycle found: %s" % cycle)
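# Editorial sketch of the flow above (names follow the diff; doc, gold_words
# and the other gold_* lists are hypothetical, and this is not a runnable
# test of the commit):
#
#   alignment = Alignment([t.orth_ for t in doc], gold_words)
#   words = alignment.to_yours(gold_words)    # gold annotations, doc-shaped
#   tags = alignment.to_yours(gold_tags)
#   heads = alignment.to_yours([alignment.index_to_yours(h) for h in gold_heads])
#   # whitespace tokens are then overwritten with '_SP' / None / 'O', and
#   # Alignment.flatten() expands any fused-token heads before the cycle check.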
@@ -690,3 +575,4 @@ def offsets_from_biluo_tags(doc, tags):
def is_punct_label(label):
return label == 'P' or label.lower() == 'punct'