mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 18:26:30 +03:00
Update train.py
This commit is contained in:
parent
41f88ce938
commit
bd1bfcca61
|
@ -17,6 +17,7 @@ import spacy.util
|
||||||
from spacy.syntax.util import Config
|
from spacy.syntax.util import Config
|
||||||
from spacy.gold import read_json_file
|
from spacy.gold import read_json_file
|
||||||
from spacy.gold import GoldParse
|
from spacy.gold import GoldParse
|
||||||
|
from spacy.gold import merge_sents
|
||||||
|
|
||||||
from spacy.scorer import Scorer
|
from spacy.scorer import Scorer
|
||||||
|
|
||||||
|
@ -63,22 +64,6 @@ def score_model(scorer, nlp, raw_text, annot_tuples, verbose=False):
|
||||||
scorer.score(tokens, gold, verbose=verbose)
|
scorer.score(tokens, gold, verbose=verbose)
|
||||||
|
|
||||||
|
|
||||||
def _merge_sents(sents):
|
|
||||||
m_deps = [[], [], [], [], [], []]
|
|
||||||
m_brackets = []
|
|
||||||
i = 0
|
|
||||||
for (ids, words, tags, heads, labels, ner), brackets in sents:
|
|
||||||
m_deps[0].extend(id_ + i for id_ in ids)
|
|
||||||
m_deps[1].extend(words)
|
|
||||||
m_deps[2].extend(tags)
|
|
||||||
m_deps[3].extend(head + i for head in heads)
|
|
||||||
m_deps[4].extend(labels)
|
|
||||||
m_deps[5].extend(ner)
|
|
||||||
m_brackets.extend((b['first'] + i, b['last'] + i, b['label']) for b in brackets)
|
|
||||||
i += len(ids)
|
|
||||||
return [(m_deps, m_brackets)]
|
|
||||||
|
|
||||||
|
|
||||||
def train(Language, train_data, dev_data, model_dir, tagger_cfg, parser_cfg, entity_cfg,
|
def train(Language, train_data, dev_data, model_dir, tagger_cfg, parser_cfg, entity_cfg,
|
||||||
n_iter=15, seed=0, gold_preproc=False, n_sents=0, corruption_level=0):
|
n_iter=15, seed=0, gold_preproc=False, n_sents=0, corruption_level=0):
|
||||||
print("Itn.\tP.Loss\tUAS\tNER F.\tTag %\tToken %")
|
print("Itn.\tP.Loss\tUAS\tNER F.\tTag %\tToken %")
|
||||||
|
@ -86,10 +71,11 @@ def train(Language, train_data, dev_data, model_dir, tagger_cfg, parser_cfg, ent
|
||||||
with Language.train(model_dir, train_data,
|
with Language.train(model_dir, train_data,
|
||||||
tagger_cfg, parser_cfg, entity_cfg) as trainer:
|
tagger_cfg, parser_cfg, entity_cfg) as trainer:
|
||||||
loss = 0
|
loss = 0
|
||||||
for itn, epoch in enumerate(trainer.epochs(n_iter, augment_data=None)):
|
for itn, epoch in enumerate(trainer.epochs(n_iter, gold_preproc=gold_preproc,
|
||||||
|
augment_data=None)):
|
||||||
for doc, gold in epoch:
|
for doc, gold in epoch:
|
||||||
trainer.update(doc, gold)
|
trainer.update(doc, gold)
|
||||||
dev_scores = trainer.evaluate(dev_data)
|
dev_scores = trainer.evaluate(dev_data, gold_preproc=gold_preproc)
|
||||||
print(format_str.format(itn, loss, **dev_scores.scores))
|
print(format_str.format(itn, loss, **dev_scores.scores))
|
||||||
|
|
||||||
|
|
||||||
|
@ -105,7 +91,7 @@ def evaluate(Language, gold_tuples, model_dir, gold_preproc=False, verbose=False
|
||||||
if gold_preproc:
|
if gold_preproc:
|
||||||
raw_text = None
|
raw_text = None
|
||||||
else:
|
else:
|
||||||
sents = _merge_sents(sents)
|
sents = merge_sents(sents)
|
||||||
for annot_tuples, brackets in sents:
|
for annot_tuples, brackets in sents:
|
||||||
if raw_text is None:
|
if raw_text is None:
|
||||||
tokens = nlp.tokenizer.tokens_from_list(annot_tuples[1])
|
tokens = nlp.tokenizer.tokens_from_list(annot_tuples[1])
|
||||||
|
|
Loading…
Reference in New Issue
Block a user