Mirror of https://github.com/explosion/spaCy.git
Shuffle data after each epoch. Improve script
commit 551c93fe01 (parent bdb0174571)
@@ -30,6 +30,7 @@ random.seed(0)
 numpy.random.seed(0)
 
 def minibatch_by_words(items, size=5000):
+    random.shuffle(items)
     if isinstance(size, int):
         size_ = itertools.repeat(size)
     else:
@@ -39,10 +40,17 @@ def minibatch_by_words(items, size=5000):
         batch_size = next(size_)
         batch = []
         while batch_size >= 0:
-            doc, gold = next(items)
+            try:
+                doc, gold = next(items)
+            except StopIteration:
+                yield batch
+                return
             batch_size -= len(doc)
             batch.append((doc, gold))
-        yield batch
+        if batch:
+            yield batch
+        else:
+            break
 
 ################
 # Data reading #
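For context: minibatch_by_words sizes batches by total word count rather than by number of documents, so a few long documents cannot blow up memory. Below is a minimal, self-contained sketch of the batching logic as it stands after this commit, using plain token lists in place of Doc/GoldParse pairs (the dummy data and the if-batch guard in the caller are illustrative, not part of the script):

    import itertools
    import random

    def minibatch_by_words(items, size=5000):
        # Shuffle up front so every epoch sees a fresh batch composition.
        random.shuffle(items)
        if isinstance(size, int):
            size_ = itertools.repeat(size)
        else:
            size_ = size
        items = iter(items)
        while True:
            batch_size = next(size_)
            batch = []
            while batch_size >= 0:
                try:
                    doc, gold = next(items)
                except StopIteration:
                    # Flush whatever is left; this final batch may be empty.
                    yield batch
                    return
                batch_size -= len(doc)
                batch.append((doc, gold))
            if batch:
                yield batch
            else:
                break

    # Dummy "docs": token lists, so len(doc) is the word count.
    docs = [["tok"] * n for n in (800, 1200, 400, 3000, 2500)]
    golds = [None] * len(docs)
    for batch in minibatch_by_words(list(zip(docs, golds)), size=3000):
        if batch:  # guard against the possibly-empty final batch
            print(len(batch), "docs,", sum(len(d) for d, _ in batch), "words")

Note that a batch can overshoot size by up to one document, because the word budget is only checked before drawing the next item, and that the StopIteration branch may yield an empty final batch, which callers need to tolerate.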
@@ -146,10 +154,6 @@ def _make_gold(nlp, text, sent_annots):
     doc = nlp.make_doc(text)
     flat.pop('spaces')
     gold = GoldParse(doc, **flat)
-    #for annot in gold.orig_annot:
-    #    print(annot)
-    #for i in range(len(doc)):
-    #    print(doc[i].text, gold.words[i], gold.labels[i], gold.heads[i])
     return doc, gold
 
 #############################
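For readers unfamiliar with the API: GoldParse (spaCy v2, the version this script targets) aligns a dict of token-level annotations to the tokenization of doc, which is what the flat dict built by _make_gold feeds into. A minimal sketch of an equivalent call, with made-up sentence and labels:

    import spacy
    from spacy.gold import GoldParse

    nlp = spacy.blank('en')
    doc = nlp.make_doc("She saw the cat")
    gold = GoldParse(doc,
                     words=['She', 'saw', 'the', 'cat'],
                     tags=['PRON', 'VERB', 'DET', 'NOUN'],
                     heads=[1, 1, 3, 1],  # token index of each token's head
                     deps=['nsubj', 'ROOT', 'det', 'dobj'])
    print(gold.words, gold.heads)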
@@ -168,12 +172,6 @@ def golds_to_gold_tuples(docs, golds):
     return tuples
 
 
-def refresh_docs(docs):
-    vocab = docs[0].vocab
-    return [Doc(vocab, words=[t.text for t in doc],
-                spaces=[t.whitespace_ for t in doc])
-            for doc in docs]
-
 ##############
 # Evaluation #
 ##############
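The deleted refresh_docs helper copied each Doc token-for-token, preserving its exact token boundaries; the replacement in main() (see the last hunk below) re-runs the tokenizer via nlp.make_doc(doc.text) instead. A small sketch of the behavioural difference, with an invented example (the exact output depends on the language's tokenizer):

    import spacy
    from spacy.tokens import Doc

    nlp = spacy.blank('en')

    def copy_doc(doc):
        # refresh_docs-style copy: keeps the exact token boundaries of `doc`.
        return Doc(doc.vocab, words=[t.text for t in doc],
                   spaces=[t.whitespace_ for t in doc])

    # Suppose the gold data split one word into two tokens:
    doc = Doc(nlp.vocab, words=["some", "thing"], spaces=[False, True])
    print([t.text for t in copy_doc(doc)])           # keeps ['some', 'thing']
    print([t.text for t in nlp.make_doc(doc.text)])  # re-tokenizes "something ",
                                                     # typically one token

So the commit trades exact preservation of the gold tokenization for a simpler refresh, presumably relying on GoldParse's alignment to absorb any mismatch.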
@@ -357,23 +355,22 @@ class TreebankPaths(object):
 )
 def main(ud_dir, corpus, config, parses_loc='/tmp/dev.conllu', limit=10):
     paths = TreebankPaths(ud_dir, corpus)
     print("Train and evaluate", corpus, "using lang", paths.lang)
     nlp = load_nlp(paths.lang, config)
 
     docs, golds = read_data(nlp, paths.train.conllu.open(), paths.train.text.open(),
-                            limit=limit)
+                            max_doc_length=config.max_doc_length, limit=limit)
 
     optimizer = initialize_pipeline(nlp, docs, golds, config)
-    n_train_words = sum(len(doc) for doc in docs)
-    print("Begin training (%d words)" % n_train_words)
     for i in range(config.nr_epoch):
-        docs = refresh_docs(docs)
+        docs = [nlp.make_doc(doc.text) for doc in docs]
         batches = minibatch_by_words(list(zip(docs, golds)), size=config.batch_size)
         losses = {}
-        for batch in tqdm.tqdm(batches, total=n_train_words//config.batch_size):
-            if not batch:
-                continue
-            batch_docs, batch_gold = zip(*batch)
-            nlp.update(batch_docs, batch_gold, sgd=optimizer,
-                       drop=config.dropout, losses=losses)
+        n_train_words = sum(len(doc) for doc in docs)
+        with tqdm.tqdm(total=n_train_words, leave=False) as pbar:
+            for batch in batches:
+                batch_docs, batch_gold = zip(*batch)
+                pbar.update(sum(len(doc) for doc in batch_docs))
+                nlp.update(batch_docs, batch_gold, sgd=optimizer,
+                           drop=config.dropout, losses=losses)
 
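Since minibatch_by_words yields a variable number of documents per batch, the old batch-count progress bar (total=n_train_words//config.batch_size) was only ever an estimate; the new loop counts words directly. A standalone sketch of that pattern with dummy data (no spaCy needed):

    import tqdm

    docs = [["tok"] * n for n in (100, 250, 75, 300)]  # dummy docs; len() = words
    n_words = sum(len(doc) for doc in docs)
    with tqdm.tqdm(total=n_words, leave=False) as pbar:
        for doc in docs:
            # ... nlp.update(...) would go here ...
            pbar.update(len(doc))  # advance by words processed, not by batches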