mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 17:36:30 +03:00
* Add train and parse scripts that use CoNLL formatted data
This commit is contained in:
parent
3072016155
commit
cfaa4bde5d
130
bin/parser/conll_parse.py
Normal file
130
bin/parser/conll_parse.py
Normal file
|
@ -0,0 +1,130 @@
|
|||
#!/usr/bin/env python
|
||||
from __future__ import division
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import os
|
||||
from os import path
|
||||
import shutil
|
||||
import codecs
|
||||
import random
|
||||
import time
|
||||
import gzip
|
||||
|
||||
import plac
|
||||
import cProfile
|
||||
import pstats
|
||||
|
||||
import spacy.util
|
||||
from spacy.en import English
|
||||
from spacy.en.pos import POS_TEMPLATES, POS_TAGS, setup_model_dir
|
||||
|
||||
from spacy.syntax.parser import GreedyParser
|
||||
from spacy.syntax.parser import OracleError
|
||||
from spacy.syntax.util import Config
|
||||
|
||||
|
||||
def is_punct_label(label):
|
||||
return label == 'P' or label.lower() == 'punct'
|
||||
|
||||
|
||||
def read_gold(file_):
|
||||
"""Read a standard CoNLL/MALT-style format"""
|
||||
sents = []
|
||||
for sent_str in file_.read().strip().split('\n\n'):
|
||||
ids = []
|
||||
words = []
|
||||
heads = []
|
||||
labels = []
|
||||
tags = []
|
||||
for i, line in enumerate(sent_str.split('\n')):
|
||||
id_, word, pos_string, head_idx, label = _parse_line(line)
|
||||
words.append(word)
|
||||
if head_idx == -1:
|
||||
head_idx = i
|
||||
ids.append(id_)
|
||||
heads.append(head_idx)
|
||||
labels.append(label)
|
||||
tags.append(pos_string)
|
||||
text = ' '.join(words)
|
||||
sents.append((text, [words], ids, words, tags, heads, labels))
|
||||
return sents
|
||||
|
||||
|
||||
def _parse_line(line):
|
||||
pieces = line.split()
|
||||
id_ = int(pieces[0])
|
||||
word = pieces[1]
|
||||
pos = pieces[3]
|
||||
head_idx = int(pieces[6])
|
||||
label = pieces[7]
|
||||
return id_, word, pos, head_idx, label
|
||||
|
||||
|
||||
def iter_data(paragraphs, tokenizer, gold_preproc=False):
|
||||
for raw, tokenized, ids, words, tags, heads, labels in paragraphs:
|
||||
assert len(words) == len(heads)
|
||||
for words in tokenized:
|
||||
sent_ids = ids[:len(words)]
|
||||
sent_tags = tags[:len(words)]
|
||||
sent_heads = heads[:len(words)]
|
||||
sent_labels = labels[:len(words)]
|
||||
sent_heads = _map_indices_to_tokens(sent_ids, sent_heads)
|
||||
tokens = tokenizer.tokens_from_list(words)
|
||||
yield tokens, sent_tags, sent_heads, sent_labels
|
||||
ids = ids[len(words):]
|
||||
tags = tags[len(words):]
|
||||
heads = heads[len(words):]
|
||||
labels = labels[len(words):]
|
||||
|
||||
|
||||
def _map_indices_to_tokens(ids, heads):
|
||||
mapped = []
|
||||
for head in heads:
|
||||
if head not in ids:
|
||||
mapped.append(None)
|
||||
else:
|
||||
mapped.append(ids.index(head))
|
||||
return mapped
|
||||
|
||||
|
||||
|
||||
def evaluate(Language, dev_loc, model_dir):
|
||||
global loss
|
||||
nlp = Language()
|
||||
n_corr = 0
|
||||
pos_corr = 0
|
||||
n_tokens = 0
|
||||
total = 0
|
||||
skipped = 0
|
||||
loss = 0
|
||||
with codecs.open(dev_loc, 'r', 'utf8') as file_:
|
||||
paragraphs = read_gold(file_)
|
||||
for tokens, tag_strs, heads, labels in iter_data(paragraphs, nlp.tokenizer):
|
||||
assert len(tokens) == len(labels)
|
||||
nlp.tagger.tag_from_strings(tokens, tag_strs)
|
||||
nlp.parser(tokens)
|
||||
for i, token in enumerate(tokens):
|
||||
try:
|
||||
pos_corr += token.tag_ == tag_strs[i]
|
||||
except:
|
||||
print i, token.orth_, token.tag
|
||||
raise
|
||||
n_tokens += 1
|
||||
if heads[i] is None:
|
||||
skipped += 1
|
||||
continue
|
||||
if is_punct_label(labels[i]):
|
||||
continue
|
||||
n_corr += token.head.i == heads[i]
|
||||
total += 1
|
||||
print loss, skipped, (loss+skipped + total)
|
||||
print pos_corr / n_tokens
|
||||
return float(n_corr) / (total + loss)
|
||||
|
||||
|
||||
def main(dev_loc, model_dir):
|
||||
print evaluate(English, dev_loc, model_dir)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
plac.call(main)
|
133
bin/parser/conll_train.py
Executable file
133
bin/parser/conll_train.py
Executable file
|
@ -0,0 +1,133 @@
|
|||
#!/usr/bin/env python
|
||||
from __future__ import division
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import os
|
||||
from os import path
|
||||
import shutil
|
||||
import codecs
|
||||
import random
|
||||
import time
|
||||
import gzip
|
||||
|
||||
import plac
|
||||
import cProfile
|
||||
import pstats
|
||||
|
||||
import spacy.util
|
||||
from spacy.en import English
|
||||
from spacy.en.pos import POS_TEMPLATES, POS_TAGS, setup_model_dir
|
||||
from spacy.gold import GoldParse
|
||||
|
||||
from spacy.syntax.util import Config
|
||||
from spacy.scorer import Scorer
|
||||
|
||||
|
||||
def read_conll(file_):
|
||||
"""Read a standard CoNLL/MALT-style format"""
|
||||
sents = []
|
||||
for sent_str in file_.read().strip().split('\n\n'):
|
||||
ids = []
|
||||
words = []
|
||||
heads = []
|
||||
labels = []
|
||||
tags = []
|
||||
for i, line in enumerate(sent_str.split('\n')):
|
||||
word, pos_string, head_idx, label = _parse_line(line)
|
||||
words.append(word)
|
||||
if head_idx < 0:
|
||||
head_idx = i
|
||||
ids.append(i)
|
||||
heads.append(head_idx)
|
||||
labels.append(label)
|
||||
tags.append(pos_string)
|
||||
text = ' '.join(words)
|
||||
annot = (ids, words, tags, heads, labels, ['O'] * len(ids))
|
||||
sents.append((None, [(annot, [])]))
|
||||
return sents
|
||||
|
||||
|
||||
def _parse_line(line):
|
||||
pieces = line.split()
|
||||
if len(pieces) == 4:
|
||||
word, pos, head_idx, label = pieces
|
||||
head_idx = int(head_idx)
|
||||
else:
|
||||
id_ = int(pieces[0])
|
||||
word = pieces[1]
|
||||
pos = pieces[4]
|
||||
head_idx = int(pieces[6])-1
|
||||
label = pieces[7]
|
||||
return word, pos, head_idx, label
|
||||
|
||||
|
||||
def score_model(scorer, nlp, raw_text, annot_tuples, verbose=False):
|
||||
tokens = nlp.tokenizer.tokens_from_list(annot_tuples[1])
|
||||
nlp.tagger(tokens)
|
||||
nlp.parser(tokens)
|
||||
gold = GoldParse(tokens, annot_tuples)
|
||||
scorer.score(tokens, gold, verbose=verbose)
|
||||
|
||||
|
||||
def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic', seed=0,
|
||||
gold_preproc=False, force_gold=False):
|
||||
dep_model_dir = path.join(model_dir, 'deps')
|
||||
pos_model_dir = path.join(model_dir, 'pos')
|
||||
if path.exists(dep_model_dir):
|
||||
shutil.rmtree(dep_model_dir)
|
||||
if path.exists(pos_model_dir):
|
||||
shutil.rmtree(pos_model_dir)
|
||||
os.mkdir(dep_model_dir)
|
||||
os.mkdir(pos_model_dir)
|
||||
setup_model_dir(sorted(POS_TAGS.keys()), POS_TAGS, POS_TEMPLATES,
|
||||
pos_model_dir)
|
||||
|
||||
Config.write(dep_model_dir, 'config', features=feat_set, seed=seed,
|
||||
labels=Language.ParserTransitionSystem.get_labels(gold_tuples),
|
||||
beam_width=0)
|
||||
|
||||
nlp = Language(data_dir=model_dir)
|
||||
for itn in range(n_iter):
|
||||
scorer = Scorer()
|
||||
loss = 0
|
||||
for _, sents in gold_tuples:
|
||||
for annot_tuples, _ in sents:
|
||||
score_model(scorer, nlp, None, annot_tuples, verbose=False)
|
||||
tokens = nlp.tokenizer.tokens_from_list(annot_tuples[1])
|
||||
nlp.tagger(tokens)
|
||||
gold = GoldParse(tokens, annot_tuples, make_projective=True)
|
||||
if not gold.is_projective:
|
||||
raise Exception(
|
||||
"Non-projective sentence in training, after we should "
|
||||
"have enforced projectivity: %s" % annot_tuples
|
||||
)
|
||||
loss += nlp.parser.train(tokens, gold)
|
||||
nlp.tagger.train(tokens, gold.tags)
|
||||
random.shuffle(gold_tuples)
|
||||
print('%d:\t%d\t%.3f\t%.3f\t%.3f' % (itn, loss, scorer.uas,
|
||||
scorer.tags_acc, scorer.token_acc))
|
||||
nlp.tagger.model.end_training()
|
||||
nlp.parser.model.end_training()
|
||||
nlp.vocab.strings.dump(path.join(model_dir, 'vocab', 'strings.txt'))
|
||||
return nlp
|
||||
|
||||
|
||||
def main(train_loc, dev_loc, model_dir):
|
||||
#with codecs.open(train_loc, 'r', 'utf8') as file_:
|
||||
# train_sents = read_conll(file_)
|
||||
#train_sents = train_sents
|
||||
#train(English, train_sents, model_dir)
|
||||
nlp = English(data_dir=model_dir)
|
||||
dev_sents = read_conll(open(dev_loc))
|
||||
scorer = Scorer()
|
||||
for _, sents in dev_sents:
|
||||
for annot_tuples, _ in sents:
|
||||
score_model(scorer, nlp, None, annot_tuples)
|
||||
print('TOK', 100-scorer.token_acc)
|
||||
print('POS', scorer.tags_acc)
|
||||
print('UAS', scorer.uas)
|
||||
print('LAS', scorer.las)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
plac.call(main)
|
Loading…
Reference in New Issue
Block a user