spaCy/spacy/gold.pyx
2015-07-03 13:29:22 +02:00

306 lines
10 KiB
Cython

import numpy
import codecs
import json
import ujson
import random
import re
import os
from os import path
from libc.string cimport memset
from .typedefs cimport flags_t
def tags_to_entities(tags):
entities = []
start = None
for i, tag in enumerate(tags):
if tag.startswith('O'):
# TODO: We shouldn't be getting these malformed inputs. Fix this.
if start is not None:
start = None
continue
elif tag == '-':
continue
elif tag.startswith('I'):
assert start is not None, tags[:i]
continue
if tag.startswith('U'):
entities.append((tag[2:], i, i))
elif tag.startswith('B'):
start = i
elif tag.startswith('L'):
entities.append((tag[2:], start, i))
start = None
else:
raise Exception(tag)
return entities
def align(cand_words, gold_words):
cost, edit_path = _min_edit_path(cand_words, gold_words)
alignment = []
i_of_gold = 0
for move in edit_path:
if move == 'M':
alignment.append(i_of_gold)
i_of_gold += 1
elif move == 'S':
alignment.append(None)
i_of_gold += 1
elif move == 'D':
alignment.append(None)
elif move == 'I':
i_of_gold += 1
else:
raise Exception(move)
return alignment
punct_re = re.compile(r'\W')
def _min_edit_path(cand_words, gold_words):
cdef:
Pool mem
int i, j, n_cand, n_gold
int* curr_costs
int* prev_costs
# TODO: Fix this --- just do it properly, make the full edit matrix and
# then walk back over it...
# Preprocess inputs
cand_words = [punct_re.sub('', w) for w in cand_words]
gold_words = [punct_re.sub('', w) for w in gold_words]
if cand_words == gold_words:
return 0, ''.join(['M' for _ in gold_words])
mem = Pool()
n_cand = len(cand_words)
n_gold = len(gold_words)
# Levenshtein distance, except we need the history, and we may want different
# costs.
# Mark operations with a string, and score the history using _edit_cost.
previous_row = []
prev_costs = <int*>mem.alloc(n_gold + 1, sizeof(int))
curr_costs = <int*>mem.alloc(n_gold + 1, sizeof(int))
for i in range(n_gold + 1):
cell = ''
for j in range(i):
cell += 'I'
previous_row.append('I' * i)
prev_costs[i] = i
for i, cand in enumerate(cand_words):
current_row = ['D' * (i + 1)]
curr_costs[0] = i+1
for j, gold in enumerate(gold_words):
if gold.lower() == cand.lower():
s_cost = prev_costs[j]
i_cost = curr_costs[j] + 1
d_cost = prev_costs[j + 1] + 1
else:
s_cost = prev_costs[j] + 1
i_cost = curr_costs[j] + 1
d_cost = prev_costs[j + 1] + (1 if cand else 0)
if s_cost <= i_cost and s_cost <= d_cost:
best_cost = s_cost
best_hist = previous_row[j] + ('M' if gold == cand else 'S')
elif i_cost <= s_cost and i_cost <= d_cost:
best_cost = i_cost
best_hist = current_row[j] + 'I'
else:
best_cost = d_cost
best_hist = previous_row[j + 1] + 'D'
current_row.append(best_hist)
curr_costs[j+1] = best_cost
previous_row = current_row
for j in range(len(gold_words) + 1):
prev_costs[j] = curr_costs[j]
curr_costs[j] = 0
return prev_costs[n_gold], previous_row[-1]
def read_json_file(loc, docs_filter=None):
print loc
if path.isdir(loc):
for filename in os.listdir(loc):
yield from read_json_file(path.join(loc, filename))
else:
with open(loc) as file_:
docs = ujson.load(file_)
for doc in docs:
if docs_filter is not None and not docs_filter(doc):
continue
paragraphs = []
for paragraph in doc['paragraphs']:
sents = []
for sent in paragraph['sentences']:
words = []
ids = []
tags = []
heads = []
labels = []
ner = []
wsd = []
for i, token in enumerate(sent['tokens']):
words.append(token['orth'])
ids.append(i)
tags.append(token['tag'])
heads.append(token['head'] + i)
labels.append(token['dep'])
# Ensure ROOT label is case-insensitive
if labels[-1].lower() == 'root':
labels[-1] = 'ROOT'
ner.append(token.get('ner', '-'))
t_wsd = [s.replace('.', '_') for s in token.get('ssenses', [])]
wsd.append(t_wsd)
sents.append((
(ids, words, tags, heads, labels, ner, wsd),
sent.get('brackets', [])))
if sents:
yield (paragraph.get('raw', None), sents)
def _iob_to_biluo(tags):
out = []
curr_label = None
tags = list(tags)
while tags:
out.extend(_consume_os(tags))
out.extend(_consume_ent(tags))
return out
def _consume_os(tags):
while tags and tags[0] == 'O':
yield tags.pop(0)
def _consume_ent(tags):
if not tags:
return []
target = tags.pop(0).replace('B', 'I')
length = 1
while tags and tags[0] == target:
length += 1
tags.pop(0)
label = target[2:]
if length == 1:
return ['U-' + label]
else:
start = 'B-' + label
end = 'L-' + label
middle = ['I-%s' % label for _ in range(1, length - 1)]
return [start] + middle + [end]
cdef class GoldParse:
def __init__(self, tokens, annot_tuples, brackets=tuple(), make_projective=False):
self.mem = Pool()
self.loss = 0
self.length = len(tokens)
# These are filled by the tagger/parser/entity recogniser
self.c.tags = <int*>self.mem.alloc(len(tokens), sizeof(int))
self.c.heads = <int*>self.mem.alloc(len(tokens), sizeof(int))
self.c.labels = <int*>self.mem.alloc(len(tokens), sizeof(int))
self.c.ssenses = <flags_t*>self.mem.alloc(len(tokens), sizeof(flags_t))
self.c.ner = <Transition*>self.mem.alloc(len(tokens), sizeof(Transition))
self.c.brackets = <int**>self.mem.alloc(len(tokens), sizeof(int*))
for i in range(len(tokens)):
self.c.brackets[i] = <int*>self.mem.alloc(len(tokens), sizeof(int))
self.tags = [None] * len(tokens)
self.heads = [None] * len(tokens)
self.labels = [''] * len(tokens)
self.ner = ['-'] * len(tokens)
self.ssenses = [[] for _ in range(len(tokens))]
self.cand_to_gold = align([t.orth_ for t in tokens], annot_tuples[1])
self.gold_to_cand = align(annot_tuples[1], [t.orth_ for t in tokens])
self.orig_annot = zip(*annot_tuples)
# This iterates 0...n for n words in the candidate, with an index
# gold_i aligned into the gold. Assign tag, label, ner and word sense.
# For the head, the value is an index into the gold sentence, so we
# have to translate it across into the candidate.
for i, gold_i in enumerate(self.cand_to_gold):
if gold_i is None:
# Missing values handled in the various oracle functions
pass
else:
self.tags[i] = annot_tuples[2][gold_i]
self.heads[i] = self.gold_to_cand[annot_tuples[3][gold_i]]
self.labels[i] = annot_tuples[4][gold_i]
self.ner[i] = annot_tuples[5][gold_i]
self.ssenses[i] = annot_tuples[6][gold_i]
# If we have any non-projective arcs, i.e. crossing brackets, consider
# the heads for those words missing in the gold-standard.
# This way, we can train from these sentences
cdef int w1, w2, h1, h2
if make_projective:
heads = list(self.heads)
for w1 in range(self.length):
if heads[w1] is not None:
h1 = heads[w1]
for w2 in range(w1+1, self.length):
if heads[w2] is not None:
h2 = heads[w2]
if _arcs_cross(w1, h1, w2, h2):
self.heads[w1] = None
self.labels[w1] = ''
self.heads[w2] = None
self.labels[w2] = ''
self.ssenses[w2] = []
# Check there are no cycles in the dependencies, i.e. we are a tree
for w in range(self.length):
seen = set([w])
head = w
while self.heads[head] != head and self.heads[head] != None:
head = self.heads[head]
if head in seen:
raise Exception("Cycle found: %s" % seen)
seen.add(head)
self.brackets = {}
for (gold_start, gold_end, label_str) in brackets:
start = self.gold_to_cand[gold_start]
end = self.gold_to_cand[gold_end]
if start is not None and end is not None:
self.brackets.setdefault(start, {}).setdefault(end, set())
self.brackets[end][start].add(label_str)
def __len__(self):
return self.length
@property
def is_projective(self):
heads = list(self.heads)
for w1 in range(self.length):
if heads[w1] is not None:
h1 = heads[w1]
for w2 in range(self.length):
if heads[w2] is not None and _arcs_cross(w1, h1, w2, heads[w2]):
return False
return True
cdef int _arcs_cross(int w1, int h1, int w2, int h2) except -1:
if w1 > h1:
w1, h1 = h1, w1
if w2 > h2:
w2, h2 = h2, w2
if w1 > w2:
w1, h1, w2, h2 = w2, h2, w1, h1
return w1 < w2 < h1 < h2 or w1 < w2 == h2 < h1
def is_punct_label(label):
return label == 'P' or label.lower() == 'punct'