Mirror of https://github.com/explosion/spaCy.git
Update CLI and add [initialize] block

commit e44a7519cd
parent d5155376fd
@@ -98,7 +98,7 @@ universal = false
 formats = gztar
 
 [flake8]
-ignore = E203, E266, E501, E731, W503
+ignore = E203, E266, E501, E731, W503, E741
 max-line-length = 80
 select = B,C,E,F,W,T4,B9
 exclude =
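Note: E741 is flake8's "ambiguous variable name" warning (single-letter names such as l, I, O). It is presumably ignored here because the initialization code below binds the resolved [initialize] section to I, as in I = registry.resolve(config["initialize"], schema=ConfigSchemaInit).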
@@ -459,24 +459,3 @@ def string_to_list(value: str, intify: bool = False) -> Union[List[str], List[in
             p = int(p)
         result.append(p)
     return result
-
-
-def load_from_paths(
-    config: Config,
-) -> Tuple[List[Dict[str, str]], Dict[str, dict], bytes]:
-    # TODO: separate checks from loading
-    raw_text = ensure_path(config["training"]["raw_text"])
-    if raw_text is not None:
-        if not raw_text.exists():
-            msg.fail("Can't find raw text", raw_text, exits=1)
-        raw_text = list(srsly.read_jsonl(config["training"]["raw_text"]))
-    tag_map = {}
-    morph_rules = {}
-    weights_data = None
-    init_tok2vec = ensure_path(config["training"]["init_tok2vec"])
-    if init_tok2vec is not None:
-        if not init_tok2vec.exists():
-            msg.fail("Can't find pretrained tok2vec", init_tok2vec, exits=1)
-        with init_tok2vec.open("rb") as file_:
-            weights_data = file_.read()
-    return raw_text, tag_map, morph_rules, weights_data
@@ -8,12 +8,12 @@ import srsly
 
 from .. import util
 from ..util import registry, resolve_dot_names, OOV_RANK
-from ..schemas import ConfigSchemaTraining, ConfigSchemaPretrain
+from ..schemas import ConfigSchemaTraining, ConfigSchemaPretrain, ConfigSchemaInit
 from ..language import Language
 from ..lookups import Lookups
 from ..errors import Errors
 from ._util import init_cli, Arg, Opt, parse_config_overrides, show_validation_error
-from ._util import import_code, get_sourced_components, load_from_paths
+from ._util import import_code, get_sourced_components
 
 
 DEFAULT_OOV_PROB = -20
@@ -67,14 +67,15 @@ def init_pipeline(config: Config, use_gpu: int = -1) -> Language:
     # Use original config here before it's resolved to functions
     sourced_components = get_sourced_components(config)
     with show_validation_error():
-        nlp = util.load_model_from_config(raw_config)
+        nlp = util.load_model_from_config(raw_config, auto_fill=True)
     msg.good("Set up nlp object from config")
+    config = nlp.config.interpolate()
     # Resolve all training-relevant sections using the filled nlp config
     T = registry.resolve(config["training"], schema=ConfigSchemaTraining)
     dot_names = [T["train_corpus"], T["dev_corpus"], T["raw_text"]]
     train_corpus, dev_corpus, raw_text = resolve_dot_names(config, dot_names)
-    # TODO: move lookups to [initialize], add vocab data
-    init_vocab(nlp, lookups=T["lookups"])
+    I = registry.resolve(config["initialize"], schema=ConfigSchemaInit)
+    init_vocab(nlp, data=I["vocab"]["data"], lookups=I["vocab"]["lookups"])
     msg.good("Created vocabulary")
     if T["vectors"] is not None:
         add_vectors(nlp, T["vectors"])
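For orientation: once [initialize] is resolved against ConfigSchemaInit, I is a plain nested dict. With the defaults added further down in this commit and no [paths] overrides, it would look roughly like this sketch (illustrative values, not part of the diff):

    I = {
        "tokenizer": {},
        "components": {},
        "vocab": {
            "data": None,          # [initialize.vocab] data = null
            "lookups": None,       # lookups = null
            "vectors": None,       # vectors = null
            "init_tok2vec": None,  # interpolated from ${paths.init_tok2vec}
            "raw_text": None,      # interpolated from ${paths.raw}
        },
    }
    init_vocab(nlp, data=I["vocab"]["data"], lookups=I["vocab"]["lookups"])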
@@ -98,22 +99,19 @@ def init_pipeline(config: Config, use_gpu: int = -1) -> Language:
     verify_config(nlp)
     if "pretraining" in config and config["pretraining"]:
         P = registry.resolve(config["pretraining"], schema=ConfigSchemaPretrain)
-        add_tok2vec_weights({"training": T, "pretraining": P}, nlp)
+        add_tok2vec_weights(nlp, P, I)
     # TODO: this should be handled better?
     nlp = before_to_disk(nlp)
     return nlp
 
 
 def init_vocab(
-    nlp: Language,
-    *,
-    vocab_data: Optional[Path] = None,
-    lookups: Optional[Lookups] = None,
+    nlp: Language, *, data: Optional[Path] = None, lookups: Optional[Lookups] = None,
 ) -> Language:
     if lookups:
         nlp.vocab.lookups = lookups
         msg.good(f"Added vocab lookups: {', '.join(lookups.tables)}")
-    data_path = util.ensure_path(vocab_data)
+    data_path = util.ensure_path(data)
     if data_path is not None:
         lex_attrs = srsly.read_jsonl(data_path)
         for lexeme in nlp.vocab:
@@ -131,11 +129,29 @@ def init_vocab(
         msg.good(f"Added {len(nlp.vocab)} lexical entries to the vocab")
 
 
-def add_tok2vec_weights(config: Config, nlp: Language) -> None:
+def add_tok2vec_weights(
+    nlp: Language, pretrain_config: Dict[str, Any], init_config: Dict[str, Any]
+) -> None:
     # Load pretrained tok2vec weights - cf. CLI command 'pretrain'
-    weights_data = load_from_paths(config)
+    P = pretrain_config
+    I = init_config
+    raw_text = util.ensure_path(I["vocab"]["raw_text"])
+    if raw_text is not None:
+        if not raw_text.exists():
+            msg.fail("Can't find raw text", raw_text, exits=1)
+        raw_text = list(srsly.read_jsonl(raw_text))
+    weights_data = None
+    init_tok2vec = util.ensure_path(I["vocab"]["init_tok2vec"])
+    if init_tok2vec is not None:
+        if P["objective"].get("type") == "vectors" and not I["vectors"]:
+            err = "Need initialize.vectors if pretraining.objective.type is vectors"
+            msg.fail(err, exits=1)
+        if not init_tok2vec.exists():
+            msg.fail("Can't find pretrained tok2vec", init_tok2vec, exits=1)
+        with init_tok2vec.open("rb") as file_:
+            weights_data = file_.read()
     if weights_data is not None:
-        tok2vec_component = config["pretraining"]["component"]
+        tok2vec_component = P["component"]
         if tok2vec_component is None:
             msg.fail(
                 f"To use pretrained tok2vec weights, [pretraining.component] "
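A rough usage sketch of the rewritten helper: both dicts come from registry.resolve as in init_pipeline above, and the weights path is now read from [initialize.vocab] instead of load_from_paths. The config values shown in the comment are made-up examples:

    # assuming config.cfg contains, for example:
    #   [paths]
    #   init_tok2vec = "pretrain/model999.bin"
    #   [initialize.vocab]
    #   init_tok2vec = ${paths.init_tok2vec}
    P = registry.resolve(config["pretraining"], schema=ConfigSchemaPretrain)
    I = registry.resolve(config["initialize"], schema=ConfigSchemaInit)
    add_tok2vec_weights(nlp, P, I)  # loads the weights into P["component"] / P["layer"]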
@@ -143,9 +159,8 @@ def add_tok2vec_weights(config: Config, nlp: Language) -> None:
                 exits=1,
             )
         layer = nlp.get_pipe(tok2vec_component).model
-        tok2vec_layer = config["pretraining"]["layer"]
-        if tok2vec_layer:
-            layer = layer.get_ref(tok2vec_layer)
+        if P["layer"]:
+            layer = layer.get_ref(P["layer"])
         layer.from_bytes(weights_data)
         msg.good(f"Loaded pretrained weights into component '{tok2vec_component}'")
 
@@ -14,7 +14,6 @@ from .init_pipeline import init_pipeline, must_initialize
 from .init_pipeline import create_before_to_disk_callback
 from ._util import app, Arg, Opt, parse_config_overrides, show_validation_error
 from ._util import import_code
-from ._util import load_from_paths  # noqa: F401 (needed for Ray extension for now)
 from ..language import Language
 from .. import util
 from ..training.example import Example
@@ -381,3 +380,26 @@ def verify_cli_args(config_path: Path, output_path: Optional[Path] = None) -> No
         if not output_path.exists():
             output_path.mkdir()
             msg.good(f"Created output directory: {output_path}")
+
+
+# TODO: this is currently imported by the ray extension and not used otherwise
+def load_from_paths(
+    config: Config,
+) -> Tuple[List[Dict[str, str]], Dict[str, dict], bytes]:
+    import srsly
+    # TODO: separate checks from loading
+    raw_text = util.ensure_path(config["training"]["raw_text"])
+    if raw_text is not None:
+        if not raw_text.exists():
+            msg.fail("Can't find raw text", raw_text, exits=1)
+        raw_text = list(srsly.read_jsonl(config["training"]["raw_text"]))
+    tag_map = {}
+    morph_rules = {}
+    weights_data = None
+    init_tok2vec = util.ensure_path(config["training"]["init_tok2vec"])
+    if init_tok2vec is not None:
+        if not init_tok2vec.exists():
+            msg.fail("Can't find pretrained tok2vec", init_tok2vec, exits=1)
+        with init_tok2vec.open("rb") as file_:
+            weights_data = file_.read()
+    return raw_text, tag_map, morph_rules, weights_data
@@ -108,3 +108,15 @@ grad_clip = 1.0
 use_averages = false
 eps = 1e-8
 learn_rate = 0.001
+
+[initialize]
+tokenizer = {}
+components = {}
+
+[initialize.vocab]
+data = null
+lookups = null
+vectors = null
+# Extra resources for transfer-learning or pseudo-rehearsal
+init_tok2vec = ${paths.init_tok2vec}
+raw_text = ${paths.raw}
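A user config can now fill the new block directly instead of relying on extra CLI arguments; a minimal sketch (the JSONL path is a made-up example, the rest keeps the defaults):

    [initialize.vocab]
    # JSONL file with lexeme attributes, read via srsly.read_jsonl
    data = "assets/vocab_data.jsonl"
    lookups = null
    vectors = null
    init_tok2vec = ${paths.init_tok2vec}
    raw_text = ${paths.raw}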
@@ -273,22 +273,37 @@ class ConfigSchemaPretrain(BaseModel):
         arbitrary_types_allowed = True
 
 
+class ConfigSchemaInitVocab(BaseModel):
+    # fmt: off
+    data: Optional[str] = Field(..., title="Path to JSON-formatted vocabulary file")
+    lookups: Optional[Lookups] = Field(..., title="Vocabulary lookups, e.g. lexeme normalization")
+    vectors: Optional[StrictStr] = Field(..., title="Path to vectors")
+    init_tok2vec: Optional[StrictStr] = Field(..., title="Path to pretrained tok2vec weights")
+    raw_text: Optional[StrictStr] = Field(default=None, title="Raw text")
+    # fmt: on
+
+    class Config:
+        extra = "forbid"
+        arbitrary_types_allowed = True
+
+
+class ConfigSchemaInit(BaseModel):
+    vocab: ConfigSchemaInitVocab
+    tokenizer: Any
+    components: Dict[str, Any]
+
+    class Config:
+        extra = "forbid"
+        arbitrary_types_allowed = True
+
+
 class ConfigSchema(BaseModel):
     training: ConfigSchemaTraining
     nlp: ConfigSchemaNlp
     pretraining: Union[ConfigSchemaPretrain, ConfigSchemaPretrainEmpty] = {}
     components: Dict[str, Dict[str, Any]]
     corpora: Dict[str, Reader]
+    initialize: ConfigSchemaInit
 
-    @root_validator(allow_reuse=True)
-    def validate_config(cls, values):
-        """Perform additional validation for settings with dependencies."""
-        pt = values.get("pretraining")
-        if pt and not isinstance(pt, ConfigSchemaPretrainEmpty):
-            if pt.objective.get("type") == "vectors" and not values["nlp"].vectors:
-                err = "Need nlp.vectors if pretraining.objective.type is vectors"
-                raise ValueError(err)
-        return values
-
     class Config:
         extra = "allow"
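Because both new schemas set extra = "forbid", a misspelled key inside [initialize] now fails validation instead of being silently ignored. A minimal sketch of that behaviour, assuming pydantic v1 semantics and that the schema is importable from spacy.schemas:

    from pydantic import ValidationError
    from spacy.schemas import ConfigSchemaInit

    try:
        ConfigSchemaInit(
            vocab={"data": None, "lookups": None, "vectors": None,
                   "init_tok2vec": None, "raw_text": None},
            tokenizer={},
            components={},
            tokeniser={},  # misspelled key: rejected because extra = "forbid"
        )
    except ValidationError as err:
        print(err)  # pydantic reports "extra fields not permitted" for tokeniser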
@@ -61,7 +61,7 @@ LEXEME_NORM_LANGS = ["da", "de", "el", "en", "id", "lb", "pt", "ru", "sr", "ta",
 # Default order of sections in the config.cfg. Not all sections needs to exist,
 # and additional sections are added at the end, in alphabetical order.
 # fmt: off
-CONFIG_SECTION_ORDER = ["paths", "variables", "system", "nlp", "components", "corpora", "training", "pretraining"]
+CONFIG_SECTION_ORDER = ["paths", "variables", "system", "nlp", "components", "corpora", "training", "pretraining", "initialize"]
 # fmt: on
 
 