mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 01:16:28 +03:00
Fix training methods
This commit is contained in:
parent
6a4221a6de
commit
89a4f262fc
|
@ -2,8 +2,8 @@
|
||||||
from __future__ import unicode_literals, division, print_function
|
from __future__ import unicode_literals, division, print_function
|
||||||
|
|
||||||
import json
|
import json
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
|
from ..util import ensure_path
|
||||||
from ..scorer import Scorer
|
from ..scorer import Scorer
|
||||||
from ..gold import GoldParse, merge_sents
|
from ..gold import GoldParse, merge_sents
|
||||||
from ..gold import read_json_file as read_gold_json
|
from ..gold import read_json_file as read_gold_json
|
||||||
|
@ -12,9 +12,9 @@ from .. import util
|
||||||
|
|
||||||
def train(language, output_dir, train_data, dev_data, n_iter, tagger, parser, ner,
|
def train(language, output_dir, train_data, dev_data, n_iter, tagger, parser, ner,
|
||||||
parser_L1):
|
parser_L1):
|
||||||
output_path = Path(output_dir)
|
output_path = ensure_path(output_dir)
|
||||||
train_path = Path(train_data)
|
train_path = ensure_path(train_data)
|
||||||
dev_path = Path(dev_data)
|
dev_path = ensure_path(dev_data)
|
||||||
check_dirs(output_path, train_path, dev_path)
|
check_dirs(output_path, train_path, dev_path)
|
||||||
|
|
||||||
lang = util.get_lang_class(language)
|
lang = util.get_lang_class(language)
|
||||||
|
@ -43,7 +43,7 @@ def train(language, output_dir, train_data, dev_data, n_iter, tagger, parser, ne
|
||||||
|
|
||||||
|
|
||||||
def train_config(config):
|
def train_config(config):
|
||||||
config_path = Path(config)
|
config_path = ensure_path(config)
|
||||||
if not config_path.is_file():
|
if not config_path.is_file():
|
||||||
util.sys_exit(config_path.as_posix(), title="Config file not found")
|
util.sys_exit(config_path.as_posix(), title="Config file not found")
|
||||||
config = json.load(config_path)
|
config = json.load(config_path)
|
||||||
|
@ -57,7 +57,8 @@ def train_model(Language, train_data, dev_data, output_path, tagger_cfg, parser_
|
||||||
entity_cfg, n_iter):
|
entity_cfg, n_iter):
|
||||||
print("Itn.\tN weight\tN feats\tUAS\tNER F.\tTag %\tToken %")
|
print("Itn.\tN weight\tN feats\tUAS\tNER F.\tTag %\tToken %")
|
||||||
|
|
||||||
with Language.train(output_path, train_data, tagger_cfg, parser_cfg, entity_cfg) as trainer:
|
with Language.train(output_path, train_data,
|
||||||
|
pos=tagger_cfg, deps=parser_cfg, ner=entity_cfg) as trainer:
|
||||||
for itn, epoch in enumerate(trainer.epochs(n_iter, augment_data=None)):
|
for itn, epoch in enumerate(trainer.epochs(n_iter, augment_data=None)):
|
||||||
for doc, gold in epoch:
|
for doc, gold in epoch:
|
||||||
trainer.update(doc, gold)
|
trainer.update(doc, gold)
|
||||||
|
|
|
@ -5,9 +5,9 @@ from __future__ import unicode_literals, print_function
|
||||||
import io
|
import io
|
||||||
import re
|
import re
|
||||||
import ujson
|
import ujson
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from .syntax import nonproj
|
from .syntax import nonproj
|
||||||
|
from .util import ensure_path
|
||||||
|
|
||||||
|
|
||||||
def tags_to_entities(tags):
|
def tags_to_entities(tags):
|
||||||
|
@ -139,12 +139,12 @@ def _min_edit_path(cand_words, gold_words):
|
||||||
|
|
||||||
|
|
||||||
def read_json_file(loc, docs_filter=None):
|
def read_json_file(loc, docs_filter=None):
|
||||||
loc = Path(loc)
|
loc = ensure_path(loc)
|
||||||
if loc.is_dir():
|
if loc.is_dir():
|
||||||
for filename in loc.iterdir():
|
for filename in loc.iterdir():
|
||||||
yield from read_json_file(loc / filename)
|
yield from read_json_file(loc / filename)
|
||||||
else:
|
else:
|
||||||
with io.open(loc, 'r', encoding='utf8') as file_:
|
with loc.open('r', encoding='utf8') as file_:
|
||||||
docs = ujson.load(file_)
|
docs = ujson.load(file_)
|
||||||
for doc in docs:
|
for doc in docs:
|
||||||
if docs_filter is not None and not docs_filter(doc):
|
if docs_filter is not None and not docs_filter(doc):
|
||||||
|
|
|
@ -204,15 +204,18 @@ class Language(object):
|
||||||
@classmethod
|
@classmethod
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def train(cls, path, gold_tuples, **configs):
|
def train(cls, path, gold_tuples, **configs):
|
||||||
if parser_cfg['pseudoprojective']:
|
parser_cfg = configs.get('deps', {})
|
||||||
|
if parser_cfg.get('pseudoprojective'):
|
||||||
# preprocess training data here before ArcEager.get_labels() is called
|
# preprocess training data here before ArcEager.get_labels() is called
|
||||||
gold_tuples = PseudoProjectivity.preprocess_training_data(gold_tuples)
|
gold_tuples = PseudoProjectivity.preprocess_training_data(gold_tuples)
|
||||||
|
|
||||||
for subdir in ('deps', 'ner', 'pos'):
|
for subdir in ('deps', 'ner', 'pos'):
|
||||||
if subdir not in configs:
|
if subdir not in configs:
|
||||||
configs[subdir] = {}
|
configs[subdir] = {}
|
||||||
configs['deps']['actions'] = ArcEager.get_actions(gold_parses=gold_tuples)
|
if parser_cfg:
|
||||||
configs['ner']['actions'] = BiluoPushDown.get_actions(gold_parses=gold_tuples)
|
configs['deps']['actions'] = ArcEager.get_actions(gold_parses=gold_tuples)
|
||||||
|
if 'ner' in configs:
|
||||||
|
configs['ner']['actions'] = BiluoPushDown.get_actions(gold_parses=gold_tuples)
|
||||||
|
|
||||||
cls.setup_directory(path, **configs)
|
cls.setup_directory(path, **configs)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user