mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-26 01:04:34 +03:00
Pass cfg through loading, for training.
This commit is contained in:
parent
608d8f5421
commit
a2f55e7015
|
@ -31,6 +31,8 @@ from .attrs import TAG, DEP, ENT_IOB, ENT_TYPE, HEAD, PROB, LANG, IS_STOP
|
||||||
from .syntax.parser import get_templates
|
from .syntax.parser import get_templates
|
||||||
from .syntax.nonproj import PseudoProjectivity
|
from .syntax.nonproj import PseudoProjectivity
|
||||||
from .pipeline import DependencyParser, EntityRecognizer
|
from .pipeline import DependencyParser, EntityRecognizer
|
||||||
|
from .syntax.arc_eager import ArcEager
|
||||||
|
from .syntax.ner import BiluoPushDown
|
||||||
|
|
||||||
|
|
||||||
class BaseDefaults(object):
|
class BaseDefaults(object):
|
||||||
|
@ -65,7 +67,7 @@ class BaseDefaults(object):
|
||||||
prefix_search = util.compile_prefix_regex(cls.prefixes).search
|
prefix_search = util.compile_prefix_regex(cls.prefixes).search
|
||||||
suffix_search = util.compile_suffix_regex(cls.suffixes).search
|
suffix_search = util.compile_suffix_regex(cls.suffixes).search
|
||||||
infix_finditer = util.compile_infix_regex(cls.infixes).finditer
|
infix_finditer = util.compile_infix_regex(cls.infixes).finditer
|
||||||
vocab = nlp.vocab if nlp is not None else cls.create_vocab(nlp)
|
vocab = nlp.vocab if nlp is not None else cls.Default.create_vocab(nlp)
|
||||||
return Tokenizer(nlp.vocab, rules=rules,
|
return Tokenizer(nlp.vocab, rules=rules,
|
||||||
prefix_search=prefix_search, suffix_search=suffix_search,
|
prefix_search=prefix_search, suffix_search=suffix_search,
|
||||||
infix_finditer=infix_finditer)
|
infix_finditer=infix_finditer)
|
||||||
|
@ -82,26 +84,27 @@ class BaseDefaults(object):
|
||||||
return Tagger.load(nlp.path / 'pos', nlp.vocab)
|
return Tagger.load(nlp.path / 'pos', nlp.vocab)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create_parser(cls, nlp=None):
|
def create_parser(cls, nlp=None, **cfg):
|
||||||
if nlp is None:
|
if nlp is None:
|
||||||
return DependencyParser(cls.create_vocab(), features=cls.parser_features)
|
return DependencyParser(cls.create_vocab(), features=cls.parser_features,
|
||||||
|
**cfg)
|
||||||
elif nlp.path is False:
|
elif nlp.path is False:
|
||||||
return DependencyParser(nlp.vocab, features=cls.parser_features)
|
return DependencyParser(nlp.vocab, features=cls.parser_features, **cfg)
|
||||||
elif nlp.path is None or not (nlp.path / 'deps').exists():
|
elif nlp.path is None or not (nlp.path / 'deps').exists():
|
||||||
return None
|
return None
|
||||||
else:
|
else:
|
||||||
return DependencyParser.load(nlp.path / 'deps', nlp.vocab)
|
return DependencyParser.load(nlp.path / 'deps', nlp.vocab, **cfg)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create_entity(cls, nlp=None):
|
def create_entity(cls, nlp=None, **cfg):
|
||||||
if nlp is None:
|
if nlp is None:
|
||||||
return EntityRecognizer(cls.create_vocab(), features=cls.entity_features)
|
return EntityRecognizer(cls.create_vocab(), features=cls.entity_features, **cfg)
|
||||||
elif nlp.path is False:
|
elif nlp.path is False:
|
||||||
return EntityRecognizer(nlp.vocab, features=cls.entity_features)
|
return EntityRecognizer(nlp.vocab, features=cls.entity_features, **cfg)
|
||||||
elif nlp.path is None or not (nlp.path / 'ner').exists():
|
elif nlp.path is None or not (nlp.path / 'ner').exists():
|
||||||
return None
|
return None
|
||||||
else:
|
else:
|
||||||
return EntityRecognizer.load(nlp.path / 'ner', nlp.vocab)
|
return EntityRecognizer.load(nlp.path / 'ner', nlp.vocab, **cfg)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create_matcher(cls, nlp=None):
|
def create_matcher(cls, nlp=None):
|
||||||
|
@ -202,8 +205,8 @@ class Language(object):
|
||||||
# preprocess training data here before ArcEager.get_labels() is called
|
# preprocess training data here before ArcEager.get_labels() is called
|
||||||
gold_tuples = PseudoProjectivity.preprocess_training_data(gold_tuples)
|
gold_tuples = PseudoProjectivity.preprocess_training_data(gold_tuples)
|
||||||
|
|
||||||
parser_cfg['labels'] = ArcEager.get_labels(gold_tuples)
|
parser_cfg['actions'] = ArcEager.get_actions(gold_parses=gold_tuples)
|
||||||
entity_cfg['labels'] = BiluoPushDown.get_labels(gold_tuples)
|
entity_cfg['actions'] = BiluoPushDown.get_actions(gold_parses=gold_tuples)
|
||||||
|
|
||||||
with (dep_model_dir / 'config.json').open('wb') as file_:
|
with (dep_model_dir / 'config.json').open('wb') as file_:
|
||||||
json.dump(parser_cfg, file_)
|
json.dump(parser_cfg, file_)
|
||||||
|
@ -224,22 +227,18 @@ class Language(object):
|
||||||
vectors=False,
|
vectors=False,
|
||||||
pipeline=False)
|
pipeline=False)
|
||||||
|
|
||||||
self.defaults.parser_labels = parser_cfg['labels']
|
self.vocab = self.Defaults.create_vocab(self)
|
||||||
self.defaults.entity_labels = entity_cfg['labels']
|
self.tokenizer = self.Defaults.create_tokenizer(self)
|
||||||
|
self.tagger = self.Defaults.create_tagger(self)
|
||||||
self.vocab = self.defaults.Vocab()
|
self.parser = self.Defaults.create_parser(self)
|
||||||
self.tokenizer = self.defaults.Tokenizer(self.vocab)
|
self.entity = self.Defaults.create_entity(self)
|
||||||
self.tagger = self.defaults.Tagger(self.vocab, **tagger_cfg)
|
self.pipeline = self.Defaults.create_pipeline(self)
|
||||||
self.parser = self.defaults.Parser(self.vocab, **parser_cfg)
|
|
||||||
self.entity = self.defaults.Entity(self.vocab, **entity_cfg)
|
|
||||||
self.pipeline = self.defaults.Pipeline(self)
|
|
||||||
yield Trainer(self, gold_tuples)
|
yield Trainer(self, gold_tuples)
|
||||||
self.end_training()
|
self.end_training()
|
||||||
|
|
||||||
def __init__(self, path=True, **overrides):
|
def __init__(self, path=True, **overrides):
|
||||||
if 'data_dir' in overrides and 'path' not in overrides:
|
if 'data_dir' in overrides and 'path' is True:
|
||||||
raise ValueError("The argument 'data_dir' has been renamed to 'path'")
|
raise ValueError("The argument 'data_dir' has been renamed to 'path'")
|
||||||
path = overrides.get('path', True)
|
|
||||||
if isinstance(path, basestring):
|
if isinstance(path, basestring):
|
||||||
path = pathlib.Path(path)
|
path = pathlib.Path(path)
|
||||||
if path is True:
|
if path is True:
|
||||||
|
@ -253,7 +252,7 @@ class Language(object):
|
||||||
add_vectors = self.Defaults.add_vectors(self) \
|
add_vectors = self.Defaults.add_vectors(self) \
|
||||||
if 'add_vectors' not in overrides \
|
if 'add_vectors' not in overrides \
|
||||||
else overrides['add_vectors']
|
else overrides['add_vectors']
|
||||||
if add_vectors:
|
if self.vocab and add_vectors:
|
||||||
add_vectors(self.vocab)
|
add_vectors(self.vocab)
|
||||||
self.tokenizer = self.Defaults.create_tokenizer(self) \
|
self.tokenizer = self.Defaults.create_tokenizer(self) \
|
||||||
if 'tokenizer' not in overrides \
|
if 'tokenizer' not in overrides \
|
||||||
|
|
Loading…
Reference in New Issue
Block a user