Mirror of https://github.com/explosion/spaCy.git (synced 2024-12-26 01:46:28 +03:00)
* Allow parsers and taggers to be trained on text without gold pre-processing.
parent 67d6e53a69
commit ca7577d8a9
@@ -9,6 +9,7 @@ import codecs
 import random
 import time
 import gzip
+import nltk
 
 import plac
 import cProfile
@@ -22,6 +23,10 @@ from spacy.syntax.parser import GreedyParser
 from spacy.syntax.util import Config
 
 
+def is_punct_label(label):
+    return label == 'P' or label.lower() == 'punct'
+
+
 def read_tokenized_gold(file_):
     """Read a standard CoNLL/MALT-style format"""
     sents = []
@@ -96,21 +101,21 @@ def _parse_line(line):
 
 
 
 loss = 0
 def _align_annotations_to_non_gold_tokens(tokens, words, annot):
     global loss
     tags = []
     heads = []
     labels = []
     loss = 0
     print [t.orth_ for t in tokens]
     print words
     orig_words = list(words)
     missed = []
     for token in tokens:
         print token.orth_, words[0]
         print token.idx, annot[0][0]
         while annot and token.idx > annot[0][0]:
             print 'pop', token.idx, annot[0][0]
             annot.pop(0)
             words.pop(0)
             loss += 1
             miss_id, miss_tag, miss_head, miss_label = annot.pop(0)
             miss_w = words.pop(0)
             if not is_punct_label(miss_label):
                 missed.append(miss_w)
                 loss += 1
         if not annot:
             tags.append(None)
             heads.append(None)
@@ -129,6 +134,11 @@ def _align_annotations_to_non_gold_tokens(tokens, words, annot):
             labels.append(None)
         else:
             raise StandardError
+    #if missed:
+    #    print orig_words
+    #    print missed
+    #    for t in tokens:
+    #        print t.idx, t.orth_
     return loss, tags, heads, labels
 
 
@@ -137,7 +147,8 @@ def iter_data(paragraphs, tokenizer, gold_preproc=False):
         if not gold_preproc:
             tokens = tokenizer(raw)
             loss, tags, heads, labels = _align_annotations_to_non_gold_tokens(
-                tokens, words, zip(ids, tags, heads, labels))
+                tokens, list(words),
+                zip(ids, tags, heads, labels))
             ids = [t.idx for t in tokens]
             heads = _map_indices_to_tokens(ids, heads)
         yield tokens, tags, heads, labels
@@ -170,7 +181,7 @@ def get_labels(sents):
 
 
 def train(Language, paragraphs, model_dir, n_iter=15, feat_set=u'basic', seed=0,
-          gold_preproc=True):
+          gold_preproc=False):
     dep_model_dir = path.join(model_dir, 'deps')
     pos_model_dir = path.join(model_dir, 'pos')
     if path.exists(dep_model_dir):
@@ -194,10 +205,9 @@ def train(Language, paragraphs, model_dir, n_iter=15, feat_set=u'basic', seed=0,
         n_tokens = 0
         for tokens, tag_strs, heads, labels in iter_data(paragraphs, nlp.tokenizer,
                                                          gold_preproc=gold_preproc):
-            tags = [nlp.tagger.tag_names.index(tag) for tag in tag_strs]
             nlp.tagger(tokens)
             heads_corr += nlp.parser.train_sent(tokens, heads, labels, force_gold=False)
-            pos_corr += nlp.tagger.train(tokens, tags)
+            pos_corr += nlp.tagger.train(tokens, tag_strs)
             n_tokens += len(tokens)
         acc = float(heads_corr) / n_tokens
         pos_acc = float(pos_corr) / n_tokens
@@ -223,12 +233,13 @@ def evaluate(Language, dev_loc, model_dir, gold_preproc=False):
         for i, token in enumerate(tokens):
             if heads[i] is None:
                 skipped += 1
-            if labels[i] == 'P' or labels[i] == 'punct':
                 continue
+            if is_punct_label(labels[i]):
+                continue
             n_corr += token.head.i == heads[i]
             total += 1
-    print skipped
-    return float(n_corr) / total
+    print loss, skipped, (loss+skipped + total)
+    return float(n_corr) / (total + loss)
 
 
 def main(train_loc, dev_loc, model_dir):
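The alignment step this commit relies on can be illustrated with a small standalone sketch: gold annotations keyed by character offset are matched against tokens produced by a different tokenizer, gold entries that no token starts at are dropped and counted as alignment loss, and tokens with no matching gold entry get None. The helper name and data layout below are hypothetical, simplified from the patched function rather than taken from it.

# Hypothetical sketch, not the repository's _align_annotations_to_non_gold_tokens.
def align_to_tokens(token_offsets, gold):
    # token_offsets: character start index of each token in the raw text
    # gold: list of (char_offset, tag) pairs, sorted by char_offset
    gold = list(gold)
    tags = []
    loss = 0
    for idx in token_offsets:
        while gold and idx > gold[0][0]:
            gold.pop(0)       # gold token with no counterpart in this tokenization
            loss += 1
        if gold and gold[0][0] == idx:
            tags.append(gold.pop(0)[1])
        else:
            tags.append(None)
    return loss, tags

# The gold tokenization had tokens at offsets 0, 2 and 6; the new tokenizer
# only produced tokens at offsets 0 and 6.
print(align_to_tokens([0, 6], [(0, 'VB'), (2, 'RB'), (6, 'NN')]))
# (1, ['VB', 'NN'])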