mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 17:36:30 +03:00
* Report LAS in train script
This commit is contained in:
parent
b07632a9ef
commit
053814ffc8
|
@ -226,7 +226,8 @@ def train(Language, paragraphs, model_dir, n_iter=15, feat_set=u'basic', seed=0,
|
|||
def evaluate(Language, dev_loc, model_dir, gold_preproc=False):
|
||||
global loss
|
||||
nlp = Language()
|
||||
n_corr = 0
|
||||
uas_corr = 0
|
||||
las_corr = 0
|
||||
pos_corr = 0
|
||||
n_tokens = 0
|
||||
total = 0
|
||||
|
@ -251,11 +252,14 @@ def evaluate(Language, dev_loc, model_dir, gold_preproc=False):
|
|||
continue
|
||||
if is_punct_label(labels[i]):
|
||||
continue
|
||||
n_corr += token.head.i == heads[i]
|
||||
uas_corr += token.head.i == heads[i]
|
||||
las_corr += token.head.i == heads[i] and token.dep_ == labels[i]
|
||||
#print token.orth_, token.head.orth_, token.dep_, labels[i]
|
||||
total += 1
|
||||
print loss, skipped, (loss+skipped + total)
|
||||
print pos_corr / n_tokens
|
||||
return float(n_corr) / (total + loss)
|
||||
print float(las_corr) / (total + loss)
|
||||
return float(uas_corr) / (total + loss)
|
||||
|
||||
|
||||
def main(train_loc, dev_loc, model_dir):
|
||||
|
|
Loading…
Reference in New Issue
Block a user