mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-25 13:11:03 +03:00 
			
		
		
		
	* Add support for tag dictionary, and fix error-code for predict method
This commit is contained in:
		
							parent
							
								
									f00afe12c4
								
							
						
					
					
						commit
						3819a88e1b
					
				|  | @ -3,6 +3,7 @@ from cymem.cymem cimport Pool | |||
| from thinc.learner cimport LinearModel | ||||
| from thinc.features cimport Extractor | ||||
| from thinc.typedefs cimport atom_t, feat_t, weight_t, class_t | ||||
| from preshed.maps cimport PreshMap | ||||
| 
 | ||||
| from .typedefs cimport hash_t | ||||
| from .tokens cimport Tokens | ||||
|  | @ -15,7 +16,7 @@ cpdef enum TagType: | |||
| 
 | ||||
| cdef class Tagger: | ||||
|     cpdef int set_tags(self, Tokens tokens) except -1 | ||||
|     cpdef class_t predict(self, int i, Tokens tokens, object golds=*) except 0 | ||||
|     cpdef class_t predict(self, int i, Tokens tokens, object golds=*) except * | ||||
|   | ||||
|     cpdef readonly Pool mem | ||||
|     cpdef readonly Extractor extractor | ||||
|  | @ -23,3 +24,4 @@ cdef class Tagger: | |||
| 
 | ||||
|     cpdef readonly TagType tag_type | ||||
|     cpdef readonly list tag_names | ||||
|     cdef dict tagdict | ||||
|  |  | |||
|  | @ -18,7 +18,7 @@ from thinc.features cimport Feature, count_feats | |||
| NULL_TAG = 0 | ||||
| 
 | ||||
| 
 | ||||
| def setup_model_dir(tag_type, tag_names, templates, model_dir): | ||||
| def setup_model_dir(tag_type, tag_names, tag_counts, templates, model_dir): | ||||
|     if path.exists(model_dir): | ||||
|         shutil.rmtree(model_dir) | ||||
|     os.mkdir(model_dir) | ||||
|  | @ -26,6 +26,7 @@ def setup_model_dir(tag_type, tag_names, templates, model_dir): | |||
|         'tag_type': tag_type, | ||||
|         'templates': templates, | ||||
|         'tag_names': tag_names, | ||||
|         'tag_counts': tag_counts, | ||||
|     } | ||||
|     with open(path.join(model_dir, 'config.json'), 'w') as file_: | ||||
|         json.dump(config, file_) | ||||
|  | @ -35,24 +36,19 @@ def train(train_sents, model_dir, nr_iter=10): | |||
|     cdef Tokens tokens | ||||
|     cdef Tagger tagger = Tagger(model_dir) | ||||
|     cdef int i | ||||
|     cdef class_t guess = 0 | ||||
|     cdef class_t gold | ||||
|     for _ in range(nr_iter): | ||||
|         n_corr = 0 | ||||
|         total = 0 | ||||
|         for tokens, golds in train_sents: | ||||
|             assert len(tokens) == len(golds), [t.string for t in tokens] | ||||
|             for i in range(tokens.length): | ||||
|                 if tagger.tag_type == POS: | ||||
|                     gold = _get_gold_pos(i, golds) | ||||
|                 else: | ||||
|                     raise StandardError | ||||
| 
 | ||||
|                 guess = tagger.predict(i, tokens) | ||||
|                 gold = golds[i] | ||||
|                 guess = tagger.predict(i, tokens, [gold]) | ||||
|                 tokens.set_tag(i, tagger.tag_type, guess) | ||||
|                 if gold is not None: | ||||
|                     tagger.tell_answer(gold) | ||||
|                 total += 1 | ||||
|                     n_corr += guess in gold | ||||
|                 #print('%s\t%d\t%d' % (tokens[i].string, guess, gold)) | ||||
|                 n_corr += guess == gold | ||||
|         print('%.4f' % ((n_corr / total) * 100)) | ||||
|         random.shuffle(train_sents) | ||||
|     tagger.model.end_training() | ||||
|  | @ -96,8 +92,9 @@ cdef class Tagger: | |||
|         templates = cfg['templates'] | ||||
|         self.tag_names = cfg['tag_names'] | ||||
|         self.tag_type = cfg['tag_type'] | ||||
|         self.tagdict = _make_tag_dict(cfg['tag_counts']) | ||||
|         self.extractor = Extractor(templates) | ||||
|         self.model = LinearModel(len(self.tag_names)) | ||||
|         self.model = LinearModel(len(self.tag_names), self.extractor.n_templ+2) | ||||
|         if path.exists(path.join(model_dir, 'model')): | ||||
|             self.model.load(path.join(model_dir, 'model')) | ||||
| 
 | ||||
|  | @ -113,7 +110,7 @@ cdef class Tagger: | |||
|         for i in range(tokens.length): | ||||
|             tokens.set_tag(i, self.tag_type, self.predict(i, tokens)) | ||||
| 
 | ||||
|     cpdef class_t predict(self, int i, Tokens tokens, object golds=None) except 0: | ||||
|     cpdef class_t predict(self, int i, Tokens tokens, object golds=None) except *: | ||||
|         """Predict the tag of tokens[i].  The tagger remembers the features and | ||||
|         prediction, in case you later call tell_answer. | ||||
| 
 | ||||
|  | @ -121,16 +118,18 @@ cdef class Tagger: | |||
|         >>> tag = EN.pos_tagger.predict(0, tokens) | ||||
|         >>> assert tag == EN.pos_tagger.tag_id('DT') == 5 | ||||
|         """ | ||||
|         cdef int n_feats | ||||
|         cdef atom_t sic = tokens.data[i].lex.sic | ||||
|         if sic in self.tagdict: | ||||
|             return self.tagdict[sic] | ||||
|         cdef atom_t[N_FIELDS] context | ||||
|         print sizeof(context) | ||||
|         fill_context(context, i, tokens.data) | ||||
|         cdef int n_feats | ||||
|         cdef Feature* feats = self.extractor.get_feats(context, &n_feats) | ||||
|         cdef weight_t* scores = self.model.get_scores(feats, n_feats) | ||||
|         cdef class_t guess = _arg_max(scores, self.nr_class) | ||||
|         guess = _arg_max(scores, self.model.nr_class) | ||||
|         if golds is not None and guess not in golds: | ||||
|             best = _arg_max_among(scores, golds) | ||||
|             counts = {} | ||||
|             counts = {guess: {}, best: {}} | ||||
|             count_feats(counts[guess], feats, n_feats, -1) | ||||
|             count_feats(counts[best], feats, n_feats, 1) | ||||
|             self.model.update(counts) | ||||
|  | @ -145,12 +144,28 @@ cdef class Tagger: | |||
|         return tag_id | ||||
| 
 | ||||
| 
 | ||||
| cdef class_t _arg_max(weight_t* scores, int n_classes): | ||||
| def _make_tag_dict(counts): | ||||
|     freq_thresh = 50 | ||||
|     ambiguity_thresh = 0.98 | ||||
|     tagdict = {} | ||||
|     cdef atom_t word | ||||
|     cdef atom_t tag | ||||
|     for word_str, tag_freqs in counts.items(): | ||||
|         tag_str, mode = max(tag_freqs.items(), key=lambda item: item[1]) | ||||
|         n = sum(tag_freqs.values()) | ||||
|         word = int(word_str) | ||||
|         tag = int(tag_str) | ||||
|         if n >= freq_thresh and (float(mode) / n) >= ambiguity_thresh: | ||||
|             tagdict[word] = tag | ||||
|     return tagdict | ||||
| 
 | ||||
| 
 | ||||
| cdef class_t _arg_max(weight_t* scores, int n_classes) except 9000: | ||||
|     cdef int best = 0 | ||||
|     cdef weight_t score = scores[best] | ||||
|     cdef int i | ||||
|     for i in range(1, n_classes): | ||||
|         if scores[i] > score: | ||||
|         if scores[i] >= score: | ||||
|             score = scores[i] | ||||
|             best = i | ||||
|     return best | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user