Pass cfg through loading, for training.

This commit is contained in:
Matthew Honnibal 2016-11-25 09:01:20 -06:00
parent 608d8f5421
commit a2f55e7015

View File

@ -31,6 +31,8 @@ from .attrs import TAG, DEP, ENT_IOB, ENT_TYPE, HEAD, PROB, LANG, IS_STOP
from .syntax.parser import get_templates from .syntax.parser import get_templates
from .syntax.nonproj import PseudoProjectivity from .syntax.nonproj import PseudoProjectivity
from .pipeline import DependencyParser, EntityRecognizer from .pipeline import DependencyParser, EntityRecognizer
from .syntax.arc_eager import ArcEager
from .syntax.ner import BiluoPushDown
class BaseDefaults(object): class BaseDefaults(object):
@ -65,7 +67,7 @@ class BaseDefaults(object):
prefix_search = util.compile_prefix_regex(cls.prefixes).search prefix_search = util.compile_prefix_regex(cls.prefixes).search
suffix_search = util.compile_suffix_regex(cls.suffixes).search suffix_search = util.compile_suffix_regex(cls.suffixes).search
infix_finditer = util.compile_infix_regex(cls.infixes).finditer infix_finditer = util.compile_infix_regex(cls.infixes).finditer
vocab = nlp.vocab if nlp is not None else cls.create_vocab(nlp) vocab = nlp.vocab if nlp is not None else cls.Default.create_vocab(nlp)
return Tokenizer(nlp.vocab, rules=rules, return Tokenizer(nlp.vocab, rules=rules,
prefix_search=prefix_search, suffix_search=suffix_search, prefix_search=prefix_search, suffix_search=suffix_search,
infix_finditer=infix_finditer) infix_finditer=infix_finditer)
@ -82,26 +84,27 @@ class BaseDefaults(object):
return Tagger.load(nlp.path / 'pos', nlp.vocab) return Tagger.load(nlp.path / 'pos', nlp.vocab)
@classmethod @classmethod
def create_parser(cls, nlp=None): def create_parser(cls, nlp=None, **cfg):
if nlp is None: if nlp is None:
return DependencyParser(cls.create_vocab(), features=cls.parser_features) return DependencyParser(cls.create_vocab(), features=cls.parser_features,
**cfg)
elif nlp.path is False: elif nlp.path is False:
return DependencyParser(nlp.vocab, features=cls.parser_features) return DependencyParser(nlp.vocab, features=cls.parser_features, **cfg)
elif nlp.path is None or not (nlp.path / 'deps').exists(): elif nlp.path is None or not (nlp.path / 'deps').exists():
return None return None
else: else:
return DependencyParser.load(nlp.path / 'deps', nlp.vocab) return DependencyParser.load(nlp.path / 'deps', nlp.vocab, **cfg)
@classmethod @classmethod
def create_entity(cls, nlp=None): def create_entity(cls, nlp=None, **cfg):
if nlp is None: if nlp is None:
return EntityRecognizer(cls.create_vocab(), features=cls.entity_features) return EntityRecognizer(cls.create_vocab(), features=cls.entity_features, **cfg)
elif nlp.path is False: elif nlp.path is False:
return EntityRecognizer(nlp.vocab, features=cls.entity_features) return EntityRecognizer(nlp.vocab, features=cls.entity_features, **cfg)
elif nlp.path is None or not (nlp.path / 'ner').exists(): elif nlp.path is None or not (nlp.path / 'ner').exists():
return None return None
else: else:
return EntityRecognizer.load(nlp.path / 'ner', nlp.vocab) return EntityRecognizer.load(nlp.path / 'ner', nlp.vocab, **cfg)
@classmethod @classmethod
def create_matcher(cls, nlp=None): def create_matcher(cls, nlp=None):
@ -202,8 +205,8 @@ class Language(object):
# preprocess training data here before ArcEager.get_labels() is called # preprocess training data here before ArcEager.get_labels() is called
gold_tuples = PseudoProjectivity.preprocess_training_data(gold_tuples) gold_tuples = PseudoProjectivity.preprocess_training_data(gold_tuples)
parser_cfg['labels'] = ArcEager.get_labels(gold_tuples) parser_cfg['actions'] = ArcEager.get_actions(gold_parses=gold_tuples)
entity_cfg['labels'] = BiluoPushDown.get_labels(gold_tuples) entity_cfg['actions'] = BiluoPushDown.get_actions(gold_parses=gold_tuples)
with (dep_model_dir / 'config.json').open('wb') as file_: with (dep_model_dir / 'config.json').open('wb') as file_:
json.dump(parser_cfg, file_) json.dump(parser_cfg, file_)
@ -224,22 +227,18 @@ class Language(object):
vectors=False, vectors=False,
pipeline=False) pipeline=False)
self.defaults.parser_labels = parser_cfg['labels'] self.vocab = self.Defaults.create_vocab(self)
self.defaults.entity_labels = entity_cfg['labels'] self.tokenizer = self.Defaults.create_tokenizer(self)
self.tagger = self.Defaults.create_tagger(self)
self.vocab = self.defaults.Vocab() self.parser = self.Defaults.create_parser(self)
self.tokenizer = self.defaults.Tokenizer(self.vocab) self.entity = self.Defaults.create_entity(self)
self.tagger = self.defaults.Tagger(self.vocab, **tagger_cfg) self.pipeline = self.Defaults.create_pipeline(self)
self.parser = self.defaults.Parser(self.vocab, **parser_cfg)
self.entity = self.defaults.Entity(self.vocab, **entity_cfg)
self.pipeline = self.defaults.Pipeline(self)
yield Trainer(self, gold_tuples) yield Trainer(self, gold_tuples)
self.end_training() self.end_training()
def __init__(self, path=True, **overrides): def __init__(self, path=True, **overrides):
if 'data_dir' in overrides and 'path' not in overrides: if 'data_dir' in overrides and 'path' is True:
raise ValueError("The argument 'data_dir' has been renamed to 'path'") raise ValueError("The argument 'data_dir' has been renamed to 'path'")
path = overrides.get('path', True)
if isinstance(path, basestring): if isinstance(path, basestring):
path = pathlib.Path(path) path = pathlib.Path(path)
if path is True: if path is True:
@ -253,7 +252,7 @@ class Language(object):
add_vectors = self.Defaults.add_vectors(self) \ add_vectors = self.Defaults.add_vectors(self) \
if 'add_vectors' not in overrides \ if 'add_vectors' not in overrides \
else overrides['add_vectors'] else overrides['add_vectors']
if add_vectors: if self.vocab and add_vectors:
add_vectors(self.vocab) add_vectors(self.vocab)
self.tokenizer = self.Defaults.create_tokenizer(self) \ self.tokenizer = self.Defaults.create_tokenizer(self) \
if 'tokenizer' not in overrides \ if 'tokenizer' not in overrides \