Handle list values in ud-train for tagger

This commit is contained in:
Matthew Honnibal 2018-04-01 18:34:28 +02:00
parent 5f68e491e1
commit 00fa41a924

View File

@ -190,6 +190,15 @@ def write_conllu(docs, file_):
for k, token in enumerate(sent):
file_.write(token._.get_conllu_lines(k) + '\n')
file_.write('\n')
for word in sent:
if word.head.i == word.i and word.dep_ == 'ROOT':
break
else:
print("Rootless sentence!")
print(sent)
print(i)
raise ValueError
def print_progress(itn, losses, ud_scores):
@ -262,7 +271,11 @@ def initialize_pipeline(nlp, docs, golds, config, device):
nlp.parser.add_multitask_objective('sent_start')
nlp.add_pipe(nlp.create_pipe('tagger'))
for gold in golds:
for tag in gold.tags:
for i, tag in enumerate(gold.tags):
if isinstance(tag, list):
for subtag in tag:
nlp.tagger.add_label(subtag)
else:
if tag is not None:
nlp.tagger.add_label(tag)
return nlp.begin_training(lambda: golds_to_gold_tuples(docs, golds), device=device)
@ -338,7 +351,7 @@ def main(ud_dir, parses_dir, config, corpus, limit=0, use_gpu=-1):
optimizer = initialize_pipeline(nlp, docs, golds, config, use_gpu)
batch_sizes = compounding(config.batch_size//10, config.batch_size, 1.001)
batch_sizes = compounding(config.batch_size, config.batch_size, 1.001)
for i in range(config.nr_epoch):
docs = [nlp.make_doc(doc.text) for doc in docs]
Xs = list(zip(docs, golds))