mirror of
https://github.com/explosion/spaCy.git
synced 2025-08-10 15:14:56 +03:00
Update ArcEager for changes to GoldParse class
This commit is contained in:
parent
c9d314b7ba
commit
aa5ecf7fd2
|
@ -19,6 +19,7 @@ from ..structs cimport TokenC
|
|||
|
||||
# Calculate cost as gold/not gold. We don't use scalar value anyway.
|
||||
cdef int BINARY_COSTS = 1
|
||||
cdef int MAX_SPLIT = 4
|
||||
|
||||
DEF NON_MONOTONIC = True
|
||||
DEF USE_BREAK = True
|
||||
|
@ -437,12 +438,12 @@ cdef class ArcEager(TransitionSystem):
|
|||
# TODO: Split?
|
||||
return actions
|
||||
|
||||
property max_split:
|
||||
def __get__(self):
|
||||
return self.cfg.get('max_split', 0)
|
||||
#property max_split:
|
||||
# def __get__(self):
|
||||
# return self.cfg.get('max_split', 0)
|
||||
|
||||
def __set__(self, int value):
|
||||
self.cfg['max_split'] = value
|
||||
# def __set__(self, int value):
|
||||
# self.cfg['max_split'] = value
|
||||
|
||||
property action_types:
|
||||
def __get__(self):
|
||||
|
@ -464,14 +465,15 @@ cdef class ArcEager(TransitionSystem):
|
|||
predicted = set()
|
||||
truth = set()
|
||||
for i in range(gold.length):
|
||||
if gold.cand_to_gold[i] is None:
|
||||
gold_i = gold._alignment.index_to_yours(i)
|
||||
if gold_i is None:
|
||||
continue
|
||||
if state.safe_get(i).dep:
|
||||
predicted.add((i, state.H(i),
|
||||
self.strings[state.safe_get(i).dep]))
|
||||
else:
|
||||
predicted.add((i, state.H(i), 'ROOT'))
|
||||
id_, word, tag, head, dep, ner = gold.orig_annot[gold.cand_to_gold[i]]
|
||||
id_, word, tag, head, dep, ner = gold.orig_annot[gold_i]
|
||||
truth.add((id_, head, dep))
|
||||
return truth == predicted
|
||||
|
||||
|
@ -487,23 +489,23 @@ cdef class ArcEager(TransitionSystem):
|
|||
return None
|
||||
subtok_label = self.strings['subtok']
|
||||
if USE_SPLIT:
|
||||
gold.resize_arrays(self.max_split * len(gold))
|
||||
gold.resize_arrays(MAX_SPLIT * len(gold))
|
||||
# Subtokens are addressed by (subposition, position).
|
||||
# This way the 'normal' tokens (at subposition 0) occupy positions
|
||||
# 0...n in the array.
|
||||
for i in range(1, self.max_split-1):
|
||||
for i in range(1, MAX_SPLIT-1):
|
||||
for j in range(len(gold)):
|
||||
index = i * len(gold) + j
|
||||
# If we've incorrectly split, we want to join them back
|
||||
# up -- so, set the head of each subtoken to the following
|
||||
# subtoken (until the end), and set the label to 'subtok'.
|
||||
gold.c.heads[index] = (i+1)*len(gold) + j
|
||||
gold.c.dep[index] = subtok_label
|
||||
gold.c.labels[index] = subtok_label
|
||||
gold.c.has_dep[index] = True
|
||||
for j in range(len(gold)):
|
||||
# For the last subtoken in each position, set head to 'unknown'.
|
||||
gold.c.heads[index] = index
|
||||
gold.c.deps[index] = 0
|
||||
gold.c.labels[index] = 0
|
||||
gold.c.has_dep[index] = False
|
||||
for child_i, (head_group, dep_group) in enumerate(zip(gold.heads, gold.labels)):
|
||||
if not USE_SPLIT and isinstance(head_group, list):
|
||||
|
|
Loading…
Reference in New Issue
Block a user