mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-24 17:06:29 +03:00
Fix training methods
This commit is contained in:
parent
6a4221a6de
commit
89a4f262fc
|
@ -2,8 +2,8 @@
|
|||
from __future__ import unicode_literals, division, print_function
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from ..util import ensure_path
|
||||
from ..scorer import Scorer
|
||||
from ..gold import GoldParse, merge_sents
|
||||
from ..gold import read_json_file as read_gold_json
|
||||
|
@ -12,9 +12,9 @@ from .. import util
|
|||
|
||||
def train(language, output_dir, train_data, dev_data, n_iter, tagger, parser, ner,
|
||||
parser_L1):
|
||||
output_path = Path(output_dir)
|
||||
train_path = Path(train_data)
|
||||
dev_path = Path(dev_data)
|
||||
output_path = ensure_path(output_dir)
|
||||
train_path = ensure_path(train_data)
|
||||
dev_path = ensure_path(dev_data)
|
||||
check_dirs(output_path, train_path, dev_path)
|
||||
|
||||
lang = util.get_lang_class(language)
|
||||
|
@ -43,7 +43,7 @@ def train(language, output_dir, train_data, dev_data, n_iter, tagger, parser, ne
|
|||
|
||||
|
||||
def train_config(config):
|
||||
config_path = Path(config)
|
||||
config_path = ensure_path(config)
|
||||
if not config_path.is_file():
|
||||
util.sys_exit(config_path.as_posix(), title="Config file not found")
|
||||
config = json.load(config_path)
|
||||
|
@ -57,7 +57,8 @@ def train_model(Language, train_data, dev_data, output_path, tagger_cfg, parser_
|
|||
entity_cfg, n_iter):
|
||||
print("Itn.\tN weight\tN feats\tUAS\tNER F.\tTag %\tToken %")
|
||||
|
||||
with Language.train(output_path, train_data, tagger_cfg, parser_cfg, entity_cfg) as trainer:
|
||||
with Language.train(output_path, train_data,
|
||||
pos=tagger_cfg, deps=parser_cfg, ner=entity_cfg) as trainer:
|
||||
for itn, epoch in enumerate(trainer.epochs(n_iter, augment_data=None)):
|
||||
for doc, gold in epoch:
|
||||
trainer.update(doc, gold)
|
||||
|
|
|
@ -5,9 +5,9 @@ from __future__ import unicode_literals, print_function
|
|||
import io
|
||||
import re
|
||||
import ujson
|
||||
from pathlib import Path
|
||||
|
||||
from .syntax import nonproj
|
||||
from .util import ensure_path
|
||||
|
||||
|
||||
def tags_to_entities(tags):
|
||||
|
@ -139,12 +139,12 @@ def _min_edit_path(cand_words, gold_words):
|
|||
|
||||
|
||||
def read_json_file(loc, docs_filter=None):
|
||||
loc = Path(loc)
|
||||
loc = ensure_path(loc)
|
||||
if loc.is_dir():
|
||||
for filename in loc.iterdir():
|
||||
yield from read_json_file(loc / filename)
|
||||
else:
|
||||
with io.open(loc, 'r', encoding='utf8') as file_:
|
||||
with loc.open('r', encoding='utf8') as file_:
|
||||
docs = ujson.load(file_)
|
||||
for doc in docs:
|
||||
if docs_filter is not None and not docs_filter(doc):
|
||||
|
|
|
@ -204,15 +204,18 @@ class Language(object):
|
|||
@classmethod
|
||||
@contextmanager
|
||||
def train(cls, path, gold_tuples, **configs):
|
||||
if parser_cfg['pseudoprojective']:
|
||||
parser_cfg = configs.get('deps', {})
|
||||
if parser_cfg.get('pseudoprojective'):
|
||||
# preprocess training data here before ArcEager.get_labels() is called
|
||||
gold_tuples = PseudoProjectivity.preprocess_training_data(gold_tuples)
|
||||
|
||||
for subdir in ('deps', 'ner', 'pos'):
|
||||
if subdir not in configs:
|
||||
configs[subdir] = {}
|
||||
configs['deps']['actions'] = ArcEager.get_actions(gold_parses=gold_tuples)
|
||||
configs['ner']['actions'] = BiluoPushDown.get_actions(gold_parses=gold_tuples)
|
||||
if parser_cfg:
|
||||
configs['deps']['actions'] = ArcEager.get_actions(gold_parses=gold_tuples)
|
||||
if 'ner' in configs:
|
||||
configs['ner']['actions'] = BiluoPushDown.get_actions(gold_parses=gold_tuples)
|
||||
|
||||
cls.setup_directory(path, **configs)
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user