mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-26 05:31:15 +03:00 
			
		
		
		
	Gradients look correct
This commit is contained in:
		
							parent
							
								
									7e04260d38
								
							
						
					
					
						commit
						8e48b58cd6
					
				|  | @ -1,4 +1,4 @@ | ||||||
| from __future__ import unicode_literals | from __future__ import unicode_literals, print_function | ||||||
| import plac | import plac | ||||||
| import json | import json | ||||||
| import random | import random | ||||||
|  | @ -9,7 +9,7 @@ from spacy.syntax.nonproj import PseudoProjectivity | ||||||
| from spacy.language import Language | from spacy.language import Language | ||||||
| from spacy.gold import GoldParse | from spacy.gold import GoldParse | ||||||
| from spacy.tagger import Tagger | from spacy.tagger import Tagger | ||||||
| from spacy.pipeline import DependencyParser, BeamDependencyParser | from spacy.pipeline import DependencyParser, TokenVectorEncoder | ||||||
| from spacy.syntax.parser import get_templates | from spacy.syntax.parser import get_templates | ||||||
| from spacy.syntax.arc_eager import ArcEager | from spacy.syntax.arc_eager import ArcEager | ||||||
| from spacy.scorer import Scorer | from spacy.scorer import Scorer | ||||||
|  | @ -36,10 +36,10 @@ def read_conllx(loc, n=0): | ||||||
|                 try: |                 try: | ||||||
|                     id_ = int(id_) - 1 |                     id_ = int(id_) - 1 | ||||||
|                     head = (int(head) - 1) if head != '0' else id_ |                     head = (int(head) - 1) if head != '0' else id_ | ||||||
|                     dep = 'ROOT' if dep == 'root' else dep |                     dep = 'ROOT' if dep == 'root' else 'unlabelled' | ||||||
|                     tokens.append((id_, word, tag, head, dep, 'O')) |                     # Hack for efficiency | ||||||
|  |                     tokens.append((id_, word, pos+'__'+morph, head, dep, 'O')) | ||||||
|                 except: |                 except: | ||||||
|                     print(line) |  | ||||||
|                     raise |                     raise | ||||||
|             tuples = [list(t) for t in zip(*tokens)] |             tuples = [list(t) for t in zip(*tokens)] | ||||||
|             yield (None, [[tuples, []]]) |             yield (None, [[tuples, []]]) | ||||||
|  | @ -48,19 +48,37 @@ def read_conllx(loc, n=0): | ||||||
|                 break |                 break | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def score_model(vocab, tagger, parser, gold_docs, verbose=False): | def score_model(vocab, encoder, tagger, parser, Xs, ys, verbose=False): | ||||||
|     scorer = Scorer() |     scorer = Scorer() | ||||||
|     for _, gold_doc in gold_docs: |     correct = 0. | ||||||
|         for (ids, words, tags, heads, deps, entities), _ in gold_doc: |     total = 0. | ||||||
|             doc = Doc(vocab, words=words) |     for doc, gold in zip(Xs, ys): | ||||||
|             tagger(doc) |         doc = Doc(vocab, words=[w.text for w in doc]) | ||||||
|             parser(doc) |         encoder(doc) | ||||||
|             PseudoProjectivity.deprojectivize(doc) |         tagger(doc) | ||||||
|             gold = GoldParse(doc, tags=tags, heads=heads, deps=deps) |         parser(doc) | ||||||
|             scorer.score(doc, gold, verbose=verbose) |         PseudoProjectivity.deprojectivize(doc) | ||||||
|  |         scorer.score(doc, gold, verbose=verbose) | ||||||
|  |         for token, tag in zip(doc, gold.tags): | ||||||
|  |             univ_guess, _ = token.tag_.split('_', 1) | ||||||
|  |             univ_truth, _ = tag.split('_', 1) | ||||||
|  |             correct += univ_guess == univ_truth | ||||||
|  |             total += 1 | ||||||
|     return scorer |     return scorer | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | def organize_data(vocab, train_sents): | ||||||
|  |     Xs = [] | ||||||
|  |     ys = [] | ||||||
|  |     for _, doc_sents in train_sents: | ||||||
|  |         for (ids, words, tags, heads, deps, ner), _ in doc_sents: | ||||||
|  |             doc = Doc(vocab, words=words) | ||||||
|  |             gold = GoldParse(doc, tags=tags, heads=heads, deps=deps) | ||||||
|  |             Xs.append(doc) | ||||||
|  |             ys.append(gold) | ||||||
|  |     return Xs, ys | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| def main(lang_name, train_loc, dev_loc, model_dir, clusters_loc=None): | def main(lang_name, train_loc, dev_loc, model_dir, clusters_loc=None): | ||||||
|     LangClass = spacy.util.get_lang_class(lang_name) |     LangClass = spacy.util.get_lang_class(lang_name) | ||||||
|     train_sents = list(read_conllx(train_loc)) |     train_sents = list(read_conllx(train_loc)) | ||||||
|  | @ -114,21 +132,37 @@ def main(lang_name, train_loc, dev_loc, model_dir, clusters_loc=None): | ||||||
|                 for tag in tags: |                 for tag in tags: | ||||||
|                     assert tag in vocab.morphology.tag_map, repr(tag) |                     assert tag in vocab.morphology.tag_map, repr(tag) | ||||||
|     tagger = Tagger(vocab) |     tagger = Tagger(vocab) | ||||||
|  |     encoder = TokenVectorEncoder(vocab) | ||||||
|     parser = DependencyParser(vocab, actions=actions, features=features, L1=0.0) |     parser = DependencyParser(vocab, actions=actions, features=features, L1=0.0) | ||||||
| 
 | 
 | ||||||
|     for itn in range(30): |      | ||||||
|         loss = 0. |     Xs, ys = organize_data(vocab, train_sents) | ||||||
|         for _, doc_sents in train_sents: |     Xs = Xs[:1] | ||||||
|             for (ids, words, tags, heads, deps, ner), _ in doc_sents: |     ys = ys[:1] | ||||||
|                 doc = Doc(vocab, words=words) |     with encoder.model.begin_training(Xs[:100], ys[:100]) as (trainer, optimizer): | ||||||
|                 gold = GoldParse(doc, tags=tags, heads=heads, deps=deps) |         docs = list(Xs) | ||||||
|                 tagger(doc) |         for doc in docs: | ||||||
|                 loss += parser.update(doc, gold, itn=itn) |             encoder(doc) | ||||||
|                 doc = Doc(vocab, words=words) |         parser.begin_training(docs, ys) | ||||||
|  |         nn_loss = [0.] | ||||||
|  |         def track_progress(): | ||||||
|  |             scorer = score_model(vocab, encoder, tagger, parser, Xs, ys) | ||||||
|  |             itn = len(nn_loss) | ||||||
|  |             print('%d:\t%.3f\t%.3f\t%.3f' % (itn, nn_loss[-1], scorer.uas, scorer.tags_acc)) | ||||||
|  |             nn_loss.append(0.) | ||||||
|  |         trainer.each_epoch.append(track_progress) | ||||||
|  |         trainer.batch_size = 1 | ||||||
|  |         trainer.nb_epoch = 100 | ||||||
|  |         for docs, golds in trainer.iterate(Xs, ys, progress_bar=False): | ||||||
|  |             docs = [Doc(vocab, words=[w.text for w in doc]) for doc in docs] | ||||||
|  |             tokvecs, upd_tokvecs = encoder.begin_update(docs) | ||||||
|  |             for doc, tokvec in zip(docs, tokvecs): | ||||||
|  |                 doc.tensor = tokvec | ||||||
|  |             for doc, gold in zip(docs, golds): | ||||||
|                 tagger.update(doc, gold) |                 tagger.update(doc, gold) | ||||||
|         random.shuffle(train_sents) |             d_tokvecs, loss = parser.update(docs, golds, sgd=optimizer) | ||||||
|         scorer = score_model(vocab, tagger, parser, read_conllx(dev_loc)) |             upd_tokvecs(d_tokvecs, sgd=optimizer) | ||||||
|         print('%d:\t%.3f\t%.3f\t%.3f' % (itn, loss, scorer.uas, scorer.tags_acc)) |             nn_loss[-1] += loss | ||||||
|     nlp = LangClass(vocab=vocab, tagger=tagger, parser=parser) |     nlp = LangClass(vocab=vocab, tagger=tagger, parser=parser) | ||||||
|     nlp.end_training(model_dir) |     nlp.end_training(model_dir) | ||||||
|     scorer = score_model(vocab, tagger, parser, read_conllx(dev_loc)) |     scorer = score_model(vocab, tagger, parser, read_conllx(dev_loc)) | ||||||
|  |  | ||||||
							
								
								
									
										34
									
								
								spacy/_ml.py
									
									
									
									
									
								
							
							
						
						
									
										34
									
								
								spacy/_ml.py
									
									
									
									
									
								
							|  | @ -1,5 +1,5 @@ | ||||||
| from thinc.api import layerize, chain, clone, concatenate, with_flatten | from thinc.api import layerize, chain, clone, concatenate, with_flatten | ||||||
| from thinc.neural import Model, Maxout, Softmax | from thinc.neural import Model, Maxout, Softmax, Affine | ||||||
| from thinc.neural._classes.hash_embed import HashEmbed | from thinc.neural._classes.hash_embed import HashEmbed | ||||||
| 
 | 
 | ||||||
| from thinc.neural._classes.convolution import ExtractWindow | from thinc.neural._classes.convolution import ExtractWindow | ||||||
|  | @ -21,11 +21,41 @@ def build_model(state2vec, width, depth, nr_class): | ||||||
|             state2vec |             state2vec | ||||||
|             >> Maxout(width, 1344) |             >> Maxout(width, 1344) | ||||||
|             >> Maxout(width, width) |             >> Maxout(width, width) | ||||||
|             >> Softmax(nr_class, width) |             >> Affine(nr_class, width) | ||||||
|         ) |         ) | ||||||
|     return model |     return model | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | def build_debug_model(state2vec, width, depth, nr_class): | ||||||
|  |     with Model.define_operators({'>>': chain, '**': clone}): | ||||||
|  |         model = ( | ||||||
|  |             state2vec | ||||||
|  |             >> Maxout(width) | ||||||
|  |             >> Affine(nr_class) | ||||||
|  |         ) | ||||||
|  |     return model | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def build_debug_state2vec(width, nr_vector=1000, nF=1, nB=0, nS=1, nL=2, nR=2): | ||||||
|  |     ops = Model.ops | ||||||
|  |     def forward(tokens_attrs_vectors, drop=0.): | ||||||
|  |         tokens, attr_vals, tokvecs = tokens_attrs_vectors | ||||||
|  |          | ||||||
|  |         orig_tokvecs_shape = tokvecs.shape | ||||||
|  |         tokvecs = tokvecs.reshape((tokvecs.shape[0], tokvecs.shape[1] * | ||||||
|  |                                    tokvecs.shape[2])) | ||||||
|  | 
 | ||||||
|  |         vector = tokvecs | ||||||
|  | 
 | ||||||
|  |         def backward(d_vector, sgd=None): | ||||||
|  |             d_tokvecs = vector.reshape(orig_tokvecs_shape) | ||||||
|  |             return (tokens, d_tokvecs) | ||||||
|  |         return vector, backward | ||||||
|  |     model = layerize(forward) | ||||||
|  |     return model | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| def build_parser_state2vec(width, nr_vector=1000, nF=1, nB=0, nS=1, nL=2, nR=2): | def build_parser_state2vec(width, nr_vector=1000, nF=1, nB=0, nS=1, nL=2, nR=2): | ||||||
|     embed_tags = _reshape(chain(get_col(0), HashEmbed(16, nr_vector))) |     embed_tags = _reshape(chain(get_col(0), HashEmbed(16, nr_vector))) | ||||||
|     embed_deps = _reshape(chain(get_col(1), HashEmbed(16, nr_vector))) |     embed_deps = _reshape(chain(get_col(1), HashEmbed(16, nr_vector))) | ||||||
|  |  | ||||||
|  | @ -28,6 +28,8 @@ from murmurhash.mrmr cimport hash64 | ||||||
| from preshed.maps cimport MapStruct | from preshed.maps cimport MapStruct | ||||||
| from preshed.maps cimport map_get | from preshed.maps cimport map_get | ||||||
| 
 | 
 | ||||||
|  | from numpy import exp | ||||||
|  | 
 | ||||||
| from . import _parse_features | from . import _parse_features | ||||||
| from ._parse_features cimport CONTEXT_SIZE | from ._parse_features cimport CONTEXT_SIZE | ||||||
| from ._parse_features cimport fill_context | from ._parse_features cimport fill_context | ||||||
|  | @ -43,6 +45,7 @@ from ..gold cimport GoldParse | ||||||
| from ..attrs cimport TAG, DEP | from ..attrs cimport TAG, DEP | ||||||
| 
 | 
 | ||||||
| from .._ml import build_parser_state2vec, build_model | from .._ml import build_parser_state2vec, build_model | ||||||
|  | from .._ml import build_debug_state2vec, build_debug_model | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| USE_FTRL = True | USE_FTRL = True | ||||||
|  | @ -111,8 +114,8 @@ cdef class Parser: | ||||||
|         return (Parser, (self.vocab, self.moves, self.model), None, None) |         return (Parser, (self.vocab, self.moves, self.model), None, None) | ||||||
| 
 | 
 | ||||||
|     def build_model(self, width=8, nr_vector=1000, nF=1, nB=1, nS=1, nL=1, nR=1, **_): |     def build_model(self, width=8, nr_vector=1000, nF=1, nB=1, nS=1, nL=1, nR=1, **_): | ||||||
|         state2vec = build_parser_state2vec(width, nr_vector, nF, nB, nL, nR) |         state2vec = build_debug_state2vec(width, nr_vector, nF, nB, nL, nR) | ||||||
|         model = build_model(state2vec, width, 2, self.moves.n_moves) |         model = build_debug_model(state2vec, width, 2, self.moves.n_moves) | ||||||
|         return model |         return model | ||||||
| 
 | 
 | ||||||
|     def __call__(self, Doc tokens): |     def __call__(self, Doc tokens): | ||||||
|  | @ -166,32 +169,22 @@ cdef class Parser: | ||||||
|         cdef Doc doc |         cdef Doc doc | ||||||
|         cdef StateClass state |         cdef StateClass state | ||||||
|         cdef int guess |         cdef int guess | ||||||
|         is_valid = self.model.ops.allocate((len(docs), nr_class), dtype='i') |  | ||||||
|         tokvecs = [d.tensor for d in docs] |         tokvecs = [d.tensor for d in docs] | ||||||
|         attr_names = self.model.ops.allocate((2,), dtype='i') |  | ||||||
|         attr_names[0] = TAG |  | ||||||
|         attr_names[1] = DEP |  | ||||||
|         all_states = list(states) |         all_states = list(states) | ||||||
|         todo = zip(states, tokvecs) |         todo = zip(states, tokvecs) | ||||||
|         while todo: |         while todo: | ||||||
|             states, tokvecs = zip(*todo) |             states, tokvecs = zip(*todo) | ||||||
|             features = self._get_features(states, tokvecs, attr_names) |             scores, _ = self._begin_update(states, tokvecs) | ||||||
|             scores = self.model.predict(features) |  | ||||||
|             self._validate_batch(is_valid, states) |  | ||||||
|             scores *= is_valid |  | ||||||
|             for state, guess in zip(states, scores.argmax(axis=1)): |             for state, guess in zip(states, scores.argmax(axis=1)): | ||||||
|                 action = self.moves.c[guess] |                 action = self.moves.c[guess] | ||||||
|                 action.do(state.c, action.label) |                 action.do(state.c, action.label) | ||||||
|             todo = filter(lambda sp: not sp[0].is_final(), todo) |             todo = filter(lambda sp: not sp[0].py_is_final(), todo) | ||||||
|         for state, doc in zip(all_states, docs): |         for state, doc in zip(all_states, docs): | ||||||
|             self.moves.finalize_state(state.c) |             self.moves.finalize_state(state.c) | ||||||
|             for i in range(doc.length): |             for i in range(doc.length): | ||||||
|                 doc.c[i] = state.c._sent[i] |                 doc.c[i] = state.c._sent[i] | ||||||
| 
 | 
 | ||||||
| 
 |     def begin_training(self, docs, golds): | ||||||
|     def update(self, docs, golds, drop=0., sgd=None): |  | ||||||
|         if isinstance(docs, Doc) and isinstance(golds, GoldParse): |  | ||||||
|             return self.update([docs], [golds], drop=drop) |  | ||||||
|         for gold in golds: |         for gold in golds: | ||||||
|             self.moves.preprocess_gold(gold) |             self.moves.preprocess_gold(gold) | ||||||
|         states = self._init_states(docs) |         states = self._init_states(docs) | ||||||
|  | @ -204,39 +197,60 @@ cdef class Parser: | ||||||
|         attr_names = self.model.ops.allocate((2,), dtype='i') |         attr_names = self.model.ops.allocate((2,), dtype='i') | ||||||
|         attr_names[0] = TAG |         attr_names[0] = TAG | ||||||
|         attr_names[1] = DEP |         attr_names[1] = DEP | ||||||
|  |          | ||||||
|  |         features = self._get_features(states, tokvecs, attr_names) | ||||||
|  |         self.model.begin_training(features) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  |     def update(self, docs, golds, drop=0., sgd=None): | ||||||
|  |         if isinstance(docs, Doc) and isinstance(golds, GoldParse): | ||||||
|  |             return self.update([docs], [golds], drop=drop) | ||||||
|  |         for gold in golds: | ||||||
|  |             self.moves.preprocess_gold(gold) | ||||||
|  |         states = self._init_states(docs) | ||||||
|  |         tokvecs = [d.tensor for d in docs] | ||||||
|  |         d_tokens = [self.model.ops.allocate(d.tensor.shape) for d in docs] | ||||||
|  |         nr_class = self.moves.n_moves | ||||||
|         output = list(d_tokens) |         output = list(d_tokens) | ||||||
|         todo = zip(states, tokvecs, golds, d_tokens) |         todo = zip(states, tokvecs, golds, d_tokens) | ||||||
|         assert len(states) == len(todo) |         assert len(states) == len(todo) | ||||||
|         loss = 0. |         loss = 0. | ||||||
|         while todo: |         while todo: | ||||||
|             states, tokvecs, golds, d_tokens = zip(*todo) |             states, tokvecs, golds, d_tokens = zip(*todo) | ||||||
|             features = self._get_features(states, tokvecs, attr_names) |             scores, finish_update = self._begin_update(states, tokvecs) | ||||||
| 
 |             token_ids, batch_token_grads = finish_update(golds, sgd=sgd) | ||||||
|             scores, finish_update = self.model.begin_update(features, drop=drop) |  | ||||||
|             assert scores.shape == (len(states), self.moves.n_moves), (len(states), scores.shape) |  | ||||||
| 
 |  | ||||||
|             self._cost_batch(costs, is_valid, states, golds) |  | ||||||
|             scores *= is_valid |  | ||||||
|             self._set_gradient(gradients, scores, costs) |  | ||||||
|             loss += numpy.abs(gradients).sum() / gradients.shape[0] |  | ||||||
| 
 |  | ||||||
|             token_ids, batch_token_grads = finish_update(gradients, sgd=sgd) |  | ||||||
|             for i, tok_i in enumerate(token_ids): |             for i, tok_i in enumerate(token_ids): | ||||||
|                 d_tokens[i][tok_i] += batch_token_grads[i] |                 d_tokens[i][tok_i] += batch_token_grads[i] | ||||||
| 
 | 
 | ||||||
|             self._transition_batch(states, scores) |             self._transition_batch(states, scores) | ||||||
| 
 | 
 | ||||||
|             # Get unfinished states (and their matching gold and token gradients) |             # Get unfinished states (and their matching gold and token gradients) | ||||||
|             todo = filter(lambda sp: not sp[0].is_final(), todo) |             todo = filter(lambda sp: not sp[0].py_is_final(), todo) | ||||||
|             costs = costs[:len(todo)] |  | ||||||
|             is_valid = is_valid[:len(todo)] |  | ||||||
|             gradients = gradients[:len(todo)] |  | ||||||
| 
 |  | ||||||
|             gradients.fill(0) |  | ||||||
|             costs.fill(0) |  | ||||||
|             is_valid.fill(1) |  | ||||||
|         return output, loss |         return output, loss | ||||||
| 
 | 
 | ||||||
|  |     def _begin_update(self, states, tokvecs, drop=0.): | ||||||
|  |         nr_class = self.moves.n_moves | ||||||
|  |         attr_names = self.model.ops.allocate((2,), dtype='i') | ||||||
|  |         attr_names[0] = TAG | ||||||
|  |         attr_names[1] = DEP | ||||||
|  | 
 | ||||||
|  |         features = self._get_features(states, tokvecs, attr_names) | ||||||
|  |         scores, finish_update = self.model.begin_update(features, drop=drop) | ||||||
|  |         is_valid = self.model.ops.allocate((len(states), nr_class), dtype='i') | ||||||
|  |         self._validate_batch(is_valid, states) | ||||||
|  |         softmaxed = self.model.ops.softmax(scores) | ||||||
|  |         softmaxed *= is_valid | ||||||
|  |         softmaxed /= softmaxed.sum(axis=1) | ||||||
|  |         print('Scores', softmaxed[0]) | ||||||
|  |         def backward(golds, sgd=None): | ||||||
|  |             costs = self.model.ops.allocate((len(states), nr_class), dtype='f') | ||||||
|  |             d_scores = self.model.ops.allocate((len(states), nr_class), dtype='f') | ||||||
|  | 
 | ||||||
|  |             self._cost_batch(costs, is_valid, states, golds) | ||||||
|  |             self._set_gradient(d_scores, scores, is_valid, costs) | ||||||
|  |             return finish_update(d_scores, sgd=sgd) | ||||||
|  |         return softmaxed, backward | ||||||
|  | 
 | ||||||
|     def _init_states(self, docs): |     def _init_states(self, docs): | ||||||
|         states = [] |         states = [] | ||||||
|         cdef Doc doc |         cdef Doc doc | ||||||
|  | @ -281,20 +295,20 @@ cdef class Parser: | ||||||
|             action = self.moves.c[guess] |             action = self.moves.c[guess] | ||||||
|             action.do(state.c, action.label) |             action.do(state.c, action.label) | ||||||
| 
 | 
 | ||||||
|     def _set_gradient(self, gradients, scores, costs): |     def _set_gradient(self, gradients, scores, is_valid, costs): | ||||||
|         """Do multi-label log loss""" |         """Do multi-label log loss""" | ||||||
|         cdef double Z, gZ, max_, g_max |         cdef double Z, gZ, max_, g_max | ||||||
|         g_scores = scores * (costs <= 0) |         scores = scores * is_valid | ||||||
|         maxes = scores.max(axis=1).reshape((scores.shape[0], 1)) |         g_scores = scores * is_valid * (costs <= 0.) | ||||||
|         g_maxes = g_scores.max(axis=1).reshape((g_scores.shape[0], 1)) |         exps = numpy.exp(scores - scores.max(axis=1)) | ||||||
|         exps = numpy.exp((scores-maxes)) |         exps *= is_valid | ||||||
|         g_exps = numpy.exp(g_scores-g_maxes) |         g_exps = numpy.exp(g_scores - g_scores.max(axis=1)) | ||||||
| 
 |         g_exps *= costs <= 0. | ||||||
|         Zs = exps.sum(axis=1).reshape((exps.shape[0], 1)) |         g_exps *= is_valid | ||||||
|         gZs = g_exps.sum(axis=1).reshape((g_exps.shape[0], 1)) |         gradients[:] = exps / exps.sum(axis=1) | ||||||
|         logprob = exps / Zs |         gradients -= g_exps / g_exps.sum(axis=1) | ||||||
|         g_logprob = g_exps / gZs |         print('Gradient', gradients[0]) | ||||||
|         gradients[:] = logprob - g_logprob |         print('Costs', costs[0]) | ||||||
| 
 | 
 | ||||||
|     def step_through(self, Doc doc, GoldParse gold=None): |     def step_through(self, Doc doc, GoldParse gold=None): | ||||||
|         """ |         """ | ||||||
|  |  | ||||||
|  | @ -34,7 +34,7 @@ cdef class StateClass: | ||||||
|     def token_vector_lenth(self): |     def token_vector_lenth(self): | ||||||
|         return self.doc.tensor.shape[1] |         return self.doc.tensor.shape[1] | ||||||
| 
 | 
 | ||||||
|     def is_final(self): |     def py_is_final(self): | ||||||
|         return self.c.is_final() |         return self.c.is_final() | ||||||
| 
 | 
 | ||||||
|     def print_state(self, words): |     def print_state(self, words): | ||||||
|  | @ -47,31 +47,38 @@ cdef class StateClass: | ||||||
|         return ' '.join((third, second, top, '|', n0, n1)) |         return ' '.join((third, second, top, '|', n0, n1)) | ||||||
| 
 | 
 | ||||||
|     def nr_context_tokens(self, int nF, int nB, int nS, int nL, int nR): |     def nr_context_tokens(self, int nF, int nB, int nS, int nL, int nR): | ||||||
|         return 1+nF+nB+nS + nL + (nS * nL) + (nS * nR) |         return 3 | ||||||
|  |         #return 1+nF+nB+nS + nL + (nS * nL) + (nS * nR) | ||||||
| 
 | 
 | ||||||
|     def set_context_tokens(self, int[:] output, nF=1, nB=0, nS=2, |     def set_context_tokens(self, int[:] output, nF=1, nB=0, nS=2, | ||||||
|             nL=2, nR=2): |             nL=2, nR=2): | ||||||
|         output[0] = self.B(0) |         output[0] = self.B(0) | ||||||
|         output[1] = self.S(0) |         output[1] = self.S(0) | ||||||
|         output[2] = self.S(1) |         output[2] = self.S(1) | ||||||
|         output[3] = self.L(self.S(0), 1) |         #output[3] = self.L(self.S(0), 1) | ||||||
|         output[4] = self.L(self.S(0), 2) |         #output[4] = self.L(self.S(0), 2) | ||||||
|         output[5] = self.R(self.S(0), 1) |         #output[5] = self.R(self.S(0), 1) | ||||||
|         output[6] = self.R(self.S(0), 2) |         #output[6] = self.R(self.S(0), 2) | ||||||
|         output[7] = self.L(self.S(1), 1) |         #output[7] = self.L(self.S(1), 1) | ||||||
|         output[8] = self.L(self.S(1), 2) |         #output[8] = self.L(self.S(1), 2) | ||||||
|         output[9] = self.R(self.S(1), 1) |         #output[9] = self.R(self.S(1), 1) | ||||||
|         output[10] = self.R(self.S(1), 2) |         #output[10] = self.R(self.S(1), 2) | ||||||
| 
 | 
 | ||||||
|     def set_attributes(self, uint64_t[:, :] vals, int[:] tokens, int[:] names): |     def set_attributes(self, uint64_t[:, :] vals, int[:] tokens, int[:] names): | ||||||
|         cdef int i, j, tok_i |         cdef int i, j, tok_i | ||||||
|         for i in range(tokens.shape[0]): |         for i in range(tokens.shape[0]): | ||||||
|             tok_i = tokens[i] |             tok_i = tokens[i] | ||||||
|             token = &self.c._sent[tok_i] |             if tok_i >= 0: | ||||||
|             for j in range(names.shape[0]): |                 token = &self.c._sent[tok_i] | ||||||
|                 vals[i, j] = Token.get_struct_attr(token, <attr_id_t>names[j]) |                 for j in range(names.shape[0]): | ||||||
|  |                     vals[i, j] = Token.get_struct_attr(token, <attr_id_t>names[j]) | ||||||
|  |             else: | ||||||
|  |                 vals[i] = 0 | ||||||
| 
 | 
 | ||||||
|     def set_token_vectors(self, float[:, :] tokvecs, |     def set_token_vectors(self, float[:, :] tokvecs, | ||||||
|             float[:, :] all_tokvecs, int[:] indices): |             float[:, :] all_tokvecs, int[:] indices): | ||||||
|         for i in range(indices.shape[0]): |         for i in range(indices.shape[0]): | ||||||
|             tokvecs[i] = all_tokvecs[indices[i]] |             if indices[i] >= 0: | ||||||
|  |                 tokvecs[i] = all_tokvecs[indices[i]] | ||||||
|  |             else: | ||||||
|  |                 tokvecs[i] = 0 | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user