mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-13 10:46:29 +03:00
* Fix standard conll file reading. Script needs refactoring.
This commit is contained in:
parent
c55a33d045
commit
27986d7f5c
|
@ -19,6 +19,7 @@ from spacy.en import English
|
||||||
from spacy.en.pos import POS_TEMPLATES, POS_TAGS, setup_model_dir
|
from spacy.en.pos import POS_TEMPLATES, POS_TAGS, setup_model_dir
|
||||||
|
|
||||||
from spacy.syntax.parser import GreedyParser
|
from spacy.syntax.parser import GreedyParser
|
||||||
|
from spacy.syntax.parser import OracleError
|
||||||
from spacy.syntax.util import Config
|
from spacy.syntax.util import Config
|
||||||
|
|
||||||
|
|
||||||
|
@ -36,7 +37,7 @@ def read_tokenized_gold(file_):
|
||||||
labels = []
|
labels = []
|
||||||
tags = []
|
tags = []
|
||||||
for i, line in enumerate(sent_str.split('\n')):
|
for i, line in enumerate(sent_str.split('\n')):
|
||||||
word, pos_string, head_idx, label = _parse_line(line)
|
id_, word, pos_string, head_idx, label = _parse_line(line)
|
||||||
words.append(word)
|
words.append(word)
|
||||||
if head_idx == -1:
|
if head_idx == -1:
|
||||||
head_idx = i
|
head_idx = i
|
||||||
|
@ -44,22 +45,29 @@ def read_tokenized_gold(file_):
|
||||||
heads.append(head_idx)
|
heads.append(head_idx)
|
||||||
labels.append(label)
|
labels.append(label)
|
||||||
tags.append(pos_string)
|
tags.append(pos_string)
|
||||||
sents.append((ids_, words, heads, labels, tags))
|
text = ' '.join(words)
|
||||||
|
sents.append((text, [words], ids, words, tags, heads, labels))
|
||||||
return sents
|
return sents
|
||||||
|
|
||||||
|
|
||||||
def read_docparse_gold(file_):
|
def read_docparse_gold(file_):
|
||||||
paragraphs = []
|
paragraphs = []
|
||||||
for sent_str in file_.read().strip().split('\n\n'):
|
for sent_str in file_.read().strip().split('<text>'):
|
||||||
|
if not sent_str.strip():
|
||||||
|
continue
|
||||||
words = []
|
words = []
|
||||||
heads = []
|
heads = []
|
||||||
labels = []
|
labels = []
|
||||||
tags = []
|
tags = []
|
||||||
ids = []
|
ids = []
|
||||||
|
try:
|
||||||
|
raw_text, sent_str = sent_str.strip().split('</text>', 1)
|
||||||
|
except:
|
||||||
|
print sent_str
|
||||||
|
raise
|
||||||
lines = sent_str.strip().split('\n')
|
lines = sent_str.strip().split('\n')
|
||||||
raw_text = lines[0]
|
tok_text = lines.pop(0)
|
||||||
tok_text = lines[1]
|
for i, line in enumerate(lines):
|
||||||
for i, line in enumerate(lines[2:]):
|
|
||||||
id_, word, pos_string, head_idx, label = _parse_line(line)
|
id_, word, pos_string, head_idx, label = _parse_line(line)
|
||||||
if label == 'root':
|
if label == 'root':
|
||||||
label = 'ROOT'
|
label = 'ROOT'
|
||||||
|
@ -180,7 +188,7 @@ def get_labels(sents):
|
||||||
|
|
||||||
|
|
||||||
def train(Language, paragraphs, model_dir, n_iter=15, feat_set=u'basic', seed=0,
|
def train(Language, paragraphs, model_dir, n_iter=15, feat_set=u'basic', seed=0,
|
||||||
gold_preproc=False):
|
gold_preproc=False, force_gold=False):
|
||||||
dep_model_dir = path.join(model_dir, 'deps')
|
dep_model_dir = path.join(model_dir, 'deps')
|
||||||
pos_model_dir = path.join(model_dir, 'pos')
|
pos_model_dir = path.join(model_dir, 'pos')
|
||||||
if path.exists(dep_model_dir):
|
if path.exists(dep_model_dir):
|
||||||
|
@ -205,7 +213,10 @@ def train(Language, paragraphs, model_dir, n_iter=15, feat_set=u'basic', seed=0,
|
||||||
for tokens, tag_strs, heads, labels in iter_data(paragraphs, nlp.tokenizer,
|
for tokens, tag_strs, heads, labels in iter_data(paragraphs, nlp.tokenizer,
|
||||||
gold_preproc=gold_preproc):
|
gold_preproc=gold_preproc):
|
||||||
nlp.tagger(tokens)
|
nlp.tagger(tokens)
|
||||||
heads_corr += nlp.parser.train_sent(tokens, heads, labels, force_gold=False)
|
try:
|
||||||
|
heads_corr += nlp.parser.train_sent(tokens, heads, labels, force_gold=force_gold)
|
||||||
|
except OracleError:
|
||||||
|
continue
|
||||||
pos_corr += nlp.tagger.train(tokens, tag_strs)
|
pos_corr += nlp.tagger.train(tokens, tag_strs)
|
||||||
n_tokens += len(tokens)
|
n_tokens += len(tokens)
|
||||||
acc = float(heads_corr) / n_tokens
|
acc = float(heads_corr) / n_tokens
|
||||||
|
@ -221,10 +232,13 @@ def evaluate(Language, dev_loc, model_dir, gold_preproc=False):
|
||||||
global loss
|
global loss
|
||||||
nlp = Language()
|
nlp = Language()
|
||||||
n_corr = 0
|
n_corr = 0
|
||||||
|
pos_corr = 0
|
||||||
|
n_tokens = 0
|
||||||
total = 0
|
total = 0
|
||||||
skipped = 0
|
skipped = 0
|
||||||
loss = 0
|
loss = 0
|
||||||
with codecs.open(dev_loc, 'r', 'utf8') as file_:
|
with codecs.open(dev_loc, 'r', 'utf8') as file_:
|
||||||
|
#paragraphs = read_tokenized_gold(file_)
|
||||||
paragraphs = read_docparse_gold(file_)
|
paragraphs = read_docparse_gold(file_)
|
||||||
for tokens, tag_strs, heads, labels in iter_data(paragraphs, nlp.tokenizer,
|
for tokens, tag_strs, heads, labels in iter_data(paragraphs, nlp.tokenizer,
|
||||||
gold_preproc=gold_preproc):
|
gold_preproc=gold_preproc):
|
||||||
|
@ -232,6 +246,8 @@ def evaluate(Language, dev_loc, model_dir, gold_preproc=False):
|
||||||
nlp.tagger(tokens)
|
nlp.tagger(tokens)
|
||||||
nlp.parser(tokens)
|
nlp.parser(tokens)
|
||||||
for i, token in enumerate(tokens):
|
for i, token in enumerate(tokens):
|
||||||
|
pos_corr += token.tag_ == tag_strs[i]
|
||||||
|
n_tokens += 1
|
||||||
if heads[i] is None:
|
if heads[i] is None:
|
||||||
skipped += 1
|
skipped += 1
|
||||||
continue
|
continue
|
||||||
|
@ -240,14 +256,16 @@ def evaluate(Language, dev_loc, model_dir, gold_preproc=False):
|
||||||
n_corr += token.head.i == heads[i]
|
n_corr += token.head.i == heads[i]
|
||||||
total += 1
|
total += 1
|
||||||
print loss, skipped, (loss+skipped + total)
|
print loss, skipped, (loss+skipped + total)
|
||||||
|
print pos_corr / n_tokens
|
||||||
return float(n_corr) / (total + loss)
|
return float(n_corr) / (total + loss)
|
||||||
|
|
||||||
|
|
||||||
def main(train_loc, dev_loc, model_dir):
|
def main(train_loc, dev_loc, model_dir):
|
||||||
with codecs.open(train_loc, 'r', 'utf8') as file_:
|
with codecs.open(train_loc, 'r', 'utf8') as file_:
|
||||||
train_sents = read_docparse_gold(file_)
|
#train_sents = read_docparse_gold(file_)
|
||||||
train(English, train_sents, model_dir, gold_preproc=False)
|
train_sents = read_tokenized_gold(file_)
|
||||||
print evaluate(English, dev_loc, model_dir, gold_preproc=False)
|
#train(English, train_sents, model_dir, gold_preproc=True, force_gold=False)
|
||||||
|
print evaluate(English, dev_loc, model_dir, gold_preproc=True)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
Loading…
Reference in New Issue
Block a user