From 3eb9f3e2b84361de37e8c7d08df75b2ead01aea0 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Thu, 13 Sep 2018 18:05:48 +0200 Subject: [PATCH] Fix defaults for ud-train --- spacy/cli/ud_train.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/spacy/cli/ud_train.py b/spacy/cli/ud_train.py index f8714fa33..71a0083bb 100644 --- a/spacy/cli/ud_train.py +++ b/spacy/cli/ud_train.py @@ -300,17 +300,25 @@ def initialize_pipeline(nlp, docs, golds, config, device): ######################## class Config(object): - def __init__(self, vectors=None, max_doc_length=10, multitask_tag=True, - multitask_sent=True, multitask_dep=True, multitask_vectors=False, - nr_epoch=30, min_batch_size=1, max_batch_size=16, batch_by_words=False, - dropout=0.2, conv_depth=4, subword_features=True): + def __init__(self, vectors=None, max_doc_length=10, multitask_tag=False, + multitask_sent=False, multitask_dep=False, multitask_vectors=None, + nr_epoch=30, min_batch_size=100, max_batch_size=1000, + batch_by_words=True, dropout=0.2, conv_depth=4, subword_features=True, + vectors_dir=None): + if vectors_dir is not None: + if vectors is None: + vectors = True + if multitask_vectors is None: + multitask_vectors = True for key, value in locals().items(): setattr(self, key, value) @classmethod - def load(cls, loc): + def load(cls, loc, vectors_dir=None): with Path(loc).open('r', encoding='utf8') as file_: cfg = json.load(file_) + if vectors_dir is not None: + cfg['vectors_dir'] = vectors_dir return cls(**cfg) @@ -353,16 +361,16 @@ class TreebankPaths(object): vectors_dir=("Path to directory with pre-trained vectors, named e.g. en/", "option", "v", Path), ) -def main(ud_dir, parses_dir, config=None, corpus, limit=0, use_gpu=-1, vectors_dir=None, +def main(ud_dir, parses_dir, corpus, config=None, limit=0, use_gpu=-1, vectors_dir=None, use_oracle_segments=False): spacy.util.fix_random_seed() lang.zh.Chinese.Defaults.use_jieba = False lang.ja.Japanese.Defaults.use_janome = False if config is not None: - config = Config.load(config) + config = Config.load(config, vectors_dir=vectors_dir) else: - config = Config() + config = Config(vectors_dir=vectors_dir) paths = TreebankPaths(ud_dir, corpus) if not (parses_dir / corpus).exists(): (parses_dir / corpus).mkdir()