mirror of
https://github.com/explosion/spaCy.git
synced 2025-08-10 07:04:53 +03:00
Handle complex tags in ud-train
This commit is contained in:
parent
6d42e0ad8e
commit
6aded3d855
|
@ -274,9 +274,13 @@ def initialize_pipeline(nlp, docs, golds, config, device):
|
|||
for i, tag in enumerate(gold.tags):
|
||||
if isinstance(tag, list):
|
||||
for subtag in tag:
|
||||
if isinstance(subtag, tuple):
|
||||
subtag = subtag[0]
|
||||
nlp.tagger.add_label(subtag)
|
||||
else:
|
||||
if tag is not None:
|
||||
if isinstance(tag, tuple):
|
||||
tag = tag[0]
|
||||
nlp.tagger.add_label(tag)
|
||||
return nlp.begin_training(lambda: golds_to_gold_tuples(docs, golds), device=device)
|
||||
|
||||
|
@ -361,7 +365,11 @@ def main(ud_dir, parses_dir, config, corpus, limit=0, use_gpu=-1):
|
|||
n_train_words = sum(len(doc) for doc in docs)
|
||||
with tqdm.tqdm(total=n_train_words, leave=False) as pbar:
|
||||
for batch in batches:
|
||||
if not batch:
|
||||
continue
|
||||
batch_docs, batch_gold = zip(*batch)
|
||||
batch_docs = list(batch_docs)
|
||||
batch_gold = list(batch_gold)
|
||||
pbar.update(sum(len(doc) for doc in batch_docs))
|
||||
nlp.update(batch_docs, batch_gold, sgd=optimizer,
|
||||
drop=config.dropout, losses=losses)
|
||||
|
|
Loading…
Reference in New Issue
Block a user