mirror of
https://github.com/explosion/spaCy.git
synced 2025-06-06 06:03:11 +03:00
* Tmp commit
This commit is contained in:
parent
10ed738df2
commit
f5f15a1ef2
|
@ -21,6 +21,7 @@ from spacy.en.pos import POS_TEMPLATES, POS_TAGS, setup_model_dir
|
||||||
from spacy.syntax.parser import GreedyParser
|
from spacy.syntax.parser import GreedyParser
|
||||||
from spacy.syntax.parser import OracleError
|
from spacy.syntax.parser import OracleError
|
||||||
from spacy.syntax.util import Config
|
from spacy.syntax.util import Config
|
||||||
|
from spacy.syntax.conll import GoldParse
|
||||||
|
|
||||||
|
|
||||||
def is_punct_label(label):
|
def is_punct_label(label):
|
||||||
|
@ -184,6 +185,7 @@ def get_labels(sents):
|
||||||
|
|
||||||
def train(Language, paragraphs, model_dir, n_iter=15, feat_set=u'basic', seed=0,
|
def train(Language, paragraphs, model_dir, n_iter=15, feat_set=u'basic', seed=0,
|
||||||
gold_preproc=False, force_gold=False):
|
gold_preproc=False, force_gold=False):
|
||||||
|
print "Setup model dir"
|
||||||
dep_model_dir = path.join(model_dir, 'deps')
|
dep_model_dir = path.join(model_dir, 'deps')
|
||||||
pos_model_dir = path.join(model_dir, 'pos')
|
pos_model_dir = path.join(model_dir, 'pos')
|
||||||
if path.exists(dep_model_dir):
|
if path.exists(dep_model_dir):
|
||||||
|
@ -198,7 +200,6 @@ def train(Language, paragraphs, model_dir, n_iter=15, feat_set=u'basic', seed=0,
|
||||||
labels = Language.ParserTransitionSystem.get_labels(gold_sents)
|
labels = Language.ParserTransitionSystem.get_labels(gold_sents)
|
||||||
Config.write(dep_model_dir, 'config', features=feat_set, seed=seed,
|
Config.write(dep_model_dir, 'config', features=feat_set, seed=seed,
|
||||||
labels=labels)
|
labels=labels)
|
||||||
|
|
||||||
nlp = Language()
|
nlp = Language()
|
||||||
|
|
||||||
for itn in range(n_iter):
|
for itn in range(n_iter):
|
||||||
|
@ -206,16 +207,16 @@ def train(Language, paragraphs, model_dir, n_iter=15, feat_set=u'basic', seed=0,
|
||||||
pos_corr = 0
|
pos_corr = 0
|
||||||
n_tokens = 0
|
n_tokens = 0
|
||||||
for gold_sent in gold_sents:
|
for gold_sent in gold_sents:
|
||||||
tokens = nlp.tokenizer(gold_sent.raw)
|
tokens = nlp.tokenizer(gold_sent.raw_text)
|
||||||
gold_sent.align_to_tokens(tokens)
|
gold_sent.align_to_tokens(tokens, nlp.parser.moves.label_ids)
|
||||||
nlp.tagger(tokens)
|
nlp.tagger(tokens)
|
||||||
heads_corr += nlp.parser.train(tokens, gold_sent, force_gold=force_gold)
|
heads_corr += nlp.parser.train(tokens, gold_sent, force_gold=force_gold)
|
||||||
pos_corr += nlp.tagger.train(tokens, gold_parse.tags)
|
pos_corr += nlp.tagger.train(tokens, gold_sent.tags)
|
||||||
n_tokens += len(tokens)
|
n_tokens += len(tokens)
|
||||||
acc = float(heads_corr) / n_tokens
|
acc = float(heads_corr) / n_tokens
|
||||||
pos_acc = float(pos_corr) / n_tokens
|
pos_acc = float(pos_corr) / n_tokens
|
||||||
print '%d: ' % itn, '%.3f' % acc, '%.3f' % pos_acc
|
print '%d: ' % itn, '%.3f' % acc, '%.3f' % pos_acc
|
||||||
random.shuffle(paragraphs)
|
random.shuffle(gold_sents)
|
||||||
nlp.parser.model.end_training()
|
nlp.parser.model.end_training()
|
||||||
nlp.tagger.model.end_training()
|
nlp.tagger.model.end_training()
|
||||||
return acc
|
return acc
|
||||||
|
@ -257,10 +258,16 @@ def evaluate(Language, dev_loc, model_dir, gold_preproc=False):
|
||||||
return float(uas_corr) / (total + loss)
|
return float(uas_corr) / (total + loss)
|
||||||
|
|
||||||
|
|
||||||
def main(train_loc, dev_loc, model_dir):
|
@plac.annotations(
|
||||||
train(English, read_docparse_gold(train_loc), model_dir,
|
train_loc=("Training file location",),
|
||||||
|
dev_loc=("Dev. file location",),
|
||||||
|
model_dir=("Location of output model directory",),
|
||||||
|
n_sents=("Number of training sentences", "option", "n", int)
|
||||||
|
)
|
||||||
|
def main(train_loc, dev_loc, model_dir, n_sents=0):
|
||||||
|
train(English, read_gold(train_loc, n=n_sents), model_dir,
|
||||||
gold_preproc=False, force_gold=False)
|
gold_preproc=False, force_gold=False)
|
||||||
print evaluate(English, read_docparse_gold(dev_loc), model_dir, gold_preproc=False)
|
print evaluate(English, read_gold(dev_loc), model_dir, gold_preproc=False)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
Loading…
Reference in New Issue
Block a user