mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 17:36:30 +03:00
* Remove support for the force_gold flag from GreedyParser, since it is not very useful and only adds clutter
This commit is contained in:
parent
6a6085f8b9
commit
e854ba0a13
|
@ -187,7 +187,7 @@ def get_labels(sents):
|
||||||
|
|
||||||
|
|
||||||
def train(Language, train_loc, model_dir, n_iter=15, feat_set=u'basic', seed=0,
|
def train(Language, train_loc, model_dir, n_iter=15, feat_set=u'basic', seed=0,
|
||||||
gold_preproc=False, force_gold=False, n_sents=0):
|
gold_preproc=False, n_sents=0):
|
||||||
dep_model_dir = path.join(model_dir, 'deps')
|
dep_model_dir = path.join(model_dir, 'deps')
|
||||||
pos_model_dir = path.join(model_dir, 'pos')
|
pos_model_dir = path.join(model_dir, 'pos')
|
||||||
ner_model_dir = path.join(model_dir, 'ner')
|
ner_model_dir = path.join(model_dir, 'ner')
|
||||||
|
@ -230,8 +230,8 @@ def train(Language, train_loc, model_dir, n_iter=15, feat_set=u'basic', seed=0,
|
||||||
for tokens in sents:
|
for tokens in sents:
|
||||||
gold = GoldParse(tokens, annot_tuples)
|
gold = GoldParse(tokens, annot_tuples)
|
||||||
nlp.tagger(tokens)
|
nlp.tagger(tokens)
|
||||||
nlp.parser.train(tokens, gold, force_gold=force_gold)
|
nlp.parser.train(tokens, gold)
|
||||||
nlp.entity.train(tokens, gold, force_gold=force_gold)
|
nlp.entity.train(tokens, gold)
|
||||||
nlp.tagger.train(tokens, gold.tags)
|
nlp.tagger.train(tokens, gold.tags)
|
||||||
|
|
||||||
print '%d:\t%.3f\t%.3f\t%.3f' % (itn, scorer.uas, scorer.ents_f, scorer.tags_acc)
|
print '%d:\t%.3f\t%.3f\t%.3f' % (itn, scorer.uas, scorer.ents_f, scorer.tags_acc)
|
||||||
|
@ -280,7 +280,7 @@ def write_parses(Language, dev_loc, model_dir, out_loc):
|
||||||
def main(train_loc, dev_loc, model_dir, n_sents=0, out_loc="", verbose=False,
|
def main(train_loc, dev_loc, model_dir, n_sents=0, out_loc="", verbose=False,
|
||||||
debug=False):
|
debug=False):
|
||||||
train(English, train_loc, model_dir, feat_set='basic' if not debug else 'debug',
|
train(English, train_loc, model_dir, feat_set='basic' if not debug else 'debug',
|
||||||
gold_preproc=False, force_gold=False, n_sents=n_sents)
|
gold_preproc=False, n_sents=n_sents)
|
||||||
if out_loc:
|
if out_loc:
|
||||||
write_parses(English, dev_loc, model_dir, out_loc)
|
write_parses(English, dev_loc, model_dir, out_loc)
|
||||||
scorer = evaluate(English, dev_loc, model_dir, gold_preproc=False, verbose=verbose)
|
scorer = evaluate(English, dev_loc, model_dir, gold_preproc=False, verbose=verbose)
|
||||||
|
|
|
@ -93,7 +93,7 @@ cdef class GreedyParser:
|
||||||
tokens.set_parse(state.sent)
|
tokens.set_parse(state.sent)
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
def train(self, Tokens tokens, GoldParse gold, force_gold=False):
|
def train(self, Tokens tokens, GoldParse gold):
|
||||||
self.moves.preprocess_gold(gold)
|
self.moves.preprocess_gold(gold)
|
||||||
cdef Pool mem = Pool()
|
cdef Pool mem = Pool()
|
||||||
cdef State* state = new_state(mem, tokens.data, tokens.length)
|
cdef State* state = new_state(mem, tokens.data, tokens.length)
|
||||||
|
@ -109,12 +109,11 @@ cdef class GreedyParser:
|
||||||
while not is_final(state):
|
while not is_final(state):
|
||||||
fill_context(context, state)
|
fill_context(context, state)
|
||||||
scores = self.model.score(context)
|
scores = self.model.score(context)
|
||||||
|
|
||||||
guess = self.moves.best_valid(scores, state)
|
guess = self.moves.best_valid(scores, state)
|
||||||
best = self.moves.best_gold(scores, state, gold)
|
best = self.moves.best_gold(scores, state, gold)
|
||||||
|
|
||||||
cost = guess.get_cost(&guess, state, gold)
|
cost = guess.get_cost(&guess, state, gold)
|
||||||
|
|
||||||
self.model.update(context, guess.clas, best.clas, cost)
|
self.model.update(context, guess.clas, best.clas, cost)
|
||||||
if force_gold:
|
|
||||||
best.do(&best, state)
|
guess.do(&guess, state)
|
||||||
else:
|
|
||||||
guess.do(&guess, state)
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user