mirror of
https://github.com/explosion/spaCy.git
synced 2025-06-03 20:53:12 +03:00
Update ud_train script
This commit is contained in:
parent
5de8a36537
commit
17af6aa3a4
|
@ -247,12 +247,18 @@ Token.set_extension('inside_fused', default=False)
|
||||||
##################
|
##################
|
||||||
|
|
||||||
|
|
||||||
def load_nlp(corpus, config):
|
def load_nlp(corpus, config, vectors=None):
|
||||||
lang = corpus.split('_')[0]
|
lang = corpus.split('_')[0]
|
||||||
nlp = spacy.blank(lang)
|
nlp = spacy.blank(lang)
|
||||||
if config.vectors:
|
if config.vectors:
|
||||||
nlp.vocab.from_disk(Path(config.vectors) / 'vocab')
|
if not vectors:
|
||||||
|
raise ValueError("config asks for vectors, but no vectors "
|
||||||
|
"directory set on command line (use -v)")
|
||||||
|
if (Path(vectors) / corpus).exists():
|
||||||
|
nlp.vocab.from_disk(Path(vectors) / corpus / 'vocab')
|
||||||
|
nlp.meta['treebank'] = corpus
|
||||||
return nlp
|
return nlp
|
||||||
|
|
||||||
|
|
||||||
def initialize_pipeline(nlp, docs, golds, config, device):
|
def initialize_pipeline(nlp, docs, golds, config, device):
|
||||||
nlp.add_pipe(nlp.create_pipe('parser'))
|
nlp.add_pipe(nlp.create_pipe('parser'))
|
||||||
|
@ -274,10 +280,12 @@ def initialize_pipeline(nlp, docs, golds, config, device):
|
||||||
|
|
||||||
class Config(object):
|
class Config(object):
|
||||||
def __init__(self, vectors=None, max_doc_length=10, multitask_tag=True,
|
def __init__(self, vectors=None, max_doc_length=10, multitask_tag=True,
|
||||||
multitask_sent=True, nr_epoch=30, batch_size=1000, dropout=0.2):
|
multitask_sent=True, multitask_dep=True, multitask_vectors=False,
|
||||||
|
nr_epoch=30, batch_size=1000, dropout=0.2):
|
||||||
for key, value in locals().items():
|
for key, value in locals().items():
|
||||||
setattr(self, key, value)
|
setattr(self, key, value)
|
||||||
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load(cls, loc):
|
def load(cls, loc):
|
||||||
with Path(loc).open('r', encoding='utf8') as file_:
|
with Path(loc).open('r', encoding='utf8') as file_:
|
||||||
|
@ -319,9 +327,11 @@ class TreebankPaths(object):
|
||||||
parses_dir=("Directory to write the development parses", "positional", None, Path),
|
parses_dir=("Directory to write the development parses", "positional", None, Path),
|
||||||
config=("Path to json formatted config file", "positional"),
|
config=("Path to json formatted config file", "positional"),
|
||||||
limit=("Size limit", "option", "n", int),
|
limit=("Size limit", "option", "n", int),
|
||||||
use_gpu=("Use GPU", "option", "g", int)
|
use_gpu=("Use GPU", "option", "g", int),
|
||||||
|
vectors_dir=("Path to directory with pre-trained vectors, named e.g. en/",
|
||||||
|
"option", "v", Path),
|
||||||
)
|
)
|
||||||
def main(ud_dir, parses_dir, config, corpus, limit=0, use_gpu=-1):
|
def main(ud_dir, parses_dir, config, corpus, limit=0, use_gpu=-1, vectors_dir=None):
|
||||||
spacy.util.fix_random_seed()
|
spacy.util.fix_random_seed()
|
||||||
lang.zh.Chinese.Defaults.use_jieba = False
|
lang.zh.Chinese.Defaults.use_jieba = False
|
||||||
lang.ja.Japanese.Defaults.use_janome = False
|
lang.ja.Japanese.Defaults.use_janome = False
|
||||||
|
@ -331,7 +341,7 @@ def main(ud_dir, parses_dir, config, corpus, limit=0, use_gpu=-1):
|
||||||
if not (parses_dir / corpus).exists():
|
if not (parses_dir / corpus).exists():
|
||||||
(parses_dir / corpus).mkdir()
|
(parses_dir / corpus).mkdir()
|
||||||
print("Train and evaluate", corpus, "using lang", paths.lang)
|
print("Train and evaluate", corpus, "using lang", paths.lang)
|
||||||
nlp = load_nlp(paths.lang, config)
|
nlp = load_nlp(paths.lang, config, vectors=vectors_dir)
|
||||||
|
|
||||||
docs, golds = read_data(nlp, paths.train.conllu.open(), paths.train.text.open(),
|
docs, golds = read_data(nlp, paths.train.conllu.open(), paths.train.text.open(),
|
||||||
max_doc_length=config.max_doc_length, limit=limit)
|
max_doc_length=config.max_doc_length, limit=limit)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user