Fix defaults for ud-train

This commit is contained in:
Matthew Honnibal 2018-09-13 18:05:48 +02:00
parent 59cf533879
commit 3eb9f3e2b8

View File

@ -300,17 +300,25 @@ def initialize_pipeline(nlp, docs, golds, config, device):
######################## ########################
class Config(object):
    """Container for the training settings, exposed as instance attributes.

    Every constructor argument is stored on the instance under its own name
    (``config.nr_epoch``, ``config.dropout``, ...).
    """

    def __init__(self, vectors=None, max_doc_length=10, multitask_tag=False,
                 multitask_sent=False, multitask_dep=False, multitask_vectors=None,
                 nr_epoch=30, min_batch_size=100, max_batch_size=1000,
                 batch_by_words=True, dropout=0.2, conv_depth=4, subword_features=True,
                 vectors_dir=None):
        """Store all arguments as attributes of the same name.

        When ``vectors_dir`` is given, ``vectors`` and ``multitask_vectors``
        each default to ``True`` if they were left as ``None``; explicitly
        passed values are kept. Without ``vectors_dir`` they stay as passed
        (``None`` by default).
        """
        if vectors_dir is not None:
            if vectors is None:
                vectors = True
            if multitask_vectors is None:
                multitask_vectors = True
        # Snapshot the arguments before any other local is created, and drop
        # `self`: copying it over would give every instance a `self.self`
        # attribute that points at itself (a reference cycle that also makes
        # `vars(config)` impossible to serialise).
        settings = dict(locals())
        del settings['self']
        for key, value in settings.items():
            setattr(self, key, value)

    @classmethod
    def load(cls, loc, vectors_dir=None):
        """Build a config from the JSON file at ``loc``.

        The file's top-level keys are passed to the constructor as keyword
        arguments. If ``vectors_dir`` is not ``None`` it replaces any
        ``vectors_dir`` entry read from the file.
        """
        with Path(loc).open('r', encoding='utf8') as file_:
            cfg = json.load(file_)
        if vectors_dir is not None:
            cfg['vectors_dir'] = vectors_dir
        return cls(**cfg)
@ -353,16 +361,16 @@ class TreebankPaths(object):
vectors_dir=("Path to directory with pre-trained vectors, named e.g. en/", vectors_dir=("Path to directory with pre-trained vectors, named e.g. en/",
"option", "v", Path), "option", "v", Path),
) )
def main(ud_dir, parses_dir, config=None, corpus, limit=0, use_gpu=-1, vectors_dir=None, def main(ud_dir, parses_dir, corpus, config=None, limit=0, use_gpu=-1, vectors_dir=None,
use_oracle_segments=False): use_oracle_segments=False):
spacy.util.fix_random_seed() spacy.util.fix_random_seed()
lang.zh.Chinese.Defaults.use_jieba = False lang.zh.Chinese.Defaults.use_jieba = False
lang.ja.Japanese.Defaults.use_janome = False lang.ja.Japanese.Defaults.use_janome = False
if config is not None: if config is not None:
config = Config.load(config) config = Config.load(config, vectors_dir=vectors_dir)
else: else:
config = Config() config = Config(vectors_dir=vectors_dir)
paths = TreebankPaths(ud_dir, corpus) paths = TreebankPaths(ud_dir, corpus)
if not (parses_dir / corpus).exists(): if not (parses_dir / corpus).exists():
(parses_dir / corpus).mkdir() (parses_dir / corpus).mkdir()