mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-30 23:47:31 +03:00 
			
		
		
		
	Fix v4 branch to build against Thinc v9 (#11921)
* Move `thinc.extra.search` to `spacy.pipeline._parser_internals` Backport of: https://github.com/explosion/spaCy/pull/11317 Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com> * Replace references to `thinc.backends.linalg` with `CBlas` Backport of: https://github.com/explosion/spaCy/pull/11292 Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com> * Use cross entropy from `thinc.legacy` * Require thinc>=9.0.0.dev0,<9.1.0 Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com>
This commit is contained in:
		
							parent
							
								
									ca75190a3d
								
							
						
					
					
						commit
						f9308aae13
					
				|  | @ -5,7 +5,7 @@ requires = [ | ||||||
|     "cymem>=2.0.2,<2.1.0", |     "cymem>=2.0.2,<2.1.0", | ||||||
|     "preshed>=3.0.2,<3.1.0", |     "preshed>=3.0.2,<3.1.0", | ||||||
|     "murmurhash>=0.28.0,<1.1.0", |     "murmurhash>=0.28.0,<1.1.0", | ||||||
|     "thinc>=8.1.0,<8.2.0", |     "thinc>=9.0.0.dev0,<9.1.0", | ||||||
|     "numpy>=1.15.0", |     "numpy>=1.15.0", | ||||||
| ] | ] | ||||||
| build-backend = "setuptools.build_meta" | build-backend = "setuptools.build_meta" | ||||||
|  |  | ||||||
|  | @ -3,7 +3,7 @@ spacy-legacy>=3.0.10,<3.1.0 | ||||||
| spacy-loggers>=1.0.0,<2.0.0 | spacy-loggers>=1.0.0,<2.0.0 | ||||||
| cymem>=2.0.2,<2.1.0 | cymem>=2.0.2,<2.1.0 | ||||||
| preshed>=3.0.2,<3.1.0 | preshed>=3.0.2,<3.1.0 | ||||||
| thinc>=8.1.0,<8.2.0 | thinc>=9.0.0.dev0,<9.1.0 | ||||||
| ml_datasets>=0.2.0,<0.3.0 | ml_datasets>=0.2.0,<0.3.0 | ||||||
| murmurhash>=0.28.0,<1.1.0 | murmurhash>=0.28.0,<1.1.0 | ||||||
| wasabi>=0.9.1,<1.1.0 | wasabi>=0.9.1,<1.1.0 | ||||||
|  |  | ||||||
|  | @ -38,7 +38,7 @@ install_requires = | ||||||
|     murmurhash>=0.28.0,<1.1.0 |     murmurhash>=0.28.0,<1.1.0 | ||||||
|     cymem>=2.0.2,<2.1.0 |     cymem>=2.0.2,<2.1.0 | ||||||
|     preshed>=3.0.2,<3.1.0 |     preshed>=3.0.2,<3.1.0 | ||||||
|     thinc>=8.1.0,<8.2.0 |     thinc>=9.0.0.dev0,<9.1.0 | ||||||
|     wasabi>=0.9.1,<1.1.0 |     wasabi>=0.9.1,<1.1.0 | ||||||
|     srsly>=2.4.3,<3.0.0 |     srsly>=2.4.3,<3.0.0 | ||||||
|     catalogue>=2.0.6,<2.1.0 |     catalogue>=2.0.6,<2.1.0 | ||||||
|  |  | ||||||
							
								
								
									
										2
									
								
								setup.py
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								setup.py
									
									
									
									
									
								
							|  | @ -48,6 +48,7 @@ MOD_NAMES = [ | ||||||
|     "spacy.pipeline._parser_internals.arc_eager", |     "spacy.pipeline._parser_internals.arc_eager", | ||||||
|     "spacy.pipeline._parser_internals.ner", |     "spacy.pipeline._parser_internals.ner", | ||||||
|     "spacy.pipeline._parser_internals.nonproj", |     "spacy.pipeline._parser_internals.nonproj", | ||||||
|  |     "spacy.pipeline._parser_internals.search", | ||||||
|     "spacy.pipeline._parser_internals._state", |     "spacy.pipeline._parser_internals._state", | ||||||
|     "spacy.pipeline._parser_internals.stateclass", |     "spacy.pipeline._parser_internals.stateclass", | ||||||
|     "spacy.pipeline._parser_internals.transition_system", |     "spacy.pipeline._parser_internals.transition_system", | ||||||
|  | @ -67,6 +68,7 @@ MOD_NAMES = [ | ||||||
|     "spacy.matcher.dependencymatcher", |     "spacy.matcher.dependencymatcher", | ||||||
|     "spacy.symbols", |     "spacy.symbols", | ||||||
|     "spacy.vectors", |     "spacy.vectors", | ||||||
|  |     "spacy.tests.parser._search", | ||||||
| ] | ] | ||||||
| COMPILE_OPTIONS = { | COMPILE_OPTIONS = { | ||||||
|     "msvc": ["/Ox", "/EHsc"], |     "msvc": ["/Ox", "/EHsc"], | ||||||
|  |  | ||||||
|  | @ -3,7 +3,6 @@ cimport numpy as np | ||||||
| from libc.math cimport exp | from libc.math cimport exp | ||||||
| from libc.string cimport memset, memcpy | from libc.string cimport memset, memcpy | ||||||
| from libc.stdlib cimport calloc, free, realloc | from libc.stdlib cimport calloc, free, realloc | ||||||
| from thinc.backends.linalg cimport Vec, VecVec |  | ||||||
| from thinc.backends.cblas cimport saxpy, sgemm | from thinc.backends.cblas cimport saxpy, sgemm | ||||||
| 
 | 
 | ||||||
| import numpy | import numpy | ||||||
|  | @ -102,11 +101,10 @@ cdef void predict_states(CBlas cblas, ActivationsC* A, StateC** states, | ||||||
|     sum_state_features(cblas, A.unmaxed, |     sum_state_features(cblas, A.unmaxed, | ||||||
|         W.feat_weights, A.token_ids, n.states, n.feats, n.hiddens * n.pieces) |         W.feat_weights, A.token_ids, n.states, n.feats, n.hiddens * n.pieces) | ||||||
|     for i in range(n.states): |     for i in range(n.states): | ||||||
|         VecVec.add_i(&A.unmaxed[i*n.hiddens*n.pieces], |         saxpy(cblas)(n.hiddens * n.pieces, 1., W.feat_bias, 1, &A.unmaxed[i*n.hiddens*n.pieces], 1) | ||||||
|             W.feat_bias, 1., n.hiddens * n.pieces) |  | ||||||
|         for j in range(n.hiddens): |         for j in range(n.hiddens): | ||||||
|             index = i * n.hiddens * n.pieces + j * n.pieces |             index = i * n.hiddens * n.pieces + j * n.pieces | ||||||
|             which = Vec.arg_max(&A.unmaxed[index], n.pieces) |             which = _arg_max(&A.unmaxed[index], n.pieces) | ||||||
|             A.hiddens[i*n.hiddens + j] = A.unmaxed[index + which] |             A.hiddens[i*n.hiddens + j] = A.unmaxed[index + which] | ||||||
|     memset(A.scores, 0, n.states * n.classes * sizeof(float)) |     memset(A.scores, 0, n.states * n.classes * sizeof(float)) | ||||||
|     if W.hidden_weights == NULL: |     if W.hidden_weights == NULL: | ||||||
|  | @ -119,8 +117,7 @@ cdef void predict_states(CBlas cblas, ActivationsC* A, StateC** states, | ||||||
|             0.0, A.scores, n.classes) |             0.0, A.scores, n.classes) | ||||||
|         # Add bias |         # Add bias | ||||||
|         for i in range(n.states): |         for i in range(n.states): | ||||||
|             VecVec.add_i(&A.scores[i*n.classes], |             saxpy(cblas)(n.classes, 1., W.hidden_bias, 1, &A.scores[i*n.classes], 1) | ||||||
|                 W.hidden_bias, 1., n.classes) |  | ||||||
|     # Set unseen classes to minimum value |     # Set unseen classes to minimum value | ||||||
|     i = 0 |     i = 0 | ||||||
|     min_ = A.scores[0] |     min_ = A.scores[0] | ||||||
|  | @ -158,7 +155,8 @@ cdef void cpu_log_loss(float* d_scores, | ||||||
|     """Do multi-label log loss""" |     """Do multi-label log loss""" | ||||||
|     cdef double max_, gmax, Z, gZ |     cdef double max_, gmax, Z, gZ | ||||||
|     best = arg_max_if_gold(scores, costs, is_valid, O) |     best = arg_max_if_gold(scores, costs, is_valid, O) | ||||||
|     guess = Vec.arg_max(scores, O) |     guess = _arg_max(scores, O) | ||||||
|  | 
 | ||||||
|     if best == -1 or guess == -1: |     if best == -1 or guess == -1: | ||||||
|         # These shouldn't happen, but if they do, we want to make sure we don't |         # These shouldn't happen, but if they do, we want to make sure we don't | ||||||
|         # cause an OOB access. |         # cause an OOB access. | ||||||
|  | @ -488,3 +486,15 @@ cdef class precompute_hiddens: | ||||||
|             return d_best.reshape((d_best.shape + (1,))) |             return d_best.reshape((d_best.shape + (1,))) | ||||||
|   |   | ||||||
|         return state_vector, backprop_relu |         return state_vector, backprop_relu | ||||||
|  | 
 | ||||||
cdef inline int _arg_max(const float* scores, const int n_classes) nogil:
    """Return the index of the largest value in `scores[0:n_classes]`.

    With exactly two classes a tie resolves to index 1; for any other
    `n_classes` the lowest index holding the maximum is returned. The caller
    must supply at least one readable score: `scores[0]` is always read when
    `n_classes != 2`.
    """
    cdef int idx
    cdef int winner = 0
    cdef float top
    # Two-class fast path: a single comparison instead of the scan below.
    if n_classes == 2:
        if scores[0] > scores[1]:
            return 0
        return 1
    top = scores[0]
    for idx in range(1, n_classes):
        # Strict comparison keeps the earliest index on ties.
        if scores[idx] > top:
            winner = idx
            top = scores[idx]
    return winner
|  |  | ||||||
|  | @ -1,6 +1,6 @@ | ||||||
| from ...typedefs cimport class_t, hash_t | from ...typedefs cimport class_t, hash_t | ||||||
| 
 | 
 | ||||||
| # These are passed as callbacks to thinc.search.Beam | # These are passed as callbacks to .search.Beam | ||||||
| cdef int transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1 | cdef int transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1 | ||||||
| 
 | 
 | ||||||
| cdef int check_final_state(void* _state, void* extra_args) except -1 | cdef int check_final_state(void* _state, void* extra_args) except -1 | ||||||
|  |  | ||||||
|  | @ -3,17 +3,16 @@ | ||||||
| cimport numpy as np | cimport numpy as np | ||||||
| import numpy | import numpy | ||||||
| from cpython.ref cimport PyObject, Py_XDECREF | from cpython.ref cimport PyObject, Py_XDECREF | ||||||
| from thinc.extra.search cimport Beam |  | ||||||
| from thinc.extra.search import MaxViolation |  | ||||||
| from thinc.extra.search cimport MaxViolation |  | ||||||
| 
 | 
 | ||||||
| from ...typedefs cimport hash_t, class_t | from ...typedefs cimport hash_t, class_t | ||||||
| from .transition_system cimport TransitionSystem, Transition | from .transition_system cimport TransitionSystem, Transition | ||||||
| from ...errors import Errors | from ...errors import Errors | ||||||
|  | from .search cimport Beam, MaxViolation | ||||||
|  | from .search import MaxViolation | ||||||
| from .stateclass cimport StateC, StateClass | from .stateclass cimport StateC, StateClass | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| # These are passed as callbacks to thinc.search.Beam | # These are passed as callbacks to .search.Beam | ||||||
| cdef int transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1: | cdef int transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1: | ||||||
|     dest = <StateC*>_dest |     dest = <StateC*>_dest | ||||||
|     src = <StateC*>_src |     src = <StateC*>_src | ||||||
|  |  | ||||||
|  | @ -15,7 +15,7 @@ from ...training.example cimport Example | ||||||
| from .stateclass cimport StateClass | from .stateclass cimport StateClass | ||||||
| from ._state cimport StateC, ArcC | from ._state cimport StateC, ArcC | ||||||
| from ...errors import Errors | from ...errors import Errors | ||||||
| from thinc.extra.search cimport Beam | from .search cimport Beam | ||||||
| 
 | 
 | ||||||
| cdef weight_t MIN_SCORE = -90000 | cdef weight_t MIN_SCORE = -90000 | ||||||
| cdef attr_t SUBTOK_LABEL = hash_string('subtok') | cdef attr_t SUBTOK_LABEL = hash_string('subtok') | ||||||
|  |  | ||||||
|  | @ -6,7 +6,6 @@ from libcpp.vector cimport vector | ||||||
| from cymem.cymem cimport Pool | from cymem.cymem cimport Pool | ||||||
| 
 | 
 | ||||||
| from collections import Counter | from collections import Counter | ||||||
| from thinc.extra.search cimport Beam |  | ||||||
| 
 | 
 | ||||||
| from ...tokens.doc cimport Doc | from ...tokens.doc cimport Doc | ||||||
| from ...tokens.span import Span | from ...tokens.span import Span | ||||||
|  | @ -17,6 +16,7 @@ from ...attrs cimport IS_SPACE | ||||||
| from ...structs cimport TokenC, SpanC | from ...structs cimport TokenC, SpanC | ||||||
| from ...training import split_bilu_label | from ...training import split_bilu_label | ||||||
| from ...training.example cimport Example | from ...training.example cimport Example | ||||||
|  | from .search cimport Beam | ||||||
| from .stateclass cimport StateClass | from .stateclass cimport StateClass | ||||||
| from ._state cimport StateC | from ._state cimport StateC | ||||||
| from .transition_system cimport Transition, do_func_t | from .transition_system cimport Transition, do_func_t | ||||||
|  |  | ||||||
							
								
								
									
										89
									
								
								spacy/pipeline/_parser_internals/search.pxd
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										89
									
								
								spacy/pipeline/_parser_internals/search.pxd
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,89 @@ | ||||||
|  | from cymem.cymem cimport Pool | ||||||
|  | 
 | ||||||
|  | from libc.stdint cimport uint32_t | ||||||
|  | from libc.stdint cimport uint64_t | ||||||
|  | from libcpp.pair cimport pair | ||||||
|  | from libcpp.queue cimport priority_queue | ||||||
|  | from libcpp.vector cimport vector | ||||||
|  | 
 | ||||||
|  | from ...typedefs cimport class_t, weight_t, hash_t | ||||||
|  | 
 | ||||||
# One scored candidate for the next beam step. Beam._fill stores the
# candidate's score in `first` and `state_index * nr_class + class_id` in
# `second`; Beam.advance decodes `second` with `/` and `%` by nr_class.
ctypedef pair[weight_t, size_t] Entry
# Candidates are collected in a C++ priority queue; Beam.advance repeatedly
# takes `top()` and pops it.
ctypedef priority_queue[Entry] Queue


# Transition callback: Beam.advance calls it as
# (destination content, parent content, class, extra_args). Returning -1
# signals an exception.
ctypedef int (*trans_func_t)(void* dest, void* src, class_t clas, void* x) except -1

# Allocation callback: Beam.initialize calls it as (beam pool, n, extra_args)
# and stores the returned pointer as a state's `content`. NULL signals an
# exception.
ctypedef void* (*init_func_t)(Pool mem, int n, void* extra_args) except NULL

# Cleanup callback: Beam.__dealloc__ calls it as (beam pool, content, NULL)
# for every state and parent slot.
ctypedef int (*del_func_t)(Pool mem, void* state, void* extra_args) except -1

# Termination callback: Beam.check_done stores its return value in the
# state's `is_done` flag.
ctypedef int (*finish_func_t)(void* state, void* extra_args) except -1

# Hash callback used by Beam.advance to skip duplicate states. A return of 0
# is reserved as the error value; Beam.advance never de-duplicates keys 0
# and 1.
ctypedef hash_t (*hash_func_t)(void* state, void* x) except 0
|  | 
 | ||||||
|  | 
 | ||||||
# One slot of the beam: an opaque, caller-owned state plus the bookkeeping
# the Beam keeps for it.
cdef struct _State:
    # Opaque state object created by the init_func passed to Beam.initialize.
    void* content
    # NOTE(review): not read or written by the Beam code visible in search.pyx.
    class_t* hist
    # Cumulative score of the path that produced this state.
    weight_t score
    # Cumulative cost: Beam.advance sets it to parent.loss + cost of the class.
    weight_t loss
    # NOTE(review): `i` and `t` are only copied between slots in Beam.advance;
    # nothing visible assigns them a fresh value.
    int i
    int t
    # Set from the finish_func result in Beam.check_done.
    bint is_done
|  | 
 | ||||||
|  | 
 | ||||||
# Declaration of the beam-search container implemented in search.pyx.
cdef class Beam:
    # Memory pool that owns every buffer the beam allocates.
    cdef Pool mem
    # Number of classes (actions) per state, i.e. columns of the tables below.
    cdef class_t nr_class
    # Maximum number of states kept, i.e. rows of the tables below.
    cdef class_t width
    # Number of states currently in the beam (starts at 1).
    cdef class_t size
    # Pruning threshold used by _fill; 0.0 disables pruning.
    cdef public weight_t min_density
    # Number of completed advance() calls.
    cdef int t
    # True once check_done() has found every live state finished.
    cdef readonly bint is_done
    # Per-state list of the classes applied so far, and the same for parents.
    cdef list histories
    cdef list _parent_histories
    # width x nr_class tables filled by the caller before each advance().
    cdef weight_t** scores
    cdef int** is_valid
    cdef weight_t** costs
    # Two arrays of `width` slots that are swapped at the start of advance().
    cdef _State* _parents
    cdef _State* _states
    # Cleanup callback registered by initialize(); NULL until then.
    cdef del_func_t del_func

    # Push the candidate (state, class) pairs for the next step onto `q`.
    cdef int _fill(self, Queue* q, weight_t** scores, int** is_valid) except -1

    # Return the opaque content pointer of state `i` (no bounds check).
    cdef inline void* at(self, int i) nogil:
        return self._states[i].content

    # Allocate content for every slot via init_func and register del_func.
    cdef int initialize(self, init_func_t init_func, del_func_t del_func, int n, void* extra_args) except -1
    # Replace the beam with the best-scoring successors of the current states.
    cdef int advance(self, trans_func_t transition_func, hash_func_t hash_func,
                     void* extra_args) except -1
    # Refresh each state's is_done flag and the beam-level is_done flag.
    cdef int check_done(self, finish_func_t finish_func, void* extra_args) except -1
 

    # Write one (state i, class j) cell of all three tables (no bounds check).
    cdef inline void set_cell(self, int i, int j, weight_t score, int is_valid, weight_t cost) nogil:
        self.scores[i][j] = score
        self.is_valid[i][j] = is_valid
        self.costs[i][j] = cost

    # Copy nr_class entries into row `i` of all three tables.
    cdef int set_row(self, int i, const weight_t* scores, const int* is_valid,
                     const weight_t* costs) except -1
    # Copy complete width x nr_class tables into the beam.
    cdef int set_table(self, weight_t** scores, int** is_valid, weight_t** costs) except -1
|  | 
 | ||||||
|  | 
 | ||||||
# Declaration of the violation tracker implemented in search.pyx. It records
# the (prediction beam, gold beam) pair with the largest score margin seen
# across calls to check() / check_crf().
cdef class MaxViolation:
    # NOTE(review): `mem` and `n` are not assigned by the code visible in
    # search.pyx.
    cdef Pool mem
    # Loss of the recorded prediction; 0 means nothing has been recorded yet.
    cdef weight_t cost
    # Recorded margin: prediction score minus gold score.
    cdef weight_t delta
    # Scores of the best prediction / gold state at the recorded step.
    cdef readonly weight_t p_score
    cdef readonly weight_t g_score
    # Sums of exp(score) accumulated by check(): Z over the costly prediction
    # states plus the zero-loss gold states, gZ over the zero-loss gold states.
    cdef readonly double Z
    cdef readonly double gZ
    cdef class_t n
    # Histories recorded for the prediction and gold side.
    cdef readonly list p_hist
    cdef readonly list g_hist
    # Values computed by check_crf(); the implementation notes that these are
    # gradients rather than probabilities.
    cdef readonly list p_probs
    cdef readonly list g_probs

    # Update the record from the best state of each beam.
    cpdef int check(self, Beam pred, Beam gold) except -1
    # Update the record using every state of both beams.
    cpdef int check_crf(self, Beam pred, Beam gold) except -1
							
								
								
									
										306
									
								
								spacy/pipeline/_parser_internals/search.pyx
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										306
									
								
								spacy/pipeline/_parser_internals/search.pyx
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,306 @@ | ||||||
|  | # cython: profile=True, experimental_cpp_class_def=True, cdivision=True, infer_types=True | ||||||
|  | cimport cython | ||||||
|  | from libc.string cimport memset, memcpy | ||||||
|  | from libc.math cimport log, exp | ||||||
|  | import math | ||||||
|  | 
 | ||||||
|  | from cymem.cymem cimport Pool | ||||||
|  | from preshed.maps cimport PreshMap | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
cdef class Beam:
    """Fixed-width beam over opaque, caller-defined states.

    The caller fills the `scores`, `is_valid` and `costs` tables (one row per
    state, one column per class), then calls `advance()` to replace the beam
    with the best-scoring successors, and `check_done()` to refresh the
    completion flags.
    """
    def __init__(self, class_t nr_class, class_t width, weight_t min_density=0.0):
        """Allocate a beam of `width` states over `nr_class` classes.

        nr_class: number of classes per state; must be non-zero.
        width: maximum number of states kept; must be non-zero.
        min_density: pruning threshold used when filling the queue; 0.0
            disables pruning.
        """
        assert nr_class != 0
        assert width != 0
        self.nr_class = nr_class
        self.width = width
        self.min_density = min_density
        self.size = 1
        self.t = 0
        self.mem = Pool()
        self.del_func = NULL
        self._parents = <_State*>self.mem.alloc(self.width, sizeof(_State))
        self._states = <_State*>self.mem.alloc(self.width, sizeof(_State))
        cdef int i
        self.histories = [[] for i in range(self.width)]
        self._parent_histories = [[] for i in range(self.width)]

        # Row-pointer arrays; each element size matches the row type it holds.
        self.scores = <weight_t**>self.mem.alloc(self.width, sizeof(weight_t*))
        self.is_valid = <int**>self.mem.alloc(self.width, sizeof(int*))
        self.costs = <weight_t**>self.mem.alloc(self.width, sizeof(weight_t*))
        for i in range(self.width):
            self.scores[i] = <weight_t*>self.mem.alloc(self.nr_class, sizeof(weight_t))
            self.is_valid[i] = <int*>self.mem.alloc(self.nr_class, sizeof(int))
            self.costs[i] = <weight_t*>self.mem.alloc(self.nr_class, sizeof(weight_t))

    def __len__(self):
        """Return the number of states currently in the beam."""
        return self.size

    property score:
        # Score of the first (best) state.
        def __get__(self):
            return self._states[0].score

    property min_score:
        # Score of the last state currently in the beam.
        def __get__(self):
            return self._states[self.size-1].score

    property loss:
        # Accumulated loss of the first (best) state.
        def __get__(self):
            return self._states[0].loss

    property probs:
        # Softmax over the scores of the states currently in the beam.
        def __get__(self):
            return _softmax([self._states[i].score for i in range(self.size)])

    property scores:
        # Scores of the states currently in the beam, as a Python list.
        def __get__(self):
            return [self._states[i].score for i in range(self.size)]

    property histories:
        # Per-state lists of the classes applied so far.
        def __get__(self):
            return self.histories

    cdef int set_row(self, int i, const weight_t* scores, const int* is_valid,
                     const weight_t* costs) except -1:
        """Copy `nr_class` entries from each input array into row `i`."""
        cdef int j
        for j in range(self.nr_class):
            self.scores[i][j] = scores[j]
            self.is_valid[i][j] = is_valid[j]
            self.costs[i][j] = costs[j]

    cdef int set_table(self, weight_t** scores, int** is_valid, weight_t** costs) except -1:
        """Copy complete `width` x `nr_class` tables into the beam.

        Every input must provide `width` rows of `nr_class` entries.
        """
        cdef int i
        for i in range(self.width):
            # Byte counts use the element type of the array being copied.
            memcpy(self.scores[i], scores[i], sizeof(weight_t) * self.nr_class)
            memcpy(self.is_valid[i], is_valid[i], sizeof(int) * self.nr_class)
            memcpy(self.costs[i], costs[i], sizeof(weight_t) * self.nr_class)

    cdef int initialize(self, init_func_t init_func, del_func_t del_func, int n, void* extra_args) except -1:
        """Create the opaque content of every state and parent slot.

        `init_func` is called twice per slot with the beam's pool, `n` and
        `extra_args`. `del_func` is remembered and called on every slot when
        the beam is deallocated.
        """
        for i in range(self.width):
            self._states[i].content = init_func(self.mem, n, extra_args)
            self._parents[i].content = init_func(self.mem, n, extra_args)
        self.del_func = del_func

    def __dealloc__(self):
        # Nothing to release if initialize() was never called.
        if self.del_func == NULL:
            return

        for i in range(self.width):
            self.del_func(self.mem, self._states[i].content, NULL)
            self.del_func(self.mem, self._parents[i].content, NULL)

    @cython.cdivision(True)
    cdef int advance(self, trans_func_t transition_func, hash_func_t hash_func,
                     void* extra_args) except -1:
        """Replace the beam with the best successors of the current states.

        transition_func: applies a class to a parent state, writing the result
            into a destination state.
        hash_func: optional (may be NULL); states whose hash was already seen
            in this step are skipped. Hash values 0 and 1 are never
            de-duplicated.
        extra_args: passed through to both callbacks.

        Afterwards `size` is the number of states kept, the three tables are
        zeroed, and `t` is incremented.
        """
        cdef weight_t** scores = self.scores
        cdef int** is_valid = self.is_valid
        cdef weight_t** costs = self.costs

        cdef Queue* q = new Queue()
        self._fill(q, scores, is_valid)
        # For a beam of width k, we only ever need 2k state objects. How?
        # Each transition takes a parent and a class and produces a new state.
        # So, we don't need the whole history --- just the parent. So at
        # each step, we take a parent, and apply one or more extensions to
        # it.
        self._parents, self._states = self._states, self._parents
        self._parent_histories, self.histories = self.histories, self._parent_histories
        cdef weight_t score
        cdef int p_i
        cdef int i = 0
        cdef class_t clas
        cdef _State* parent
        cdef _State* state
        cdef hash_t key
        cdef PreshMap seen_states = PreshMap(self.width)
        cdef uint64_t is_seen
        cdef uint64_t one = 1
        while i < self.width and not q.empty():
            data = q.top()
            # `second` encodes parent index * nr_class + class (see _fill).
            p_i = data.second / self.nr_class
            clas = data.second % self.nr_class
            score = data.first
            q.pop()
            parent = &self._parents[p_i]
            # Indicates terminal state reached; i.e. state is done
            if parent.is_done:
                # Now parent will not be changed, so we don't have to copy.
                # Once finished, should also be unbranching.
                self._states[i], parent[0] = parent[0], self._states[i]
                parent.i = self._states[i].i
                parent.t = self._states[i].t
                # NOTE(review): this copies `t`, not `is_done`, into the flag.
                # It looks like a typo, but the parent slot is reused as a
                # destination by the next advance() call, so confirm the
                # intended value before changing it.
                parent.is_done = self._states[i].t
                self._states[i].score = score
                self.histories[i] = list(self._parent_histories[p_i])
                i += 1
            else:
                state = &self._states[i]
                # The supplied transition function should adjust the destination
                # state to be the result of applying the class to the source state
                transition_func(state.content, parent.content, clas, extra_args)
                key = hash_func(state.content, extra_args) if hash_func is not NULL else 0
                is_seen = <uint64_t>seen_states.get(key)
                if key == 0 or key == 1 or not is_seen:
                    if key != 0 and key != 1:
                        seen_states.set(key, <void*>one)
                    state.score = score
                    state.loss = parent.loss + costs[p_i][clas]
                    self.histories[i] = list(self._parent_histories[p_i])
                    self.histories[i].append(clas)
                    i += 1
        del q
        self.size = i
        assert self.size >= 1
        # Clear the tables so the caller starts the next step from zeros.
        for i in range(self.width):
            memset(self.scores[i], 0, sizeof(weight_t) * self.nr_class)
            memset(self.costs[i], 0, sizeof(weight_t) * self.nr_class)
            memset(self.is_valid[i], 0, sizeof(int) * self.nr_class)
        self.t += 1

    cdef int check_done(self, finish_func_t finish_func, void* extra_args) except -1:
        """Refresh the per-state and beam-level completion flags.

        `finish_func` is only consulted for states not already marked done.
        `is_done` becomes True when every state in the beam is done.
        """
        cdef int i
        for i in range(self.size):
            if not self._states[i].is_done:
                self._states[i].is_done = finish_func(self._states[i].content, extra_args)
        for i in range(self.size):
            if not self._states[i].is_done:
                self.is_done = False
                break
        else:
            self.is_done = True

    @cython.cdivision(True)
    cdef int _fill(self, Queue* q, weight_t** scores, int** is_valid) except -1:
        """Populate the queue from a k * n matrix of scores, where k is the
        beam-width, and n is the number of classes.

        A finished state contributes a single entry; an unfinished state
        contributes one entry per valid class. When `min_density` is non-zero,
        entries whose softmax probability falls below the cutoff are dropped.
        """
        cdef Entry entry
        cdef _State* s
        cdef int i, j, move_id
        assert self.size >= 1
        cdef vector[Entry] entries
        for i in range(self.size):
            s = &self._states[i]
            move_id = i * self.nr_class
            if s.is_done:
                # Update score by path average, following TACL '13 paper.
                if self.histories[i]:
                    entry.first = s.score + (s.score / self.t)
                else:
                    entry.first = s.score
                entry.second = move_id
                entries.push_back(entry)
            else:
                for j in range(self.nr_class):
                    if is_valid[i][j]:
                        entry.first = s.score + scores[i][j]
                        entry.second = move_id + j
                        entries.push_back(entry)
        cdef double max_, Z, cutoff
        if self.min_density == 0.0:
            for i in range(entries.size()):
                q.push(entries[i])
        elif not entries.empty():
            max_ = entries[0].first
            Z = 0.
            cutoff = 0.
            # Softmax into probabilities, so we can prune
            for i in range(entries.size()):
                if entries[i].first > max_:
                    max_ = entries[i].first
            for i in range(entries.size()):
                Z += exp(entries[i].first-max_)
            cutoff = (1. / Z) * self.min_density
            for i in range(entries.size()):
                prob = exp(entries[i].first-max_) / Z
                if prob >= cutoff:
                    q.push(entries[i])
|  | 
 | ||||||
|  | 
 | ||||||
cdef class MaxViolation:
    """Remember the step at which the prediction beam most out-scores gold.

    Both `check()` and `check_crf()` only overwrite the stored record when the
    best prediction is costly and either nothing has been stored yet
    (`cost == 0`) or the new margin exceeds the stored `delta`.
    """
    def __init__(self):
        self.cost = 0
        self.delta = -1
        self.p_score = 0.0
        self.g_score = 0.0
        self.Z = 0.0
        self.gZ = 0.0
        self.p_hist = []
        self.g_hist = []
        self.p_probs = []
        self.g_probs = []

    cpdef int check(self, Beam pred, Beam gold) except -1:
        """Update the record from the best state of each beam.

        Stores the histories and scores of the two best states, plus `Z`
        (exp-scores of costly prediction states and zero-loss gold states)
        and `gZ` (exp-scores of zero-loss gold states only).
        """
        cdef _State* pred_top = &pred._states[0]
        cdef _State* gold_top = &gold._states[0]
        cdef weight_t margin = pred_top.score - gold_top.score
        cdef double total = 1e-10
        cdef double gold_total = 1e-10
        cdef double weight
        # Nothing to record unless the prediction is wrong by a wider margin.
        if not (pred_top.loss >= 1 and (self.cost == 0 or margin > self.delta)):
            return 0
        self.cost = pred_top.loss
        self.delta = margin
        self.p_hist = list(pred.histories[0])
        self.g_hist = list(gold.histories[0])
        self.p_score = pred_top.score
        self.g_score = gold_top.score
        for i in range(pred.size):
            if pred._states[i].loss > 0:
                total += exp(pred._states[i].score)
        for i in range(gold.size):
            if gold._states[i].loss == 0:
                weight = exp(gold._states[i].score)
                total += weight
                gold_total += weight
        self.Z = total
        self.gZ = gold_total
        return 0

    cpdef int check_crf(self, Beam pred, Beam gold) except -1:
        """Update the record using every state of both beams.

        Costly prediction states form the "predicted" group. Zero-loss gold
        states, plus zero-or-negative-loss prediction states whose history is
        not already in the gold beam, form the "gold" group.
        """
        margin = pred.score - gold.score
        seen_golds = set([tuple(gold.histories[i]) for i in range(gold.size)])
        if not (pred.loss > 0 and (self.cost == 0 or margin > self.delta)):
            return 0
        wrong_hists = []
        wrong_scores = []
        right_hists = []
        right_scores = []
        for i in range(pred.size):
            if pred._states[i].loss > 0:
                wrong_scores.append(pred._states[i].score)
                wrong_hists.append(list(pred.histories[i]))
            # This can happen from non-monotonic actions
            # If we find a better gold analysis this way, be sure to keep it.
            elif pred._states[i].loss <= 0 \
            and tuple(pred.histories[i]) not in seen_golds:
                right_scores.append(pred._states[i].score)
                right_hists.append(list(pred.histories[i]))
        for i in range(gold.size):
            if gold._states[i].loss == 0:
                right_scores.append(gold._states[i].score)
                right_hists.append(list(gold.histories[i]))

        # One softmax over both groups, then split it back by position.
        n_wrong = len(wrong_scores)
        joint = _softmax(wrong_scores + right_scores)
        right_joint = joint[n_wrong:]
        right_only = _softmax(right_scores)

        self.cost = pred.loss
        self.delta = margin
        self.p_hist = wrong_hists
        self.g_hist = right_hists
        # TODO: These variables are misnamed! These are the gradients of the loss.
        self.p_probs = joint[:n_wrong]
        # Intuition here:
        # The gradient of the loss is:
        # P(model) - P(truth)
        # Normally, P(truth) is 1 for the gold
        # But, if we want to do the "partial credit" scheme, we want
        # to create a distribution over the gold, proportional to the scores
        # awarded.
        self.g_probs = [x-y for x, y in zip(right_joint, right_only)]
        return 0
|  | 
 | ||||||
|  | 
 | ||||||
|  | def _softmax(nums): | ||||||
|  |     if not nums: | ||||||
|  |         return [] | ||||||
|  |     max_ = max(nums) | ||||||
|  |     nums = [(exp(n-max_) if n is not None else None) for n in nums] | ||||||
|  |     Z = sum(n for n in nums if n is not None) | ||||||
|  |     return [(n/Z if n is not None else None) for n in nums] | ||||||
|  | @ -5,8 +5,9 @@ from itertools import islice | ||||||
| import numpy as np | import numpy as np | ||||||
| 
 | 
 | ||||||
| import srsly | import srsly | ||||||
| from thinc.api import Config, Model, SequenceCategoricalCrossentropy | from thinc.api import Config, Model | ||||||
| from thinc.types import ArrayXd, Floats2d, Ints1d | from thinc.types import ArrayXd, Floats2d, Ints1d | ||||||
|  | from thinc.legacy import LegacySequenceCategoricalCrossentropy | ||||||
| 
 | 
 | ||||||
| from ._edit_tree_internals.edit_trees import EditTrees | from ._edit_tree_internals.edit_trees import EditTrees | ||||||
| from ._edit_tree_internals.schemas import validate_edit_tree | from ._edit_tree_internals.schemas import validate_edit_tree | ||||||
|  | @ -129,7 +130,9 @@ class EditTreeLemmatizer(TrainablePipe): | ||||||
|         self, examples: Iterable[Example], scores: List[Floats2d] |         self, examples: Iterable[Example], scores: List[Floats2d] | ||||||
|     ) -> Tuple[float, List[Floats2d]]: |     ) -> Tuple[float, List[Floats2d]]: | ||||||
|         validate_examples(examples, "EditTreeLemmatizer.get_loss") |         validate_examples(examples, "EditTreeLemmatizer.get_loss") | ||||||
|         loss_func = SequenceCategoricalCrossentropy(normalize=False, missing_value=-1) |         loss_func = LegacySequenceCategoricalCrossentropy( | ||||||
|  |             normalize=False, missing_value=-1 | ||||||
|  |         ) | ||||||
| 
 | 
 | ||||||
|         truths = [] |         truths = [] | ||||||
|         for eg in examples: |         for eg in examples: | ||||||
|  |  | ||||||
|  | @ -1,7 +1,8 @@ | ||||||
| # cython: infer_types=True, profile=True, binding=True | # cython: infer_types=True, profile=True, binding=True | ||||||
| from typing import Callable, Dict, Iterable, List, Optional, Union | from typing import Callable, Dict, Iterable, List, Optional, Union | ||||||
| import srsly | import srsly | ||||||
| from thinc.api import SequenceCategoricalCrossentropy, Model, Config | from thinc.api import Model, Config | ||||||
|  | from thinc.legacy import LegacySequenceCategoricalCrossentropy | ||||||
| from thinc.types import Floats2d, Ints1d | from thinc.types import Floats2d, Ints1d | ||||||
| from itertools import islice | from itertools import islice | ||||||
| 
 | 
 | ||||||
|  | @ -290,7 +291,7 @@ class Morphologizer(Tagger): | ||||||
|         DOCS: https://spacy.io/api/morphologizer#get_loss |         DOCS: https://spacy.io/api/morphologizer#get_loss | ||||||
|         """ |         """ | ||||||
|         validate_examples(examples, "Morphologizer.get_loss") |         validate_examples(examples, "Morphologizer.get_loss") | ||||||
|         loss_func = SequenceCategoricalCrossentropy(names=tuple(self.labels), normalize=False) |         loss_func = LegacySequenceCategoricalCrossentropy(names=tuple(self.labels), normalize=False) | ||||||
|         truths = [] |         truths = [] | ||||||
|         for eg in examples: |         for eg in examples: | ||||||
|             eg_truths = [] |             eg_truths = [] | ||||||
|  |  | ||||||
|  | @ -3,7 +3,9 @@ from typing import Dict, Iterable, Optional, Callable, List, Union | ||||||
| from itertools import islice | from itertools import islice | ||||||
| 
 | 
 | ||||||
| import srsly | import srsly | ||||||
| from thinc.api import Model, SequenceCategoricalCrossentropy, Config | from thinc.api import Model, Config | ||||||
|  | from thinc.legacy import LegacySequenceCategoricalCrossentropy | ||||||
|  | 
 | ||||||
| from thinc.types import Floats2d, Ints1d | from thinc.types import Floats2d, Ints1d | ||||||
| 
 | 
 | ||||||
| from ..tokens.doc cimport Doc | from ..tokens.doc cimport Doc | ||||||
|  | @ -161,7 +163,7 @@ class SentenceRecognizer(Tagger): | ||||||
|         """ |         """ | ||||||
|         validate_examples(examples, "SentenceRecognizer.get_loss") |         validate_examples(examples, "SentenceRecognizer.get_loss") | ||||||
|         labels = self.labels |         labels = self.labels | ||||||
|         loss_func = SequenceCategoricalCrossentropy(names=labels, normalize=False) |         loss_func = LegacySequenceCategoricalCrossentropy(names=labels, normalize=False) | ||||||
|         truths = [] |         truths = [] | ||||||
|         for eg in examples: |         for eg in examples: | ||||||
|             eg_truth = [] |             eg_truth = [] | ||||||
|  |  | ||||||
|  | @ -2,7 +2,8 @@ | ||||||
| from typing import Callable, Dict, Iterable, List, Optional, Union | from typing import Callable, Dict, Iterable, List, Optional, Union | ||||||
| import numpy | import numpy | ||||||
| import srsly | import srsly | ||||||
| from thinc.api import Model, set_dropout_rate, SequenceCategoricalCrossentropy, Config | from thinc.api import Model, set_dropout_rate, Config | ||||||
|  | from thinc.legacy import LegacySequenceCategoricalCrossentropy | ||||||
| from thinc.types import Floats2d, Ints1d | from thinc.types import Floats2d, Ints1d | ||||||
| import warnings | import warnings | ||||||
| from itertools import islice | from itertools import islice | ||||||
|  | @ -244,7 +245,7 @@ class Tagger(TrainablePipe): | ||||||
| 
 | 
 | ||||||
|         DOCS: https://spacy.io/api/tagger#rehearse |         DOCS: https://spacy.io/api/tagger#rehearse | ||||||
|         """ |         """ | ||||||
|         loss_func = SequenceCategoricalCrossentropy() |         loss_func = LegacySequenceCategoricalCrossentropy() | ||||||
|         if losses is None: |         if losses is None: | ||||||
|             losses = {} |             losses = {} | ||||||
|         losses.setdefault(self.name, 0.0) |         losses.setdefault(self.name, 0.0) | ||||||
|  | @ -275,7 +276,7 @@ class Tagger(TrainablePipe): | ||||||
|         DOCS: https://spacy.io/api/tagger#get_loss |         DOCS: https://spacy.io/api/tagger#get_loss | ||||||
|         """ |         """ | ||||||
|         validate_examples(examples, "Tagger.get_loss") |         validate_examples(examples, "Tagger.get_loss") | ||||||
|         loss_func = SequenceCategoricalCrossentropy(names=self.labels, normalize=False, neg_prefix=self.cfg["neg_prefix"]) |         loss_func = LegacySequenceCategoricalCrossentropy(names=self.labels, normalize=False, neg_prefix=self.cfg["neg_prefix"]) | ||||||
|         # Convert empty tag "" to missing value None so that both misaligned |         # Convert empty tag "" to missing value None so that both misaligned | ||||||
|         # tokens and tokens with missing annotation have the default missing |         # tokens and tokens with missing annotation have the default missing | ||||||
|         # value None. |         # value None. | ||||||
|  |  | ||||||
|  | @ -10,12 +10,12 @@ import random | ||||||
| 
 | 
 | ||||||
| import srsly | import srsly | ||||||
| from thinc.api import get_ops, set_dropout_rate, CupyOps, NumpyOps | from thinc.api import get_ops, set_dropout_rate, CupyOps, NumpyOps | ||||||
| from thinc.extra.search cimport Beam |  | ||||||
| import numpy.random | import numpy.random | ||||||
| import numpy | import numpy | ||||||
| import warnings | import warnings | ||||||
| 
 | 
 | ||||||
| from ._parser_internals.stateclass cimport StateClass | from ._parser_internals.stateclass cimport StateClass | ||||||
|  | from ._parser_internals.search cimport Beam | ||||||
| from ..ml.parser_model cimport alloc_activations, free_activations | from ..ml.parser_model cimport alloc_activations, free_activations | ||||||
| from ..ml.parser_model cimport predict_states, arg_max_if_valid | from ..ml.parser_model cimport predict_states, arg_max_if_valid | ||||||
| from ..ml.parser_model cimport WeightsC, ActivationsC, SizesC, cpu_log_loss | from ..ml.parser_model cimport WeightsC, ActivationsC, SizesC, cpu_log_loss | ||||||
|  |  | ||||||
|  | @ -1,6 +1,10 @@ | ||||||
| import pytest | import pytest | ||||||
| from spacy.util import get_lang_class | from spacy.util import get_lang_class | ||||||
|  | import functools | ||||||
| from hypothesis import settings | from hypothesis import settings | ||||||
|  | import inspect | ||||||
|  | import importlib | ||||||
|  | import sys | ||||||
| 
 | 
 | ||||||
| # Functionally disable deadline settings for tests | # Functionally disable deadline settings for tests | ||||||
| # to prevent spurious test failures in CI builds. | # to prevent spurious test failures in CI builds. | ||||||
|  | @ -47,6 +51,33 @@ def pytest_runtest_setup(item): | ||||||
|             pytest.skip("not referencing any issues") |             pytest.skip("not referencing any issues") | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | # Decorator for Cython-built tests | ||||||
|  | # https://shwina.github.io/cython-testing/ | ||||||
def cytest(func):
    """
    Return a plain Python function that forwards its call to `func`.

    On every call the received arguments are first bound against the
    signature of `func`; the bound positional and keyword arguments are then
    passed on. Metadata of `func` is copied onto the wrapper.
    """

    @functools.wraps(func)
    def _call_bound(*args, **kwargs):
        signature = inspect.signature(func)
        call = signature.bind(*args, **kwargs)
        return func(*call.args, **call.kwargs)

    return _call_bound
|  | 
 | ||||||
|  | 
 | ||||||
def register_cython_tests(cython_mod_name: str, test_mod_name: str):
    """
    Copy every callable named `test_*` from the module `cython_mod_name` onto
    the already-imported module `test_mod_name`, so that pytest can discover
    them there.
    """
    source_module = importlib.import_module(cython_mod_name)
    # Lazily pair each attribute name with its value, in `dir()` order.
    members = ((attr, getattr(source_module, attr)) for attr in dir(source_module))
    for attr, value in members:
        if not (callable(value) and attr.startswith("test_")):
            continue
        setattr(sys.modules[test_mod_name], attr, value)
|  | 
 | ||||||
|  | 
 | ||||||
| # Fixtures for language tokenizers (languages sorted alphabetically) | # Fixtures for language tokenizers (languages sorted alphabetically) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
							
								
								
									
										119
									
								
								spacy/tests/parser/_search.pyx
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										119
									
								
								spacy/tests/parser/_search.pyx
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,119 @@ | ||||||
|  | # cython: infer_types=True, binding=True | ||||||
|  | from spacy.pipeline._parser_internals.search cimport Beam, MaxViolation | ||||||
|  | from spacy.typedefs cimport class_t, weight_t | ||||||
|  | from cymem.cymem cimport Pool | ||||||
|  | 
 | ||||||
|  | from ..conftest import cytest | ||||||
|  | import pytest | ||||||
|  | 
 | ||||||
# Minimal state record stored per beam entry by the callbacks in this module.
cdef struct TestState:
    # Set from the `n` argument of `initialize`.
    int length
    # Starts at 1 in `initialize`; `transition` adds the chosen class to it.
    int x
    # Either the `extra_args` pointer handed to the callbacks or u'default'.
    Py_UNICODE* string
|  | 
 | ||||||
|  | 
 | ||||||
cdef int transition(void* dest, void* src, class_t clas, void* extra_args) except -1:
    # Beam transition callback: derive `dest` from `src` by adding `clas` to
    # the counter `x`. `length` is carried over unchanged. The string pointer
    # is taken from `extra_args` when given, otherwise inherited from `src`.
    target = <TestState*>dest
    origin = <TestState*>src
    target.length = origin.length
    target.x = origin.x + clas
    if extra_args == NULL:
        target.string = origin.string
    else:
        target.string = <Py_UNICODE*>extra_args
|  | 
 | ||||||
|  | 
 | ||||||
cdef void* initialize(Pool mem, int n, void* extra_args) except NULL:
    # Beam initialisation callback: allocate one TestState from `mem`, with
    # `length` set to `n` and the counter `x` starting at 1. The string pointer
    # is `extra_args` when provided, otherwise the literal u'default'.
    fresh = <TestState*>mem.alloc(1, sizeof(TestState))
    fresh.length = n
    fresh.x = 1
    if extra_args != NULL:
        fresh.string = <Py_UNICODE*>extra_args
    else:
        fresh.string = u'default'
    return fresh
|  | 
 | ||||||
|  | 
 | ||||||
cdef int destroy(Pool mem, void* state, void* extra_args) except -1:
    # Beam teardown callback: hand the TestState allocation back to `mem`.
    # `extra_args` is unused.
    mem.free(<TestState*>state)
|  | 
 | ||||||
@cytest
@pytest.mark.parametrize(
    "nr_class,beam_width",
    [(2, 3), (3, 6), (4, 20)],
)
def test_init(nr_class, beam_width):
    """A freshly constructed Beam reports size 1 and echoes its width and class count."""
    beam = Beam(nr_class, beam_width)
    assert beam.size == 1
    assert beam.width == beam_width
    assert beam.nr_class == nr_class
|  | 
 | ||||||
@cytest
def test_init_violn():
    """Smoke test: MaxViolation can be constructed without arguments."""
    MaxViolation()
|  | 
 | ||||||
@cytest
@pytest.mark.parametrize(
    "nr_class,beam_width,length",
    [(2, 3, 3), (3, 6, 15), (4, 20, 32)],
)
def test_initialize(nr_class, beam_width, length):
    """After `Beam.initialize` without extra args, every slot holds the default state."""
    beam = Beam(nr_class, beam_width)
    beam.initialize(initialize, destroy, length, NULL)
    for slot in range(beam.width):
        state = <TestState*>beam.at(slot)
        assert state.length == length, state.length
        assert state.string == 'default'
|  | 
 | ||||||
|  | 
 | ||||||
@cytest
@pytest.mark.parametrize(
    "nr_class,beam_width,length,extra",
    [(2, 3, 4, None), (3, 6, 15, u"test beam 1")],
)
def test_initialize_extra(nr_class, beam_width, length, extra):
    """`Beam.initialize` sets `length` on every slot, with or without an extra-args pointer."""
    beam = Beam(nr_class, beam_width)
    if extra is not None:
        # Pass the unicode buffer through as the opaque extra-args pointer.
        beam.initialize(initialize, destroy, length, <void*><Py_UNICODE*>extra)
    else:
        beam.initialize(initialize, destroy, length, NULL)
    for slot in range(beam.width):
        state = <TestState*>beam.at(slot)
        assert state.length == length
|  | 
 | ||||||
|  | 
 | ||||||
@cytest
@pytest.mark.parametrize("nr_class,beam_width,length",
    [
        (3, 6, 15),
        (4, 20, 32),
    ]
)
def test_transition(nr_class, beam_width, length):
    """Advance a beam twice and check the resulting scores and state counters."""
    b = Beam(nr_class, beam_width)
    b.initialize(initialize, destroy, length, NULL)
    # NOTE(review): set_cell arguments are presumably
    # (beam index, class, score, is_valid, cost) - confirm against Beam.set_cell.
    b.set_cell(0, 2, 30, True, 0)
    b.set_cell(0, 1, 42, False, 0)
    b.advance(transition, NULL, NULL)
    # Only the cell flagged True survives, even though the other scored higher.
    assert b.size == 1, b.size
    assert b.score == 30, b.score
    s = <TestState*>b.at(0)
    # `initialize` starts x at 1 and `transition` adds the class (2).
    assert s.x == 3
    assert b._states[0].score == 30, b._states[0].score
    b.set_cell(0, 1, 10, True, 0)
    b.set_cell(0, 2, 20, True, 0)
    b.advance(transition, NULL, NULL)
    # The expected scores equal the previous 30 plus 20 and plus 10, best first.
    assert b._states[0].score == 50, b._states[0].score
    assert b._states[1].score == 40
    s = <TestState*>b.at(0)
    # The top state took class 2 again: 3 + 2.
    assert s.x == 5
							
								
								
									
										3
									
								
								spacy/tests/parser/test_search.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										3
									
								
								spacy/tests/parser/test_search.py
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,3 @@ | ||||||
|  | from ..conftest import register_cython_tests | ||||||
|  | 
 | ||||||
# Expose the `test_*` callables of the compiled Cython module as attributes of
# this module so that pytest can collect them from here.
register_cython_tests("spacy.tests.parser._search", __name__)
		Loading…
	
		Reference in New Issue
	
	Block a user