mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-11 04:08:09 +03:00
Use rising beam update prob
This commit is contained in:
parent
544ae7f1db
commit
74d5c625b3
|
@ -370,7 +370,7 @@ def main(ud_dir, parses_dir, config, corpus, limit=0, use_gpu=-1, vectors_dir=No
|
|||
optimizer = initialize_pipeline(nlp, docs, golds, config, use_gpu)
|
||||
|
||||
batch_sizes = compounding(config.batch_size//10, config.batch_size, 1.001)
|
||||
nlp.parser.cfg['beam_update_prob'] = 1.0
|
||||
beam_prob = compounding(0.2, 0.8, 1.001)
|
||||
for i in range(config.nr_epoch):
|
||||
docs, golds = read_data(nlp, paths.train.conllu.open(), paths.train.text.open(),
|
||||
max_doc_length=config.max_doc_length, limit=limit,
|
||||
|
@ -385,6 +385,7 @@ def main(ud_dir, parses_dir, config, corpus, limit=0, use_gpu=-1, vectors_dir=No
|
|||
for batch in batches:
|
||||
batch_docs, batch_gold = zip(*batch)
|
||||
pbar.update(sum(len(doc) for doc in batch_docs))
|
||||
nlp.parser.cfg['beam_update_prob'] = next(beam_prob)
|
||||
nlp.update(batch_docs, batch_gold, sgd=optimizer,
|
||||
drop=config.dropout, losses=losses)
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user