Mirror of https://github.com/explosion/spaCy.git (synced 2025-01-12 18:26:30 +03:00)
commit 427dbecdd6
parent 0c35885751

    cleanup and formatting
@@ -71,9 +71,7 @@ def pretrain_cli(
     with show_validation_error(config_path):
         config = util.load_config(
-            config_path,
-            overrides=config_overrides,
-            interpolate=True
+            config_path, overrides=config_overrides, interpolate=True
         )
     if not config.get("pretraining"):
         # TODO: What's the solution here? How do we handle optional blocks?
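Note: `util.load_config` takes the path, the overrides mapping, and the `interpolate` flag in a single call, which is why the arguments fit on one line after this cleanup. A minimal sketch of the call (the path and override key are made-up examples):

```python
from spacy import util

# Hypothetical path and override key, for illustration only.
config = util.load_config(
    "config.cfg",
    overrides={"training.max_epochs": 2},  # dotted keys target nested sections
    interpolate=True,  # resolve ${...} variable references while loading
)
```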
@@ -84,7 +82,7 @@ def pretrain_cli(
     config.to_disk(output_dir / "config.cfg")
     msg.good("Saved config file in the output directory")

     pretrain(
         config,
         output_dir,
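`config.to_disk` writes the loaded config next to the pretraining output. A small round-trip sketch using thinc's `Config` (spaCy configs are thinc `Config` objects; the section name and paths here are invented):

```python
from pathlib import Path
from thinc.api import Config

cfg = Config().from_str("[section]\nvalue = 1\n")  # toy config for illustration
out_dir = Path("output")
out_dir.mkdir(exist_ok=True)
cfg.to_disk(out_dir / "config.cfg")  # mirrors the call in the diff
```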
@@ -99,7 +97,7 @@ def pretrain(
     output_dir: Path,
     resume_path: Optional[Path] = None,
     epoch_resume: Optional[int] = None,
-    use_gpu: int=-1
+    use_gpu: int = -1,
 ):
     if config["system"].get("seed") is not None:
         fix_random_seed(config["system"]["seed"])
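The whitespace change is a PEP 8 fix: when a parameter has both an annotation and a default value, spaces belong around the `=`.

```python
# Without an annotation, PEP 8 wants no spaces around "=".
def run(gpu=-1): ...

# With an annotation, spaces around "=" are required (what the diff applies).
def run_annotated(gpu: int = -1): ...
```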
@@ -107,7 +105,7 @@ def pretrain(
     use_pytorch_for_gpu_memory()
     nlp, config = util.load_model_from_config(config)
     P_cfg = config["pretraining"]
-    corpus = dot_to_object(config, config["pretraining"]["corpus"])
+    corpus = dot_to_object(config, P_cfg["corpus"])
     batcher = P_cfg["batcher"]
     model = create_pretraining_model(nlp, config["pretraining"])
     optimizer = config["pretraining"]["optimizer"]
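`dot_to_object` resolves a dotted path string against the nested config; the change just reuses the `P_cfg` alias instead of indexing `config["pretraining"]` a second time. A simplified sketch of the lookup (the real `spacy.util.dot_to_object` also raises a friendlier error on invalid paths):

```python
def dot_to_object(data, path):
    # Walk a nested dict with an "a.b.c"-style key (simplified sketch).
    obj = data
    for key in path.split("."):
        obj = obj[key]
    return obj

config = {"corpora": {"pretrain": {"path": "data.jsonl"}}}  # toy config
assert dot_to_object(config, "corpora.pretrain") == {"path": "data.jsonl"}
```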
@@ -148,9 +146,7 @@ def pretrain(
             progress = tracker.update(epoch, loss, docs)
             if progress:
                 msg.row(progress, **row_settings)
-            if P_cfg["n_save_every"] and (
-                batch_id % P_cfg["n_save_every"] == 0
-            ):
+            if P_cfg["n_save_every"] and (batch_id % P_cfg["n_save_every"] == 0):
                 _save_model(epoch, is_temp=True)
         _save_model(epoch)
         tracker.epoch_loss = 0.0
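The collapsed condition implements a save-every-N-batches checkpoint; the truthiness check also lets a falsy `n_save_every` disable periodic saving entirely. A stand-alone sketch of the same pattern (names and values are stand-ins):

```python
def run_epoch(batches, n_save_every=None):
    # Save a temporary checkpoint every n_save_every batches; None disables it.
    for batch_id, batch in enumerate(batches):
        _ = sum(batch)  # stand-in for the real training step
        if n_save_every and (batch_id % n_save_every == 0):
            print(f"temp checkpoint at batch {batch_id}")

run_epoch([[1, 2], [3, 4], [5, 6]], n_save_every=2)  # saves at batches 0 and 2
```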
@@ -93,8 +93,8 @@ def train(
     raw_text, tag_map, morph_rules, weights_data = load_from_paths(config)
     T_cfg = config["training"]
     optimizer = T_cfg["optimizer"]
-    train_corpus = dot_to_object(config, config["training"]["train_corpus"])
-    dev_corpus = dot_to_object(config, config["training"]["dev_corpus"])
+    train_corpus = dot_to_object(config, T_cfg["train_corpus"])
+    dev_corpus = dot_to_object(config, T_cfg["dev_corpus"])
     batcher = T_cfg["batcher"]
     train_logger = T_cfg["logger"]
     # Components that shouldn't be updated during training
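As in the pretraining hunk, `train_corpus` and `dev_corpus` hold dotted references into the config rather than inline blocks, which is why they go through `dot_to_object`. An abridged sketch of the layout those references assume (section names follow spaCy v3's defaults; paths are invented):

```python
from thinc.api import Config

config = Config().from_str("""
[training]
train_corpus = "corpora.train"
dev_corpus = "corpora.dev"

[corpora.train]
path = "train.spacy"

[corpora.dev]
path = "dev.spacy"
""")
assert config["training"]["train_corpus"] == "corpora.train"
```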
@@ -104,7 +104,7 @@ class TokenPatternOperator(str, Enum):
 StringValue = Union[TokenPatternString, StrictStr]
 NumberValue = Union[TokenPatternNumber, StrictInt, StrictFloat]
 UnderscoreValue = Union[
-    TokenPatternString, TokenPatternNumber, str, int, float, list, bool,
+    TokenPatternString, TokenPatternNumber, str, int, float, list, bool
 ]
 
 
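`UnderscoreValue` is the schema type for the `_` key of a token pattern, i.e. values matched against custom extension attributes; the edit only drops a trailing comma. A sketch of the kind of pattern it validates, assuming spaCy v3's Matcher API (the `is_fruit` extension is invented for the example):

```python
import spacy
from spacy.matcher import Matcher
from spacy.tokens import Token

Token.set_extension("is_fruit", default=False)  # custom attribute for the demo

nlp = spacy.blank("en")
matcher = Matcher(nlp.vocab)
# The {"_": {...}} value is what UnderscoreValue validates in the schema.
matcher.add("FRUIT", [[{"_": {"is_fruit": True}}]])

doc = nlp("I like apples")
doc[2]._.is_fruit = True
print(matcher(doc))  # one match covering "apples"
```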
@@ -26,12 +26,15 @@ def test_readers():
     [components.textcat]
     factory = "textcat"
     """

    @registry.readers.register("myreader.v1")
    def myreader() -> Dict[str, Callable[[Language, str], Iterable[Example]]]:
        annots = {"cats": {"POS": 1.0, "NEG": 0.0}}

        def reader(nlp: Language):
            doc = nlp.make_doc(f"This is an example")
            return [Example.from_dict(doc, annots)]

        return {"train": reader, "dev": reader, "extra": reader, "something": reader}

    config = Config().from_str(config_string)
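Once registered, `"myreader.v1"` can be referenced from a config block and is called when the config is resolved. A sketch of the wiring (assumes the registration above has already run in the same process):

```python
from spacy.util import registry
from thinc.api import Config

config = Config().from_str("""
[corpora]
@readers = "myreader.v1"
""")
resolved = registry.resolve(config)
readers = resolved["corpora"]  # the dict returned by myreader()
```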