This commit is contained in:
Matthew Honnibal 2015-03-09 01:46:22 -04:00
parent 220ce8bfed
commit b3eda03c9c
9 changed files with 336 additions and 240 deletions

View File

@ -42,9 +42,17 @@ cdef struct PosTag:
univ_pos_t pos
# Record describing one named-entity span, plus the per-token entity tag.
# An Entity is embedded in every TokenC (field `ent`) and is also what
# State.ent points at.
cdef struct Entity:
# Token index at which the entity begins; the BEGIN and UNIT handlers store the state's cursor `i` here.
int start
# One past the last token of the entity; the LAST and UNIT handlers store `i + 1` here.
int end
# Class id of the transition applied to the token; entity_is_open() treats a value >= 1 as "inside an entity".
int tag
# Entity-type id copied from the transition's `label`.
int label
cdef struct TokenC:
const LexemeC* lex
Morphology morph
Entity ent
univ_pos_t pos
int tag
int idx

View File

@ -2,15 +2,17 @@ from libc.stdint cimport uint32_t
from cymem.cymem cimport Pool
from ..structs cimport TokenC
from ..structs cimport TokenC, Entity
# Mutable state shared by the transition systems while processing a sentence.
cdef struct State:
# Token array for the sentence; the entity handlers index it with `i`.
TokenC* sent
# Stack storage — presumably token indices for the dependency parser; TODO confirm against _state.pyx.
int* stack
# Pointer to the entity record currently being written; BEGIN/UNIT advance it with `ent += 1`.
Entity* ent
# Cursor: index of the next token to be processed.
int i
# Number of tokens in `sent` (assumed from the name — TODO confirm).
int sent_len
# Number of entries currently on `stack` (assumed from the name — TODO confirm).
int stack_len
# Count of entities opened so far; incremented alongside `ent` by BEGIN/UNIT.
int ents_len
cdef int add_dep(const State *s, const int head, const int child, const int label) except -1

View File

@ -35,16 +35,16 @@ cdef get_cost_func_t[N_MOVES] get_cost_funcs
cdef class ArcEager(TransitionSystem):
@classmethod
def get_labels(cls, gold_parses):
labels = {SHIFT: {'ROOT': True}, REDUCE: {'ROOT': True}, RIGHT: {},
LEFT: {}, BREAK: {'ROOT': True}}
for parse in gold_parses:
for i, (head, label) in enumerate(zip(parse.heads, parse.labels)):
move_labels = {SHIFT: {'ROOT': True}, REDUCE: {'ROOT': True}, RIGHT: {},
LEFT: {}, BREAK: {'ROOT': True}}
for raw_text, segmented, (ids, tags, heads, labels, iob) in gold_parses:
for i, (head, label) in enumerate(zip(heads, labels)):
if label != 'ROOT':
if head > i:
labels[RIGHT][label] = True
move_labels[RIGHT][label] = True
elif head < i:
labels[LEFT][label] = True
return labels
move_labels[LEFT][label] = True
return move_labels
cdef Transition init_transition(self, int clas, int move, int label) except *:
# TODO: Apparent Cython bug here when we try to use the Transition()

View File

@ -1,22 +1,33 @@
from cymem.cymem cimport Pool
from ..structs cimport TokenC
from .transition_system cimport Transition
cdef class GoldParse:
cdef Pool mem
cdef int length
cdef readonly int loss
cdef readonly object ids
cdef readonly object tags
cdef readonly object heads
cdef readonly object labels
cdef readonly object tags_
cdef readonly object labels_
cdef readonly object ner_
cdef Transition* ner
cdef int* c_heads
cdef int* c_labels
cdef int length
cdef int loss
cdef readonly unicode raw_text
cdef readonly list words
cdef readonly list ids
cdef readonly list tags
cdef readonly list heads
cdef readonly list labels
cdef int heads_correct(self, TokenC* tokens, bint score_punct=?) except -1
# C-level layout of NERAnnotation: three per-token int arrays describing the
# entity (if any) that covers each token, plus the original span list.
cdef class NERAnnotation:
# Memory pool that owns the three arrays below.
cdef Pool mem
# starts[i]: start index of the entity covering token i, or -1 when there is none.
cdef int* starts
# ends[i]: exclusive end index of the entity covering token i, or -1 when there is none.
cdef int* ends
# labels[i]: label id of the entity covering token i, or -1 when there is none.
cdef int* labels
# The (start, end, label) tuples the arrays were built from.
cdef readonly list entities

View File

@ -1,67 +1,24 @@
cdef class GoldParse:
def __init__(self, raw_text, words, ids, tags, heads, labels):
self.mem = Pool()
self.loss = 0
self.length = len(words)
self.raw_text = raw_text
self.words = words
self.ids = ids
self.tags = tags
self.heads = heads
self.labels = labels
self.c_heads = <int*>self.mem.alloc(self.length, sizeof(int))
self.c_labels = <int*>self.mem.alloc(self.length, sizeof(int))
import numpy
import codecs
from .ner_util import iob_to_biluo
@property
def n_non_punct(self):
return len([l for l in self.labels if l != 'P'])
from libc.string cimport memset
@property
def py_heads(self):
return [self.c_heads[i] for i in range(self.length)]
cdef int heads_correct(self, TokenC* tokens, bint score_punct=False) except -1:
n = 0
for i in range(self.length):
if not score_punct and self.labels[i] == 'P':
continue
n += (i + tokens[i].head) == self.c_heads[i]
return n
def is_correct(self, i, head):
return head == self.c_heads[i]
@classmethod
def from_conll(cls, unicode sent_str):
ids = []
words = []
heads = []
labels = []
tags = []
for i, line in enumerate(sent_str.split('\n')):
id_, word, pos_string, head_idx, label = _parse_line(line)
words.append(word)
if head_idx == -1:
head_idx = i
ids.append(id_)
heads.append(head_idx)
labels.append(label)
tags.append(pos_string)
text = ' '.join(words)
return cls(text, [words], ids, words, tags, heads, labels)
@classmethod
def from_docparse(cls, unicode sent_str):
def read_docparse_file(loc):
sents = []
for sent_str in codecs.open(loc, 'r', 'utf8').read().strip().split('\n\n'):
words = []
heads = []
labels = []
tags = []
ids = []
iob_ents = []
lines = sent_str.strip().split('\n')
raw_text = lines.pop(0).strip()
tok_text = lines.pop(0).strip()
for i, line in enumerate(lines):
id_, word, pos_string, head_idx, label = _parse_line(line)
id_, word, pos_string, head_idx, label, iob_ent = _parse_line(line)
if label == 'root':
label = 'ROOT'
words.append(word)
@ -71,57 +28,78 @@ cdef class GoldParse:
heads.append(head_idx)
labels.append(label)
tags.append(pos_string)
tokenized = [sent_str.replace('<SEP>', ' ').split(' ')
for sent_str in tok_text.split('<SENT>')]
return cls(raw_text, words, ids, tags, heads, labels)
iob_ents.append(iob_ent)
tokenized = [s.replace('<SEP>', ' ').split(' ')
for s in tok_text.split('<SENT>')]
sents.append((raw_text, tokenized, (ids, tags, heads, labels, iob_ents)))
return sents
def align_to_tokens(self, tokens, label_ids):
orig_words = list(self.words)
annot = zip(self.ids, self.tags, self.heads, self.labels)
self.ids = []
self.tags = []
self.heads = []
self.labels = []
missed = []
for token in tokens:
while annot and token.idx > annot[0][0]:
miss_id, miss_tag, miss_head, miss_label = annot.pop(0)
if not is_punct_label(miss_label):
self.loss += 1
if not annot:
self.tags.append(None)
self.heads.append(None)
self.labels.append(None)
continue
id_, tag, head, label = annot[0]
if token.idx == id_:
self.tags.append(tag)
self.heads.append(head)
self.labels.append(label)
annot.pop(0)
elif token.idx < id_:
self.tags.append(None)
self.heads.append(None)
self.labels.append(None)
else:
raise StandardError
cdef class GoldParse:
def __init__(self, tokens, annot_tuples, pos_tags, dep_labels, entity_types):
self.mem = Pool()
self.loss = 0
self.length = len(tokens)
self.c_heads = <int*>self.mem.alloc(self.length, sizeof(int))
self.c_labels = <int*>self.mem.alloc(self.length, sizeof(int))
self.ids = [token.idx for token in tokens]
self.map_heads(label_ids)
return self.loss
self.ids = numpy.empty(shape=(len(tokens), 1), dtype=numpy.int32)
self.tags = numpy.empty(shape=(len(tokens), 1), dtype=numpy.int32)
self.heads = numpy.empty(shape=(len(tokens), 1), dtype=numpy.int32)
self.labels = numpy.empty(shape=(len(tokens), 1), dtype=numpy.int32)
def map_heads(self, label_ids):
mapped_heads = _map_indices_to_tokens(self.ids, self.heads)
for i in range(self.length):
if mapped_heads[i] is None:
self.ids[:] = -1
self.tags[:] = -1
self.heads[:] = -1
self.labels[:] = -1
self.ner = <Transition*>self.mem.alloc(len(tokens), sizeof(Transition))
self.c_heads = <int*>self.mem.alloc(len(tokens), sizeof(int))
self.c_labels = <int*>self.mem.alloc(len(tokens), sizeof(int))
for i in range(len(tokens)):
self.c_heads[i] = -1
self.c_labels[i] = -1
self.tags_ = [None] * len(tokens)
self.labels_ = [None] * len(tokens)
self.ner_ = [None] * len(tokens)
idx_map = {token.idx: token.i for token in tokens}
print idx_map
# TODO: Fill NER moves
print raw_text
for idx, tag, head, label, ner in zip(*annot_tuples):
if idx < tokens[0].idx:
pass
elif idx > tokens[-1].idx:
break
elif idx in idx_map:
i = idx_map[idx]
print i, idx, head, idx_map.get(head, -1)
self.ids[i] = idx
self.tags[i] = pos_tags.index(tag)
self.heads[i] = idx_map.get(head, -1)
self.labels[i] = dep_labels[label]
self.c_heads[i] = -1
self.c_labels[i] = -1
else:
self.c_heads[i] = mapped_heads[i]
self.c_labels[i] = label_ids[self.labels[i]]
return self.loss
self.tags_[i] = tag
self.labels_[i] = label
self.ner_[i] = ner
@property
def n_non_punct(self):
    """Return how many tokens carry a dependency label other than 'P'.

    The string labels live in ``labels_``; ``labels`` holds integer label
    ids, so comparing its entries against 'P' would never exclude anything.
    This mirrors the punctuation test used by ``heads_correct``. Tokens
    without gold annotation (``None``) are counted as non-punctuation.
    """
    return sum(1 for label in self.labels_ if label != 'P')
cdef int heads_correct(self, TokenC* tokens, bint score_punct=False) except -1:
    """Count the tokens whose predicted head agrees with the gold head.

    The prediction for token ``idx`` is taken as ``idx + tokens[idx].head``
    and compared with ``self.heads[idx]``. Tokens labelled 'P' are skipped
    unless ``score_punct`` is true, and tokens whose gold head is -1 are
    always skipped.
    """
    n_right = 0
    for idx in range(self.length):
        is_skipped_punct = (not score_punct) and self.labels_[idx] == 'P'
        if is_skipped_punct or self.heads[idx] == -1:
            continue
        n_right += (idx + tokens[idx].head) == self.heads[idx]
    return n_right
def is_correct(self, i, head):
    """Return whether ``head`` equals the gold head stored in ``c_heads`` for token ``i``."""
    gold_head = self.c_heads[i]
    return head == gold_head
def is_punct_label(label):
@ -146,6 +124,63 @@ def _parse_line(line):
id_ = int(pieces[0])
word = pieces[1]
pos = pieces[3]
iob_ent = pieces[5]
head_idx = int(pieces[6])
label = pieces[7]
return id_, word, pos, head_idx, label
return id_, word, pos, head_idx, label, iob_ent
cdef class NERAnnotation:
    """Gold-standard named entities for one sentence, indexed per token.

    For every token covered by an entity, ``starts``/``ends``/``labels`` hold
    that entity's start index, exclusive end index and label id; tokens
    outside any entity hold -1 in all three arrays.
    """
    def __init__(self, entities, length, entity_types):
        """Build the per-token arrays.

        entities: iterable of ``(start, end, label)`` integer triples, with
            ``end`` exclusive.
        length: number of tokens in the sentence (size of the arrays).
        entity_types: accepted for signature symmetry with the alternate
            constructors; not read here.

        Raises ValueError if a non-empty span falls outside ``[0, length]``,
        since writing it would run past the allocated buffers.
        """
        cdef int start, end, label
        cdef int i
        self.mem = Pool()
        self.starts = <int*>self.mem.alloc(length, sizeof(int))
        self.ends = <int*>self.mem.alloc(length, sizeof(int))
        self.labels = <int*>self.mem.alloc(length, sizeof(int))
        self.entities = entities
        # memset fills every byte with 0xFF, which reads back as -1 for a
        # two's-complement int: "no entity covers this token".
        memset(self.starts, -1, sizeof(int) * length)
        memset(self.ends, -1, sizeof(int) * length)
        memset(self.labels, -1, sizeof(int) * length)
        for start, end, label in entities:
            if start < end and (start < 0 or end > length):
                raise ValueError(
                    "Entity span (%d, %d) is outside sentence of length %d"
                    % (start, end, length))
            for i in range(start, end):
                self.starts[i] = start
                self.ends[i] = end
                self.labels[i] = label

    @property
    def biluo_tags(self):
        # Placeholder: not implemented yet, evaluates to None.
        pass

    @property
    def iob_tags(self):
        # Placeholder: not implemented yet, evaluates to None.
        pass

    @classmethod
    def from_iobs(cls, iob_strs, entity_types):
        """Construct from IOB tag strings by converting them to BILUO first."""
        return cls.from_biluos(iob_to_biluo(iob_strs), entity_types)

    @classmethod
    def from_biluos(cls, tag_strs, entity_types):
        """Construct from BILUO tag strings such as 'B-PER', 'L-PER', 'U-ORG'.

        'O' and '-' tags are skipped. ``entity_types`` maps a label's list
        position to its name; a label not yet present is appended to it
        (the list is mutated in place) and receives the next free id.
        """
        entities = []
        start = None
        for i, tag_str in enumerate(tag_strs):
            if tag_str == 'O' or tag_str == '-':
                continue
            # Split on the first hyphen only, so labels that themselves
            # contain '-' stay intact.
            move, label_str = tag_str.split('-', 1)
            # list.index() raises ValueError for a missing item (it never
            # returns -1), so unseen labels are registered in the handler.
            try:
                label = entity_types.index(label_str)
            except ValueError:
                label = len(entity_types)
                entity_types.append(label_str)
            if move == 'U':
                assert start is None
                entities.append((i, i+1, label))
            elif move == 'B':
                assert start is None
                start = i
            elif move == 'L':
                assert start is not None
                entities.append((start, i+1, label))
                start = None
            # 'I' tags need no action: the span is closed by the 'L' tag.
        return cls(entities, len(tag_strs), entity_types)

View File

@ -1,28 +1,7 @@
from cymem.cymem cimport Pool
from thinc.typedefs cimport weight_t
from .transition_system cimport TransitionSystem
from .transition_system cimport Transition
from ._state cimport State
from ._state cimport State
cdef struct Transition:
int clas
int move
int label
int cost
weight_t score
cdef class TransitionSystem:
cdef Pool mem
cdef readonly int n_moves
cdef dict label_ids
cdef const Transition* _moves
cdef Transition best_valid(self, const weight_t* scores, const State* s) except *
cdef Transition best_gold(self, Transition* guess, const weight_t* scores,
const State* s,
const int* gold_heads, const int* gold_labels) except *
cdef int transition(self, State *s, const Transition* t) except -1
# Declaration of the BiluoPushDown transition system. It declares no C-level
# attributes or methods of its own beyond those of TransitionSystem.
cdef class BiluoPushDown(TransitionSystem):
pass

View File

@ -1,16 +1,15 @@
from __future__ import unicode_literals
from ._state cimport State
from ._state cimport has_head, get_idx, get_s0, get_n0
from ._state cimport is_final, at_eol, pop_stack, push_stack, add_dep
from ._state cimport head_in_buffer, children_in_buffer
from ._state cimport head_in_stack, children_in_stack
from ..structs cimport TokenC
from .transition_system cimport Transition
from .transition_system cimport do_func_t
from ..structs cimport TokenC, Entity
DEF NON_MONOTONIC = True
DEF USE_BREAK = True
from thinc.typedefs cimport weight_t
from .conll cimport GoldParse
from .ner_util import iob_to_biluo
cdef enum:
@ -23,13 +22,34 @@ cdef enum:
N_MOVES
cdef int is_valid(ActionType act, int label, State* s) except -1:
cdef do_func_t[N_MOVES] do_funcs
cdef bint entity_is_open(const State *s) except -1:
    """Return True when the token just before the cursor has an entity tag >= 1.

    NOTE(review): with ``s.i == 0`` this reads ``s.sent[-1]``; presumably the
    token array is padded in front — confirm against the state initialiser.
    """
    cdef const TokenC* prev_token = &s.sent[s.i - 1]
    return prev_token.ent.tag >= 1
cdef bint _entity_is_sunk(const State *s, Transition* golds) except -1:
    """Return True when the currently open entity can no longer match the gold.

    With no open entity the answer is False. Otherwise the gold transition
    at index ``(s.i - 1) + start`` is inspected: the entity is sunk if that
    transition is neither BEGIN nor UNIT, or if its label differs from the
    label of the entity being built.

    NOTE(review): the gold index adds the previous token's position to the
    entity start, which looks unusual — confirm it is intended.
    """
    if not entity_is_open(s):
        return False
    cdef const Entity* open_ent = &s.sent[s.i - 1].ent
    cdef const Transition* gold_move = &golds[(s.i - 1) + open_ent.start]
    cdef bint wrong_move = gold_move.move != BEGIN and gold_move.move != UNIT
    cdef bint wrong_label = gold_move.label != s.ent.label
    return wrong_move or wrong_label
cdef int _is_valid(int act, int label, const State* s) except -1:
if act == BEGIN:
return not entity_is_open(s)
elif act == IN:
return entity_is_open(s) and s.curr.label == label
return entity_is_open(s) and s.ent.label == label
elif act == LAST:
return entity_is_open(s) and s.curr.label == label
return entity_is_open(s) and s.ent.label == label
elif act == UNIT:
return not entity_is_open(s)
elif act == OUT:
@ -38,8 +58,56 @@ cdef int is_valid(ActionType act, int label, State* s) except -1:
raise UnknownMove(act, label)
cdef bint is_gold(ActionType act, int tag, ActionType g_act, int g_tag,
ActionType next_act, bint is_sunk):
cdef class BiluoPushDown(TransitionSystem):
    """Transition system that tags entities with BILUO-style moves
    (BEGIN, IN, LAST, UNIT, OUT)."""
    @classmethod
    def get_labels(cls, gold_tuples):
        """Collect, per move, the set of entity labels seen in the gold data.

        gold_tuples: iterable of
            ``(raw_text, toks, (ids, tags, heads, labels, iob))``.
        Returns a dict mapping each move id to a ``{label: True}`` dict.
        OUT is always registered with the single label 'ROOT'.
        """
        move_labels = {BEGIN: {}, IN: {}, LAST: {}, UNIT: {}, OUT: {'ROOT': True}}
        # Position in this tuple is used directly as the move id, so it
        # assumes the move enum numbers BEGIN=1, IN=2, LAST=3, UNIT=4 —
        # the enum body is not visible here; TODO confirm.
        moves = ('-', 'B', 'I', 'L', 'U')
        for (raw_text, toks, (ids, tags, heads, labels, iob)) in gold_tuples:
            for ner_tag in iob_to_biluo(iob):
                if ner_tag != 'O' and ner_tag != '-':
                    # Split on the first hyphen only so labels containing
                    # '-' are kept whole instead of raising on unpacking.
                    move_str, label = ner_tag.split('-', 1)
                    move_labels[moves.index(move_str)][label] = True
        return move_labels

    cdef Transition init_transition(self, int clas, int move, int label) except *:
        """Build the Transition record for class ``clas`` with the given move
        and label, wiring in the move's handler and the shared cost function."""
        # TODO: Apparent Cython bug here when we try to use the Transition()
        # constructor with the function pointers
        cdef Transition t
        t.score = 0
        t.clas = clas
        t.move = move
        t.label = label
        t.do = do_funcs[move]
        t.get_cost = _get_cost
        return t

    cdef Transition best_valid(self, const weight_t* scores, const State* s) except *:
        """Return a copy of the highest-scoring transition that is valid in
        state ``s``, with its ``score`` field set.

        ``scores[i]`` is the score of transition ``i``. Scores at or below
        -90000 are never selected; an AssertionError is raised when no
        transition qualifies.
        """
        cdef int best = -1
        cdef weight_t score = -90000
        cdef const Transition* m
        cdef int i
        for i in range(self.n_moves):
            m = &self.c[i]
            if _is_valid(m.move, m.label, s) and scores[i] > score:
                best = i
                score = scores[i]
        assert best >= 0
        cdef Transition t = self.c[best]
        t.score = score
        return t
cdef int _get_cost(const Transition* self, const State* s, GoldParse gold) except -1:
    """Return the cost of applying transition ``self`` in state ``s``.

    An invalid transition costs 9000. Otherwise the cost is 0 when the
    transition agrees with the gold annotation and 1 when it does not.
    """
    if not _is_valid(self.move, self.label, s):
        return 9000
    cdef bint is_sunk = _entity_is_sunk(s, gold.ner)
    # Only look one token ahead when that token exists. gold.ner holds one
    # entry per token, so at the last token the next action defaults to OUT
    # instead of reading past the end of the array.
    cdef int next_act = gold.ner[s.i+1].move if (s.i + 1) < s.sent_len else OUT
    return not _is_gold(self.move, self.label, gold.ner[s.i].move, gold.ner[s.i].label,
                        next_act, is_sunk)
cdef bint _is_gold(int act, int tag, int g_act, int g_tag,
int next_act, bint is_sunk):
if g_act == MISSING:
return True
if act == BEGIN:
@ -112,98 +180,46 @@ cdef bint is_gold(ActionType act, int tag, ActionType g_act, int g_tag,
return False
cdef bint entity_is_open(State *s) except -1:
return s.sent[s.i - 1].ent.tag >= 1
cdef int _do_begin(const Transition* self, State* s) except -1:
    """Apply BEGIN: open a new entity at the current token and advance."""
    # Move on to a fresh entity slot and count it.
    s.ents_len += 1
    s.ent += 1
    # The new entity starts here and takes the transition's label.
    s.ent.label = self.label
    s.ent.start = s.i
    # Tag the token with the transition class, then step the cursor.
    s.sent[s.i].ent.tag = self.clas
    s.i += 1
cdef bint entity_is_sunk(State *s, Move* golds) except -1:
if not entity_is_open(s):
return False
cdef const Entity* curr = &s.sent[s.i - 1].ent
cdef Move* gold = &golds[(s.i - 1) + curr.start]
if gold.action != BEGIN and gold.action != UNIT:
return True
elif gold.label != s.curr.label:
return True
else:
return False
cdef int _do_in(const Transition* self, State* s) except -1:
    """Apply IN: tag the current token with this transition's class and advance."""
    cdef TokenC* current = &s.sent[s.i]
    current.ent.tag = self.clas
    s.i += 1
cdef class TransitionSystem:
def __init__(self, list entity_type_strs):
self.mem = Pool()
cdef int _do_last(const Transition* self, State* s) except -1:
    """Apply LAST: close the open entity after the current token and advance."""
    # Tag the token first; the entity's exclusive end is one past it.
    s.sent[s.i].ent.tag = self.clas
    s.ent.end = s.i + 1
    s.i += 1
cdef Move* m
label_names = {'-': 0}
for i, tag_name in enumerate(tag_names):
m = &moves[i]
if '-' in tag_name:
action_str, label = tag_name.split('-')
elif tag_name == 'O':
action_str = 'O'
label = '-'
elif tag_name == 'NULL' or tag_name == 'EOL':
action_str = '?'
label = '-'
else:
raise StandardError(tag_name)
m.action = ACTION_NAMES.index(action_str)
m.label = label_names.setdefault(label, len(label_names))
m.clas = i
cdef int transition(self, State *s, Move* move) except -1:
if move.action == BEGIN:
s.curr.start = s.i
s.curr.label = label
elif move.action == IN:
pass
elif move.action == LAST:
s.curr.end = s.i
s.ents[s.j] = s.curr
s.j += 1
s.curr.start = 0
s.curr.label = -1
s.curr.end = 0
elif move.action == UNIT:
begin_entity(s, move.label)
end_entity(s)
elif move.action == OUT:
pass
s.tags[s.i] = move.clas
s.i += 1
cdef int _do_unit(const Transition* self, State* s) except -1:
    """Apply UNIT: record a single-token entity at the current token and advance."""
    cdef int here = s.i
    # Move on to a fresh entity slot and count it.
    s.ents_len += 1
    s.ent += 1
    # The entity covers exactly the current token: [here, here + 1).
    s.ent.label = self.label
    s.ent.start = here
    s.ent.end = here + 1
    s.sent[here].ent.tag = self.clas
    s.i = here + 1
cdef Transition best_valid(self, const weight_t* scores, const State* s) except *:
cdef int best = -1
cdef weight_t score = -90000
cdef const Transition* m
cdef int i
for i in range(self.n_moves):
m = &self._moves[i]
if _is_valid(s, m.ent_move, m.ent_label) and scores[i] > score:
best = i
score = scores[i]
assert best >= 0
cdef Transition t = self._moves[best]
t.score = score
return t
cdef Transition best_gold(self, Transition* guess, const weight_t* scores,
const State* s, Move* golds) except *:
cdef Move* g = &golds[s.i]
cdef ActionType next_act = <ActionType>golds[s.i+1].action if s.i < s.length else OUT
cdef bint is_sunk = entity_is_sunk(s, golds)
cdef Move* m
cdef int n_accept = 0
for i in range(1, self.n_classes):
m = &moves[i]
if _is_valid(s, m.move, m.label) and \
_is_gold(s, m.move, m.label, next_act, is_sunk) and \
scores[i] > score:
best = i
score = scores[i]
assert best >= 0
return self._moves[best]
cdef int _do_out(const Transition* self, State* s) except -1:
    """Apply OUT: tag the current token with this transition's class and advance."""
    cdef TokenC* current = &s.sent[s.i]
    current.ent.tag = self.clas
    s.i += 1
# Fill the move-id -> handler table at import time; init_transition() looks
# up `do_funcs[move]` to set each Transition's `do` pointer.
do_funcs[BEGIN] = _do_begin
do_funcs[IN] = _do_in
do_funcs[LAST] = _do_last
do_funcs[UNIT] = _do_unit
do_funcs[OUT] = _do_out
class OracleError(Exception):
@ -212,3 +228,5 @@ class OracleError(Exception):
class UnknownMove(Exception):
pass

View File

@ -35,3 +35,10 @@ cdef class TransitionSystem:
cdef Transition best_gold(self, const weight_t* scores, const State* state,
GoldParse gold) except *
#cdef class PyState:
# """Provide a Python class for testing purposes."""
# cdef Pool mem
# cdef TransitionSystem system
# cdef State* _state

View File

@ -45,3 +45,39 @@ cdef class TransitionSystem:
score = scores[i]
assert score > MIN_SCORE
return best
#cdef class PyState:
# """Provide a Python class for testing purposes."""
# def __init__(self, GoldParse gold):
# self.mem = Pool()
# self.system = EntityRecognition(labels)
# self._state = init_state(self.mem, tokens, gold.length)
#
# def transition(self, name):
# cdef const Transition* trans = self._transition_by_name(name)
# trans.do(trans, self._state)
#
# def is_valid(self, name):
# cdef const Transition* trans = self._transition_by_name(name)
# return _is_valid(trans.move, trans.label, self._state)
#
# def is_gold(self, name):
# cdef const Transition* trans = self._transition_by_name(name)
# return _get_cost(trans, self._state, self._gold)
#
# property ent:
# def __get__(self):
# pass
#
# property n_ents:
# def __get__(self):
# pass
#
# property i:
# def __get__(self):
# pass
#
# property open_entity:
# def __get__(self):
# return entity_is_open(self._s)