Support oracle segmentation in ud-train CLI command

Repository: https://github.com/explosion/spaCy.git
parent c49e44349a
commit fc4dd49b77
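In short, this commit threads a new `use_oracle_segments` option through the ud-train workflow: `evaluate()` learns to read gold segmentation straight from a `.conllu` file, the training loop re-reads the corpus each epoch with either oracle segments or raw text, and the dev evaluation switches between gold and predicted segmentation accordingly. The hunks below are all from the ud-train CLI module.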
@@ -161,6 +161,16 @@ def golds_to_gold_tuples(docs, golds):
 
 
 def evaluate(nlp, text_loc, gold_loc, sys_loc, limit=None):
-    with text_loc.open('r', encoding='utf8') as text_file:
-        texts = split_text(text_file.read())
-    docs = list(nlp.pipe(texts))
+    if text_loc.parts[-1].endswith('.conllu'):
+        docs = []
+        with text_loc.open() as file_:
+            for conllu_doc in read_conllu(file_):
+                for conllu_sent in conllu_doc:
+                    words = [line[1] for line in conllu_sent]
+                    docs.append(Doc(nlp.vocab, words=words))
+        for name, component in nlp.pipeline:
+            docs = list(component.pipe(docs))
+    else:
+        with text_loc.open('r', encoding='utf8') as text_file:
+            texts = split_text(text_file.read())
+        docs = list(nlp.pipe(texts))
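With this hunk, `evaluate()` can take the dev `.conllu` file itself as the text source: instead of running the tokenizer over raw text, it rebuilds each `Doc` from the gold word forms (`line[1]` is the FORM column of a CoNLL-U row) and then runs the remaining pipeline components over those pre-segmented docs. A minimal sketch of the structures this assumes, with a hypothetical stand-in for `read_conllu` (the real helper lives elsewhere in this module):

```python
from spacy.vocab import Vocab
from spacy.tokens import Doc

def read_conllu_sketch(file_):
    # Hypothetical stand-in for read_conllu: groups tab-split 10-column
    # CoNLL-U rows into sentences (blank-line separated) and sentences
    # into documents ('# newdoc' separated). Column 1 is the word form.
    docs, doc, sent = [], [], []
    for line in file_:
        line = line.strip()
        if line.startswith('# newdoc') and doc:
            docs.append(doc)
            doc = []
        elif line.startswith('#'):
            continue
        elif not line:
            if sent:
                doc.append(sent)
                sent = []
        else:
            sent.append(line.split('\t'))
    if sent:
        doc.append(sent)
    if doc:
        docs.append(doc)
    return docs

# Building a Doc directly from gold words skips the tokenizer entirely,
# which is what "oracle segmentation" means here:
words = ['Oracle', 'segmentation', 'uses', 'gold', 'tokens', '.']
doc = Doc(Vocab(), words=words)
```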
@@ -261,12 +271,12 @@ def load_nlp(corpus, config, vectors=None):
 
 
 def initialize_pipeline(nlp, docs, golds, config, device):
-    nlp.add_pipe(nlp.create_pipe('tagger'))
     nlp.add_pipe(nlp.create_pipe('parser'))
     if config.multitask_tag:
         nlp.parser.add_multitask_objective('tag')
     if config.multitask_sent:
         nlp.parser.add_multitask_objective('sent_start')
+    nlp.add_pipe(nlp.create_pipe('tagger'))
     for gold in golds:
         for tag in gold.tags:
             if tag is not None:
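This hunk only moves the tagger: the parser and its multitask objectives are now registered before the standalone tagger is appended, so the pipeline order becomes parser-then-tagger. A minimal sketch of the resulting assembly, assuming a blank spaCy 2.x pipeline (model name and the always-on flags here are illustrative):

```python
import spacy

nlp = spacy.blank('en')  # illustrative blank pipeline
nlp.add_pipe(nlp.create_pipe('parser'))
# Multitask objectives, roughly, piggy-back extra prediction tasks on the
# parser's shared token representations; they are registered before training.
nlp.parser.add_multitask_objective('tag')
nlp.parser.add_multitask_objective('sent_start')
nlp.add_pipe(nlp.create_pipe('tagger'))
print([name for name, _ in nlp.pipeline])  # ['parser', 'tagger']
```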
@@ -328,10 +338,12 @@ class TreebankPaths(object):
     config=("Path to json formatted config file", "positional"),
     limit=("Size limit", "option", "n", int),
     use_gpu=("Use GPU", "option", "g", int),
+    use_oracle_segments=("Use oracle segments", "flag", "G", int),
     vectors_dir=("Path to directory with pre-trained vectors, named e.g. en/",
                  "option", "v", Path),
 )
-def main(ud_dir, parses_dir, config, corpus, limit=0, use_gpu=-1, vectors_dir=None):
+def main(ud_dir, parses_dir, config, corpus, limit=0, use_gpu=-1, vectors_dir=None,
+         use_oracle_segments=False):
     spacy.util.fix_random_seed()
     lang.zh.Chinese.Defaults.use_jieba = False
     lang.ja.Japanese.Defaults.use_janome = False
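The CLI surface follows plac's annotation convention: each keyword in `@plac.annotations(...)` is a `(help, kind, abbrev, ...)` tuple, and a `"flag"` kind becomes a boolean switch. A self-contained sketch of the same pattern (script and option names are illustrative, not part of the commit):

```python
import plac

@plac.annotations(
    use_oracle_segments=("Use oracle segments", "flag", "G"),
)
def main(use_oracle_segments=False):
    # Passing -G (or --use-oracle-segments) flips this to True.
    print("oracle segmentation:", use_oracle_segments)

if __name__ == '__main__':
    plac.call(main)
```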
@@ -344,13 +356,17 @@ def main(ud_dir, parses_dir, config, corpus, limit=0, use_gpu=-1, vectors_dir=No
     nlp = load_nlp(paths.lang, config, vectors=vectors_dir)
 
     docs, golds = read_data(nlp, paths.train.conllu.open(), paths.train.text.open(),
-                            max_doc_length=config.max_doc_length, limit=limit)
+                            max_doc_length=None, limit=limit)
 
     optimizer = initialize_pipeline(nlp, docs, golds, config, use_gpu)
 
     batch_sizes = compounding(config.batch_size//10, config.batch_size, 1.001)
+    nlp.parser.cfg['beam_update_prob'] = 1.0
     for i in range(config.nr_epoch):
-        docs = [nlp.make_doc(doc.text) for doc in docs]
+        docs, golds = read_data(nlp, paths.train.conllu.open(), paths.train.text.open(),
+                                max_doc_length=config.max_doc_length, limit=limit,
+                                oracle_segments=use_oracle_segments,
+                                raw_text=not use_oracle_segments)
         Xs = list(zip(docs, golds))
         random.shuffle(Xs)
         batches = minibatch_by_words(Xs, size=batch_sizes)
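Two behavioural changes land here. First, the initial `read_data()` call now only feeds pipeline initialization (hence `max_doc_length=None`), while the training loop re-reads the treebank every epoch with `oracle_segments`/`raw_text` toggled by the new flag, replacing the old `nlp.make_doc()` re-tokenization. Second, `beam_update_prob` is pinned to 1.0 on the parser config. The surrounding batching uses spaCy's compounding schedule; a sketch of that pattern, assuming a spaCy 2.x version that ships both helpers in `spacy.util` (sizes and dummy data are illustrative):

```python
from spacy.tokens import Doc
from spacy.vocab import Vocab
from spacy.util import compounding, minibatch_by_words

vocab = Vocab()
docs = [Doc(vocab, words=['a'] * n) for n in (5, 50, 500)]  # dummy docs
golds = [None] * len(docs)                                  # stand-in golds

# compounding(start, stop, compound) yields start, start*compound, ...
# capped at stop, so early batches are small and later ones larger.
batch_sizes = compounding(100, 1000, 1.001)
for batch in minibatch_by_words(list(zip(docs, golds)), size=batch_sizes):
    # Each batch groups (doc, gold) pairs until the word budget drawn
    # from batch_sizes is exhausted.
    print(sum(len(doc) for doc, gold in batch))
```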
@@ -365,7 +381,12 @@ def main(ud_dir, parses_dir, config, corpus, limit=0, use_gpu=-1, vectors_dir=No
 
 
         out_path = parses_dir / corpus / 'epoch-{i}.conllu'.format(i=i)
         with nlp.use_params(optimizer.averages):
-            parsed_docs, scores = evaluate(nlp, paths.dev.text, paths.dev.conllu, out_path)
+            if use_oracle_segments:
+                parsed_docs, scores = evaluate(nlp, paths.dev.conllu,
+                                               paths.dev.conllu, out_path)
+            else:
+                parsed_docs, scores = evaluate(nlp, paths.dev.text,
+                                               paths.dev.conllu, out_path)
         print_progress(i, losses, scores)
         _render_parses(i, parsed_docs[:50])
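At evaluation time the same flag picks the input: with oracle segments on, the dev `.conllu` file is passed as both the text source and the gold file, so the `.conllu` branch added to `evaluate()` above supplies gold tokenization and sentence boundaries and the scores isolate tagging/parsing quality from segmentation errors; with it off, raw dev text is parsed end to end as before. A small sketch of the suffix dispatch this relies on (the paths are hypothetical placeholders):

```python
from pathlib import Path

dev_conllu = Path('corpus/UD_English/en-ud-dev.conllu')  # hypothetical path
dev_text = Path('corpus/UD_English/en-ud-dev.txt')       # hypothetical path

# evaluate() dispatches on the file suffix: a .conllu text source takes
# the oracle branch, anything else is treated as raw text.
use_oracle_segments = True
text_loc = dev_conllu if use_oracle_segments else dev_text
print(text_loc.parts[-1].endswith('.conllu'))  # True -> oracle branch
```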