mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-04 01:48:04 +03:00 
			
		
		
		
	cleanup and formatting
This commit is contained in:
		
							parent
							
								
									0c35885751
								
							
						
					
					
						commit
						427dbecdd6
					
				| 
						 | 
					@ -71,9 +71,7 @@ def pretrain_cli(
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    with show_validation_error(config_path):
 | 
					    with show_validation_error(config_path):
 | 
				
			||||||
        config = util.load_config(
 | 
					        config = util.load_config(
 | 
				
			||||||
            config_path,
 | 
					            config_path, overrides=config_overrides, interpolate=True
 | 
				
			||||||
            overrides=config_overrides,
 | 
					 | 
				
			||||||
            interpolate=True
 | 
					 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
    if not config.get("pretraining"):
 | 
					    if not config.get("pretraining"):
 | 
				
			||||||
        # TODO: What's the solution here? How do we handle optional blocks?
 | 
					        # TODO: What's the solution here? How do we handle optional blocks?
 | 
				
			||||||
| 
						 | 
					@ -99,7 +97,7 @@ def pretrain(
 | 
				
			||||||
    output_dir: Path,
 | 
					    output_dir: Path,
 | 
				
			||||||
    resume_path: Optional[Path] = None,
 | 
					    resume_path: Optional[Path] = None,
 | 
				
			||||||
    epoch_resume: Optional[int] = None,
 | 
					    epoch_resume: Optional[int] = None,
 | 
				
			||||||
    use_gpu: int=-1
 | 
					    use_gpu: int = -1,
 | 
				
			||||||
):
 | 
					):
 | 
				
			||||||
    if config["system"].get("seed") is not None:
 | 
					    if config["system"].get("seed") is not None:
 | 
				
			||||||
        fix_random_seed(config["system"]["seed"])
 | 
					        fix_random_seed(config["system"]["seed"])
 | 
				
			||||||
| 
						 | 
					@ -107,7 +105,7 @@ def pretrain(
 | 
				
			||||||
        use_pytorch_for_gpu_memory()
 | 
					        use_pytorch_for_gpu_memory()
 | 
				
			||||||
    nlp, config = util.load_model_from_config(config)
 | 
					    nlp, config = util.load_model_from_config(config)
 | 
				
			||||||
    P_cfg = config["pretraining"]
 | 
					    P_cfg = config["pretraining"]
 | 
				
			||||||
    corpus = dot_to_object(config, config["pretraining"]["corpus"])
 | 
					    corpus = dot_to_object(config, P_cfg["corpus"])
 | 
				
			||||||
    batcher = P_cfg["batcher"]
 | 
					    batcher = P_cfg["batcher"]
 | 
				
			||||||
    model = create_pretraining_model(nlp, config["pretraining"])
 | 
					    model = create_pretraining_model(nlp, config["pretraining"])
 | 
				
			||||||
    optimizer = config["pretraining"]["optimizer"]
 | 
					    optimizer = config["pretraining"]["optimizer"]
 | 
				
			||||||
| 
						 | 
					@ -148,9 +146,7 @@ def pretrain(
 | 
				
			||||||
            progress = tracker.update(epoch, loss, docs)
 | 
					            progress = tracker.update(epoch, loss, docs)
 | 
				
			||||||
            if progress:
 | 
					            if progress:
 | 
				
			||||||
                msg.row(progress, **row_settings)
 | 
					                msg.row(progress, **row_settings)
 | 
				
			||||||
            if P_cfg["n_save_every"] and (
 | 
					            if P_cfg["n_save_every"] and (batch_id % P_cfg["n_save_every"] == 0):
 | 
				
			||||||
                batch_id % P_cfg["n_save_every"] == 0
 | 
					 | 
				
			||||||
            ):
 | 
					 | 
				
			||||||
                _save_model(epoch, is_temp=True)
 | 
					                _save_model(epoch, is_temp=True)
 | 
				
			||||||
        _save_model(epoch)
 | 
					        _save_model(epoch)
 | 
				
			||||||
        tracker.epoch_loss = 0.0
 | 
					        tracker.epoch_loss = 0.0
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -93,8 +93,8 @@ def train(
 | 
				
			||||||
    raw_text, tag_map, morph_rules, weights_data = load_from_paths(config)
 | 
					    raw_text, tag_map, morph_rules, weights_data = load_from_paths(config)
 | 
				
			||||||
    T_cfg = config["training"]
 | 
					    T_cfg = config["training"]
 | 
				
			||||||
    optimizer = T_cfg["optimizer"]
 | 
					    optimizer = T_cfg["optimizer"]
 | 
				
			||||||
    train_corpus = dot_to_object(config, config["training"]["train_corpus"])
 | 
					    train_corpus = dot_to_object(config, T_cfg["train_corpus"])
 | 
				
			||||||
    dev_corpus = dot_to_object(config, config["training"]["dev_corpus"])
 | 
					    dev_corpus = dot_to_object(config, T_cfg["dev_corpus"])
 | 
				
			||||||
    batcher = T_cfg["batcher"]
 | 
					    batcher = T_cfg["batcher"]
 | 
				
			||||||
    train_logger = T_cfg["logger"]
 | 
					    train_logger = T_cfg["logger"]
 | 
				
			||||||
    # Components that shouldn't be updated during training
 | 
					    # Components that shouldn't be updated during training
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -104,7 +104,7 @@ class TokenPatternOperator(str, Enum):
 | 
				
			||||||
StringValue = Union[TokenPatternString, StrictStr]
 | 
					StringValue = Union[TokenPatternString, StrictStr]
 | 
				
			||||||
NumberValue = Union[TokenPatternNumber, StrictInt, StrictFloat]
 | 
					NumberValue = Union[TokenPatternNumber, StrictInt, StrictFloat]
 | 
				
			||||||
UnderscoreValue = Union[
 | 
					UnderscoreValue = Union[
 | 
				
			||||||
    TokenPatternString, TokenPatternNumber, str, int, float, list, bool,
 | 
					    TokenPatternString, TokenPatternNumber, str, int, float, list, bool
 | 
				
			||||||
]
 | 
					]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -26,12 +26,15 @@ def test_readers():
 | 
				
			||||||
    [components.textcat]
 | 
					    [components.textcat]
 | 
				
			||||||
    factory = "textcat"
 | 
					    factory = "textcat"
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @registry.readers.register("myreader.v1")
 | 
					    @registry.readers.register("myreader.v1")
 | 
				
			||||||
    def myreader() -> Dict[str, Callable[[Language, str], Iterable[Example]]]:
 | 
					    def myreader() -> Dict[str, Callable[[Language, str], Iterable[Example]]]:
 | 
				
			||||||
        annots = {"cats": {"POS": 1.0, "NEG": 0.0}}
 | 
					        annots = {"cats": {"POS": 1.0, "NEG": 0.0}}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        def reader(nlp: Language):
 | 
					        def reader(nlp: Language):
 | 
				
			||||||
            doc = nlp.make_doc(f"This is an example")
 | 
					            doc = nlp.make_doc(f"This is an example")
 | 
				
			||||||
            return [Example.from_dict(doc, annots)]
 | 
					            return [Example.from_dict(doc, annots)]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        return {"train": reader, "dev": reader, "extra": reader, "something": reader}
 | 
					        return {"train": reader, "dev": reader, "extra": reader, "something": reader}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    config = Config().from_str(config_string)
 | 
					    config = Config().from_str(config_string)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user