mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 01:16:28 +03:00
Improve printing in train_ud script
This commit is contained in:
parent
ca9c8c57c0
commit
a155482fda
|
@ -24,9 +24,10 @@ import io
|
|||
|
||||
|
||||
|
||||
def read_conllx(loc):
|
||||
def read_conllx(loc, n=0):
|
||||
with io.open(loc, 'r', encoding='utf8') as file_:
|
||||
text = file_.read()
|
||||
i = 0
|
||||
for sent in text.strip().split('\n\n'):
|
||||
lines = sent.strip().split('\n')
|
||||
if lines:
|
||||
|
@ -47,6 +48,9 @@ def read_conllx(loc):
|
|||
raise
|
||||
tuples = [list(t) for t in zip(*tokens)]
|
||||
yield (None, [[tuples, []]])
|
||||
i += 1
|
||||
if n >= 1 and i >= n:
|
||||
break
|
||||
|
||||
|
||||
def score_model(vocab, tagger, parser, gold_docs, verbose=False):
|
||||
|
@ -73,7 +77,7 @@ def main(train_loc, dev_loc, model_dir, tag_map_loc=None):
|
|||
|
||||
actions = ArcEager.get_actions(gold_parses=train_sents)
|
||||
features = get_templates('basic')
|
||||
|
||||
|
||||
model_dir = pathlib.Path(model_dir)
|
||||
if not (model_dir / 'deps').exists():
|
||||
(model_dir / 'deps').mkdir()
|
||||
|
@ -95,25 +99,26 @@ def main(train_loc, dev_loc, model_dir, tag_map_loc=None):
|
|||
for tag in tags:
|
||||
assert tag in tag_map, repr(tag)
|
||||
tagger = Tagger(vocab, tag_map=tag_map)
|
||||
parser = DependencyParser(vocab, actions=actions, features=features)
|
||||
|
||||
parser = DependencyParser(vocab, actions=actions, features=features, L1=0.0)
|
||||
|
||||
for itn in range(15):
|
||||
loss = 0.
|
||||
for _, doc_sents in train_sents:
|
||||
for (ids, words, tags, heads, deps, ner), _ in doc_sents:
|
||||
doc = Doc(vocab, words=words)
|
||||
gold = GoldParse(doc, tags=tags, heads=heads, deps=deps)
|
||||
tagger(doc)
|
||||
parser.update(doc, gold)
|
||||
loss += parser.update(doc, gold, itn=itn)
|
||||
doc = Doc(vocab, words=words)
|
||||
tagger.update(doc, gold)
|
||||
random.shuffle(train_sents)
|
||||
scorer = score_model(vocab, tagger, parser, read_conllx(dev_loc))
|
||||
print('%d:\t%.3f\t%.3f' % (itn, scorer.uas, scorer.tags_acc))
|
||||
print('%d:\t%.3f\t%.3f\t%.3f' % (itn, loss, scorer.uas, scorer.tags_acc))
|
||||
nlp = Language(vocab=vocab, tagger=tagger, parser=parser)
|
||||
nlp.end_training(model_dir)
|
||||
scorer = score_model(vocab, tagger, parser, read_conllx(dev_loc))
|
||||
print('%d:\t%.3f\t%.3f\t%.3f' % (itn, scorer.uas, scorer.las, scorer.tags_acc))
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
plac.call(main)
|
||||
|
|
Loading…
Reference in New Issue
Block a user