* Hacks to conll.pyx. Should clean these up.

This commit is contained in:
Matthew Honnibal 2015-03-08 00:14:48 -05:00
parent f321b2b2eb
commit 5278c7504b

View File

@@ -12,12 +12,25 @@ cdef class GoldParse:
self.c_heads = <int*>self.mem.alloc(self.length, sizeof(int))
self.c_labels = <int*>self.mem.alloc(self.length, sizeof(int))
@property
def n_non_punct(self):
    """Count the entries of ``self.labels`` that are not the label 'P'.

    NOTE(review): 'P' is presumably the punctuation label, going by this
    property's name -- confirm against the label scheme.
    """
    return sum(1 for label in self.labels if label != 'P')
@property
def py_heads(self):
    """Return the first ``self.length`` entries of ``self.c_heads`` as a list."""
    heads = []
    for idx in range(self.length):
        heads.append(self.c_heads[idx])
    return heads
cdef int heads_correct(self, TokenC* tokens, bint score_punct=False) except -1:
    # Count the positions whose predicted head matches the gold head.
    #
    # For each index below self.length, the prediction is taken as
    # `index + tokens[index].head` and compared with `self.c_heads[index]`.
    # NOTE(review): this implies tokens[index].head is an offset relative to
    # the token's own position -- confirm against the TokenC definition.
    #
    # Positions labelled 'P' are left out of the count unless `score_punct`
    # is true. Returns the number of matching positions.
    n_correct = 0
    for idx in range(self.length):
        # The label is only inspected when punctuation is not being scored.
        if score_punct or self.labels[idx] != 'P':
            if (idx + tokens[idx].head) == self.c_heads[idx]:
                n_correct += 1
    return n_correct
def is_correct(self, i, head):
    """Return whether ``head`` equals the entry stored at ``self.c_heads[i]``."""
    gold_head = self.c_heads[i]
    return head == gold_head
@classmethod
def from_conll(cls, unicode sent_str):
ids = []
@@ -96,6 +109,10 @@ cdef class GoldParse:
self.c_heads = <int*>self.mem.alloc(self.length, sizeof(int))
self.c_labels = <int*>self.mem.alloc(self.length, sizeof(int))
self.ids = [token.idx for token in tokens]
self.map_heads(label_ids)
return self.loss
def map_heads(self, label_ids):
mapped_heads = _map_indices_to_tokens(self.ids, self.heads)
for i in range(self.length):
if mapped_heads[i] is None:
@@ -121,7 +138,6 @@ def _map_indices_to_tokens(ids, heads):
return mapped
def _parse_line(line):
pieces = line.split()
if len(pieces) == 4: