cleanup and formatting

This commit is contained in:
svlandeg 2020-09-17 11:48:04 +02:00
parent 0c35885751
commit 427dbecdd6
4 changed files with 11 additions and 12 deletions

View File

@ -71,9 +71,7 @@ def pretrain_cli(
with show_validation_error(config_path):
config = util.load_config(
config_path,
overrides=config_overrides,
interpolate=True
config_path, overrides=config_overrides, interpolate=True
)
if not config.get("pretraining"):
# TODO: What's the solution here? How do we handle optional blocks?
@ -84,7 +82,7 @@ def pretrain_cli(
config.to_disk(output_dir / "config.cfg")
msg.good("Saved config file in the output directory")
pretrain(
config,
output_dir,
@ -99,7 +97,7 @@ def pretrain(
output_dir: Path,
resume_path: Optional[Path] = None,
epoch_resume: Optional[int] = None,
use_gpu: int=-1
use_gpu: int = -1,
):
if config["system"].get("seed") is not None:
fix_random_seed(config["system"]["seed"])
@ -107,7 +105,7 @@ def pretrain(
use_pytorch_for_gpu_memory()
nlp, config = util.load_model_from_config(config)
P_cfg = config["pretraining"]
corpus = dot_to_object(config, config["pretraining"]["corpus"])
corpus = dot_to_object(config, P_cfg["corpus"])
batcher = P_cfg["batcher"]
model = create_pretraining_model(nlp, config["pretraining"])
optimizer = config["pretraining"]["optimizer"]
@ -148,9 +146,7 @@ def pretrain(
progress = tracker.update(epoch, loss, docs)
if progress:
msg.row(progress, **row_settings)
if P_cfg["n_save_every"] and (
batch_id % P_cfg["n_save_every"] == 0
):
if P_cfg["n_save_every"] and (batch_id % P_cfg["n_save_every"] == 0):
_save_model(epoch, is_temp=True)
_save_model(epoch)
tracker.epoch_loss = 0.0

View File

@ -93,8 +93,8 @@ def train(
raw_text, tag_map, morph_rules, weights_data = load_from_paths(config)
T_cfg = config["training"]
optimizer = T_cfg["optimizer"]
train_corpus = dot_to_object(config, config["training"]["train_corpus"])
dev_corpus = dot_to_object(config, config["training"]["dev_corpus"])
train_corpus = dot_to_object(config, T_cfg["train_corpus"])
dev_corpus = dot_to_object(config, T_cfg["dev_corpus"])
batcher = T_cfg["batcher"]
train_logger = T_cfg["logger"]
# Components that shouldn't be updated during training

View File

@ -104,7 +104,7 @@ class TokenPatternOperator(str, Enum):
StringValue = Union[TokenPatternString, StrictStr]
NumberValue = Union[TokenPatternNumber, StrictInt, StrictFloat]
UnderscoreValue = Union[
TokenPatternString, TokenPatternNumber, str, int, float, list, bool,
TokenPatternString, TokenPatternNumber, str, int, float, list, bool
]

View File

@ -26,12 +26,15 @@ def test_readers():
[components.textcat]
factory = "textcat"
"""
@registry.readers.register("myreader.v1")
def myreader() -> Dict[str, Callable[[Language, str], Iterable[Example]]]:
    """Build a dictionary of example readers for the config-resolution test.

    The same ``reader`` callable is returned under the keys ``"train"``,
    ``"dev"``, ``"extra"`` and ``"something"``.

    NOTE(review): the return annotation advertises callables taking
    ``(Language, str)``, while ``reader`` below accepts only ``nlp`` --
    confirm which signature is intended before tightening the annotation.
    """
    annots = {"cats": {"POS": 1.0, "NEG": 0.0}}

    def reader(nlp: Language):
        """Return a single-item list holding one annotated example."""
        # Plain literal: the text has no placeholders, so no f-string is needed.
        doc = nlp.make_doc("This is an example")
        return [Example.from_dict(doc, annots)]

    return {"train": reader, "dev": reader, "extra": reader, "something": reader}
config = Config().from_str(config_string)