Pass cfg through loading, for training.

This commit is contained in:
Matthew Honnibal 2016-11-25 09:01:20 -06:00
parent 608d8f5421
commit a2f55e7015

View File

@ -31,6 +31,8 @@ from .attrs import TAG, DEP, ENT_IOB, ENT_TYPE, HEAD, PROB, LANG, IS_STOP
from .syntax.parser import get_templates
from .syntax.nonproj import PseudoProjectivity
from .pipeline import DependencyParser, EntityRecognizer
from .syntax.arc_eager import ArcEager
from .syntax.ner import BiluoPushDown
class BaseDefaults(object):
@ -65,7 +67,7 @@ class BaseDefaults(object):
prefix_search = util.compile_prefix_regex(cls.prefixes).search
suffix_search = util.compile_suffix_regex(cls.suffixes).search
infix_finditer = util.compile_infix_regex(cls.infixes).finditer
vocab = nlp.vocab if nlp is not None else cls.create_vocab(nlp)
vocab = nlp.vocab if nlp is not None else cls.Default.create_vocab(nlp)
return Tokenizer(nlp.vocab, rules=rules,
prefix_search=prefix_search, suffix_search=suffix_search,
infix_finditer=infix_finditer)
@ -82,26 +84,27 @@ class BaseDefaults(object):
return Tagger.load(nlp.path / 'pos', nlp.vocab)
@classmethod
def create_parser(cls, nlp=None):
def create_parser(cls, nlp=None, **cfg):
if nlp is None:
return DependencyParser(cls.create_vocab(), features=cls.parser_features)
return DependencyParser(cls.create_vocab(), features=cls.parser_features,
**cfg)
elif nlp.path is False:
return DependencyParser(nlp.vocab, features=cls.parser_features)
return DependencyParser(nlp.vocab, features=cls.parser_features, **cfg)
elif nlp.path is None or not (nlp.path / 'deps').exists():
return None
else:
return DependencyParser.load(nlp.path / 'deps', nlp.vocab)
return DependencyParser.load(nlp.path / 'deps', nlp.vocab, **cfg)
@classmethod
def create_entity(cls, nlp=None):
def create_entity(cls, nlp=None, **cfg):
if nlp is None:
return EntityRecognizer(cls.create_vocab(), features=cls.entity_features)
return EntityRecognizer(cls.create_vocab(), features=cls.entity_features, **cfg)
elif nlp.path is False:
return EntityRecognizer(nlp.vocab, features=cls.entity_features)
return EntityRecognizer(nlp.vocab, features=cls.entity_features, **cfg)
elif nlp.path is None or not (nlp.path / 'ner').exists():
return None
else:
return EntityRecognizer.load(nlp.path / 'ner', nlp.vocab)
return EntityRecognizer.load(nlp.path / 'ner', nlp.vocab, **cfg)
@classmethod
def create_matcher(cls, nlp=None):
@ -202,8 +205,8 @@ class Language(object):
# preprocess training data here before ArcEager.get_labels() is called
gold_tuples = PseudoProjectivity.preprocess_training_data(gold_tuples)
parser_cfg['labels'] = ArcEager.get_labels(gold_tuples)
entity_cfg['labels'] = BiluoPushDown.get_labels(gold_tuples)
parser_cfg['actions'] = ArcEager.get_actions(gold_parses=gold_tuples)
entity_cfg['actions'] = BiluoPushDown.get_actions(gold_parses=gold_tuples)
with (dep_model_dir / 'config.json').open('wb') as file_:
json.dump(parser_cfg, file_)
@ -224,22 +227,18 @@ class Language(object):
vectors=False,
pipeline=False)
self.defaults.parser_labels = parser_cfg['labels']
self.defaults.entity_labels = entity_cfg['labels']
self.vocab = self.defaults.Vocab()
self.tokenizer = self.defaults.Tokenizer(self.vocab)
self.tagger = self.defaults.Tagger(self.vocab, **tagger_cfg)
self.parser = self.defaults.Parser(self.vocab, **parser_cfg)
self.entity = self.defaults.Entity(self.vocab, **entity_cfg)
self.pipeline = self.defaults.Pipeline(self)
self.vocab = self.Defaults.create_vocab(self)
self.tokenizer = self.Defaults.create_tokenizer(self)
self.tagger = self.Defaults.create_tagger(self)
self.parser = self.Defaults.create_parser(self)
self.entity = self.Defaults.create_entity(self)
self.pipeline = self.Defaults.create_pipeline(self)
yield Trainer(self, gold_tuples)
self.end_training()
def __init__(self, path=True, **overrides):
if 'data_dir' in overrides and 'path' not in overrides:
if 'data_dir' in overrides and 'path' is True:
raise ValueError("The argument 'data_dir' has been renamed to 'path'")
path = overrides.get('path', True)
if isinstance(path, basestring):
path = pathlib.Path(path)
if path is True:
@ -253,7 +252,7 @@ class Language(object):
add_vectors = self.Defaults.add_vectors(self) \
if 'add_vectors' not in overrides \
else overrides['add_vectors']
if add_vectors:
if self.vocab and add_vectors:
add_vectors(self.vocab)
self.tokenizer = self.Defaults.create_tokenizer(self) \
if 'tokenizer' not in overrides \