spaCy/spacy/pipeline/_parser_internals/search.pxd
Daniël de Kok f9308aae13
Fix v4 branch to build against Thinc v9 (#11921)
* Move `thinc.extra.search` to `spacy.pipeline._parser_internals`

Backport of:
https://github.com/explosion/spaCy/pull/11317

Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com>

* Replace references to `thinc.backends.linalg` with `CBlas`

Backport of:
https://github.com/explosion/spaCy/pull/11292

Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com>

* Use cross entropy from `thinc.legacy`

* Require thinc>=9.0.0.dev0,<9.1.0

Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com>
2022-12-17 14:32:19 +01:00

90 lines
2.6 KiB
Cython

from cymem.cymem cimport Pool
from libc.stdint cimport uint32_t
from libc.stdint cimport uint64_t
from libcpp.pair cimport pair
from libcpp.queue cimport priority_queue
from libcpp.vector cimport vector
from ...typedefs cimport class_t, weight_t, hash_t
# Queue element: a (weight, index) pair. NOTE(review): the size_t is presumably
# an index into a flattened (state, class) table -- confirm against the
# implementation file.
ctypedef pair[weight_t, size_t] Entry
# std::priority_queue with the default comparator: top() is the largest Entry,
# comparing the weight first and the size_t second.
ctypedef priority_queue[Entry] Queue
# Callback signatures accepted by the Beam methods declared below. The
# `except <value>` clause means that returning <value> signals that a Python
# exception has been raised.
# Transition callback (passed to Beam.advance).
ctypedef int (*trans_func_t)(void* dest, void* src, class_t clas, void* x) except -1
# State-construction callback (passed to Beam.initialize); NULL signals an error.
ctypedef void* (*init_func_t)(Pool mem, int n, void* extra_args) except NULL
# State-destruction callback (passed to Beam.initialize, stored as Beam.del_func).
ctypedef int (*del_func_t)(Pool mem, void* state, void* extra_args) except -1
# Completion callback (passed to Beam.check_done).
ctypedef int (*finish_func_t)(void* state, void* extra_args) except -1
# Hash callback (passed to Beam.advance). Because of `except 0`, a return value
# of 0 is always treated as an error, so 0 cannot be used as a legitimate hash.
ctypedef hash_t (*hash_func_t)(void* state, void* x) except 0
# Per-entry record stored in Beam._states and Beam._parents.
cdef struct _State:
    # Opaque payload pointer; Beam.at(i) returns this field.
    void* content
    # NOTE(review): presumably the sequence of classes applied so far --
    # confirm against the implementation file.
    class_t* hist
    weight_t score, loss
    # NOTE(review): meaning of `i` and `t` is not visible in this file
    # (likely an index and a step counter) -- confirm.
    int i, t
    bint is_done
# Extension-type declaration for Beam. Only `at` and `set_cell` have bodies
# here; every other method is declared only and has to be implemented in the
# matching implementation file (presumably search.pyx).
cdef class Beam:
    cdef Pool mem
    cdef class_t nr_class, width, size
    cdef public weight_t min_density
    cdef int t
    cdef readonly bint is_done
    cdef list histories, _parent_histories
    # Two-level tables addressed as [i][j]; set_cell() writes one cell of each.
    cdef weight_t** scores
    cdef int** is_valid
    cdef weight_t** costs
    cdef _State* _parents
    cdef _State* _states
    cdef del_func_t del_func

    cdef int _fill(self, Queue* q, weight_t** scores, int** is_valid) except -1

    # NOTE(review): the two inline nogil methods below carry no exception
    # clause. Under Cython 3 semantics that implies exception checking on each
    # call; confirm whether `noexcept` should be added for the Cython version
    # this project builds with.
    cdef inline void* at(self, int i) nogil:
        # Return the opaque `content` pointer of entry `i` in `_states`.
        # No bounds check is performed on `i`.
        cdef _State* entry = &self._states[i]
        return entry.content

    cdef int initialize(
        self,
        init_func_t init_func,
        del_func_t del_func,
        int n,
        void* extra_args
    ) except -1

    cdef int advance(
        self,
        trans_func_t transition_func,
        hash_func_t hash_func,
        void* extra_args
    ) except -1

    cdef int check_done(self, finish_func_t finish_func, void* extra_args) except -1

    cdef inline void set_cell(
        self,
        int i,
        int j,
        weight_t score,
        int is_valid,
        weight_t cost
    ) nogil:
        # Store one cell in each of the three tables. No bounds check is
        # performed on `i` or `j`.
        self.scores[i][j] = score
        self.is_valid[i][j] = is_valid
        self.costs[i][j] = cost

    cdef int set_row(
        self,
        int i,
        const weight_t* scores,
        const int* is_valid,
        const weight_t* costs
    ) except -1

    cdef int set_table(
        self,
        weight_t** scores,
        int** is_valid,
        weight_t** costs
    ) except -1
# Extension-type declaration for MaxViolation. Both methods are declared only;
# their bodies belong in the matching implementation file.
cdef class MaxViolation:
    cdef Pool mem
    cdef weight_t cost, delta
    # NOTE(review): the `p_` / `g_` prefixes presumably refer to the `pred` and
    # `gold` beams passed to check()/check_crf() -- confirm.
    cdef readonly weight_t p_score, g_score
    cdef readonly double Z, gZ
    cdef class_t n
    cdef readonly list p_hist, g_hist, p_probs, g_probs

    # cpdef: callable from both Python and Cython; -1 signals an exception.
    cpdef int check(self, Beam pred, Beam gold) except -1
    cpdef int check_crf(self, Beam pred, Beam gold) except -1