Use new alignment implementation in GoldParse

This commit is contained in:
Matthew Honnibal 2018-02-20 21:16:35 +01:00
parent c0734ba526
commit f466f0186e

View File

@ -7,7 +7,9 @@ import ujson
import random import random
import cytoolz import cytoolz
import itertools import itertools
import numpy
from . import _align
from .syntax import nonproj from .syntax import nonproj
from .tokens import Doc from .tokens import Doc
from . import util from . import util
@ -59,90 +61,15 @@ def merge_sents(sents):
return [(m_deps, m_brackets)] return [(m_deps, m_brackets)]
def align(cand_words, gold_words):
cost, edit_path = _min_edit_path(cand_words, gold_words)
alignment = []
i_of_gold = 0
for move in edit_path:
if move == 'M':
alignment.append(i_of_gold)
i_of_gold += 1
elif move == 'S':
alignment.append(None)
i_of_gold += 1
elif move == 'D':
alignment.append(None)
elif move == 'I':
i_of_gold += 1
else:
raise Exception(move)
return alignment
punct_re = re.compile(r'\W') punct_re = re.compile(r'\W')
def align(cand_words, gold_words):
def _min_edit_path(cand_words, gold_words):
cdef:
Pool mem
int i, j, n_cand, n_gold
int* curr_costs
int* prev_costs
# TODO: Fix this --- just do it properly, make the full edit matrix and
# then walk back over it...
# Preprocess inputs
cand_words = [punct_re.sub('', w).lower() for w in cand_words] cand_words = [punct_re.sub('', w).lower() for w in cand_words]
gold_words = [punct_re.sub('', w).lower() for w in gold_words] gold_words = [punct_re.sub('', w).lower() for w in gold_words]
if cand_words == gold_words: if cand_words == gold_words:
return 0, ''.join(['M' for _ in gold_words]) alignment = numpy.arange(len(cand_words))
mem = Pool() return 0, alignment, alignment
n_cand = len(cand_words) cost, i2j, j2i, matrix = _align.align(cand_words, gold_words)
n_gold = len(gold_words) return cost, i2j, j2i
# Levenshtein distance, except we need the history, and we may want
# different costs. Mark operations with a string, and score the history
# using _edit_cost.
previous_row = []
prev_costs = <int*>mem.alloc(n_gold + 1, sizeof(int))
curr_costs = <int*>mem.alloc(n_gold + 1, sizeof(int))
for i in range(n_gold + 1):
cell = ''
for j in range(i):
cell += 'I'
previous_row.append('I' * i)
prev_costs[i] = i
for i, cand in enumerate(cand_words):
current_row = ['D' * (i + 1)]
curr_costs[0] = i+1
for j, gold in enumerate(gold_words):
if gold.lower() == cand.lower():
s_cost = prev_costs[j]
i_cost = curr_costs[j] + 1
d_cost = prev_costs[j + 1] + 1
else:
s_cost = prev_costs[j] + 1
i_cost = curr_costs[j] + 1
d_cost = prev_costs[j + 1] + (1 if cand else 0)
if s_cost <= i_cost and s_cost <= d_cost:
best_cost = s_cost
best_hist = previous_row[j] + ('M' if gold == cand else 'S')
elif i_cost <= s_cost and i_cost <= d_cost:
best_cost = i_cost
best_hist = current_row[j] + 'I'
else:
best_cost = d_cost
best_hist = previous_row[j + 1] + 'D'
current_row.append(best_hist)
curr_costs[j+1] = best_cost
previous_row = current_row
for j in range(len(gold_words) + 1):
prev_costs[j] = curr_costs[j]
curr_costs[j] = 0
return prev_costs[n_gold], previous_row[-1]
class GoldCorpus(object): class GoldCorpus(object):
@ -434,8 +361,9 @@ cdef class GoldParse:
self.labels = [None] * len(doc) self.labels = [None] * len(doc)
self.ner = [None] * len(doc) self.ner = [None] * len(doc)
self.cand_to_gold = align([t.orth_ for t in doc], words) cost, i2j, j2i = align([t.orth_ for t in doc], words)
self.gold_to_cand = align(words, [t.orth_ for t in doc]) self.cand_to_gold = [(j if j != -1 else None) for j in i2j]
self.gold_to_cand = [(i if i != -1 else None) for i in j2i]
annot_tuples = (range(len(words)), words, tags, heads, deps, entities) annot_tuples = (range(len(words)), words, tags, heads, deps, entities)
self.orig_annot = list(zip(*annot_tuples)) self.orig_annot = list(zip(*annot_tuples))