mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 01:16:28 +03:00
Update train.py
This commit is contained in:
parent
41f88ce938
commit
bd1bfcca61
|
@ -17,6 +17,7 @@ import spacy.util
|
|||
from spacy.syntax.util import Config
|
||||
from spacy.gold import read_json_file
|
||||
from spacy.gold import GoldParse
|
||||
from spacy.gold import merge_sents
|
||||
|
||||
from spacy.scorer import Scorer
|
||||
|
||||
|
@ -63,22 +64,6 @@ def score_model(scorer, nlp, raw_text, annot_tuples, verbose=False):
|
|||
scorer.score(tokens, gold, verbose=verbose)
|
||||
|
||||
|
||||
def _merge_sents(sents):
|
||||
m_deps = [[], [], [], [], [], []]
|
||||
m_brackets = []
|
||||
i = 0
|
||||
for (ids, words, tags, heads, labels, ner), brackets in sents:
|
||||
m_deps[0].extend(id_ + i for id_ in ids)
|
||||
m_deps[1].extend(words)
|
||||
m_deps[2].extend(tags)
|
||||
m_deps[3].extend(head + i for head in heads)
|
||||
m_deps[4].extend(labels)
|
||||
m_deps[5].extend(ner)
|
||||
m_brackets.extend((b['first'] + i, b['last'] + i, b['label']) for b in brackets)
|
||||
i += len(ids)
|
||||
return [(m_deps, m_brackets)]
|
||||
|
||||
|
||||
def train(Language, train_data, dev_data, model_dir, tagger_cfg, parser_cfg, entity_cfg,
|
||||
n_iter=15, seed=0, gold_preproc=False, n_sents=0, corruption_level=0):
|
||||
print("Itn.\tP.Loss\tUAS\tNER F.\tTag %\tToken %")
|
||||
|
@ -86,10 +71,11 @@ def train(Language, train_data, dev_data, model_dir, tagger_cfg, parser_cfg, ent
|
|||
with Language.train(model_dir, train_data,
|
||||
tagger_cfg, parser_cfg, entity_cfg) as trainer:
|
||||
loss = 0
|
||||
for itn, epoch in enumerate(trainer.epochs(n_iter, augment_data=None)):
|
||||
for itn, epoch in enumerate(trainer.epochs(n_iter, gold_preproc=gold_preproc,
|
||||
augment_data=None)):
|
||||
for doc, gold in epoch:
|
||||
trainer.update(doc, gold)
|
||||
dev_scores = trainer.evaluate(dev_data)
|
||||
dev_scores = trainer.evaluate(dev_data, gold_preproc=gold_preproc)
|
||||
print(format_str.format(itn, loss, **dev_scores.scores))
|
||||
|
||||
|
||||
|
@ -105,7 +91,7 @@ def evaluate(Language, gold_tuples, model_dir, gold_preproc=False, verbose=False
|
|||
if gold_preproc:
|
||||
raw_text = None
|
||||
else:
|
||||
sents = _merge_sents(sents)
|
||||
sents = merge_sents(sents)
|
||||
for annot_tuples, brackets in sents:
|
||||
if raw_text is None:
|
||||
tokens = nlp.tokenizer.tokens_from_list(annot_tuples[1])
|
||||
|
|
Loading…
Reference in New Issue
Block a user