Remove state argument in pipeline. Other changes

commit c12ab47a56
parent 66ea9aebe7
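With this commit the per-component `state` dict is gone from the pipeline API: `TokenVectorEncoder.begin_update` returns `(tokvecs, bp_tokvecs)` directly, downstream components take a `(docs, tokvecs)` pair in `update()` and return `d_tokvecs`, and the caller closes the loop by feeding that gradient back through `bp_tokvecs`. A minimal sketch of the resulting training step (the driver function and component names here are illustrative, not part of this diff):

    # Hypothetical driver showing the post-commit update contract.
    def update_step(tok2vec, tagger, docs, golds, sgd):
        # The encoder returns the vectors and a backprop callback.
        tokvecs, bp_tokvecs = tok2vec.begin_update(docs, drop=0.)
        # Downstream components take (docs, tokvecs) and return d_tokvecs.
        d_tokvecs = tagger.update((docs, tokvecs), golds, drop=0., sgd=sgd)
        # The caller, not the component, backprops into the encoder.
        bp_tokvecs(d_tokvecs, sgd=sgd)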
@@ -33,7 +33,7 @@ from .morphology cimport Morphology
 from .vocab cimport Vocab
 
 from .attrs import ID, LOWER, PREFIX, SUFFIX, SHAPE, TAG, DEP, POS
-from ._ml import Tok2Vec, flatten, get_col, doc2feats
+from ._ml import rebatch, Tok2Vec, flatten, get_col, doc2feats
 from .parts_of_speech import X
 
 
@@ -57,18 +57,12 @@ class TokenVectorEncoder(object):
             docs = [docs]
         tokvecs = self.predict(docs)
         self.set_annotations(docs, tokvecs)
-        state = {} if state is None else state
-        state['tokvecs'] = tokvecs
-        return state
 
     def pipe(self, stream, batch_size=128, n_threads=-1):
-        for batch in cytoolz.partition_all(batch_size, stream):
-            docs, states = zip(*batch)
+        for docs in cytoolz.partition_all(batch_size, stream):
             tokvecs = self.predict(docs)
             self.set_annotations(docs, tokvecs)
-            for state in states:
-                state['tokvecs'] = tokvecs
-            yield from zip(docs, states)
+            yield from docs
 
     def predict(self, docs):
         feats = self.doc2feats(docs)
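`pipe()` now consumes and yields bare `Doc`s rather than `(doc, state)` pairs. The batching idiom is unchanged: `cytoolz.partition_all` groups the stream into tuples of at most `batch_size` items, each batch is annotated in place, and the items are yielded back in order. A runnable sketch of that shape (`process` stands in for predict plus set_annotations):

    import cytoolz

    def pipe_like(stream, batch_size=128, process=lambda docs: None):
        # partition_all yields tuples of up to batch_size items.
        for docs in cytoolz.partition_all(batch_size, stream):
            process(docs)        # annotate the batch in place
            yield from docs      # then pass the items through unchanged

    assert list(pipe_like(range(5), batch_size=2)) == [0, 1, 2, 3, 4]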
@@ -81,18 +75,12 @@ class TokenVectorEncoder(object):
             doc.tensor = tokvecs[start : start + len(doc)]
             start += len(doc)
 
-    def update(self, docs, golds, state=None,
-               drop=0., sgd=None):
+    def begin_update(self, docs, drop=0.):
         if isinstance(docs, Doc):
             docs = [docs]
-            golds = [golds]
-        state = {} if state is None else state
         feats = self.doc2feats(docs)
         tokvecs, bp_tokvecs = self.model.begin_update(feats, drop=drop)
-        state['feats'] = feats
-        state['tokvecs'] = tokvecs
-        state['bp_tokvecs'] = bp_tokvecs
-        return state
+        return tokvecs, bp_tokvecs
 
     def get_loss(self, docs, golds, scores):
         raise NotImplementedError
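The rename from `update()` to `begin_update()` follows the thinc convention: the forward pass returns `(output, backprop)`, where the callback maps the gradient of the output back to the gradient of the input. A toy model of that contract:

    # Toy begin_update: forward computes y = 2x; backprop maps dy -> dx.
    def toy_begin_update(X, drop=0.):
        Y = [x * 2.0 for x in X]
        def backprop(dY, sgd=None):
            return [dy * 2.0 for dy in dY]
        return Y, backprop

    Y, bp = toy_begin_update([1.0, 2.0])
    assert Y == [2.0, 4.0] and bp([1.0, 1.0]) == [2.0, 2.0]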
@@ -113,22 +101,16 @@ class NeuralTagger(object):
         self.vocab = vocab
         self.model = model
 
-    def __call__(self, doc, state=None):
-        assert state is not None
-        assert 'tokvecs' in state
-        tokvecs = state['tokvecs']
-        tags = self.predict(tokvecs)
+    def __call__(self, doc):
+        tags = self.predict(doc.tensor)
         self.set_annotations([doc], tags)
-        return state
 
     def pipe(self, stream, batch_size=128, n_threads=-1):
-        for batch in cytoolz.partition_all(batch_size, stream):
-            docs, states = zip(*batch)
-            tag_ids = self.predict(states[0]['tokvecs'])
+        for docs in cytoolz.partition_all(batch_size, stream):
+            tokvecs = self.model.ops.flatten([d.tensor for d in docs])
+            tag_ids = self.predict(tokvecs)
             self.set_annotations(docs, tag_ids)
-            for state in states:
-                state['tag_ids'] = tag_ids
-            yield from zip(docs, states)
+            yield from docs
 
     def predict(self, tokvecs):
         scores = self.model(tokvecs)
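`Doc.tensor` now carries the token vectors between components, so the tagger rebuilds the batch matrix by flattening the per-doc tensors. `ops.flatten` behaves like concatenation along the token axis, which numpy can illustrate:

    import numpy

    # Two docs of 3 and 2 tokens, width 4, joined into one (5, 4) matrix.
    doc_tensors = [numpy.ones((3, 4)), numpy.ones((2, 4))]
    tokvecs = numpy.concatenate(doc_tensors)  # ~ self.model.ops.flatten(...)
    assert tokvecs.shape == (5, 4)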
@@ -150,11 +132,9 @@ class NeuralTagger(object):
                 vocab.morphology.assign_tag_id(&doc.c[j], tag_id)
                 idx += 1
 
-    def update(self, docs, golds, state=None, drop=0., sgd=None):
-        state = {} if state is None else state
+    def update(self, docs_tokvecs, golds, drop=0., sgd=None):
+        docs, tokvecs = docs_tokvecs
 
-        tokvecs = state['tokvecs']
-        bp_tokvecs = state['bp_tokvecs']
         if self.model.nI is None:
             self.model.nI = tokvecs.shape[1]
 
@@ -163,20 +143,20 @@ class NeuralTagger(object):
 
         d_tokvecs = bp_tag_scores(d_tag_scores, sgd=sgd)
 
-        bp_tokvecs(d_tokvecs, sgd=sgd)
-
-        state['tag_scores'] = tag_scores
-        state['tag_loss'] = loss
-        return state
+        return d_tokvecs
 
     def get_loss(self, docs, golds, scores):
         tag_index = {tag: i for i, tag in enumerate(self.vocab.morphology.tag_names)}
 
         cdef int idx = 0
         correct = numpy.zeros((scores.shape[0],), dtype='i')
+        guesses = scores.argmax(axis=1)
         for gold in golds:
             for tag in gold.tags:
-                correct[idx] = tag_index[tag]
+                if tag is None:
+                    correct[idx] = guesses[idx]
+                else:
+                    correct[idx] = tag_index[tag]
                 idx += 1
         correct = self.model.ops.xp.array(correct, dtype='i')
         d_scores = scores - to_categorical(correct, nb_classes=scores.shape[1])
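The `get_loss` change makes unannotated tokens inert: when a gold tag is `None`, the model's own argmax guess is recorded as the "correct" class, so those rows are not punished for a missing annotation. In numpy terms:

    import numpy

    scores = numpy.asarray([[0.7, 0.3], [0.2, 0.8]])
    guesses = scores.argmax(axis=1)
    gold_tags = [0, None]                     # second token is unannotated
    correct = numpy.array([t if t is not None else guesses[i]
                           for i, t in enumerate(gold_tags)])
    d_scores = scores - numpy.eye(scores.shape[1])[correct]
    # Row 1's target is its own prediction, so its gradient only nudges
    # the model toward more confidence in what it already chose.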
@@ -198,15 +178,16 @@ class NeuralTagger(object):
         cdef Vocab vocab = self.vocab
         vocab.morphology = Morphology(vocab.strings, new_tag_map,
                                       vocab.morphology.lemmatizer)
-        self.model = Softmax(self.vocab.morphology.n_tags)
-        print("Tagging", self.model.nO, "tags")
+        token_vector_width = pipeline[0].model.nO
+        self.model = rebatch(1024, Softmax(self.vocab.morphology.n_tags,
+                                           token_vector_width))
+        #self.model = Softmax(self.vocab.morphology.n_tags)
 
     def use_params(self, params):
         with self.model.use_params(params):
             yield
 
 
-
 cdef class EntityRecognizer(LinearParser):
     """
     Annotate named entities on Doc objects.
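`rebatch` comes from `spacy._ml` and its definition is not shown in this diff; from its use here and in the parser below, it evidently wraps a layer so that large inputs are processed in row-chunks of at most 1024, bounding peak memory on long batches. A plausible forward-only sketch of such a wrapper (assumed behaviour, not the actual `_ml.rebatch`):

    import numpy

    def rebatch_sketch(size, layer):
        # Call the wrapped layer on chunks of at most `size` rows,
        # then re-join the outputs.
        def wrapped(X):
            parts = [layer(X[i:i + size]) for i in range(0, len(X), size)]
            return numpy.concatenate(parts)
        return wrapped

    double = rebatch_sketch(1024, lambda X: X * 2)
    assert double(numpy.ones((3000, 4))).shape == (3000, 4)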
@@ -275,8 +256,6 @@ cdef class NeuralEntityRecognizer(NeuralParser):
         return ids
 
 
-
-
 cdef class BeamDependencyParser(BeamParser):
     TransitionSystem = ArcEager
 
@@ -35,12 +35,12 @@ from preshed.maps cimport map_get
 
 from thinc.api import layerize, chain
 from thinc.neural import Model, Affine, ELU, ReLu, Maxout
-from thinc.neural.ops import NumpyOps
+from thinc.neural.ops import NumpyOps, CupyOps
 
 from .. import util
 from ..util import get_async, get_cuda_stream
 from .._ml import zero_init, PrecomputableAffine, PrecomputableMaxouts
-from .._ml import Tok2Vec, doc2feats
+from .._ml import Tok2Vec, doc2feats, rebatch
 
 from . import _parse_features
 from ._parse_features cimport CONTEXT_SIZE
@@ -229,6 +229,8 @@ cdef class Parser:
                     nI=token_vector_width,
                     pieces=maxout_pieces)
 
+        lower = rebatch(1024, lower)
+
         with Model.use_device('cpu'):
             upper = chain(
                         Maxout(hidden_width),
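This is the same `rebatch` wrapper sketched earlier: the parser's lower (state-to-vector) layer now also sees at most 1024 rows per call.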
@@ -274,7 +276,7 @@ cdef class Parser:
     def __reduce__(self):
         return (Parser, (self.vocab, self.moves, self.model), None, None)
 
-    def __call__(self, Doc tokens, state=None):
+    def __call__(self, Doc doc):
         """
         Apply the parser or entity recognizer, setting the annotations onto the Doc object.
 
@@ -283,10 +285,9 @@
         Returns:
             None
         """
-        self.parse_batch([tokens], state['tokvecs'])
-        return state
+        self.parse_batch([doc], doc.tensor)
 
-    def pipe(self, stream, int batch_size=1000, int n_threads=2):
+    def pipe(self, docs, int batch_size=1000, int n_threads=2):
         """
         Process a stream of documents.
 
@@ -301,12 +302,11 @@
         cdef StateClass parse_state
         cdef Doc doc
         queue = []
-        for batch in cytoolz.partition_all(batch_size, stream):
-            batch = list(batch)
-            docs, states = zip(*batch)
-            parse_states = self.parse_batch(docs, states[0]['tokvecs'])
+        for docs in cytoolz.partition_all(batch_size, docs):
+            tokvecs = self.model[0].ops.flatten([d.tensor for d in docs])
+            parse_states = self.parse_batch(docs, tokvecs)
             self.set_annotations(docs, parse_states)
-            yield from zip(docs, states)
+            yield from docs
 
     def parse_batch(self, docs, tokvecs):
         cuda_stream = get_cuda_stream()
@@ -324,10 +324,8 @@
             todo = [st for st in states if not st.is_final()]
         return states
 
-    def update(self, docs, golds, state=None, drop=0., sgd=None):
-        assert state is not None
-        assert 'tokvecs' in state
-        assert 'bp_tokvecs' in state
+    def update(self, docs_tokvecs, golds, drop=0., sgd=None):
+        docs, tokvecs = docs_tokvecs
         if isinstance(docs, Doc) and isinstance(golds, GoldParse):
             docs = [docs]
             golds = [golds]
@@ -336,9 +334,6 @@
         for gold in golds:
             self.moves.preprocess_gold(gold)
 
-        tokvecs = state['tokvecs']
-        bp_tokvecs = state['bp_tokvecs']
-
         states = self.moves.init_batch(docs)
         state2vec, vec2scores = self.get_batch_model(len(states), tokvecs, cuda_stream,
                                                      drop)
@@ -357,17 +352,17 @@
 
             d_scores = self.get_batch_loss(states, golds, scores)
             d_vector = bp_scores(d_scores, sgd=sgd)
-            loss += (d_scores**2).sum()
 
-            if not isinstance(tokvecs, state2vec.ops.xp.ndarray):
-                backprops.append((token_ids, d_vector, bp_vector))
-            else:
+            if isinstance(self.model[0].ops, CupyOps) \
+            and not isinstance(token_ids, state2vec.ops.xp.ndarray):
                 # Move token_ids and d_vector to CPU, asynchronously
                 backprops.append((
                     get_async(cuda_stream, token_ids),
                     get_async(cuda_stream, d_vector),
                     bp_vector
                 ))
+            else:
+                backprops.append((token_ids, d_vector, bp_vector))
             self.transition_batch(states, scores)
             todo = [st for st in todo if not st[0].is_final()]
         # Tells CUDA to block, so our async copies complete.
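The rewritten condition queues the asynchronous copy via `get_async` on the CUDA stream only when the model actually runs with `CupyOps` and `token_ids` lives on the other side of the host/device boundary; otherwise the arrays are stored as-is. The `synchronize` referred to in the comment above is what makes the copied arrays safe to read later. The shape of the branch, with the spaCy specifics abstracted away:

    # Sketch of the new branch; the predicates and copy_async stand in
    # for the CupyOps/ndarray checks and get_async(cuda_stream, ...).
    def queue_backprop(backprops, on_gpu, needs_copy,
                       token_ids, d_vector, bp_vector, copy_async):
        if on_gpu and needs_copy:
            # Enqueue async copies; a later synchronize() guarantees
            # they have completed before the backward pass reads them.
            backprops.append((copy_async(token_ids),
                              copy_async(d_vector), bp_vector))
        else:
            backprops.append((token_ids, d_vector, bp_vector))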
@@ -385,9 +380,7 @@
             else:
                 xp.add.at(d_tokvecs,
                     token_ids, d_state_features * active_feats)
-        bp_tokvecs(d_tokvecs, sgd)
-        state['parser_loss'] = loss
-        return state
+        return d_tokvecs
 
     def get_batch_model(self, batch_size, tokvecs, stream, dropout):
         lower, upper = self.model
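The surrounding context accumulates the state-feature gradients back into `d_tokvecs` with `xp.add.at`, the unbuffered scatter-add: unlike fancy-index assignment, repeated token ids accumulate rather than overwrite, which matters because the same token can feed many parse states. With numpy:

    import numpy

    d_tokvecs = numpy.zeros(4)
    token_ids = numpy.array([0, 1, 1, 3])   # token 1 appears twice
    numpy.add.at(d_tokvecs, token_ids, numpy.ones(4))
    assert d_tokvecs.tolist() == [1.0, 2.0, 0.0, 1.0]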
@@ -445,7 +438,6 @@
             self.moves.finalize_doc(doc)
 
     def add_label(self, label):
-        # Doesn't set label into serializer -- subclasses override it to do that.
         for action in self.moves.action_types:
             added = self.moves.add_action(action, label)
             if added: