mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-10 19:57:17 +03:00
* Work on updating train script for named entity recognition
This commit is contained in:
parent
357dcdcc01
commit
4539c70542
|
@ -21,7 +21,8 @@ from spacy.en.pos import POS_TEMPLATES, POS_TAGS, setup_model_dir
|
||||||
from spacy.syntax.parser import GreedyParser
|
from spacy.syntax.parser import GreedyParser
|
||||||
from spacy.syntax.parser import OracleError
|
from spacy.syntax.parser import OracleError
|
||||||
from spacy.syntax.util import Config
|
from spacy.syntax.util import Config
|
||||||
from spacy.syntax.conll import GoldParse, is_punct_label
|
from spacy.syntax.conll import read_docparse_file
|
||||||
|
from spacy.syntax.conll import GoldParse
|
||||||
|
|
||||||
|
|
||||||
def is_punct_label(label):
|
def is_punct_label(label):
|
||||||
|
@ -183,47 +184,56 @@ def get_labels(sents):
|
||||||
return list(sorted(left_labels)), list(sorted(right_labels))
|
return list(sorted(left_labels)), list(sorted(right_labels))
|
||||||
|
|
||||||
|
|
||||||
def train(Language, paragraphs, model_dir, n_iter=15, feat_set=u'basic', seed=0,
|
def train(Language, train_loc, model_dir, n_iter=15, feat_set=u'basic', seed=0,
|
||||||
gold_preproc=False, force_gold=False):
|
gold_preproc=False, force_gold=False, n_sents=0):
|
||||||
print "Setup model dir"
|
print "Setup model dir"
|
||||||
dep_model_dir = path.join(model_dir, 'deps')
|
dep_model_dir = path.join(model_dir, 'deps')
|
||||||
pos_model_dir = path.join(model_dir, 'pos')
|
pos_model_dir = path.join(model_dir, 'pos')
|
||||||
|
ner_model_dir = path.join(model_dir, 'ner')
|
||||||
if path.exists(dep_model_dir):
|
if path.exists(dep_model_dir):
|
||||||
shutil.rmtree(dep_model_dir)
|
shutil.rmtree(dep_model_dir)
|
||||||
if path.exists(pos_model_dir):
|
if path.exists(pos_model_dir):
|
||||||
shutil.rmtree(pos_model_dir)
|
shutil.rmtree(pos_model_dir)
|
||||||
|
if path.exists(ner_model_dir):
|
||||||
|
shutil.rmtree(ner_model_dir)
|
||||||
os.mkdir(dep_model_dir)
|
os.mkdir(dep_model_dir)
|
||||||
os.mkdir(pos_model_dir)
|
os.mkdir(pos_model_dir)
|
||||||
setup_model_dir(sorted(POS_TAGS.keys()), POS_TAGS, POS_TEMPLATES,
|
os.mkdir(ner_model_dir)
|
||||||
pos_model_dir)
|
|
||||||
|
setup_model_dir(sorted(POS_TAGS.keys()), POS_TAGS, POS_TEMPLATES, pos_model_dir)
|
||||||
|
|
||||||
|
gold_tuples = read_docparse_file(train_loc)
|
||||||
|
|
||||||
labels = Language.ParserTransitionSystem.get_labels(gold_sents)
|
|
||||||
Config.write(dep_model_dir, 'config', features=feat_set, seed=seed,
|
Config.write(dep_model_dir, 'config', features=feat_set, seed=seed,
|
||||||
labels=labels)
|
labels=Language.ParserTransitionSystem.get_labels(gold_tuples))
|
||||||
|
Config.write(ner_model_dir, 'config', features=feat_set, seed=seed,
|
||||||
|
labels=Language.EntityTransitionSystem.get_labels(gold_tuples))
|
||||||
|
|
||||||
nlp = Language()
|
nlp = Language()
|
||||||
|
|
||||||
for itn in range(n_iter):
|
for itn in range(n_iter):
|
||||||
heads_corr = 0
|
dep_corr = 0
|
||||||
pos_corr = 0
|
pos_corr = 0
|
||||||
n_tokens = 0
|
n_tokens = 0
|
||||||
n_all_tokens = 0
|
for raw_text, segmented_text, annot_tuples in gold_tuples:
|
||||||
for gold_sent in gold_sents:
|
|
||||||
if gold_preproc:
|
if gold_preproc:
|
||||||
#print ' '.join(gold_sent.words)
|
sents = [nlp.tokenizer.tokens_from_list(s) for s in segmented_text]
|
||||||
tokens = nlp.tokenizer.tokens_from_list(gold_sent.words)
|
|
||||||
gold_sent.map_heads(nlp.parser.moves.label_ids)
|
|
||||||
else:
|
else:
|
||||||
tokens = nlp.tokenizer(gold_sent.raw_text)
|
sents = [nlp.tokenizer(raw_text)]
|
||||||
gold_sent.align_to_tokens(tokens, nlp.parser.moves.label_ids)
|
for tokens in sents:
|
||||||
nlp.tagger(tokens)
|
|
||||||
heads_corr += nlp.parser.train(tokens, gold_sent, force_gold=force_gold)
|
gold = GoldParse(tokens, annot_tuples, nlp.tags,
|
||||||
pos_corr += nlp.tagger.train(tokens, gold_sent.tags)
|
nlp.parser.moves.label_ids,
|
||||||
n_tokens += gold_sent.n_non_punct
|
nlp.entity.moves.label_ids)
|
||||||
n_all_tokens += len(tokens)
|
|
||||||
acc = float(heads_corr) / n_tokens
|
nlp.tagger(tokens)
|
||||||
pos_acc = float(pos_corr) / n_all_tokens
|
dep_corr += nlp.parser.train(tokens, gold, force_gold=force_gold)
|
||||||
|
pos_corr += nlp.tagger.train(tokens, gold.tags_)
|
||||||
|
n_tokens += len(tokens)
|
||||||
|
acc = float(dep_corr) / n_tokens
|
||||||
|
pos_acc = float(pos_corr) / n_tokens
|
||||||
print '%d: ' % itn, '%.3f' % acc, '%.3f' % pos_acc
|
print '%d: ' % itn, '%.3f' % acc, '%.3f' % pos_acc
|
||||||
random.shuffle(gold_sents)
|
random.shuffle(gold_tuples)
|
||||||
nlp.parser.model.end_training()
|
nlp.parser.model.end_training()
|
||||||
nlp.tagger.model.end_training()
|
nlp.tagger.model.end_training()
|
||||||
return acc
|
return acc
|
||||||
|
@ -239,22 +249,22 @@ def evaluate(Language, dev_loc, model_dir, gold_preproc=False):
|
||||||
total = 0
|
total = 0
|
||||||
skipped = 0
|
skipped = 0
|
||||||
loss = 0
|
loss = 0
|
||||||
with codecs.open(dev_loc, 'r', 'utf8') as file_:
|
gold_tuples = read_docparse_file(train_loc)
|
||||||
#paragraphs = read_tokenized_gold(file_)
|
for raw_text, segmented_text, annot_tuples in gold_tuples:
|
||||||
paragraphs = read_docparse_gold(file_)
|
if gold_preproc:
|
||||||
for tokens, tag_strs, heads, labels in iter_data(paragraphs, nlp.tokenizer,
|
tokens = nlp.tokenizer.tokens_from_list(gold_sent.words)
|
||||||
gold_preproc=gold_preproc):
|
nlp.tagger(tokens)
|
||||||
assert len(tokens) == len(labels)
|
nlp.parser(tokens)
|
||||||
nlp.tagger(tokens)
|
gold_sent.map_heads(nlp.parser.moves.label_ids)
|
||||||
nlp.parser(tokens)
|
else:
|
||||||
|
tokens = nlp(gold_sent.raw_text)
|
||||||
|
loss += gold_sent.align_to_tokens(tokens, nlp.parser.moves.label_ids)
|
||||||
for i, token in enumerate(tokens):
|
for i, token in enumerate(tokens):
|
||||||
pos_corr += token.tag_ == gold_sent.tags[i]
|
pos_corr += token.tag_ == gold_sent.tags[i]
|
||||||
n_tokens += 1
|
n_tokens += 1
|
||||||
if gold_sent.heads[i] is None:
|
if gold_sent.heads[i] is None:
|
||||||
skipped += 1
|
skipped += 1
|
||||||
continue
|
continue
|
||||||
#print i, token.orth_, token.head.i, gold_sent.py_heads[i], gold_sent.labels[i],
|
|
||||||
#print gold_sent.is_correct(i, token.head.i)
|
|
||||||
if gold_sent.labels[i] != 'P':
|
if gold_sent.labels[i] != 'P':
|
||||||
n_corr += gold_sent.is_correct(i, token.head.i)
|
n_corr += gold_sent.is_correct(i, token.head.i)
|
||||||
total += 1
|
total += 1
|
||||||
|
@ -263,12 +273,6 @@ def evaluate(Language, dev_loc, model_dir, gold_preproc=False):
|
||||||
return float(n_corr) / (total + loss)
|
return float(n_corr) / (total + loss)
|
||||||
|
|
||||||
|
|
||||||
def read_gold(loc, n=0):
|
|
||||||
sent_strs = open(loc).read().strip().split('\n\n')
|
|
||||||
if n == 0:
|
|
||||||
n = len(sent_strs)
|
|
||||||
return [GoldParse.from_docparse(sent) for sent in sent_strs[:n]]
|
|
||||||
|
|
||||||
|
|
||||||
@plac.annotations(
|
@plac.annotations(
|
||||||
train_loc=("Training file location",),
|
train_loc=("Training file location",),
|
||||||
|
@ -277,9 +281,9 @@ def read_gold(loc, n=0):
|
||||||
n_sents=("Number of training sentences", "option", "n", int)
|
n_sents=("Number of training sentences", "option", "n", int)
|
||||||
)
|
)
|
||||||
def main(train_loc, dev_loc, model_dir, n_sents=0):
|
def main(train_loc, dev_loc, model_dir, n_sents=0):
|
||||||
#train(English, read_gold(train_loc, n=n_sents), model_dir,
|
train(English, train_loc, model_dir,
|
||||||
# gold_preproc=False, force_gold=False)
|
gold_preproc=False, force_gold=False, n_sents=n_sents)
|
||||||
print evaluate(English, read_gold(dev_loc), model_dir, gold_preproc=False)
|
print evaluate(English, dev_loc, model_dir, gold_preproc=False)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
Loading…
Reference in New Issue
Block a user