mirror of
https://github.com/explosion/spaCy.git
synced 2024-09-21 03:19:13 +03:00
* Tmp commit of train, while I move to better alignment in gold standard
This commit is contained in:
parent
bdaddc4103
commit
f35503018e
|
@ -11,6 +11,7 @@ import random
|
||||||
import plac
|
import plac
|
||||||
import cProfile
|
import cProfile
|
||||||
import pstats
|
import pstats
|
||||||
|
import re
|
||||||
|
|
||||||
import spacy.util
|
import spacy.util
|
||||||
from spacy.en import English
|
from spacy.en import English
|
||||||
|
@ -51,11 +52,10 @@ def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic', seed=0
|
||||||
gold_tuples = gold_tuples[:n_sents]
|
gold_tuples = gold_tuples[:n_sents]
|
||||||
nlp = Language(data_dir=model_dir)
|
nlp = Language(data_dir=model_dir)
|
||||||
|
|
||||||
print "Itn.\tUAS\tNER F.\tTag %"
|
print "Itn.\tUAS\tNER F.\tTag %\tToken %"
|
||||||
for itn in range(n_iter):
|
for itn in range(n_iter):
|
||||||
scorer = Scorer()
|
scorer = Scorer()
|
||||||
for raw_text, segmented_text, annot_tuples, ctnt in gold_tuples:
|
for raw_text, segmented_text, annot_tuples, ctnt in gold_tuples:
|
||||||
# Eval before train
|
|
||||||
tokens = nlp(raw_text, merge_mwes=False)
|
tokens = nlp(raw_text, merge_mwes=False)
|
||||||
gold = GoldParse(tokens, annot_tuples)
|
gold = GoldParse(tokens, annot_tuples)
|
||||||
scorer.score(tokens, gold, verbose=False)
|
scorer.score(tokens, gold, verbose=False)
|
||||||
|
@ -67,12 +67,18 @@ def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic', seed=0
|
||||||
for tokens in sents:
|
for tokens in sents:
|
||||||
gold = GoldParse(tokens, annot_tuples)
|
gold = GoldParse(tokens, annot_tuples)
|
||||||
nlp.tagger(tokens)
|
nlp.tagger(tokens)
|
||||||
nlp.parser.train(tokens, gold)
|
try:
|
||||||
|
nlp.parser.train(tokens, gold)
|
||||||
|
except AssertionError:
|
||||||
|
# TODO: Do something about non-projective sentences
|
||||||
|
continue
|
||||||
if gold.ents:
|
if gold.ents:
|
||||||
nlp.entity.train(tokens, gold)
|
nlp.entity.train(tokens, gold)
|
||||||
nlp.tagger.train(tokens, gold.tags)
|
nlp.tagger.train(tokens, gold.tags)
|
||||||
|
|
||||||
print '%d:\t%.3f\t%.3f\t%.3f' % (itn, scorer.uas, scorer.ents_f, scorer.tags_acc)
|
print '%d:\t%.3f\t%.3f\t%.3f\t%.3f' % (itn, scorer.uas, scorer.ents_f,
|
||||||
|
scorer.tags_acc,
|
||||||
|
scorer.token_acc)
|
||||||
random.shuffle(gold_tuples)
|
random.shuffle(gold_tuples)
|
||||||
nlp.parser.model.end_training()
|
nlp.parser.model.end_training()
|
||||||
nlp.entity.model.end_training()
|
nlp.entity.model.end_training()
|
||||||
|
@ -106,18 +112,22 @@ def write_parses(Language, dev_loc, model_dir, out_loc):
|
||||||
|
|
||||||
|
|
||||||
def get_sents(json_dir, section):
|
def get_sents(json_dir, section):
|
||||||
if section == 'train':
|
if path.exists(path.join(json_dir, section + '.json')):
|
||||||
file_range = range(2, 22)
|
for sent in read_json_file(path.join(json_dir, section + '.json')):
|
||||||
elif section == 'dev':
|
|
||||||
file_range = range(22, 23)
|
|
||||||
|
|
||||||
for i in file_range:
|
|
||||||
sec = str(i)
|
|
||||||
if len(sec) == 1:
|
|
||||||
sec = '0' + sec
|
|
||||||
loc = path.join(json_dir, sec + '.json')
|
|
||||||
for sent in read_json_file(loc):
|
|
||||||
yield sent
|
yield sent
|
||||||
|
else:
|
||||||
|
if section == 'train':
|
||||||
|
file_range = range(2, 22)
|
||||||
|
elif section == 'dev':
|
||||||
|
file_range = range(22, 23)
|
||||||
|
|
||||||
|
for i in file_range:
|
||||||
|
sec = str(i)
|
||||||
|
if len(sec) == 1:
|
||||||
|
sec = '0' + sec
|
||||||
|
loc = path.join(json_dir, sec + '.json')
|
||||||
|
for sent in read_json_file(loc):
|
||||||
|
yield sent
|
||||||
|
|
||||||
|
|
||||||
@plac.annotations(
|
@plac.annotations(
|
||||||
|
@ -137,7 +147,7 @@ def main(json_dir, model_dir, n_sents=0, out_loc="", verbose=False,
|
||||||
write_parses(English, dev_loc, model_dir, out_loc)
|
write_parses(English, dev_loc, model_dir, out_loc)
|
||||||
scorer = evaluate(English, list(get_sents(json_dir, 'dev')),
|
scorer = evaluate(English, list(get_sents(json_dir, 'dev')),
|
||||||
model_dir, gold_preproc=False, verbose=verbose)
|
model_dir, gold_preproc=False, verbose=verbose)
|
||||||
print 'TOK', scorer.mistokened
|
print 'TOK', 100-scorer.token_acc
|
||||||
print 'POS', scorer.tags_acc
|
print 'POS', scorer.tags_acc
|
||||||
print 'UAS', scorer.uas
|
print 'UAS', scorer.uas
|
||||||
print 'LAS', scorer.las
|
print 'LAS', scorer.las
|
||||||
|
|
Loading…
Reference in New Issue
Block a user