mirror of
https://github.com/explosion/spaCy.git
synced 2025-08-10 15:14:56 +03:00
Handle list values in ud-train for tagger
This commit is contained in:
parent
5f68e491e1
commit
00fa41a924
|
@ -190,6 +190,15 @@ def write_conllu(docs, file_):
|
|||
for k, token in enumerate(sent):
|
||||
file_.write(token._.get_conllu_lines(k) + '\n')
|
||||
file_.write('\n')
|
||||
for word in sent:
|
||||
if word.head.i == word.i and word.dep_ == 'ROOT':
|
||||
break
|
||||
else:
|
||||
print("Rootless sentence!")
|
||||
print(sent)
|
||||
print(i)
|
||||
raise ValueError
|
||||
|
||||
|
||||
|
||||
def print_progress(itn, losses, ud_scores):
|
||||
|
@ -262,7 +271,11 @@ def initialize_pipeline(nlp, docs, golds, config, device):
|
|||
nlp.parser.add_multitask_objective('sent_start')
|
||||
nlp.add_pipe(nlp.create_pipe('tagger'))
|
||||
for gold in golds:
|
||||
for tag in gold.tags:
|
||||
for i, tag in enumerate(gold.tags):
|
||||
if isinstance(tag, list):
|
||||
for subtag in tag:
|
||||
nlp.tagger.add_label(subtag)
|
||||
else:
|
||||
if tag is not None:
|
||||
nlp.tagger.add_label(tag)
|
||||
return nlp.begin_training(lambda: golds_to_gold_tuples(docs, golds), device=device)
|
||||
|
@ -338,7 +351,7 @@ def main(ud_dir, parses_dir, config, corpus, limit=0, use_gpu=-1):
|
|||
|
||||
optimizer = initialize_pipeline(nlp, docs, golds, config, use_gpu)
|
||||
|
||||
batch_sizes = compounding(config.batch_size//10, config.batch_size, 1.001)
|
||||
batch_sizes = compounding(config.batch_size, config.batch_size, 1.001)
|
||||
for i in range(config.nr_epoch):
|
||||
docs = [nlp.make_doc(doc.text) for doc in docs]
|
||||
Xs = list(zip(docs, golds))
|
||||
|
|
Loading…
Reference in New Issue
Block a user