mirror of https://github.com/explosion/spaCy.git
synced 2024-11-10 19:57:17 +03:00
* Allow gold tokenization training, for debugging
parent 8da53cbe3c
commit 7a1a333f04
@@ -21,7 +21,7 @@ from spacy.en.pos import POS_TEMPLATES, POS_TAGS, setup_model_dir
 from spacy.syntax.parser import GreedyParser
 from spacy.syntax.parser import OracleError
 from spacy.syntax.util import Config
-from spacy.syntax.conll import GoldParse
+from spacy.syntax.conll import GoldParse, is_punct_label
 
 
 def is_punct_label(label):
@@ -206,15 +206,22 @@ def train(Language, paragraphs, model_dir, n_iter=15, feat_set=u'basic', seed=0,
         heads_corr = 0
         pos_corr = 0
         n_tokens = 0
+        n_all_tokens = 0
         for gold_sent in gold_sents:
-            tokens = nlp.tokenizer(gold_sent.raw_text)
-            gold_sent.align_to_tokens(tokens, nlp.parser.moves.label_ids)
+            if gold_preproc:
+                #print ' '.join(gold_sent.words)
+                tokens = nlp.tokenizer.tokens_from_list(gold_sent.words)
+                gold_sent.map_heads(nlp.parser.moves.label_ids)
+            else:
+                tokens = nlp.tokenizer(gold_sent.raw_text)
+                gold_sent.align_to_tokens(tokens, nlp.parser.moves.label_ids)
             nlp.tagger(tokens)
             heads_corr += nlp.parser.train(tokens, gold_sent, force_gold=force_gold)
             pos_corr += nlp.tagger.train(tokens, gold_sent.tags)
-            n_tokens += len(tokens)
+            n_tokens += gold_sent.n_non_punct
+            n_all_tokens += len(tokens)
         acc = float(heads_corr) / n_tokens
-        pos_acc = float(pos_corr) / n_tokens
+        pos_acc = float(pos_corr) / n_all_tokens
         print '%d: ' % itn, '%.3f' % acc, '%.3f' % pos_acc
         random.shuffle(gold_sents)
     nlp.parser.model.end_training()
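Note: the hunk above is the core of the change. For clarity, here is the tokenization switch restated as a standalone helper -- a minimal sketch lifted from the diff, not code from the commit; get_training_tokens is a hypothetical name, while tokens_from_list, map_heads, and align_to_tokens are the 2015-era spaCy internals shown above.

def get_training_tokens(nlp, gold_sent, gold_preproc=False):
    # Sketch only: logic copied from train() in the diff above.
    if gold_preproc:
        # Gold tokenization: build the token sequence directly from the
        # gold-standard word list, so tokenizer errors cannot confound
        # parser/tagger debugging.
        tokens = nlp.tokenizer.tokens_from_list(gold_sent.words)
        gold_sent.map_heads(nlp.parser.moves.label_ids)
    else:
        # Normal path: tokenize the raw text, then align the gold
        # annotation to whatever boundaries the tokenizer predicted.
        tokens = nlp.tokenizer(gold_sent.raw_text)
        gold_sent.align_to_tokens(tokens, nlp.parser.moves.label_ids)
    return tokens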
@@ -241,21 +248,26 @@ def evaluate(Language, dev_loc, model_dir, gold_preproc=False):
         nlp.tagger(tokens)
         nlp.parser(tokens)
         for i, token in enumerate(tokens):
-            pos_corr += token.tag_ == tag_strs[i]
+            pos_corr += token.tag_ == gold_sent.tags[i]
             n_tokens += 1
-            if heads[i] is None:
+            if gold_sent.heads[i] is None:
                 skipped += 1
                 continue
-            if is_punct_label(labels[i]):
-                continue
-            uas_corr += token.head.i == heads[i]
-            las_corr += token.head.i == heads[i] and token.dep_ == labels[i]
-            #print token.orth_, token.head.orth_, token.dep_, labels[i]
-            total += 1
+            #print i, token.orth_, token.head.i, gold_sent.py_heads[i], gold_sent.labels[i],
+            #print gold_sent.is_correct(i, token.head.i)
+            if gold_sent.labels[i] != 'P':
+                n_corr += gold_sent.is_correct(i, token.head.i)
+                total += 1
     print loss, skipped, (loss+skipped + total)
     print pos_corr / n_tokens
-    print float(las_corr) / (total + loss)
-    return float(uas_corr) / (total + loss)
+    return float(n_corr) / (total + loss)
 
 
+def read_gold(loc, n=0):
+    sent_strs = open(loc).read().strip().split('\n\n')
+    if n == 0:
+        n = len(sent_strs)
+    return [GoldParse.from_docparse(sent) for sent in sent_strs[:n]]
+
+
 @plac.annotations(
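Note: the rewritten scoring above skips tokens that have no gold head and excludes punctuation (label 'P') from attachment accuracy. A minimal sketch of the rule as it reads after this commit, assuming loss counts sentences dropped on OracleError, so failed sentences still weigh down the score:

n_corr = 0
total = 0
for i, token in enumerate(tokens):
    if gold_sent.heads[i] is None:
        continue  # no gold head: excluded from scoring
    if gold_sent.labels[i] != 'P':  # 'P' marks punctuation
        n_corr += gold_sent.is_correct(i, token.head.i)
        total += 1
accuracy = float(n_corr) / (total + loss)  # failed sentences stay in the denominator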
@@ -265,8 +277,8 @@ def evaluate(Language, dev_loc, model_dir, gold_preproc=False):
     n_sents=("Number of training sentences", "option", "n", int)
 )
 def main(train_loc, dev_loc, model_dir, n_sents=0):
-    train(English, read_gold(train_loc, n=n_sents), model_dir,
-          gold_preproc=False, force_gold=False)
+    #train(English, read_gold(train_loc, n=n_sents), model_dir,
+    #      gold_preproc=False, force_gold=False)
     print evaluate(English, read_gold(dev_loc), model_dir, gold_preproc=False)
 
 
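Note: assuming the script ends with plac.call(main), as spaCy's bin/ scripts of this era typically do (the filename and data paths here are hypothetical), it would be invoked as:

python train.py train.docparse dev.docparse /path/to/model -n 1000

With the train() call commented out as above, the run only evaluates whatever model is already saved in model_dir.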