From 89a4f262fc7ec88b6deae3e65ca75d5c138a7352 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Sun, 16 Apr 2017 13:00:37 -0500 Subject: [PATCH] Fix training methods --- spacy/cli/train.py | 13 +++++++------ spacy/gold.pyx | 6 +++--- spacy/language.py | 9 ++++++--- 3 files changed, 16 insertions(+), 12 deletions(-) diff --git a/spacy/cli/train.py b/spacy/cli/train.py index 489430634..3900c7f39 100644 --- a/spacy/cli/train.py +++ b/spacy/cli/train.py @@ -2,8 +2,8 @@ from __future__ import unicode_literals, division, print_function import json -from pathlib import Path +from ..util import ensure_path from ..scorer import Scorer from ..gold import GoldParse, merge_sents from ..gold import read_json_file as read_gold_json @@ -12,9 +12,9 @@ from .. import util def train(language, output_dir, train_data, dev_data, n_iter, tagger, parser, ner, parser_L1): - output_path = Path(output_dir) - train_path = Path(train_data) - dev_path = Path(dev_data) + output_path = ensure_path(output_dir) + train_path = ensure_path(train_data) + dev_path = ensure_path(dev_data) check_dirs(output_path, train_path, dev_path) lang = util.get_lang_class(language) @@ -43,7 +43,7 @@ def train(language, output_dir, train_data, dev_data, n_iter, tagger, parser, ne def train_config(config): - config_path = Path(config) + config_path = ensure_path(config) if not config_path.is_file(): util.sys_exit(config_path.as_posix(), title="Config file not found") config = json.load(config_path) @@ -57,7 +57,8 @@ def train_model(Language, train_data, dev_data, output_path, tagger_cfg, parser_ entity_cfg, n_iter): print("Itn.\tN weight\tN feats\tUAS\tNER F.\tTag %\tToken %") - with Language.train(output_path, train_data, tagger_cfg, parser_cfg, entity_cfg) as trainer: + with Language.train(output_path, train_data, + pos=tagger_cfg, deps=parser_cfg, ner=entity_cfg) as trainer: for itn, epoch in enumerate(trainer.epochs(n_iter, augment_data=None)): for doc, gold in epoch: trainer.update(doc, gold) diff --git a/spacy/gold.pyx b/spacy/gold.pyx index 425ad0fe0..1e55075c7 100644 --- a/spacy/gold.pyx +++ b/spacy/gold.pyx @@ -5,9 +5,9 @@ from __future__ import unicode_literals, print_function import io import re import ujson -from pathlib import Path from .syntax import nonproj +from .util import ensure_path def tags_to_entities(tags): @@ -139,12 +139,12 @@ def _min_edit_path(cand_words, gold_words): def read_json_file(loc, docs_filter=None): - loc = Path(loc) + loc = ensure_path(loc) if loc.is_dir(): for filename in loc.iterdir(): yield from read_json_file(loc / filename) else: - with io.open(loc, 'r', encoding='utf8') as file_: + with loc.open('r', encoding='utf8') as file_: docs = ujson.load(file_) for doc in docs: if docs_filter is not None and not docs_filter(doc): diff --git a/spacy/language.py b/spacy/language.py index 4b6c3397d..47408921c 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -204,15 +204,18 @@ class Language(object): @classmethod @contextmanager def train(cls, path, gold_tuples, **configs): - if parser_cfg['pseudoprojective']: + parser_cfg = configs.get('deps', {}) + if parser_cfg.get('pseudoprojective'): # preprocess training data here before ArcEager.get_labels() is called gold_tuples = PseudoProjectivity.preprocess_training_data(gold_tuples) for subdir in ('deps', 'ner', 'pos'): if subdir not in configs: configs[subdir] = {} - configs['deps']['actions'] = ArcEager.get_actions(gold_parses=gold_tuples) - configs['ner']['actions'] = BiluoPushDown.get_actions(gold_parses=gold_tuples) + if parser_cfg: + configs['deps']['actions'] = ArcEager.get_actions(gold_parses=gold_tuples) + if 'ner' in configs: + configs['ner']['actions'] = BiluoPushDown.get_actions(gold_parses=gold_tuples) cls.setup_directory(path, **configs)