mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-04 01:48:04 +03:00 
			
		
		
		
	Update train.py
This commit is contained in:
		
							parent
							
								
									49145b9ec1
								
							
						
					
					
						commit
						17efd6bfec
					
				| 
						 | 
				
			
			@ -210,7 +210,8 @@ def train(
 | 
			
		|||
        nlp.resume_training()
 | 
			
		||||
    else:
 | 
			
		||||
        msg.info(f"Initializing the nlp pipeline: {nlp.pipe_names}")
 | 
			
		||||
        nlp.begin_training(lambda: corpus.train_dataset(nlp))
 | 
			
		||||
        train_examples = list(corpus.train_dataset(nlp, shuffle=False))
 | 
			
		||||
        nlp.begin_training(lambda: train_examples)
 | 
			
		||||
 | 
			
		||||
    # Update tag map with provided mapping
 | 
			
		||||
    nlp.vocab.morphology.tag_map.update(tag_map)
 | 
			
		||||
| 
						 | 
				
			
			@ -280,11 +281,14 @@ def train(
 | 
			
		|||
                eg.reference = None
 | 
			
		||||
                eg.predicted = None
 | 
			
		||||
    except Exception as e:
 | 
			
		||||
        msg.warn(
 | 
			
		||||
            f"Aborting and saving the final best model. "
 | 
			
		||||
            f"Encountered exception: {str(e)}",
 | 
			
		||||
            exits=1,
 | 
			
		||||
        )
 | 
			
		||||
        if output_path is not None:
 | 
			
		||||
            msg.warn(
 | 
			
		||||
                f"Aborting and saving the final best model. "
 | 
			
		||||
                f"Encountered exception: {str(e)}",
 | 
			
		||||
                exits=1,
 | 
			
		||||
            )
 | 
			
		||||
        else:
 | 
			
		||||
            raise e
 | 
			
		||||
    finally:
 | 
			
		||||
        if output_path is not None:
 | 
			
		||||
            final_model_path = output_path / "model-final"
 | 
			
		||||
| 
						 | 
				
			
			@ -300,7 +304,6 @@ def create_train_batches(nlp, corpus, cfg):
 | 
			
		|||
    epochs_todo = cfg.get("max_epochs", 0)
 | 
			
		||||
    while True:
 | 
			
		||||
        train_examples = list(corpus.train_dataset(nlp))
 | 
			
		||||
 | 
			
		||||
        if len(train_examples) == 0:
 | 
			
		||||
            raise ValueError(Errors.E988)
 | 
			
		||||
        random.shuffle(train_examples)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue
	
	Block a user