Handle complex tags in ud-train

Matthew Honnibal 2018-04-03 01:57:37 +02:00
parent 6d42e0ad8e
commit 6aded3d855

@@ -274,9 +274,13 @@ def initialize_pipeline(nlp, docs, golds, config, device):
         for i, tag in enumerate(gold.tags):
             if isinstance(tag, list):
                 for subtag in tag:
+                    if isinstance(subtag, tuple):
+                        subtag = subtag[0]
                     nlp.tagger.add_label(subtag)
             else:
                 if tag is not None:
+                    if isinstance(tag, tuple):
+                        tag = tag[0]
                     nlp.tagger.add_label(tag)
     return nlp.begin_training(lambda: golds_to_gold_tuples(docs, golds), device=device)
@@ -361,7 +365,11 @@ def main(ud_dir, parses_dir, config, corpus, limit=0, use_gpu=-1):
         n_train_words = sum(len(doc) for doc in docs)
         with tqdm.tqdm(total=n_train_words, leave=False) as pbar:
             for batch in batches:
+                if not batch:
+                    continue
                 batch_docs, batch_gold = zip(*batch)
+                batch_docs = list(batch_docs)
+                batch_gold = list(batch_gold)
                 pbar.update(sum(len(doc) for doc in batch_docs))
                 nlp.update(batch_docs, batch_gold, sgd=optimizer,
                            drop=config.dropout, losses=losses)
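
In the training loop, unpacking zip(*batch) into two names fails on an empty batch, so the hunk skips empty batches before unpacking and then converts the resulting tuples to lists before calling nlp.update. A small self-contained sketch of that behaviour; the placeholder strings stand in for Doc and GoldParse pairs and are not from the script.

# Illustrative only: strings stand in for (Doc, GoldParse) pairs.
batches = [[], [('doc1', 'gold1'), ('doc2', 'gold2')]]

for batch in batches:
    if not batch:
        # Without this guard, "batch_docs, batch_gold = zip(*batch)" on an empty
        # batch raises ValueError: not enough values to unpack (expected 2, got 0).
        continue
    batch_docs, batch_gold = zip(*batch)   # two tuples: all docs, all golds
    batch_docs = list(batch_docs)          # converted to lists, presumably the
    batch_gold = list(batch_gold)          # input type nlp.update expects
    print(batch_docs, batch_gold)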