mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-26 01:46:28 +03:00
* Tmp commit of train, while I move to better alignment in gold standard
This commit is contained in:
parent
bdaddc4103
commit
f35503018e
|
@ -11,6 +11,7 @@ import random
|
|||
import plac
|
||||
import cProfile
|
||||
import pstats
|
||||
import re
|
||||
|
||||
import spacy.util
|
||||
from spacy.en import English
|
||||
|
@ -51,11 +52,10 @@ def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic', seed=0
|
|||
gold_tuples = gold_tuples[:n_sents]
|
||||
nlp = Language(data_dir=model_dir)
|
||||
|
||||
print "Itn.\tUAS\tNER F.\tTag %"
|
||||
print "Itn.\tUAS\tNER F.\tTag %\tToken %"
|
||||
for itn in range(n_iter):
|
||||
scorer = Scorer()
|
||||
for raw_text, segmented_text, annot_tuples, ctnt in gold_tuples:
|
||||
# Eval before train
|
||||
tokens = nlp(raw_text, merge_mwes=False)
|
||||
gold = GoldParse(tokens, annot_tuples)
|
||||
scorer.score(tokens, gold, verbose=False)
|
||||
|
@ -67,12 +67,18 @@ def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic', seed=0
|
|||
for tokens in sents:
|
||||
gold = GoldParse(tokens, annot_tuples)
|
||||
nlp.tagger(tokens)
|
||||
nlp.parser.train(tokens, gold)
|
||||
try:
|
||||
nlp.parser.train(tokens, gold)
|
||||
except AssertionError:
|
||||
# TODO: Do something about non-projective sentences
|
||||
continue
|
||||
if gold.ents:
|
||||
nlp.entity.train(tokens, gold)
|
||||
nlp.tagger.train(tokens, gold.tags)
|
||||
|
||||
print '%d:\t%.3f\t%.3f\t%.3f' % (itn, scorer.uas, scorer.ents_f, scorer.tags_acc)
|
||||
print '%d:\t%.3f\t%.3f\t%.3f\t%.3f' % (itn, scorer.uas, scorer.ents_f,
|
||||
scorer.tags_acc,
|
||||
scorer.token_acc)
|
||||
random.shuffle(gold_tuples)
|
||||
nlp.parser.model.end_training()
|
||||
nlp.entity.model.end_training()
|
||||
|
@ -106,18 +112,22 @@ def write_parses(Language, dev_loc, model_dir, out_loc):
|
|||
|
||||
|
||||
def get_sents(json_dir, section):
|
||||
if section == 'train':
|
||||
file_range = range(2, 22)
|
||||
elif section == 'dev':
|
||||
file_range = range(22, 23)
|
||||
|
||||
for i in file_range:
|
||||
sec = str(i)
|
||||
if len(sec) == 1:
|
||||
sec = '0' + sec
|
||||
loc = path.join(json_dir, sec + '.json')
|
||||
for sent in read_json_file(loc):
|
||||
if path.exists(path.join(json_dir, section + '.json')):
|
||||
for sent in read_json_file(path.join(json_dir, section + '.json')):
|
||||
yield sent
|
||||
else:
|
||||
if section == 'train':
|
||||
file_range = range(2, 22)
|
||||
elif section == 'dev':
|
||||
file_range = range(22, 23)
|
||||
|
||||
for i in file_range:
|
||||
sec = str(i)
|
||||
if len(sec) == 1:
|
||||
sec = '0' + sec
|
||||
loc = path.join(json_dir, sec + '.json')
|
||||
for sent in read_json_file(loc):
|
||||
yield sent
|
||||
|
||||
|
||||
@plac.annotations(
|
||||
|
@ -137,7 +147,7 @@ def main(json_dir, model_dir, n_sents=0, out_loc="", verbose=False,
|
|||
write_parses(English, dev_loc, model_dir, out_loc)
|
||||
scorer = evaluate(English, list(get_sents(json_dir, 'dev')),
|
||||
model_dir, gold_preproc=False, verbose=verbose)
|
||||
print 'TOK', scorer.mistokened
|
||||
print 'TOK', 100-scorer.token_acc
|
||||
print 'POS', scorer.tags_acc
|
||||
print 'UAS', scorer.uas
|
||||
print 'LAS', scorer.las
|
||||
|
|
Loading…
Reference in New Issue
Block a user