Refactor alignment into its own class

Matthew Honnibal 2018-04-02 21:54:29 +02:00
parent 9c3612d40b
commit e6641a11b1
2 changed files with 195 additions and 143 deletions

View File

@ -85,12 +85,134 @@ i.e. D[i,j+1] + 1
from __future__ import unicode_literals
from libc.stdint cimport uint32_t
import numpy
import copy
cimport numpy as np
from .compat import unicode_
from murmurhash.mrmr cimport hash32
from collections import Counter
def align(S, T):
class Alignment(object):
def __init__(self, your_words, their_words):
cost, your2their, their2your = self._align(your_words, their_words)
self.cost = cost
self._y2t = your2their
self._t2y = their2your
def to_yours(self, items):
'''Translate a list of token annotations into your tokenization. Returns
a list with one entry per token in your tokenization. When one of your
tokens aligns to multiple items, the entry will be a list. When multiple
of your tokens align to one item, each gets a tuple (value, index, n_to_go),
where index is an int starting from 0, and n_to_go tells you how many
remaining subtokens align to the same value.
'''
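# Example of the intended shape (hypothetical values): if one of your
# tokens covers two of theirs, its entry is a list like [x, y]; if two of
# your tokens share one of theirs, each gets a tuple as described above.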
output = []
for i, alignment in enumerate(self._t2y):
if len(alignment) == 1 and alignment[0][1] == 0:
output.append(items[alignment[0][0]])
else:
output.append([])
for j1, j2 in alignment:
output[-1].append((items[j1], j2))
return output
def index_to_yours(self, index):
'''Translate an index that points into their tokens to point into yours'''
alignment = self._t2y[index]
if len(alignment) == 1 and alignment[0][2] == 0:
return alignment[0][0]
else:
output = []
for i1, i2, n_to_go in alignment:
output.append((i1, i2, n_to_go))
return output
def to_theirs(self, items):
raise NotImplementedError
def index_to_theirs(self, index):
raise NotImplementedError
@classmethod
def _align(cls, cand_words, gold_words):
'''Find the best alignment between the candidate tokenization and the gold
tokenization. Returns the alignment cost and alignment tables in both
directions: cand_to_gold and gold_to_cand.
Alignment entries are lists of addresses, where an address is a tuple
(position, subtoken). This allows one-to-many and many-to-one alignment.
For instance, let's say we align:
Cand: ['ab', 'c', 'd']
Gold: ['a', 'b', 'cd']
The cand_to_gold alignment would be:
[[0, 1], (2, 0), (2, 1)]
And the gold_to_cand alignment:
[(0, 0), (0, 1), [1, 2]]
'''
if cand_words == gold_words:
alignment = numpy.arange(len(cand_words))
return 0, alignment, alignment
cand_words = [w.replace(' ', '') for w in cand_words]
gold_words = [w.replace(' ', '') for w in gold_words]
cost, i2j, j2i, matrix = levenshtein_align(cand_words, gold_words)
i_lengths = [len(w) for w in cand_words]
j_lengths = [len(w) for w in gold_words]
cand2gold, gold2cand = multi_align(i2j, j2i, i_lengths, j_lengths)
return cost, cand2gold, gold2cand
@staticmethod
def flatten(heads):
'''Let's say we have a heads array with fused tokens. We might have
something like:
[[(0, 1), 1], 1]
This indicates that token 0 aligns to two gold tokens. The head of the
first subtoken is the second subtoken. The head of the second subtoken
is the second token.
So we expand to a tree:
[1, 2, 2]
This is helpful for preventing other functions from knowing our weird
format.
'''
# Get an alignment -- normalize to the more complicated format; so
# if we have an int i, treat it as [(i, 0)]
j = 0
alignment = {(None, 0): None}
for i, tokens in enumerate(heads):
if not isinstance(tokens, list):
alignment[(i, 0)] = j
j += 1
else:
for sub_i in range(len(tokens)):
alignment[(i, sub_i)] = j
j += 1
# Apply the alignment to get the new values
new = []
for head_vals in heads:
if not isinstance(head_vals, list):
head_vals = [(head_vals, 0)]
for head_val in head_vals:
if not isinstance(head_val, tuple):
head_val = (head_val, 0)
new.append(alignment[head_val])
return new
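# A minimal usage sketch (illustrative only; the absolute import path is an
# assumption -- the other file imports this module relatively via
# `from ._align import Alignment`):
#
#     from spacy._align import Alignment
#     a = Alignment(['ab', 'c', 'd'], ['a', 'b', 'cd'])
#     a.cost                              # alignment cost from _align()
#     a.to_yours(['A', 'B', 'CD'])        # one entry per token on your side
#     Alignment.flatten([[(0, 1), 1], 1]) # -> [1, 2, 2], per the docstring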
def levenshtein_align(S, T):
cdef int m = len(S)
cdef int n = len(T)
cdef np.ndarray matrix = numpy.zeros((m+1, n+1), dtype='int32')
@ -128,9 +250,39 @@ def multi_align(np.ndarray i2j, np.ndarray j2i, i_lengths, j_lengths):
'''
i2j_miss = _get_regions(i2j, i_lengths)
j2i_miss = _get_regions(j2i, j_lengths)
i2j_many2one = _get_many2one(i2j_miss, j2i_miss, i_lengths, j_lengths)
j2i_many2one = _get_many2one(j2i_miss, i2j_miss, j_lengths, i_lengths)
i2j_one2many = _get_one2many(j2i_many2one)
j2i_one2many = _get_one2many(i2j_many2one)
i2j_one2part = _get_one2part(j2i_many2one)
j2i_one2part = _get_one2part(i2j_many2one)
i2j_multi, j2i_multi = _get_mapping(i2j_miss, j2i_miss, i_lengths, j_lengths)
return i2j_multi, j2i_multi
# Now get the more usable format we'll return
cand2gold = _convert_multi_align(i2j, i2j_many2one, i2j_one2many, i2j_one2part)
gold2cand = _convert_multi_align(j2i, j2i_many2one, j2i_one2many, j2i_one2part)
return cand2gold, gold2cand
def _convert_multi_align(one2one, many2one, one2many, one2part):
output = []
seen_j = Counter()
for i, j in enumerate(one2one):
if j != -1:
output.append(j)
elif i in many2one:
j = many2one[i]
output.append((j, seen_j[j]))
seen_j[j] += 1
elif i in one2many:
output.append([])
for j in one2many[i]:
output[-1].append(j)
elif i in one2part:
output.append(one2part[i])
else:
output.append(None)
return output
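# For intuition, a hypothetical conversion: one2one = [-1, 3, -1, -1],
# many2one = {0: 0}, one2many = {2: [1, 2]}, one2part = {3: (4, 0)}
# yields [(0, 0), 3, [1, 2], (4, 0)] -- a plain int for one-to-one entries,
# a (target, subtoken) tuple for many-to-one and partial entries, and a
# list of targets for one-to-many entries.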
def _get_regions(alignment, lengths):
@ -149,9 +301,10 @@ def _get_regions(alignment, lengths):
return regions
def _get_mapping(miss1, miss2, lengths1, lengths2):
i2j = {}
j2i = {}
def _get_many2one(miss1, miss2, lengths1, lengths2):
miss1 = copy.deepcopy(miss1)
miss2 = copy.deepcopy(miss2)
i2j_many2one = {}
for start, region1 in miss1.items():
if not region1 or start not in miss2:
continue
@ -166,8 +319,7 @@ def _get_mapping(miss1, miss2, lengths1, lengths2):
buff.append(region1.pop(0))
if sum(lengths1[i] for i in buff) == lengths2[j]:
for i in buff:
i2j[i] = j
j2i[j] = buff[-1]
i2j_many2one[i] = j
j += 1
buff = []
elif sum(lengths1[i] for i in buff) > lengths2[j]:
@ -175,11 +327,25 @@ def _get_mapping(miss1, miss2, lengths1, lengths2):
else:
if buff and sum(lengths1[i] for i in buff) == lengths2[j]:
for i in buff:
i2j[i] = j
j2i[j] = buff[-1]
return i2j, j2i
i2j_many2one[i] = j
return i2j_many2one
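# Sketch of the intended behaviour: with lengths1 = [1, 1] ('a', 'b') and
# lengths2 = [2] ('ab'), the two missed tokens accumulate to the length of
# the single target, giving {0: 0, 1: 0}.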
def _get_one2many(many2one):
one2many = {}
for j, i in many2one.items():
one2many.setdefault(i, []).append(j)
return one2many
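# e.g. a many2one mapping {1: 0, 2: 0} inverts to the one2many
# mapping {0: [1, 2]}.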
def _get_one2part(many2one):
one2part = {}
seen_j = Counter()
for i, j in many2one.items():
one2part[i] = (j, seen_j[j])
seen_j[j] += 1
return one2part
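# e.g. a many2one mapping {1: 0, 2: 0} becomes {1: (0, 0), 2: (0, 1)}:
# each key keeps its target plus a running subtoken index within it
# (the 0/1 order follows dict iteration order).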
def _convert_sequence(seq):
if isinstance(seq, numpy.ndarray):
return numpy.ascontiguousarray(seq, dtype='uint32')

View File

@ -11,10 +11,11 @@ import tempfile
import shutil
from pathlib import Path
import msgpack
from collections import Counter
import ujson
from . import _align
from ._align import Alignment
from .syntax import nonproj
from .tokens import Doc
from . import util
@ -23,6 +24,7 @@ from .compat import json_dumps
from libc.stdio cimport FILE, fopen, fclose, fread, fwrite, feof, fseek
def tags_to_entities(tags):
entities = []
start = None
@ -68,32 +70,6 @@ def merge_sents(sents):
return [(m_deps, m_brackets)]
punct_re = re.compile(r'\W')
def align(cand_words, gold_words):
if cand_words == gold_words:
alignment = numpy.arange(len(cand_words))
return 0, alignment, alignment, {}, {}, []
cand_words = [w.replace(' ', '') for w in cand_words]
gold_words = [w.replace(' ', '') for w in gold_words]
cost, i2j, j2i, matrix = _align.align(cand_words, gold_words)
i2j_multi, j2i_multi = _align.multi_align(i2j, j2i, [len(w) for w in cand_words],
[len(w) for w in gold_words])
for i, j in list(i2j_multi.items()):
if i2j_multi.get(i+1) != j and i2j_multi.get(i-1) != j:
i2j[i] = j
i2j_multi.pop(i)
for j, i in list(j2i_multi.items()):
if j2i_multi.get(j+1) != i and j2i_multi.get(j-1) != i:
j2i[j] = i
j2i_multi.pop(j)
reverse_j2i, reverse_i2j = _align.multi_align(j2i, i2j, [len(w) for w in gold_words],
[len(w) for w in cand_words])
undersegmented = {}
for j, i in reverse_j2i.items():
undersegmented.setdefault(i, []).append(j)
return cost, i2j, j2i, i2j_multi, j2i_multi, undersegmented
class GoldCorpus(object):
"""An annotated corpus, using the JSON file format. Manages
annotations for tagging, dependency parsing and NER."""
@ -385,47 +361,6 @@ def _consume_ent(tags):
return [start] + middle + [end]
def _flatten_fused_heads(heads):
'''Let's say we have a heads array with fused tokens. We might have
something like:
[[(0, 1), 1], 1]
This indicates that token 0 aligns to two gold tokens. The head of the
first subtoken is the second subtoken. The head of the second subtoken
is the second token.
So we expand to a tree:
[1, 2, 2]
This is helpful for preventing other functions from knowing our weird
format.
'''
# Get an alignment -- normalize to the more complicated format; so
# if we have an int i, treat it as [(i, 0)]
j = 0
alignment = {(None, 0): None}
for i, tokens in enumerate(heads):
if not isinstance(tokens, list):
alignment[(i, 0)] = j
j += 1
else:
for sub_i in range(len(tokens)):
alignment[(i, sub_i)] = j
j += 1
# Apply the alignment to get the new values
new = []
for head_vals in heads:
if not isinstance(head_vals, list):
head_vals = [(head_vals, 0)]
for head_val in head_vals:
if not isinstance(head_val, tuple):
head_val = (head_val, 0)
new.append(alignment[head_val])
return new
cdef class GoldParse:
"""Collection for training annotations."""
@classmethod
@ -508,80 +443,30 @@ cdef class GoldParse:
# sequence of gold words.
# If we "mis-segment", we'll have a sequence of predicted words covering
# a sequence of gold words. That's many-to-many -- we don't do that.
cost, i2j, j2i, i2j_multi, j2i_multi, undersegmented = align([t.orth_ for t in doc], words)
self.cand_to_gold = [(j if j >= 0 else None) for j in i2j]
self.gold_to_cand = [(i if i >= 0 else None) for i in j2i]
self._alignment = Alignment([t.orth_ for t in doc], words)
annot_tuples = (range(len(words)), words, tags, heads, deps, entities)
self.orig_annot = list(zip(*annot_tuples))
for i, gold_i in enumerate(self.cand_to_gold):
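# Project the gold annotations onto the doc's tokenization via the
# alignment, so each attribute below is indexed by the doc's tokens.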
self.words = self._alignment.to_yours(words)
self.tags = self._alignment.to_yours(tags)
self.labels = self._alignment.to_yours(deps)
self.ner = self._alignment.to_yours(entities)
aligned_heads = [self._alignment.index_to_yours(h) for h in heads]
self.heads = self._alignment.to_yours(aligned_heads)
for i in range(len(doc)):
# Fix spaces
if doc[i].text.isspace():
self.words[i] = doc[i].text
self.tags[i] = '_SP'
self.heads[i] = None
self.labels[i] = None
self.ner[i] = 'O'
if gold_i is None:
if i in undersegmented:
self.words[i] = [words[j] for j in undersegmented[i]]
self.tags[i] = [tags[j] for j in undersegmented[i]]
self.labels[i] = [deps[j] for j in undersegmented[i]]
self.ner[i] = [entities[j] for j in undersegmented[i]]
self.heads[i] = []
for h in [heads[j] for j in undersegmented[i]]:
if heads[h] is None:
self.heads[i].append(None)
else:
self.heads[i].append(self.gold_to_cand[heads[h]])
elif i in i2j_multi:
self.words[i] = words[i2j_multi[i]]
self.tags[i] = tags[i2j_multi[i]]
is_last = i2j_multi[i] != i2j_multi.get(i+1)
is_first = i2j_multi[i] != i2j_multi.get(i-1)
# Set next word in multi-token span as head, until last
if not is_last:
self.heads[i] = i+1
self.labels[i] = 'subtok'
else:
self.heads[i] = self.gold_to_cand[heads[i2j_multi[i]]]
self.labels[i] = deps[i2j_multi[i]]
# Now set NER... This is annoying because if we've got an entity
# word split into two, we need to adjust the BILOU tags. We can't
# have B-B or L-L sequences, etc.
# Case 1: O -- easy.
ner_tag = entities[i2j_multi[i]]
if ner_tag is not None:
if ner_tag == 'O':
self.ner[i] = 'O'
# Case 2: U. This has to become a B I* L sequence.
elif ner_tag.startswith('U-'):
if is_first:
self.ner[i] = ner_tag.replace('U-', 'B-', 1)
elif is_last:
self.ner[i] = ner_tag.replace('U-', 'L-', 1)
else:
self.ner[i] = ner_tag.replace('U-', 'I-', 1)
# Case 3: L. If not last, change to I.
elif ner_tag.startswith('L-'):
if is_last:
self.ner[i] = ner_tag
else:
self.ner[i] = ner_tag.replace('L-', 'I-', 1)
# Case 4: I. Stays correct
elif ner_tag.startswith('I-'):
self.ner[i] = ner_tag
else:
self.words[i] = words[gold_i]
self.tags[i] = tags[gold_i]
if heads[gold_i] is None:
self.heads[i] = None
else:
self.heads[i] = self.gold_to_cand[heads[gold_i]]
self.labels[i] = deps[gold_i]
self.ner[i] = entities[gold_i]
cycle = nonproj.contains_cycle(_flatten_fused_heads(self.heads))
cycle = nonproj.contains_cycle(self._alignment.flatten(self.heads))
if cycle is not None:
raise Exception("Cycle found: %s" % cycle)
@ -690,3 +575,4 @@ def offsets_from_biluo_tags(doc, tags):
def is_punct_label(label):
return label == 'P' or label.lower() == 'punct'