Restore support for prior data format -- specifically, the labels field of the config.

This commit is contained in:
Matthew Honnibal 2016-10-17 00:53:26 +02:00
parent c36e8676aa
commit 59038f7efa
2 changed files with 5 additions and 15 deletions

View File

@@ -287,10 +287,6 @@ cdef class ArcEager(TransitionSystem):
RIGHT: {}, RIGHT: {},
LEFT: {}, LEFT: {},
BREAK: {'ROOT': True}}) BREAK: {'ROOT': True}})
for label in kwargs.get('labels', []):
if label.upper() != 'ROOT':
actions[LEFT][label] = True
actions[RIGHT][label] = True
for label in kwargs.get('left_labels', []): for label in kwargs.get('left_labels', []):
if label.upper() != 'ROOT': if label.upper() != 'ROOT':
actions[LEFT][label] = True actions[LEFT][label] = True

View File

@@ -78,6 +78,9 @@ cdef class Parser:
def load(cls, path, Vocab vocab, TransitionSystem=None, require=False): def load(cls, path, Vocab vocab, TransitionSystem=None, require=False):
with (path / 'config.json').open() as file_: with (path / 'config.json').open() as file_:
cfg = json.load(file_) cfg = json.load(file_)
# TODO: remove this shim when we don't have to support older data
if 'labels' in cfg:
cfg['actions'] = cfg.pop('labels')
self = cls(vocab, TransitionSystem=TransitionSystem, model=None, **cfg) self = cls(vocab, TransitionSystem=TransitionSystem, model=None, **cfg)
if (path / 'model').exists(): if (path / 'model').exists():
self.model.load(str(path / 'model')) self.model.load(str(path / 'model'))
@@ -188,8 +191,8 @@ cdef class Parser:
free(eg.is_valid) free(eg.is_valid)
return 0 return 0
def update(self, Doc tokens, raw_gold): def update(self, Doc tokens, GoldParse gold):
cdef GoldParse gold = self.preprocess_gold(raw_gold) self.moves.preprocess_gold(gold)
cdef StateClass stcls = StateClass.init(tokens.c, tokens.length) cdef StateClass stcls = StateClass.init(tokens.c, tokens.length)
self.moves.initialize_state(stcls.c) self.moves.initialize_state(stcls.c)
cdef Pool mem = Pool() cdef Pool mem = Pool()
@@ -226,15 +229,6 @@ cdef class Parser:
for action in self.moves.action_types: for action in self.moves.action_types:
self.moves.add_action(action, label) self.moves.add_action(action, label)
def preprocess_gold(self, raw_gold):
cdef GoldParse gold
if isinstance(raw_gold, GoldParse):
gold = raw_gold
self.moves.preprocess_gold(raw_gold)
return gold
else:
raise ValueError("Parser.preprocess_gold requires GoldParse-type input.")
cdef class StepwiseState: cdef class StepwiseState:
cdef readonly StateClass stcls cdef readonly StateClass stcls