mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-13 10:46:29 +03:00
Improve printing in train_ud script
This commit is contained in:
parent
ca9c8c57c0
commit
a155482fda
|
@ -24,9 +24,10 @@ import io
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def read_conllx(loc):
|
def read_conllx(loc, n=0):
|
||||||
with io.open(loc, 'r', encoding='utf8') as file_:
|
with io.open(loc, 'r', encoding='utf8') as file_:
|
||||||
text = file_.read()
|
text = file_.read()
|
||||||
|
i = 0
|
||||||
for sent in text.strip().split('\n\n'):
|
for sent in text.strip().split('\n\n'):
|
||||||
lines = sent.strip().split('\n')
|
lines = sent.strip().split('\n')
|
||||||
if lines:
|
if lines:
|
||||||
|
@ -47,6 +48,9 @@ def read_conllx(loc):
|
||||||
raise
|
raise
|
||||||
tuples = [list(t) for t in zip(*tokens)]
|
tuples = [list(t) for t in zip(*tokens)]
|
||||||
yield (None, [[tuples, []]])
|
yield (None, [[tuples, []]])
|
||||||
|
i += 1
|
||||||
|
if n >= 1 and i >= n:
|
||||||
|
break
|
||||||
|
|
||||||
|
|
||||||
def score_model(vocab, tagger, parser, gold_docs, verbose=False):
|
def score_model(vocab, tagger, parser, gold_docs, verbose=False):
|
||||||
|
@ -95,20 +99,21 @@ def main(train_loc, dev_loc, model_dir, tag_map_loc=None):
|
||||||
for tag in tags:
|
for tag in tags:
|
||||||
assert tag in tag_map, repr(tag)
|
assert tag in tag_map, repr(tag)
|
||||||
tagger = Tagger(vocab, tag_map=tag_map)
|
tagger = Tagger(vocab, tag_map=tag_map)
|
||||||
parser = DependencyParser(vocab, actions=actions, features=features)
|
parser = DependencyParser(vocab, actions=actions, features=features, L1=0.0)
|
||||||
|
|
||||||
for itn in range(15):
|
for itn in range(15):
|
||||||
|
loss = 0.
|
||||||
for _, doc_sents in train_sents:
|
for _, doc_sents in train_sents:
|
||||||
for (ids, words, tags, heads, deps, ner), _ in doc_sents:
|
for (ids, words, tags, heads, deps, ner), _ in doc_sents:
|
||||||
doc = Doc(vocab, words=words)
|
doc = Doc(vocab, words=words)
|
||||||
gold = GoldParse(doc, tags=tags, heads=heads, deps=deps)
|
gold = GoldParse(doc, tags=tags, heads=heads, deps=deps)
|
||||||
tagger(doc)
|
tagger(doc)
|
||||||
parser.update(doc, gold)
|
loss += parser.update(doc, gold, itn=itn)
|
||||||
doc = Doc(vocab, words=words)
|
doc = Doc(vocab, words=words)
|
||||||
tagger.update(doc, gold)
|
tagger.update(doc, gold)
|
||||||
random.shuffle(train_sents)
|
random.shuffle(train_sents)
|
||||||
scorer = score_model(vocab, tagger, parser, read_conllx(dev_loc))
|
scorer = score_model(vocab, tagger, parser, read_conllx(dev_loc))
|
||||||
print('%d:\t%.3f\t%.3f' % (itn, scorer.uas, scorer.tags_acc))
|
print('%d:\t%.3f\t%.3f\t%.3f' % (itn, loss, scorer.uas, scorer.tags_acc))
|
||||||
nlp = Language(vocab=vocab, tagger=tagger, parser=parser)
|
nlp = Language(vocab=vocab, tagger=tagger, parser=parser)
|
||||||
nlp.end_training(model_dir)
|
nlp.end_training(model_dir)
|
||||||
scorer = score_model(vocab, tagger, parser, read_conllx(dev_loc))
|
scorer = score_model(vocab, tagger, parser, read_conllx(dev_loc))
|
||||||
|
|
Loading…
Reference in New Issue
Block a user