mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 18:26:30 +03:00
* Tmp commit
This commit is contained in:
parent
10ed738df2
commit
f5f15a1ef2
|
@ -21,6 +21,7 @@ from spacy.en.pos import POS_TEMPLATES, POS_TAGS, setup_model_dir
|
|||
from spacy.syntax.parser import GreedyParser
|
||||
from spacy.syntax.parser import OracleError
|
||||
from spacy.syntax.util import Config
|
||||
from spacy.syntax.conll import GoldParse
|
||||
|
||||
|
||||
def is_punct_label(label):
|
||||
|
@ -184,6 +185,7 @@ def get_labels(sents):
|
|||
|
||||
def train(Language, paragraphs, model_dir, n_iter=15, feat_set=u'basic', seed=0,
|
||||
gold_preproc=False, force_gold=False):
|
||||
print "Setup model dir"
|
||||
dep_model_dir = path.join(model_dir, 'deps')
|
||||
pos_model_dir = path.join(model_dir, 'pos')
|
||||
if path.exists(dep_model_dir):
|
||||
|
@ -198,7 +200,6 @@ def train(Language, paragraphs, model_dir, n_iter=15, feat_set=u'basic', seed=0,
|
|||
labels = Language.ParserTransitionSystem.get_labels(gold_sents)
|
||||
Config.write(dep_model_dir, 'config', features=feat_set, seed=seed,
|
||||
labels=labels)
|
||||
|
||||
nlp = Language()
|
||||
|
||||
for itn in range(n_iter):
|
||||
|
@ -206,16 +207,16 @@ def train(Language, paragraphs, model_dir, n_iter=15, feat_set=u'basic', seed=0,
|
|||
pos_corr = 0
|
||||
n_tokens = 0
|
||||
for gold_sent in gold_sents:
|
||||
tokens = nlp.tokenizer(gold_sent.raw)
|
||||
gold_sent.align_to_tokens(tokens)
|
||||
tokens = nlp.tokenizer(gold_sent.raw_text)
|
||||
gold_sent.align_to_tokens(tokens, nlp.parser.moves.label_ids)
|
||||
nlp.tagger(tokens)
|
||||
heads_corr += nlp.parser.train(tokens, gold_sent, force_gold=force_gold)
|
||||
pos_corr += nlp.tagger.train(tokens, gold_parse.tags)
|
||||
pos_corr += nlp.tagger.train(tokens, gold_sent.tags)
|
||||
n_tokens += len(tokens)
|
||||
acc = float(heads_corr) / n_tokens
|
||||
pos_acc = float(pos_corr) / n_tokens
|
||||
print '%d: ' % itn, '%.3f' % acc, '%.3f' % pos_acc
|
||||
random.shuffle(paragraphs)
|
||||
random.shuffle(gold_sents)
|
||||
nlp.parser.model.end_training()
|
||||
nlp.tagger.model.end_training()
|
||||
return acc
|
||||
|
@ -257,10 +258,16 @@ def evaluate(Language, dev_loc, model_dir, gold_preproc=False):
|
|||
return float(uas_corr) / (total + loss)
|
||||
|
||||
|
||||
def main(train_loc, dev_loc, model_dir):
|
||||
train(English, read_docparse_gold(train_loc), model_dir,
|
||||
@plac.annotations(
|
||||
train_loc=("Training file location",),
|
||||
dev_loc=("Dev. file location",),
|
||||
model_dir=("Location of output model directory",),
|
||||
n_sents=("Number of training sentences", "option", "n", int)
|
||||
)
|
||||
def main(train_loc, dev_loc, model_dir, n_sents=0):
|
||||
train(English, read_gold(train_loc, n=n_sents), model_dir,
|
||||
gold_preproc=False, force_gold=False)
|
||||
print evaluate(English, read_docparse_gold(dev_loc), model_dir, gold_preproc=False)
|
||||
print evaluate(English, read_gold(dev_loc), model_dir, gold_preproc=False)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
Loading…
Reference in New Issue
Block a user