Mirror of https://github.com/explosion/spaCy.git
Handle complex tags in ud-train
parent 6d42e0ad8e
commit 6aded3d855
@@ -274,9 +274,13 @@ def initialize_pipeline(nlp, docs, golds, config, device):
         for i, tag in enumerate(gold.tags):
             if isinstance(tag, list):
                 for subtag in tag:
+                    if isinstance(subtag, tuple):
+                        subtag = subtag[0]
                     nlp.tagger.add_label(subtag)
             else:
                 if tag is not None:
+                    if isinstance(tag, tuple):
+                        tag = tag[0]
                     nlp.tagger.add_label(tag)
     return nlp.begin_training(lambda: golds_to_gold_tuples(docs, golds), device=device)
 
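The first hunk makes label registration tolerant of complex UD tags: a gold tag may be a list of sub-tags, and an entry may itself be a tuple, in which case only its first element is registered with the tagger. As a rough standalone sketch of the same normalization (the helper name and signature are illustrative, not part of the commit):

# Hypothetical helper (not in the commit): register gold tags with the tagger,
# unwrapping list-valued and tuple-valued tags the same way the hunk above does.
def add_tag_labels(tagger, gold_tags):
    for tag in gold_tags:
        if isinstance(tag, list):
            for subtag in tag:
                if isinstance(subtag, tuple):
                    subtag = subtag[0]  # keep only the first element of a complex tag
                tagger.add_label(subtag)
        elif tag is not None:
            if isinstance(tag, tuple):
                tag = tag[0]
            tagger.add_label(tag)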
@@ -361,7 +365,11 @@ def main(ud_dir, parses_dir, config, corpus, limit=0, use_gpu=-1):
         n_train_words = sum(len(doc) for doc in docs)
         with tqdm.tqdm(total=n_train_words, leave=False) as pbar:
             for batch in batches:
+                if not batch:
+                    continue
                 batch_docs, batch_gold = zip(*batch)
+                batch_docs = list(batch_docs)
+                batch_gold = list(batch_gold)
                 pbar.update(sum(len(doc) for doc in batch_docs))
                 nlp.update(batch_docs, batch_gold, sgd=optimizer,
                            drop=config.dropout, losses=losses)
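The second hunk hardens the training loop: empty batches are skipped, and the tuples produced by zip(*batch) are converted to lists before being passed to nlp.update. A minimal sketch of that batch preparation, assuming batch is a sequence of (Doc, GoldParse) pairs (the helper is illustrative, not from the commit):

# Illustrative helper (not in the commit): prepare one minibatch for nlp.update.
def prepare_batch(batch):
    if not batch:
        return None  # the loop above skips empty batches with `continue`
    batch_docs, batch_gold = zip(*batch)
    # zip() yields tuples; convert to lists before calling nlp.update.
    return list(batch_docs), list(batch_gold)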