mirror of https://github.com/explosion/spaCy.git
synced 2025-01-28 02:04:07 +03:00

Shuffle data after each epoch. Improve script

parent bdb0174571
commit 551c93fe01
@@ -30,6 +30,7 @@ random.seed(0)
 numpy.random.seed(0)
 
 def minibatch_by_words(items, size=5000):
+    random.shuffle(items)
     if isinstance(size, int):
         size_ = itertools.repeat(size)
     else:
@@ -39,10 +40,17 @@ def minibatch_by_words(items, size=5000):
         batch_size = next(size_)
         batch = []
         while batch_size >= 0:
+            try:
                 doc, gold = next(items)
+            except StopIteration:
+                yield batch
+                return
             batch_size -= len(doc)
             batch.append((doc, gold))
+        if batch:
             yield batch
+        else:
+            break
 
 ################
 # Data reading #
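
Taken together, the two hunks above give minibatch_by_words a shuffle at the top and a clean exit when the input iterator runs dry. Below is a hedged, self-contained sketch of the patched generator, using plain token lists as stand-ins for Doc/GoldParse pairs; the `size_ = iter(size)` and `items = iter(items)` lines are assumptions, since the diff does not show the lines between the two hunks.

    import itertools
    import random

    def minibatch_by_words(items, size=5000):
        # Shuffle once per call; main() calls this every epoch, so each
        # epoch sees the training pairs in a fresh order.
        random.shuffle(items)
        if isinstance(size, int):
            size_ = itertools.repeat(size)
        else:
            size_ = iter(size)   # assumed: not visible in the hunks above
        items = iter(items)      # assumed: not visible in the hunks above
        while True:
            batch_size = next(size_)
            batch = []
            while batch_size >= 0:
                try:
                    doc, gold = next(items)
                except StopIteration:
                    yield batch  # flush whatever is left, then stop
                    return
                batch_size -= len(doc)
                batch.append((doc, gold))
            if batch:
                yield batch
            else:
                break

    random.seed(0)
    pairs = [(['tok'] * n, None) for n in (3, 8, 2, 5, 7)]
    for batch in minibatch_by_words(pairs, size=10):
        print([len(doc) for doc, _ in batch])

Each yielded batch holds roughly `size` words rather than a fixed number of documents, which keeps per-batch compute even when document lengths vary widely.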
@@ -146,10 +154,6 @@ def _make_gold(nlp, text, sent_annots):
     doc = nlp.make_doc(text)
     flat.pop('spaces')
     gold = GoldParse(doc, **flat)
-    #for annot in gold.orig_annot:
-    #    print(annot)
-    #for i in range(len(doc)):
-    #    print(doc[i].text, gold.words[i], gold.labels[i], gold.heads[i])
     return doc, gold
 
 #############################
@@ -168,12 +172,6 @@ def golds_to_gold_tuples(docs, golds):
     return tuples
 
 
-def refresh_docs(docs):
-    vocab = docs[0].vocab
-    return [Doc(vocab, words=[t.text for t in doc],
-                spaces=[t.whitespace_ for t in doc])
-            for doc in docs]
-
 ##############
 # Evaluation #
 ##############
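
The deleted refresh_docs helper rebuilt each Doc token by token, preserving whitespace; the final hunk below replaces it with an inline `[nlp.make_doc(doc.text) for doc in docs]`, which simply re-tokenizes the raw text. A small sketch of why the two routes agree for unparsed docs, assuming a blank pipeline (the sample texts are made up):

    import spacy
    from spacy.tokens import Doc

    nlp = spacy.blank('en')
    docs = [nlp.make_doc('This is a sentence.'),
            nlp.make_doc('Another one here.')]

    # Old route: copy tokens and trailing whitespace verbatim.
    vocab = docs[0].vocab
    refreshed = [Doc(vocab, words=[t.text for t in doc],
                     spaces=[t.whitespace_ for t in doc])
                 for doc in docs]

    # New route: re-tokenize the original text each epoch.
    remade = [nlp.make_doc(doc.text) for doc in docs]

    assert all(a.text == b.text for a, b in zip(refreshed, remade))

Either way, the point is the same: each epoch starts from fresh, unannotated Doc objects, so predictions attached during the previous epoch cannot leak into training.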
@@ -357,23 +355,22 @@ class TreebankPaths(object):
 )
 def main(ud_dir, corpus, config, parses_loc='/tmp/dev.conllu', limit=10):
     paths = TreebankPaths(ud_dir, corpus)
+    print("Train and evaluate", corpus, "using lang", paths.lang)
     nlp = load_nlp(paths.lang, config)
 
     docs, golds = read_data(nlp, paths.train.conllu.open(), paths.train.text.open(),
-                            limit=limit)
+                            max_doc_length=config.max_doc_length, limit=limit)
 
     optimizer = initialize_pipeline(nlp, docs, golds, config)
-    n_train_words = sum(len(doc) for doc in docs)
-    print("Begin training (%d words)" % n_train_words)
     for i in range(config.nr_epoch):
-        docs = refresh_docs(docs)
+        docs = [nlp.make_doc(doc.text) for doc in docs]
         batches = minibatch_by_words(list(zip(docs, golds)), size=config.batch_size)
         losses = {}
-        for batch in tqdm.tqdm(batches, total=n_train_words//config.batch_size):
-            if not batch:
-                continue
+        n_train_words = sum(len(doc) for doc in docs)
+        with tqdm.tqdm(total=n_train_words, leave=False) as pbar:
+            for batch in batches:
                 batch_docs, batch_gold = zip(*batch)
+                pbar.update(sum(len(doc) for doc in batch_docs))
                 nlp.update(batch_docs, batch_gold, sgd=optimizer,
                            drop=config.dropout, losses=losses)
 
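
The loop's progress reporting also changes: instead of counting batches (whose number was only estimated via n_train_words//config.batch_size), the bar now totals words and advances by each batch's word count. A minimal sketch of that pattern with fabricated batches; the nlp.update call is elided:

    import tqdm

    # Fabricated (doc, gold) batches; token lists stand in for Docs.
    batches = [[(['w'] * 4, None), (['w'] * 6, None)],
               [(['w'] * 7, None)]]
    n_train_words = sum(len(doc) for batch in batches for doc, _ in batch)

    with tqdm.tqdm(total=n_train_words, leave=False) as pbar:
        for batch in batches:
            batch_docs, batch_gold = zip(*batch)
            # Advance by words actually processed, not by batch count.
            pbar.update(sum(len(doc) for doc in batch_docs))
            # nlp.update(batch_docs, batch_gold, ...) would run here.

leave=False clears the finished bar, so running config.nr_epoch epochs does not stack stale progress bars in the terminal.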