mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-30 23:47:31 +03:00 
			
		
		
		
	* Refactor sense tagger to get rid of intermediary layers
This commit is contained in:
		
							parent
							
								
									6735439abf
								
							
						
					
					
						commit
						2fbcdd0ea8
					
				|  | @ -1,13 +1,17 @@ | ||||||
| from thinc.api cimport Example |  | ||||||
| from thinc.typedefs cimport atom_t |  | ||||||
| 
 |  | ||||||
| from .typedefs cimport flags_t | from .typedefs cimport flags_t | ||||||
| from .structs cimport TokenC | from .structs cimport TokenC | ||||||
| from .strings cimport StringStore | from .strings cimport StringStore | ||||||
| from .tokens cimport Tokens | from .tokens cimport Tokens | ||||||
| from ._ml cimport Model |  | ||||||
| from .senses cimport POS_SENSES, N_SENSES, encode_sense_strs | from .senses cimport POS_SENSES, N_SENSES, encode_sense_strs | ||||||
| from .gold cimport GoldParse | from .gold cimport GoldParse | ||||||
|  | from .parts_of_speech cimport NOUN, VERB | ||||||
|  | 
 | ||||||
|  | from thinc.learner cimport LinearModel | ||||||
|  | from thinc.features cimport Extractor | ||||||
|  | 
 | ||||||
|  | from thinc.typedefs cimport atom_t, weight_t, feat_t | ||||||
|  | 
 | ||||||
|  | from os import path | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @ -173,6 +177,8 @@ cdef int fill_token(atom_t* ctxt, const TokenC* token) except -1: | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| cdef int fill_context(atom_t* ctxt, const TokenC* token) except -1: | cdef int fill_context(atom_t* ctxt, const TokenC* token) except -1: | ||||||
|  |     # NB: we have padding to keep us safe here | ||||||
|  |     # See tokens.pyx | ||||||
|     fill_token(&ctxt[P2W], token - 2) |     fill_token(&ctxt[P2W], token - 2) | ||||||
|     fill_token(&ctxt[P1W], token - 1) |     fill_token(&ctxt[P1W], token - 1) | ||||||
| 
 | 
 | ||||||
|  | @ -185,62 +191,79 @@ cdef int fill_context(atom_t* ctxt, const TokenC* token) except -1: | ||||||
| 
 | 
 | ||||||
| cdef class SenseTagger: | cdef class SenseTagger: | ||||||
|     cdef readonly StringStore strings |     cdef readonly StringStore strings | ||||||
|     cdef readonly Model model |     cdef readonly LinearModel model | ||||||
|  |     cdef readonly Extractor extractor | ||||||
|  |     cdef readonly model_dir | ||||||
| 
 | 
 | ||||||
|     def __init__(self, StringStore strings, model_dir): |     def __init__(self, StringStore strings, model_dir): | ||||||
|         self.strings = strings |         if model_dir is not None and path.isdir(model_dir): | ||||||
|  |             model_dir = path.join(model_dir, 'model') | ||||||
|  | 
 | ||||||
|         templates = unigrams + bigrams + trigrams |         templates = unigrams + bigrams + trigrams | ||||||
|         self.model = Model(N_SENSES, templates, model_dir) |         self.extractor = Extractor(templates) | ||||||
|  |         self.model = LinearModel(N_SENSES, self.extractor.n_templ) | ||||||
|  |         self.model_dir = model_dir | ||||||
|  |         if self.model_dir and path.exists(self.model_dir): | ||||||
|  |             self.model.load(self.model_dir, freq_thresh=0) | ||||||
|  |         self.strings = strings | ||||||
| 
 | 
 | ||||||
|     def __call__(self, Tokens tokens): |     def __call__(self, Tokens tokens): | ||||||
|         eg = Example(self.model.n_classes, CONTEXT_SIZE, self.model.n_feats, |         cdef atom_t[CONTEXT_SIZE] context | ||||||
|                      self.model.n_feats) |         cdef int i, guess, n_feats | ||||||
|         cdef int i |         cdef const TokenC* token | ||||||
|         for i in range(tokens.length): |         for i in range(tokens.length): | ||||||
|             n_valid = self._set_valid(<bint*>eg.c.is_valid, pos_senses(&tokens.data[i])) |             token = &tokens.data[i] | ||||||
|             if n_valid >= 1: |             if token.pos in (NOUN, VERB): | ||||||
|                 fill_context(eg.c.atoms, &tokens.data[i]) |                 fill_context(context, token) | ||||||
|                 self.model.predict(eg) |                 feats = self.extractor.get_feats(context, &n_feats) | ||||||
|                 tokens.data[i].sense = eg.c.guess |                 scores = self.model.get_scores(feats, n_feats) | ||||||
|  |                 tokens.data[i].sense = self.best_in_set(scores, POS_SENSES[<int>token.pos]) | ||||||
| 
 | 
 | ||||||
|     def train(self, Tokens tokens, GoldParse gold): |     def train(self, Tokens tokens, GoldParse gold): | ||||||
|         eg = Example(self.model.n_classes, CONTEXT_SIZE, self.model.n_feats+1, |         cdef int i, j | ||||||
|                      self.model.n_feats+1) |  | ||||||
|         cdef int i |  | ||||||
|         for i, ssenses in enumerate(gold.ssenses): |         for i, ssenses in enumerate(gold.ssenses): | ||||||
|             if ssenses: |             if ssenses: | ||||||
|                 gold.c.ssenses[i] = encode_sense_strs(ssenses) |                 gold.c.ssenses[i] = encode_sense_strs(ssenses) | ||||||
|             else: |             else: | ||||||
|                 gold.c.ssenses[i] = pos_senses(&tokens.data[i]) |                 gold.c.ssenses[i] = pos_senses(&tokens.data[i]) | ||||||
|  |          | ||||||
|  |         cdef atom_t[CONTEXT_SIZE] context | ||||||
|  |         cdef int n_feats | ||||||
|  |         cdef feat_t f_key | ||||||
|  |         cdef int f_i | ||||||
|         cdef int cost = 0 |         cdef int cost = 0 | ||||||
|         for i in range(tokens.length): |         for i in range(tokens.length): | ||||||
|             if tokens.data[i].lex.senses == 0 or tokens.data[i].lex.senses == 1: |             token = &tokens.data[i] | ||||||
|                 continue |             if token.pos in (NOUN, VERB) \ | ||||||
|             self._set_costs(<bint*>eg.c.is_valid, eg.c.costs, gold.c.ssenses[i]) |             and token.lex.senses >= 2 \ | ||||||
|             fill_context(eg.c.atoms, &tokens.data[i]) |             and gold.c.ssenses[i] >= 2: | ||||||
| 
 |                 fill_context(context, token) | ||||||
|             self.model.train(eg) |                 feats = self.extractor.get_feats(context, &n_feats) | ||||||
| 
 |                 scores = self.model.get_scores(feats, n_feats) | ||||||
|             tokens.data[i].sense = eg.c.guess |                 token.sense = self.best_in_set(scores, POS_SENSES[<int>token.pos]) | ||||||
|             cost += eg.c.cost |                 best = self.best_in_set(scores, gold.c.ssenses[i]) | ||||||
|  |                 guess_counts = {} | ||||||
|  |                 gold_counts = {} | ||||||
|  |                 if token.sense != best: | ||||||
|  |                     for j in range(n_feats): | ||||||
|  |                         f_key = feats[j].key | ||||||
|  |                         f_i = feats[j].i | ||||||
|  |                         feat = (f_i, f_key) | ||||||
|  |                         gold_counts[feat]  = gold_counts.get(feat, 0) + 1.0 | ||||||
|  |                         guess_counts[feat] = guess_counts.get(feat, 0) - 1.0 | ||||||
|  |                 #self.model.update({token.sense: guess_counts, best: gold_counts}) | ||||||
|         return cost |         return cost | ||||||
| 
 | 
 | ||||||
|     cdef int _set_valid(self, bint* is_valid, flags_t senses) except -1: |     cdef int best_in_set(self, const weight_t* scores, flags_t senses) except -1: | ||||||
|         cdef int n_valid |         cdef weight_t max_ = 0 | ||||||
|         cdef flags_t bit |         cdef int argmax = -1 | ||||||
|         is_valid[0] = False |         cdef flags_t i | ||||||
|         for bit in range(1, N_SENSES): |         for i in range(N_SENSES): | ||||||
|             is_valid[bit] = senses & (1 << bit) |             if (senses & (1 << i)) and (argmax == -1 or scores[i] > max_): | ||||||
|             n_valid += is_valid[bit] |                 max_ = scores[i] | ||||||
|         return n_valid |                 argmax = i | ||||||
| 
 |         assert argmax >= 0 | ||||||
|     cdef int _set_costs(self, bint* is_valid, int* costs, flags_t senses): |         return argmax | ||||||
|         cdef flags_t bit |  | ||||||
|         is_valid[0] = False |  | ||||||
|         costs[0] = 1 |  | ||||||
|         for bit in range(1, N_SENSES): |  | ||||||
|             is_valid[bit] = True  |  | ||||||
|             costs[bit] = 0 if (senses & (1 << bit)) else 1 |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| cdef flags_t pos_senses(const TokenC* token) nogil: | cdef flags_t pos_senses(const TokenC* token) nogil: | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user