mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-24 00:04:15 +03:00
* Update train.py for language-generic spaCy
This commit is contained in:
parent
950ce36660
commit
d1eea2d865
|
@ -14,7 +14,6 @@ import re
|
||||||
|
|
||||||
import spacy.util
|
import spacy.util
|
||||||
from spacy.en import English
|
from spacy.en import English
|
||||||
from spacy.en.pos import POS_TEMPLATES, POS_TAGS, setup_model_dir
|
|
||||||
|
|
||||||
from spacy.syntax.util import Config
|
from spacy.syntax.util import Config
|
||||||
from spacy.gold import read_json_file
|
from spacy.gold import read_json_file
|
||||||
|
@ -22,6 +21,11 @@ from spacy.gold import GoldParse
|
||||||
|
|
||||||
from spacy.scorer import Scorer
|
from spacy.scorer import Scorer
|
||||||
|
|
||||||
|
from spacy.syntax.arc_eager import ArcEager
|
||||||
|
from spacy.syntax.ner import BiluoPushDown
|
||||||
|
from spacy.tagger import Tagger
|
||||||
|
from spacy.syntax.parser import Parser
|
||||||
|
|
||||||
|
|
||||||
def _corrupt(c, noise_level):
|
def _corrupt(c, noise_level):
|
||||||
if random.random() >= noise_level:
|
if random.random() >= noise_level:
|
||||||
|
@ -80,32 +84,28 @@ def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic',
|
||||||
beam_width=1, verbose=False,
|
beam_width=1, verbose=False,
|
||||||
use_orig_arc_eager=False):
|
use_orig_arc_eager=False):
|
||||||
dep_model_dir = path.join(model_dir, 'deps')
|
dep_model_dir = path.join(model_dir, 'deps')
|
||||||
pos_model_dir = path.join(model_dir, 'pos')
|
|
||||||
ner_model_dir = path.join(model_dir, 'ner')
|
ner_model_dir = path.join(model_dir, 'ner')
|
||||||
if path.exists(dep_model_dir):
|
if path.exists(dep_model_dir):
|
||||||
shutil.rmtree(dep_model_dir)
|
shutil.rmtree(dep_model_dir)
|
||||||
if path.exists(pos_model_dir):
|
|
||||||
shutil.rmtree(pos_model_dir)
|
|
||||||
if path.exists(ner_model_dir):
|
if path.exists(ner_model_dir):
|
||||||
shutil.rmtree(ner_model_dir)
|
shutil.rmtree(ner_model_dir)
|
||||||
os.mkdir(dep_model_dir)
|
os.mkdir(dep_model_dir)
|
||||||
os.mkdir(pos_model_dir)
|
|
||||||
os.mkdir(ner_model_dir)
|
os.mkdir(ner_model_dir)
|
||||||
|
|
||||||
setup_model_dir(sorted(POS_TAGS.keys()), POS_TAGS, POS_TEMPLATES, pos_model_dir)
|
|
||||||
|
|
||||||
Config.write(dep_model_dir, 'config', features=feat_set, seed=seed,
|
Config.write(dep_model_dir, 'config', features=feat_set, seed=seed,
|
||||||
labels=Language.ParserTransitionSystem.get_labels(gold_tuples),
|
labels=ArcEager.get_labels(gold_tuples),
|
||||||
beam_width=beam_width)
|
beam_width=beam_width)
|
||||||
Config.write(ner_model_dir, 'config', features='ner', seed=seed,
|
Config.write(ner_model_dir, 'config', features='ner', seed=seed,
|
||||||
labels=Language.EntityTransitionSystem.get_labels(gold_tuples),
|
labels=BiluoPushDown.get_labels(gold_tuples),
|
||||||
beam_width=0)
|
beam_width=0)
|
||||||
|
|
||||||
if n_sents > 0:
|
if n_sents > 0:
|
||||||
gold_tuples = gold_tuples[:n_sents]
|
gold_tuples = gold_tuples[:n_sents]
|
||||||
|
|
||||||
nlp = Language(data_dir=model_dir)
|
nlp = Language(data_dir=model_dir, tagger=False, parser=False, entity=False)
|
||||||
|
nlp.tagger = Tagger.blank(nlp.vocab, Tagger.default_templates())
|
||||||
|
nlp.parser = Parser.from_dir(dep_model_dir, nlp.vocab.strings, ArcEager)
|
||||||
|
nlp.entity = Parser.from_dir(ner_model_dir, nlp.vocab.strings, BiluoPushDown)
|
||||||
print("Itn.\tP.Loss\tUAS\tNER F.\tTag %\tToken %")
|
print("Itn.\tP.Loss\tUAS\tNER F.\tTag %\tToken %")
|
||||||
for itn in range(n_iter):
|
for itn in range(n_iter):
|
||||||
scorer = Scorer()
|
scorer = Scorer()
|
||||||
|
@ -140,7 +140,7 @@ def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic',
|
||||||
print('%d:\t%d\t%.3f\t%.3f\t%.3f\t%.3f' % (itn, loss, scorer.uas, scorer.ents_f,
|
print('%d:\t%d\t%.3f\t%.3f\t%.3f\t%.3f' % (itn, loss, scorer.uas, scorer.ents_f,
|
||||||
scorer.tags_acc,
|
scorer.tags_acc,
|
||||||
scorer.token_acc))
|
scorer.token_acc))
|
||||||
nlp.end_training()
|
nlp.end_training(model_dir)
|
||||||
|
|
||||||
def evaluate(Language, gold_tuples, model_dir, gold_preproc=False, verbose=False,
|
def evaluate(Language, gold_tuples, model_dir, gold_preproc=False, verbose=False,
|
||||||
beam_width=None):
|
beam_width=None):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user