mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-26 13:41:21 +03:00 
			
		
		
		
	Update train_ner example for spaCy v2.0
This commit is contained in:
		
							parent
							
								
									e904075f35
								
							
						
					
					
						commit
						9d58673aaf
					
				|  | @ -1,13 +1,104 @@ | ||||||
|  | #!/usr/bin/env python | ||||||
|  | # coding: utf8 | ||||||
|  | """ | ||||||
|  | Example of training spaCy's named entity recognizer, starting off with an | ||||||
|  | existing model or a blank model. | ||||||
|  | 
 | ||||||
|  | For more details, see the documentation: | ||||||
|  | * Training: https://alpha.spacy.io/usage/training | ||||||
|  | * NER: https://alpha.spacy.io/usage/linguistic-features#named-entities | ||||||
|  | 
 | ||||||
|  | Developed for: spaCy 2.0.0a18 | ||||||
|  | Last updated for: spaCy 2.0.0a18 | ||||||
|  | """ | ||||||
| from __future__ import unicode_literals, print_function | from __future__ import unicode_literals, print_function | ||||||
| 
 | 
 | ||||||
| import random | import random | ||||||
|  | from pathlib import Path | ||||||
| 
 | 
 | ||||||
| from spacy.lang.en import English | import spacy | ||||||
| from spacy.gold import GoldParse, biluo_tags_from_offsets | from spacy.gold import GoldParse, biluo_tags_from_offsets | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | # training data | ||||||
|  | TRAIN_DATA = [ | ||||||
|  |     ('Who is Shaka Khan?', [(7, 17, 'PERSON')]), | ||||||
|  |     ('I like London and Berlin.', [(7, 13, 'LOC'), (18, 24, 'LOC')]) | ||||||
|  | ] | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def main(model=None, output_dir=None, n_iter=100): | ||||||
|  |     """Load the model, set up the pipeline and train the entity recognizer. | ||||||
|  | 
 | ||||||
|  |     model (unicode): Model name to start off with. If None, a blank English | ||||||
|  |         Language class is created. | ||||||
|  |     output_dir (unicode / Path): Optional output directory. If None, no model | ||||||
|  |         will be saved. | ||||||
|  |     n_iter (int): Number of iterations during training. | ||||||
|  |     """ | ||||||
|  |     if model is not None: | ||||||
|  |         nlp = spacy.load(model)  # load existing spaCy model | ||||||
|  |         print("Loaded model '%s'" % model) | ||||||
|  |     else: | ||||||
|  |         nlp = spacy.blank('en')  # create blank Language class | ||||||
|  |         print("Created blank 'en' model") | ||||||
|  | 
 | ||||||
|  |     # create the built-in pipeline components and add them to the pipeline | ||||||
|  |     # nlp.create_pipe works for built-ins that are registered with spaCy! | ||||||
|  |     if 'ner' not in nlp.pipe_names: | ||||||
|  |         ner = nlp.create_pipe('ner') | ||||||
|  |         nlp.add_pipe(ner, last=True) | ||||||
|  | 
 | ||||||
|  |     # function that allows begin_training to get the training data | ||||||
|  |     get_data = lambda: reformat_train_data(nlp.tokenizer, TRAIN_DATA) | ||||||
|  | 
 | ||||||
|  |     # get names of other pipes to disable them during training | ||||||
|  |     other_pipes = [pipe for pipe in nlp.pipe_names if pipe != 'ner'] | ||||||
|  |     with nlp.disable_pipes(*other_pipes) as disabled:  # only train NER | ||||||
|  |         optimizer = nlp.begin_training(get_data) | ||||||
|  |         for itn in range(n_iter): | ||||||
|  |             random.shuffle(TRAIN_DATA) | ||||||
|  |             losses = {} | ||||||
|  |             for raw_text, entity_offsets in TRAIN_DATA: | ||||||
|  |                 doc = nlp.make_doc(raw_text) | ||||||
|  |                 gold = GoldParse(doc, entities=entity_offsets) | ||||||
|  |                 nlp.update( | ||||||
|  |                     [doc], # Batch of Doc objects | ||||||
|  |                     [gold], # Batch of GoldParse objects | ||||||
|  |                     drop=0.5, # Dropout -- make it harder to memorise data | ||||||
|  |                     sgd=optimizer, # Callable to update weights | ||||||
|  |                     losses=losses) | ||||||
|  |             print(losses) | ||||||
|  | 
 | ||||||
|  |     # test the trained model | ||||||
|  |     for text, _ in TRAIN_DATA: | ||||||
|  |         doc = nlp(text) | ||||||
|  |         print('Entities', [(ent.text, ent.label_) for ent in doc.ents]) | ||||||
|  |         print('Tokens', [(t.text, t.ent_type_, t.ent_iob) for t in doc]) | ||||||
|  | 
 | ||||||
|  |     # save model to output directory | ||||||
|  |     if output_dir is not None: | ||||||
|  |         output_dir = Path(output_dir) | ||||||
|  |         if not output_dir.exists(): | ||||||
|  |             output_dir.mkdir() | ||||||
|  |         nlp.to_disk(output_dir) | ||||||
|  |         print("Saved model to", output_dir) | ||||||
|  | 
 | ||||||
|  |         # test the saved model | ||||||
|  |         print("Loading from", output_dir) | ||||||
|  |         for text, _ in TRAIN_DATA: | ||||||
|  |             doc = nlp(text) | ||||||
|  |             print('Entities', [(ent.text, ent.label_) for ent in doc.ents]) | ||||||
|  |             print('Tokens', [(t.text, t.ent_type_, t.ent_iob) for t in doc]) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| def reformat_train_data(tokenizer, examples): | def reformat_train_data(tokenizer, examples): | ||||||
|     """Reformat data to match JSON format""" |     """Reformat data to match JSON format. | ||||||
|  |     https://alpha.spacy.io/api/annotation#json-input | ||||||
|  | 
 | ||||||
|  |     tokenizer (Tokenizer): Tokenizer to process the raw text. | ||||||
|  |     examples (list): The training data. | ||||||
|  |     RETURNS (list): The reformatted training data.""" | ||||||
|     output = [] |     output = [] | ||||||
|     for i, (text, entity_offsets) in enumerate(examples): |     for i, (text, entity_offsets) in enumerate(examples): | ||||||
|         doc = tokenizer(text) |         doc = tokenizer(text) | ||||||
|  | @ -21,49 +112,6 @@ def reformat_train_data(tokenizer, examples): | ||||||
|     return output |     return output | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def main(model_dir=None): |  | ||||||
|     train_data = [ |  | ||||||
|         ( |  | ||||||
|             'Who is Shaka Khan?', |  | ||||||
|             [(len('Who is '), len('Who is Shaka Khan'), 'PERSON')] |  | ||||||
|         ), |  | ||||||
|         ( |  | ||||||
|             'I like London and Berlin.', |  | ||||||
|             [(len('I like '), len('I like London'), 'LOC'), |  | ||||||
|             (len('I like London and '), len('I like London and Berlin'), 'LOC')] |  | ||||||
|         ) |  | ||||||
|     ] |  | ||||||
|     nlp = English(pipeline=['tensorizer', 'ner']) |  | ||||||
|     get_data = lambda: reformat_train_data(nlp.tokenizer, train_data) |  | ||||||
|     optimizer = nlp.begin_training(get_data) |  | ||||||
|     for itn in range(100): |  | ||||||
|         random.shuffle(train_data) |  | ||||||
|         losses = {} |  | ||||||
|         for raw_text, entity_offsets in train_data: |  | ||||||
|             doc = nlp.make_doc(raw_text) |  | ||||||
|             gold = GoldParse(doc, entities=entity_offsets) |  | ||||||
|             nlp.update( |  | ||||||
|                 [doc], # Batch of Doc objects |  | ||||||
|                 [gold], # Batch of GoldParse objects |  | ||||||
|                 drop=0.5, # Dropout -- make it harder to memorise data |  | ||||||
|                 sgd=optimizer, # Callable to update weights |  | ||||||
|                 losses=losses) |  | ||||||
|         print(losses) |  | ||||||
|     print("Save to", model_dir) |  | ||||||
|     nlp.to_disk(model_dir) |  | ||||||
|     print("Load from", model_dir) |  | ||||||
|     nlp = spacy.lang.en.English(pipeline=['tensorizer', 'ner']) |  | ||||||
|     nlp.from_disk(model_dir) |  | ||||||
|     for raw_text, _ in train_data: |  | ||||||
|         doc = nlp(raw_text) |  | ||||||
|         for word in doc: |  | ||||||
|             print(word.text, word.ent_type_, word.ent_iob_) |  | ||||||
| 
 |  | ||||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||||
|     import plac |     import plac | ||||||
|     plac.call(main) |     plac.call(main) | ||||||
|     # Who "" 2 |  | ||||||
|     # is "" 2 |  | ||||||
|     # Shaka "" PERSON 3 |  | ||||||
|     # Khan "" PERSON 1 |  | ||||||
|     # ? "" 2 |  | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user