mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-31 16:07:41 +03:00 
			
		
		
		
	train_ner should save vocab; add load_ner example
This commit is contained in:
		
							parent
							
								
									5ad5408242
								
							
						
					
					
						commit
						ad54a929f8
					
				|  | @ -10,6 +10,13 @@ from spacy.tagger import Tagger | |||
| 
 | ||||
| 
 | ||||
| def train_ner(nlp, train_data, entity_types): | ||||
|     # Add new words to vocab. | ||||
|     for raw_text, _ in train_data: | ||||
|         doc = nlp.make_doc(raw_text) | ||||
|         for word in doc: | ||||
|             _ = nlp.vocab[word.orth] | ||||
| 
 | ||||
|     # Train NER. | ||||
|     ner = EntityRecognizer(nlp.vocab, entity_types=entity_types) | ||||
|     for itn in range(5): | ||||
|         random.shuffle(train_data) | ||||
|  | @ -20,21 +27,30 @@ def train_ner(nlp, train_data, entity_types): | |||
|     ner.model.end_training() | ||||
|     return ner | ||||
| 
 | ||||
| def save_model(ner, model_dir): | ||||
|     model_dir = pathlib.Path(model_dir) | ||||
|     if not model_dir.exists(): | ||||
|         model_dir.mkdir() | ||||
|     assert model_dir.is_dir() | ||||
| 
 | ||||
|     with (model_dir / 'config.json').open('w') as file_: | ||||
|         json.dump(ner.cfg, file_) | ||||
|     ner.model.dump(str(model_dir / 'model')) | ||||
|     if not (model_dir / 'vocab').exists(): | ||||
|         (model_dir / 'vocab').mkdir() | ||||
|     ner.vocab.dump(str(model_dir / 'vocab' / 'lexemes.bin')) | ||||
|     with (model_dir / 'vocab' / 'strings.json').open('w', encoding='utf8') as file_: | ||||
|         ner.vocab.strings.dump(file_) | ||||
| 
 | ||||
| 
 | ||||
| def main(model_dir=None): | ||||
|     if model_dir is not None: | ||||
|         model_dir = pathlib.Path(model_dir) | ||||
|         if not model_dir.exists(): | ||||
|             model_dir.mkdir() | ||||
|         assert model_dir.is_dir() | ||||
| 
 | ||||
|     nlp = spacy.load('en', parser=False, entity=False, add_vectors=False) | ||||
| 
 | ||||
|     # v1.1.2 onwards | ||||
|     if nlp.tagger is None: | ||||
|         print('---- WARNING ----') | ||||
|         print('Data directory not found') | ||||
|         print('please run: `python -m spacy.en.download –force all` for better performance') | ||||
|         print('please run: `python -m spacy.en.download --force all` for better performance') | ||||
|         print('Using feature templates for tagging') | ||||
|         print('-----------------') | ||||
|         nlp.tagger = Tagger(nlp.vocab, features=Tagger.feature_templates) | ||||
|  | @ -56,16 +72,17 @@ def main(model_dir=None): | |||
|     nlp.tagger(doc) | ||||
|     ner(doc) | ||||
|     for word in doc: | ||||
|         print(word.text, word.tag_, word.ent_type_, word.ent_iob) | ||||
|         print(word.text, word.orth, word.lower, word.tag_, word.ent_type_, word.ent_iob) | ||||
| 
 | ||||
|     if model_dir is not None: | ||||
|         with (model_dir / 'config.json').open('w') as file_: | ||||
|             json.dump(ner.cfg, file_) | ||||
|         ner.model.dump(str(model_dir / 'model')) | ||||
|         save_model(ner, model_dir) | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == '__main__': | ||||
|     main() | ||||
|     main('ner') | ||||
|     # Who "" 2 | ||||
|     # is "" 2 | ||||
|     # Shaka "" PERSON 3 | ||||
|  |  | |||
|  | @ -69,7 +69,7 @@ def main(output_dir=None): | |||
|         print(word.text, word.tag_, word.pos_) | ||||
|     if output_dir is not None: | ||||
|         tagger.model.dump(str(output_dir / 'pos' / 'model')) | ||||
|         with (output_dir / 'vocab' / 'strings.json').open('wb') as file_: | ||||
|         with (output_dir / 'vocab' / 'strings.json').open('w') as file_: | ||||
|             tagger.vocab.strings.dump(file_) | ||||
| 
 | ||||
| 
 | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user