2017-03-15 17:29:42 +03:00
|
|
|
# cython: profile=True
|
2016-10-09 13:24:24 +03:00
|
|
|
from __future__ import unicode_literals, print_function
|
|
|
|
|
2015-09-30 21:20:09 +03:00
|
|
|
import io
|
2015-05-06 17:27:31 +03:00
|
|
|
import json
|
2015-05-24 22:50:48 +03:00
|
|
|
import re
|
2015-05-29 04:52:55 +03:00
|
|
|
import os
|
|
|
|
from os import path
|
2015-03-08 08:14:48 +03:00
|
|
|
|
2016-10-20 22:23:26 +03:00
|
|
|
import ujson as json
|
2015-07-25 19:11:36 +03:00
|
|
|
|
2016-03-01 12:09:08 +03:00
|
|
|
from .syntax import nonproj
|
2016-02-22 16:40:40 +03:00
|
|
|
|
2015-02-21 19:06:58 +03:00
|
|
|
|
2015-06-08 01:54:13 +03:00
|
|
|
def tags_to_entities(tags):
    """Convert a sequence of BILUO tags into (label, start, end) entity spans.

    Token indices in the returned spans are inclusive: a 'U-X' tag at
    position i yields ('X', i, i). None and '-' tags are skipped, and an 'O'
    tag closes any entity that is currently open. An 'L' tag with no
    preceding 'B' yields a span whose start is None.

    Raises:
        AssertionError: on an 'I' tag with no open entity.
        Exception: on a tag with an unrecognised prefix.
    """
    spans = []
    open_at = None
    for idx, tag in enumerate(tags):
        if tag is None or tag == '-':
            continue
        kind = tag[:1]
        if kind == 'O':
            # TODO: We shouldn't be getting these malformed inputs. Fix this.
            open_at = None
        elif kind == 'I':
            assert open_at is not None, tags[:idx]
        elif kind == 'U':
            spans.append((tag[2:], idx, idx))
        elif kind == 'B':
            open_at = idx
        elif kind == 'L':
            spans.append((tag[2:], open_at, idx))
            open_at = None
        else:
            raise Exception(tag)
    return spans
|
|
|
|
|
|
|
|
|
2016-10-13 04:24:29 +03:00
|
|
|
def merge_sents(sents):
    """Concatenate several annotated sentences into a single sentence.

    Each element of sents is ((ids, words, tags, heads, labels, ner), brackets).
    Token ids and head indices are shifted by the number of tokens in the
    preceding sentences. Bracket dicts are converted to (first, last, label)
    tuples with the same shift.

    Returns:
        A one-element list [(deps, brackets)], where deps is a list of six
        lists: ids, words, tags, heads, labels, ner.
    """
    ids_all, words_all, tags_all = [], [], []
    heads_all, labels_all, ner_all = [], [], []
    merged_brackets = []
    offset = 0
    for annots, brackets in sents:
        ids, words, tags, heads, labels, ner = annots
        ids_all.extend([tok_id + offset for tok_id in ids])
        words_all.extend(words)
        tags_all.extend(tags)
        heads_all.extend([head + offset for head in heads])
        labels_all.extend(labels)
        ner_all.extend(ner)
        merged_brackets.extend([(br['first'] + offset, br['last'] + offset, br['label'])
                                for br in brackets])
        offset += len(ids)
    deps = [ids_all, words_all, tags_all, heads_all, labels_all, ner_all]
    return [(deps, merged_brackets)]
|
|
|
|
|
2015-06-08 01:54:13 +03:00
|
|
|
|
2015-05-24 22:50:48 +03:00
|
|
|
def align(cand_words, gold_words):
    """Map each candidate word to the index of its aligned gold word, or None.

    The mapping is derived from the edit script returned by _min_edit_path:
    - 'M' maps the candidate word to the current gold position.
    - 'S' and 'D' leave the candidate word unaligned (None).
    - 'I' consumes a gold word that has no candidate counterpart.
    'M', 'S' and 'I' advance the gold position.

    Returns:
        list: one entry per 'M', 'S' or 'D' move in the edit script.

    Raises:
        Exception: on an unrecognised move code.
    """
    _, moves = _min_edit_path(cand_words, gold_words)
    gold_pos = 0
    mapping = []
    for move in moves:
        if move == 'I':
            gold_pos += 1
            continue
        if move == 'D':
            mapping.append(None)
            continue
        if move not in ('M', 'S'):
            raise Exception(move)
        mapping.append(gold_pos if move == 'M' else None)
        gold_pos += 1
    return mapping
|
|
|
|
|
|
|
|
|
|
|
|
punct_re = re.compile(r'\W')


def _min_edit_path(cand_words, gold_words):
    """Find a minimum-cost edit script turning cand_words into gold_words.

    Both inputs are first normalised by deleting every non-word character
    (``punct_re``), so words that differ only in punctuation compare equal.
    For cost purposes, words are compared case-insensitively.

    Returns:
        (cost, path): ``cost`` is the edit cost. ``path`` is a string of move
        codes read left to right:
        - 'M': match.
        - 'S': substitution.
        - 'D': candidate word with no gold counterpart.
        - 'I': gold word with no candidate counterpart.
    """
    cdef:
        Pool mem
        int i, j, n_cand, n_gold
        int* curr_costs
        int* prev_costs

    # TODO: Fix this --- just do it properly, make the full edit matrix and
    # then walk back over it...
    # Preprocess inputs
    cand_words = [punct_re.sub('', w) for w in cand_words]
    gold_words = [punct_re.sub('', w) for w in gold_words]

    if cand_words == gold_words:
        return 0, ''.join(['M' for _ in gold_words])
    mem = Pool()
    n_cand = len(cand_words)
    n_gold = len(gold_words)
    # Levenshtein distance, except we need the history, and we may want
    # different costs. Only two rows of the DP table are kept:
    # - prev_costs / previous_row: the candidate prefix already processed.
    # - curr_costs / current_row: the row being filled.
    # Each history cell stores the full move string that reaches it.
    previous_row = []
    prev_costs = <int*>mem.alloc(n_gold + 1, sizeof(int))
    curr_costs = <int*>mem.alloc(n_gold + 1, sizeof(int))
    # Row 0: reaching gold_words[:i] from an empty candidate takes i inserts.
    for i in range(n_gold + 1):
        previous_row.append('I' * i)
        prev_costs[i] = i
    for i, cand in enumerate(cand_words):
        current_row = ['D' * (i + 1)]
        curr_costs[0] = i + 1
        for j, gold in enumerate(gold_words):
            if gold.lower() == cand.lower():
                s_cost = prev_costs[j]
                i_cost = curr_costs[j] + 1
                d_cost = prev_costs[j + 1] + 1
            else:
                s_cost = prev_costs[j] + 1
                i_cost = curr_costs[j] + 1
                # Dropping a candidate word that became empty after
                # punctuation stripping is free.
                d_cost = prev_costs[j + 1] + (1 if cand else 0)

            # Ties prefer match/substitution, then insertion, then deletion.
            if s_cost <= i_cost and s_cost <= d_cost:
                best_cost = s_cost
                # NOTE(review): a case-only difference costs nothing above,
                # but it is recorded as 'S' here, so align() leaves that word
                # unaligned. Confirm this is intended.
                best_hist = previous_row[j] + ('M' if gold == cand else 'S')
            elif i_cost <= s_cost and i_cost <= d_cost:
                best_cost = i_cost
                best_hist = current_row[j] + 'I'
            else:
                best_cost = d_cost
                best_hist = previous_row[j + 1] + 'D'
            current_row.append(best_hist)
            curr_costs[j + 1] = best_cost
        previous_row = current_row
        for j in range(n_gold + 1):
            prev_costs[j] = curr_costs[j]
            curr_costs[j] = 0
    return prev_costs[n_gold], previous_row[-1]
|
|
|
|
|
2015-05-27 20:13:11 +03:00
|
|
|
|
2015-06-12 03:42:08 +03:00
|
|
|
def read_json_file(loc, docs_filter=None):
    """Yield (raw_text, sents) pairs for each paragraph in JSON training data.

    If loc is a directory, every entry in it is read recursively, applying
    the same docs_filter.

    Arguments:
        loc: Path to a JSON file, or to a directory of such files.
        docs_filter: Optional callable taking a document dict. Documents for
            which it returns a falsy value are skipped.

    Yields:
        (raw, sents): raw is the paragraph's 'raw' text, or None. sents is a
        list of ((ids, words, tags, heads, labels, ner), brackets) tuples,
        one per sentence. Paragraphs with no sentences are skipped.
    """
    if path.isdir(loc):
        for filename in os.listdir(loc):
            # Pass the filter down so it also applies to files in directories.
            yield from read_json_file(path.join(loc, filename), docs_filter=docs_filter)
        return
    # Read the whole file up front so the handle is not kept open while the
    # generator is suspended between yields.
    with io.open(loc, 'r', encoding='utf8') as file_:
        docs = json.load(file_)
    for doc in docs:
        if docs_filter is not None and not docs_filter(doc):
            continue
        for paragraph in doc['paragraphs']:
            sents = []
            for sent in paragraph['sentences']:
                ids, words, tags, heads, labels, ner = [], [], [], [], [], []
                for i, token in enumerate(sent['tokens']):
                    words.append(token['orth'])
                    ids.append(i)
                    tags.append(token.get('tag', '-'))
                    # 'head' is stored relative to the token; make it absolute.
                    heads.append(token.get('head', 0) + i)
                    label = token.get('dep', '')
                    # Ensure ROOT label is case-insensitive
                    if label.lower() == 'root':
                        label = 'ROOT'
                    labels.append(label)
                    ner.append(token.get('ner', '-'))
                sents.append(((ids, words, tags, heads, labels, ner),
                              sent.get('brackets', [])))
            if sents:
                yield (paragraph.get('raw', None), sents)
|
2015-05-06 17:27:31 +03:00
|
|
|
|
|
|
|
|
2015-04-10 05:59:11 +03:00
|
|
|
def _iob_to_biluo(tags):
    """Convert a sequence of IOB tags into BILUO tags.

    Runs of 'O' tags are passed through. Each entity run is re-encoded by
    _consume_ent. The function works on a private list copy, so the caller's
    sequence is not modified.
    """
    remaining = list(tags)
    out = []
    while remaining:
        out.extend(_consume_os(remaining))
        out.extend(_consume_ent(remaining))
    return out
|
|
|
|
|
|
|
|
|
|
|
|
def _consume_os(tags):
|
|
|
|
while tags and tags[0] == 'O':
|
|
|
|
yield tags.pop(0)
|
|
|
|
|
|
|
|
|
|
|
|
def _consume_ent(tags):
|
|
|
|
if not tags:
|
|
|
|
return []
|
|
|
|
target = tags.pop(0).replace('B', 'I')
|
|
|
|
length = 1
|
|
|
|
while tags and tags[0] == target:
|
|
|
|
length += 1
|
|
|
|
tags.pop(0)
|
|
|
|
label = target[2:]
|
|
|
|
if length == 1:
|
|
|
|
return ['U-' + label]
|
|
|
|
else:
|
|
|
|
start = 'B-' + label
|
|
|
|
end = 'L-' + label
|
|
|
|
middle = ['I-%s' % label for _ in range(1, length - 1)]
|
|
|
|
return [start] + middle + [end]
|
|
|
|
|
|
|
|
|
2015-03-09 08:46:22 +03:00
|
|
|
cdef class GoldParse:
    """Collection for training annotations."""
    @classmethod
    def from_annot_tuples(cls, doc, annot_tuples, make_projective=False):
        """Create a GoldParse from a 6-tuple of annotations.

        annot_tuples is unpacked as (ids, words, tags, heads, deps, entities).
        The ids element is ignored.
        """
        _, words, tags, heads, deps, entities = annot_tuples
        return cls(doc, words=words, tags=tags, heads=heads, deps=deps, entities=entities,
                   make_projective=make_projective)

    def __init__(self, doc, annot_tuples=None, words=None, tags=None, heads=None,
                 deps=None, entities=None, make_projective=False):
        """Create a GoldParse.

        Arguments:
            doc (Doc):
                The document the annotations refer to.
            annot_tuples:
                Accepted for backward compatibility; not read by this constructor.
            words:
                A sequence of unicode word strings. Defaults to the token
                texts of doc.
            tags:
                A sequence of strings, representing tag annotations.
            heads:
                A sequence of integers: for each word, the index (into words)
                of its syntactic head. Defaults to each word being its own head.
            deps:
                A sequence of strings, representing the syntactic relation types.
            entities:
                A sequence of named entity annotations, either as BILUO tag strings,
                or as (start_char, end_char, label) tuples, representing the entity
                positions. None marks every token as missing ('-'); an empty
                sequence marks every token as outside an entity ('O').
            make_projective (bool):
                If true, the aligned heads are projectivized with
                nonproj.PseudoProjectivity.projectivize.

        Raises:
            Exception: if the aligned heads contain a cycle.
        """
        if words is None:
            words = [token.text for token in doc]
        # The annotation lists are indexed by gold position (positions in
        # words) below, so size the defaults by len(words). This equals
        # len(doc) when words was not supplied.
        if tags is None:
            tags = [None] * len(words)
        if heads is None:
            heads = list(range(len(words)))
        if deps is None:
            deps = [None] * len(words)
        if entities is None:
            entities = ['-'] * len(words)
        elif len(entities) == 0:
            entities = ['O'] * len(words)
        elif not isinstance(entities[0], basestring):
            # Assume we have entities specified by character offset.
            # NOTE(review): this produces one tag per token of doc, yet the
            # tags are indexed by gold position below. Confirm that words
            # matches the doc tokenization when offsets are used.
            entities = biluo_tags_from_offsets(doc, entities)

        self.mem = Pool()
        self.loss = 0
        self.length = len(doc)

        # These are filled by the tagger/parser/entity recogniser
        self.c.tags = <int*>self.mem.alloc(len(doc), sizeof(int))
        self.c.heads = <int*>self.mem.alloc(len(doc), sizeof(int))
        self.c.labels = <int*>self.mem.alloc(len(doc), sizeof(int))
        self.c.ner = <Transition*>self.mem.alloc(len(doc), sizeof(Transition))

        # Per-candidate-token (doc token) annotations, filled from the gold
        # annotations via the alignment below.
        self.words = [None] * len(doc)
        self.tags = [None] * len(doc)
        self.heads = [None] * len(doc)
        self.labels = [None] * len(doc)
        self.ner = [None] * len(doc)

        self.cand_to_gold = align([t.orth_ for t in doc], words)
        self.gold_to_cand = align(words, [t.orth_ for t in doc])

        annot_tuples = (range(len(words)), words, tags, heads, deps, entities)
        self.orig_annot = list(zip(*annot_tuples))

        for i, gold_i in enumerate(self.cand_to_gold):
            if doc[i].text.isspace():
                self.words[i] = doc[i].text
                self.tags[i] = 'SP'
                self.heads[i] = None
                self.labels[i] = None
                self.ner[i] = 'O'
            # An aligned gold token overrides the whitespace defaults above.
            if gold_i is not None:
                self.words[i] = words[gold_i]
                self.tags[i] = tags[gold_i]
                # Gold heads index into words; map them to doc positions
                # (None when the head word is unaligned).
                self.heads[i] = self.gold_to_cand[heads[gold_i]]
                self.labels[i] = deps[gold_i]
                self.ner[i] = entities[gold_i]

        cycle = nonproj.contains_cycle(self.heads)
        if cycle is not None:
            raise Exception("Cycle found: %s" % cycle)

        if make_projective:
            proj_heads, _ = nonproj.PseudoProjectivity.projectivize(self.heads, self.labels)
            self.heads = proj_heads

    def __len__(self):
        """Get the number of gold-standard tokens.

        Returns (int): The number of gold-standard tokens.
        """
        return self.length

    @property
    def is_projective(self):
        """Whether the provided syntactic annotations form a projective dependency
        tree."""
        return not nonproj.is_nonproj_tree(self.heads)
|
2015-05-30 02:25:46 +03:00
|
|
|
|
2015-02-21 19:06:58 +03:00
|
|
|
|
2016-10-15 22:51:04 +03:00
|
|
|
def biluo_tags_from_offsets(doc, entities):
    '''Encode character-offset entity spans as per-token BILUO tags.

    Arguments:
        doc (Doc):
            The document the offsets refer to. One tag is produced per token,
            following the document's token boundaries.

        entities (sequence):
            (start_char, end_char, label) triples, where start_char and
            end_char slice the original text.

    Returns:
        tags (list):
            One unicode string per token, either "-", "O" or "{action}-{label}":
            - "B", "I", "L", "U" are the actions for the beginning, inside,
              last and unit (single-token) positions of an entity.
            - "O" marks a token that no entity touches.
            - "-" marks a token covered by an entity whose offsets do not line
              up with token boundaries. Training treats these as missing values.

    Example:
        text = 'I like London.'
        entities = [(len('I like '), len('I like London'), 'LOC')]
        doc = nlp.tokenizer(text)
        tags = biluo_tags_from_offsets(doc, entities)
        assert tags == ['O', 'O', 'U-LOC', 'O']
    '''
    token_at_start = {}
    token_at_end = {}
    for token in doc:
        token_at_start[token.idx] = token.i
        token_at_end[token.idx + len(token)] = token.i
    biluo = ['-' for _ in doc]
    # Tag entities whose boundaries coincide with token boundaries.
    for start_char, end_char, label in entities:
        first = token_at_start.get(start_char)
        last = token_at_end.get(end_char)
        if first is None or last is None:
            continue
        if first == last:
            biluo[first] = 'U-%s' % label
            continue
        biluo[first] = 'B-%s' % label
        for k in range(first + 1, last):
            biluo[k] = 'I-%s' % label
        biluo[last] = 'L-%s' % label
    # Tokens that touch no entity character are definitely outside ('O');
    # anything else left as '-' is a tokenization mismatch.
    covered = set()
    for start_char, end_char, _ in entities:
        covered.update(range(start_char, end_char))
    for token in doc:
        span = range(token.idx, token.idx + len(token))
        if not any(pos in covered for pos in span):
            biluo[token.i] = 'O'
    return biluo
|
|
|
|
|
|
|
|
|
2015-02-21 19:06:58 +03:00
|
|
|
def is_punct_label(label):
    """Return True if *label* names the punctuation dependency relation.

    Accepts the exact label 'P' or any casing of 'punct'.
    """
    if label == 'P':
        return True
    return label.lower() == 'punct'
|