mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 18:26:30 +03:00
* Update conll_train.py script for spaCy v0.97
This commit is contained in:
parent
cfaa4bde5d
commit
4fb038a9eb
|
@ -16,11 +16,15 @@ import pstats
|
|||
|
||||
import spacy.util
|
||||
from spacy.en import English
|
||||
from spacy.en.pos import POS_TEMPLATES, POS_TAGS, setup_model_dir
|
||||
from spacy.gold import GoldParse
|
||||
|
||||
from spacy.syntax.util import Config
|
||||
from spacy.syntax.arc_eager import ArcEager
|
||||
from spacy.syntax.parser import Parser
|
||||
from spacy.scorer import Scorer
|
||||
from spacy.tagger import Tagger
|
||||
|
||||
# Last updated for spaCy v0.97
|
||||
|
||||
|
||||
def read_conll(file_):
|
||||
|
@ -79,20 +83,25 @@ def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic', seed=0
|
|||
shutil.rmtree(pos_model_dir)
|
||||
os.mkdir(dep_model_dir)
|
||||
os.mkdir(pos_model_dir)
|
||||
setup_model_dir(sorted(POS_TAGS.keys()), POS_TAGS, POS_TEMPLATES,
|
||||
pos_model_dir)
|
||||
|
||||
Config.write(dep_model_dir, 'config', features=feat_set, seed=seed,
|
||||
labels=Language.ParserTransitionSystem.get_labels(gold_tuples),
|
||||
beam_width=0)
|
||||
labels=ArcEager.get_labels(gold_tuples))
|
||||
|
||||
nlp = Language(data_dir=model_dir)
|
||||
nlp = Language(data_dir=model_dir, tagger=False, parser=False, entity=False)
|
||||
nlp.tagger = Tagger.blank(nlp.vocab, Tagger.default_templates())
|
||||
nlp.parser = Parser.from_dir(dep_model_dir, nlp.vocab.strings, ArcEager)
|
||||
|
||||
print("Itn.\tP.Loss\tUAS\tNER F.\tTag %\tToken %")
|
||||
for itn in range(n_iter):
|
||||
scorer = Scorer()
|
||||
loss = 0
|
||||
for _, sents in gold_tuples:
|
||||
for annot_tuples, _ in sents:
|
||||
if len(annot_tuples[1]) == 1:
|
||||
continue
|
||||
|
||||
score_model(scorer, nlp, None, annot_tuples, verbose=False)
|
||||
|
||||
tokens = nlp.tokenizer.tokens_from_list(annot_tuples[1])
|
||||
nlp.tagger(tokens)
|
||||
gold = GoldParse(tokens, annot_tuples, make_projective=True)
|
||||
|
@ -101,22 +110,21 @@ def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic', seed=0
|
|||
"Non-projective sentence in training, after we should "
|
||||
"have enforced projectivity: %s" % annot_tuples
|
||||
)
|
||||
|
||||
loss += nlp.parser.train(tokens, gold)
|
||||
nlp.tagger.train(tokens, gold.tags)
|
||||
random.shuffle(gold_tuples)
|
||||
print('%d:\t%d\t%.3f\t%.3f\t%.3f' % (itn, loss, scorer.uas,
|
||||
scorer.tags_acc, scorer.token_acc))
|
||||
nlp.tagger.model.end_training()
|
||||
nlp.parser.model.end_training()
|
||||
nlp.vocab.strings.dump(path.join(model_dir, 'vocab', 'strings.txt'))
|
||||
return nlp
|
||||
print('end training')
|
||||
nlp.end_training(model_dir)
|
||||
print('done')
|
||||
|
||||
|
||||
def main(train_loc, dev_loc, model_dir):
|
||||
#with codecs.open(train_loc, 'r', 'utf8') as file_:
|
||||
# train_sents = read_conll(file_)
|
||||
#train_sents = train_sents
|
||||
#train(English, train_sents, model_dir)
|
||||
with codecs.open(train_loc, 'r', 'utf8') as file_:
|
||||
train_sents = read_conll(file_)
|
||||
train(English, train_sents, model_dir)
|
||||
nlp = English(data_dir=model_dir)
|
||||
dev_sents = read_conll(open(dev_loc))
|
||||
scorer = Scorer()
|
||||
|
|
Loading…
Reference in New Issue
Block a user