Fix training methods

This commit is contained in:
Matthew Honnibal 2017-04-16 13:00:37 -05:00
parent 6a4221a6de
commit 89a4f262fc
3 changed files with 16 additions and 12 deletions

View File

@ -2,8 +2,8 @@
from __future__ import unicode_literals, division, print_function
import json
from pathlib import Path
from ..util import ensure_path
from ..scorer import Scorer
from ..gold import GoldParse, merge_sents
from ..gold import read_json_file as read_gold_json
@ -12,9 +12,9 @@ from .. import util
def train(language, output_dir, train_data, dev_data, n_iter, tagger, parser, ner,
parser_L1):
output_path = Path(output_dir)
train_path = Path(train_data)
dev_path = Path(dev_data)
output_path = ensure_path(output_dir)
train_path = ensure_path(train_data)
dev_path = ensure_path(dev_data)
check_dirs(output_path, train_path, dev_path)
lang = util.get_lang_class(language)
@ -43,7 +43,7 @@ def train(language, output_dir, train_data, dev_data, n_iter, tagger, parser, ne
def train_config(config):
config_path = Path(config)
config_path = ensure_path(config)
if not config_path.is_file():
util.sys_exit(config_path.as_posix(), title="Config file not found")
config = json.load(config_path)
@ -57,7 +57,8 @@ def train_model(Language, train_data, dev_data, output_path, tagger_cfg, parser_
entity_cfg, n_iter):
print("Itn.\tN weight\tN feats\tUAS\tNER F.\tTag %\tToken %")
with Language.train(output_path, train_data, tagger_cfg, parser_cfg, entity_cfg) as trainer:
with Language.train(output_path, train_data,
pos=tagger_cfg, deps=parser_cfg, ner=entity_cfg) as trainer:
for itn, epoch in enumerate(trainer.epochs(n_iter, augment_data=None)):
for doc, gold in epoch:
trainer.update(doc, gold)

View File

@ -5,9 +5,9 @@ from __future__ import unicode_literals, print_function
import io
import re
import ujson
from pathlib import Path
from .syntax import nonproj
from .util import ensure_path
def tags_to_entities(tags):
@ -139,12 +139,12 @@ def _min_edit_path(cand_words, gold_words):
def read_json_file(loc, docs_filter=None):
loc = Path(loc)
loc = ensure_path(loc)
if loc.is_dir():
for filename in loc.iterdir():
yield from read_json_file(loc / filename)
else:
with io.open(loc, 'r', encoding='utf8') as file_:
with loc.open('r', encoding='utf8') as file_:
docs = ujson.load(file_)
for doc in docs:
if docs_filter is not None and not docs_filter(doc):

View File

@ -204,15 +204,18 @@ class Language(object):
@classmethod
@contextmanager
def train(cls, path, gold_tuples, **configs):
if parser_cfg['pseudoprojective']:
parser_cfg = configs.get('deps', {})
if parser_cfg.get('pseudoprojective'):
# preprocess training data here before ArcEager.get_labels() is called
gold_tuples = PseudoProjectivity.preprocess_training_data(gold_tuples)
for subdir in ('deps', 'ner', 'pos'):
if subdir not in configs:
configs[subdir] = {}
configs['deps']['actions'] = ArcEager.get_actions(gold_parses=gold_tuples)
configs['ner']['actions'] = BiluoPushDown.get_actions(gold_parses=gold_tuples)
if parser_cfg:
configs['deps']['actions'] = ArcEager.get_actions(gold_parses=gold_tuples)
if 'ner' in configs:
configs['ner']['actions'] = BiluoPushDown.get_actions(gold_parses=gold_tuples)
cls.setup_directory(path, **configs)