mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-04 01:48:04 +03:00 
			
		
		
		
	Tidy up example and only save/test if output_directory is not None
This commit is contained in:
		
							parent
							
								
									d10bd0eaf9
								
							
						
					
					
						commit
						c7adca58a9
					
				| 
						 | 
				
			
			@ -1,22 +1,16 @@
 | 
			
		|||
from __future__ import unicode_literals, print_function
 | 
			
		||||
import json
 | 
			
		||||
import pathlib
 | 
			
		||||
 | 
			
		||||
import random
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
 | 
			
		||||
import spacy
 | 
			
		||||
from spacy.pipeline import EntityRecognizer
 | 
			
		||||
from spacy.gold import GoldParse
 | 
			
		||||
from spacy.tagger import Tagger
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
try:
 | 
			
		||||
    unicode
 | 
			
		||||
except:
 | 
			
		||||
    unicode = str
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def train_ner(nlp, train_data, output_dir):
 | 
			
		||||
    # Add new words to vocab.
 | 
			
		||||
    # Add new words to vocab
 | 
			
		||||
    for raw_text, _ in train_data:
 | 
			
		||||
        doc = nlp.make_doc(raw_text)
 | 
			
		||||
        for word in doc:
 | 
			
		||||
| 
						 | 
				
			
			@ -30,11 +24,14 @@ def train_ner(nlp, train_data, output_dir):
 | 
			
		|||
            nlp.tagger(doc)
 | 
			
		||||
            loss = nlp.entity.update(doc, gold)
 | 
			
		||||
    nlp.end_training()
 | 
			
		||||
    nlp.save_to_directory(output_dir)
 | 
			
		||||
    if output_dir:
 | 
			
		||||
        nlp.save_to_directory(output_dir)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def main(model_name, output_directory=None):
 | 
			
		||||
    nlp = spacy.load(model_name)
 | 
			
		||||
    if output_directory is not None:
 | 
			
		||||
        output_directory = Path(output_directory)
 | 
			
		||||
 | 
			
		||||
    train_data = [
 | 
			
		||||
        (
 | 
			
		||||
| 
						 | 
				
			
			@ -55,18 +52,18 @@ def main(model_name, output_directory=None):
 | 
			
		|||
        )
 | 
			
		||||
    ]
 | 
			
		||||
    nlp.entity.add_label('ANIMAL')
 | 
			
		||||
    if output_directory is not None:
 | 
			
		||||
        output_directory = pathlib.Path(output_directory)
 | 
			
		||||
    ner = train_ner(nlp, train_data, output_directory)
 | 
			
		||||
 | 
			
		||||
    # Test that the entity is recognized
 | 
			
		||||
    doc = nlp('Do you like horses?')
 | 
			
		||||
    for ent in doc.ents:
 | 
			
		||||
        print(ent.label_, ent.text)
 | 
			
		||||
    nlp2 = spacy.load('en', path=output_directory)
 | 
			
		||||
    nlp2.entity.add_label('ANIMAL')
 | 
			
		||||
    doc2 = nlp2('Do you like horses?')
 | 
			
		||||
    for ent in doc2.ents:
 | 
			
		||||
        print(ent.label_, ent.text)
 | 
			
		||||
    if output_directory:
 | 
			
		||||
        nlp2 = spacy.load('en', path=output_directory)
 | 
			
		||||
        nlp2.entity.add_label('ANIMAL')
 | 
			
		||||
        doc2 = nlp2('Do you like horses?')
 | 
			
		||||
        for ent in doc2.ents:
 | 
			
		||||
            print(ent.label_, ent.text)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == '__main__':
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue
	
	Block a user