mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-26 01:04:34 +03:00
Pass cfg through loading, for training.
This commit is contained in:
parent
608d8f5421
commit
a2f55e7015
|
@ -31,6 +31,8 @@ from .attrs import TAG, DEP, ENT_IOB, ENT_TYPE, HEAD, PROB, LANG, IS_STOP
|
|||
from .syntax.parser import get_templates
|
||||
from .syntax.nonproj import PseudoProjectivity
|
||||
from .pipeline import DependencyParser, EntityRecognizer
|
||||
from .syntax.arc_eager import ArcEager
|
||||
from .syntax.ner import BiluoPushDown
|
||||
|
||||
|
||||
class BaseDefaults(object):
|
||||
|
@ -65,7 +67,7 @@ class BaseDefaults(object):
|
|||
prefix_search = util.compile_prefix_regex(cls.prefixes).search
|
||||
suffix_search = util.compile_suffix_regex(cls.suffixes).search
|
||||
infix_finditer = util.compile_infix_regex(cls.infixes).finditer
|
||||
vocab = nlp.vocab if nlp is not None else cls.create_vocab(nlp)
|
||||
vocab = nlp.vocab if nlp is not None else cls.Default.create_vocab(nlp)
|
||||
return Tokenizer(nlp.vocab, rules=rules,
|
||||
prefix_search=prefix_search, suffix_search=suffix_search,
|
||||
infix_finditer=infix_finditer)
|
||||
|
@ -82,26 +84,27 @@ class BaseDefaults(object):
|
|||
return Tagger.load(nlp.path / 'pos', nlp.vocab)
|
||||
|
||||
@classmethod
|
||||
def create_parser(cls, nlp=None):
|
||||
def create_parser(cls, nlp=None, **cfg):
|
||||
if nlp is None:
|
||||
return DependencyParser(cls.create_vocab(), features=cls.parser_features)
|
||||
return DependencyParser(cls.create_vocab(), features=cls.parser_features,
|
||||
**cfg)
|
||||
elif nlp.path is False:
|
||||
return DependencyParser(nlp.vocab, features=cls.parser_features)
|
||||
return DependencyParser(nlp.vocab, features=cls.parser_features, **cfg)
|
||||
elif nlp.path is None or not (nlp.path / 'deps').exists():
|
||||
return None
|
||||
else:
|
||||
return DependencyParser.load(nlp.path / 'deps', nlp.vocab)
|
||||
return DependencyParser.load(nlp.path / 'deps', nlp.vocab, **cfg)
|
||||
|
||||
@classmethod
|
||||
def create_entity(cls, nlp=None):
|
||||
def create_entity(cls, nlp=None, **cfg):
|
||||
if nlp is None:
|
||||
return EntityRecognizer(cls.create_vocab(), features=cls.entity_features)
|
||||
return EntityRecognizer(cls.create_vocab(), features=cls.entity_features, **cfg)
|
||||
elif nlp.path is False:
|
||||
return EntityRecognizer(nlp.vocab, features=cls.entity_features)
|
||||
return EntityRecognizer(nlp.vocab, features=cls.entity_features, **cfg)
|
||||
elif nlp.path is None or not (nlp.path / 'ner').exists():
|
||||
return None
|
||||
else:
|
||||
return EntityRecognizer.load(nlp.path / 'ner', nlp.vocab)
|
||||
return EntityRecognizer.load(nlp.path / 'ner', nlp.vocab, **cfg)
|
||||
|
||||
@classmethod
|
||||
def create_matcher(cls, nlp=None):
|
||||
|
@ -202,8 +205,8 @@ class Language(object):
|
|||
# preprocess training data here before ArcEager.get_labels() is called
|
||||
gold_tuples = PseudoProjectivity.preprocess_training_data(gold_tuples)
|
||||
|
||||
parser_cfg['labels'] = ArcEager.get_labels(gold_tuples)
|
||||
entity_cfg['labels'] = BiluoPushDown.get_labels(gold_tuples)
|
||||
parser_cfg['actions'] = ArcEager.get_actions(gold_parses=gold_tuples)
|
||||
entity_cfg['actions'] = BiluoPushDown.get_actions(gold_parses=gold_tuples)
|
||||
|
||||
with (dep_model_dir / 'config.json').open('wb') as file_:
|
||||
json.dump(parser_cfg, file_)
|
||||
|
@ -224,22 +227,18 @@ class Language(object):
|
|||
vectors=False,
|
||||
pipeline=False)
|
||||
|
||||
self.defaults.parser_labels = parser_cfg['labels']
|
||||
self.defaults.entity_labels = entity_cfg['labels']
|
||||
|
||||
self.vocab = self.defaults.Vocab()
|
||||
self.tokenizer = self.defaults.Tokenizer(self.vocab)
|
||||
self.tagger = self.defaults.Tagger(self.vocab, **tagger_cfg)
|
||||
self.parser = self.defaults.Parser(self.vocab, **parser_cfg)
|
||||
self.entity = self.defaults.Entity(self.vocab, **entity_cfg)
|
||||
self.pipeline = self.defaults.Pipeline(self)
|
||||
self.vocab = self.Defaults.create_vocab(self)
|
||||
self.tokenizer = self.Defaults.create_tokenizer(self)
|
||||
self.tagger = self.Defaults.create_tagger(self)
|
||||
self.parser = self.Defaults.create_parser(self)
|
||||
self.entity = self.Defaults.create_entity(self)
|
||||
self.pipeline = self.Defaults.create_pipeline(self)
|
||||
yield Trainer(self, gold_tuples)
|
||||
self.end_training()
|
||||
|
||||
def __init__(self, path=True, **overrides):
|
||||
if 'data_dir' in overrides and 'path' not in overrides:
|
||||
if 'data_dir' in overrides and 'path' is True:
|
||||
raise ValueError("The argument 'data_dir' has been renamed to 'path'")
|
||||
path = overrides.get('path', True)
|
||||
if isinstance(path, basestring):
|
||||
path = pathlib.Path(path)
|
||||
if path is True:
|
||||
|
@ -253,7 +252,7 @@ class Language(object):
|
|||
add_vectors = self.Defaults.add_vectors(self) \
|
||||
if 'add_vectors' not in overrides \
|
||||
else overrides['add_vectors']
|
||||
if add_vectors:
|
||||
if self.vocab and add_vectors:
|
||||
add_vectors(self.vocab)
|
||||
self.tokenizer = self.Defaults.create_tokenizer(self) \
|
||||
if 'tokenizer' not in overrides \
|
||||
|
|
Loading…
Reference in New Issue
Block a user