mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-04 01:48:04 +03:00 
			
		
		
		
	* Increase default number of iterations from 5 to 10
This commit is contained in:
		
							parent
							
								
									3cab1d9a29
								
							
						
					
					
						commit
						949a6245f9
					
				| 
						 | 
					@ -10,6 +10,7 @@ import random
 | 
				
			||||||
import json
 | 
					import json
 | 
				
			||||||
import cython
 | 
					import cython
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from .context cimport fill_slots
 | 
					from .context cimport fill_slots
 | 
				
			||||||
from .context cimport fill_flat
 | 
					from .context cimport fill_flat
 | 
				
			||||||
from .context cimport N_FIELDS
 | 
					from .context cimport N_FIELDS
 | 
				
			||||||
| 
						 | 
					@ -33,7 +34,7 @@ def setup_model_dir(tag_type, tag_names, templates, model_dir):
 | 
				
			||||||
        json.dump(config, file_)
 | 
					        json.dump(config, file_)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def train(train_sents, model_dir, nr_iter=5):
 | 
					def train(train_sents, model_dir, nr_iter=10):
 | 
				
			||||||
    tagger = Tagger(model_dir)
 | 
					    tagger = Tagger(model_dir)
 | 
				
			||||||
    for _ in range(nr_iter):
 | 
					    for _ in range(nr_iter):
 | 
				
			||||||
        n_corr = 0
 | 
					        n_corr = 0
 | 
				
			||||||
| 
						 | 
					@ -43,15 +44,15 @@ def train(train_sents, model_dir, nr_iter=5):
 | 
				
			||||||
            for i, gold in enumerate(golds):
 | 
					            for i, gold in enumerate(golds):
 | 
				
			||||||
                guess = tagger.predict(i, tokens)
 | 
					                guess = tagger.predict(i, tokens)
 | 
				
			||||||
                tokens.set_tag(i, tagger.tag_type, guess)
 | 
					                tokens.set_tag(i, tagger.tag_type, guess)
 | 
				
			||||||
                tagger.tell_answer(gold)
 | 
					 | 
				
			||||||
                if gold != NULL_TAG:
 | 
					                if gold != NULL_TAG:
 | 
				
			||||||
 | 
					                    tagger.tell_answer(gold)
 | 
				
			||||||
                    total += 1
 | 
					                    total += 1
 | 
				
			||||||
                    n_corr += guess == gold
 | 
					                    n_corr += guess == gold
 | 
				
			||||||
                #print('%s\t%d\t%d' % (tokens[i].string, guess, gold))
 | 
					                #print('%s\t%d\t%d' % (tokens[i].string, guess, gold))
 | 
				
			||||||
        print('%.4f' % ((n_corr / total) * 100))
 | 
					        print('%.4f' % ((n_corr / total) * 100))
 | 
				
			||||||
        random.shuffle(train_sents)
 | 
					        random.shuffle(train_sents)
 | 
				
			||||||
    tagger.model.end_training()
 | 
					    tagger.model.end_training()
 | 
				
			||||||
    tagger.model.dump(path.join(model_dir, 'model'), freq_thresh=10)
 | 
					    tagger.model.dump(path.join(model_dir, 'model'))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def evaluate(tagger, sents):
 | 
					def evaluate(tagger, sents):
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user