mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-26 01:46:28 +03:00
* Update conll_train.py script for spaCy v0.97
This commit is contained in:
parent
cfaa4bde5d
commit
4fb038a9eb
|
@ -16,11 +16,15 @@ import pstats
|
||||||
|
|
||||||
import spacy.util
|
import spacy.util
|
||||||
from spacy.en import English
|
from spacy.en import English
|
||||||
from spacy.en.pos import POS_TEMPLATES, POS_TAGS, setup_model_dir
|
|
||||||
from spacy.gold import GoldParse
|
from spacy.gold import GoldParse
|
||||||
|
|
||||||
from spacy.syntax.util import Config
|
from spacy.syntax.util import Config
|
||||||
|
from spacy.syntax.arc_eager import ArcEager
|
||||||
|
from spacy.syntax.parser import Parser
|
||||||
from spacy.scorer import Scorer
|
from spacy.scorer import Scorer
|
||||||
|
from spacy.tagger import Tagger
|
||||||
|
|
||||||
|
# Last updated for spaCy v0.97
|
||||||
|
|
||||||
|
|
||||||
def read_conll(file_):
|
def read_conll(file_):
|
||||||
|
@ -79,20 +83,25 @@ def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic', seed=0
|
||||||
shutil.rmtree(pos_model_dir)
|
shutil.rmtree(pos_model_dir)
|
||||||
os.mkdir(dep_model_dir)
|
os.mkdir(dep_model_dir)
|
||||||
os.mkdir(pos_model_dir)
|
os.mkdir(pos_model_dir)
|
||||||
setup_model_dir(sorted(POS_TAGS.keys()), POS_TAGS, POS_TEMPLATES,
|
|
||||||
pos_model_dir)
|
|
||||||
|
|
||||||
Config.write(dep_model_dir, 'config', features=feat_set, seed=seed,
|
Config.write(dep_model_dir, 'config', features=feat_set, seed=seed,
|
||||||
labels=Language.ParserTransitionSystem.get_labels(gold_tuples),
|
labels=ArcEager.get_labels(gold_tuples))
|
||||||
beam_width=0)
|
|
||||||
|
|
||||||
nlp = Language(data_dir=model_dir)
|
nlp = Language(data_dir=model_dir, tagger=False, parser=False, entity=False)
|
||||||
|
nlp.tagger = Tagger.blank(nlp.vocab, Tagger.default_templates())
|
||||||
|
nlp.parser = Parser.from_dir(dep_model_dir, nlp.vocab.strings, ArcEager)
|
||||||
|
|
||||||
|
print("Itn.\tP.Loss\tUAS\tNER F.\tTag %\tToken %")
|
||||||
for itn in range(n_iter):
|
for itn in range(n_iter):
|
||||||
scorer = Scorer()
|
scorer = Scorer()
|
||||||
loss = 0
|
loss = 0
|
||||||
for _, sents in gold_tuples:
|
for _, sents in gold_tuples:
|
||||||
for annot_tuples, _ in sents:
|
for annot_tuples, _ in sents:
|
||||||
|
if len(annot_tuples[1]) == 1:
|
||||||
|
continue
|
||||||
|
|
||||||
score_model(scorer, nlp, None, annot_tuples, verbose=False)
|
score_model(scorer, nlp, None, annot_tuples, verbose=False)
|
||||||
|
|
||||||
tokens = nlp.tokenizer.tokens_from_list(annot_tuples[1])
|
tokens = nlp.tokenizer.tokens_from_list(annot_tuples[1])
|
||||||
nlp.tagger(tokens)
|
nlp.tagger(tokens)
|
||||||
gold = GoldParse(tokens, annot_tuples, make_projective=True)
|
gold = GoldParse(tokens, annot_tuples, make_projective=True)
|
||||||
|
@ -101,22 +110,21 @@ def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic', seed=0
|
||||||
"Non-projective sentence in training, after we should "
|
"Non-projective sentence in training, after we should "
|
||||||
"have enforced projectivity: %s" % annot_tuples
|
"have enforced projectivity: %s" % annot_tuples
|
||||||
)
|
)
|
||||||
|
|
||||||
loss += nlp.parser.train(tokens, gold)
|
loss += nlp.parser.train(tokens, gold)
|
||||||
nlp.tagger.train(tokens, gold.tags)
|
nlp.tagger.train(tokens, gold.tags)
|
||||||
random.shuffle(gold_tuples)
|
random.shuffle(gold_tuples)
|
||||||
print('%d:\t%d\t%.3f\t%.3f\t%.3f' % (itn, loss, scorer.uas,
|
print('%d:\t%d\t%.3f\t%.3f\t%.3f' % (itn, loss, scorer.uas,
|
||||||
scorer.tags_acc, scorer.token_acc))
|
scorer.tags_acc, scorer.token_acc))
|
||||||
nlp.tagger.model.end_training()
|
print('end training')
|
||||||
nlp.parser.model.end_training()
|
nlp.end_training(model_dir)
|
||||||
nlp.vocab.strings.dump(path.join(model_dir, 'vocab', 'strings.txt'))
|
print('done')
|
||||||
return nlp
|
|
||||||
|
|
||||||
|
|
||||||
def main(train_loc, dev_loc, model_dir):
|
def main(train_loc, dev_loc, model_dir):
|
||||||
#with codecs.open(train_loc, 'r', 'utf8') as file_:
|
with codecs.open(train_loc, 'r', 'utf8') as file_:
|
||||||
# train_sents = read_conll(file_)
|
train_sents = read_conll(file_)
|
||||||
#train_sents = train_sents
|
train(English, train_sents, model_dir)
|
||||||
#train(English, train_sents, model_dir)
|
|
||||||
nlp = English(data_dir=model_dir)
|
nlp = English(data_dir=model_dir)
|
||||||
dev_sents = read_conll(open(dev_loc))
|
dev_sents = read_conll(open(dev_loc))
|
||||||
scorer = Scorer()
|
scorer = Scorer()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user