2023-12-08 22:24:09 +03:00
|
|
|
from libc.string cimport memcpy, memset
|
2023-12-08 22:23:08 +03:00
|
|
|
from thinc.backends.cblas cimport CBlas
|
2023-12-08 22:24:09 +03:00
|
|
|
|
2023-12-08 22:23:08 +03:00
|
|
|
from ..pipeline._parser_internals._state cimport StateC
|
2023-12-08 22:24:09 +03:00
|
|
|
from ..typedefs cimport hash_t, weight_t
|
2023-12-08 22:23:08 +03:00
|
|
|
|
|
|
|
|
|
|
|
# Plain C struct bundling the integer dimensions used by the parser-model
# routines declared in this file: `get_c_sizes` returns one by value, and
# `alloc_activations` / `predict_states` each take one by value.
# NOTE(review): the per-field meanings below are inferred from the field
# names only -- confirm against the implementation of `get_c_sizes`.
cdef struct SizesC:
    int states       # presumably the number of parser states in the batch
    int classes      # presumably the number of output classes (actions)
    int hiddens      # presumably the hidden-layer width
    int pieces       # presumably the number of maxout pieces per hidden unit
    int feats        # presumably the number of token features per state
    int embed_width  # presumably the width of each token embedding row
|
|
|
|
|
|
|
|
|
|
|
|
# Plain C struct of pointers to the model's parameter arrays. Every member is
# `const float*`, so code holding a WeightsC cannot write to the parameters
# through it. `get_c_weights` returns one by value and `predict_states`
# receives a `const WeightsC*`.
# NOTE(review): the struct only stores raw pointers; whoever owns the
# underlying buffers must keep them alive while the struct is in use --
# confirm the ownership rules in `get_c_weights`.
cdef struct WeightsC:
    const float* feat_weights    # presumably the feature/embedding layer weights
    const float* feat_bias       # presumably the feature-layer bias
    const float* hidden_bias     # presumably the hidden/output layer bias
    const float* hidden_weights  # presumably the hidden/output layer weights
    const float* seen_classes    # presumably a per-class "seen in training" mask
|
|
|
|
|
|
|
|
|
|
|
|
# Plain C struct of mutable scratch buffers plus two integer size fields.
# `alloc_activations` returns one by value, `free_activations` takes a
# `const ActivationsC*`, and `predict_states` takes a mutable `ActivationsC*`.
# NOTE(review): buffer lengths and the exact role of each buffer are not
# visible in this declaration file -- the field notes below are inferred from
# names and should be confirmed against the implementation.
cdef struct ActivationsC:
    int* token_ids   # presumably the token indices of the features per state
    float* unmaxed   # presumably the pre-maxout hidden activations
    float* scores    # presumably the per-class output scores
    float* hiddens   # presumably the post-maxout hidden activations
    int* is_valid    # presumably per-class validity flags
    int _curr_size   # leading underscore suggests internal bookkeeping -- confirm
    int _max_size    # leading underscore suggests internal capacity tracking -- confirm
|
|
|
|
|
|
|
|
|
|
|
|
# Prototype: takes a Python object `model` and returns a WeightsC by value.
# `except *` makes Cython check for a pending Python exception after every
# call, so errors raised inside propagate to the caller.
# NOTE(review): presumably fills the struct with pointers into the model's
# parameter arrays -- confirm in the implementation.
cdef WeightsC get_c_weights(model) except *
|
|
|
|
|
|
|
|
# Prototype: takes a Python object `model` and a C int `batch_size`, and
# returns a SizesC by value. `except *` makes Cython check for a pending
# Python exception after every call, so errors raised inside propagate.
# NOTE(review): presumably `batch_size` populates `SizesC.states` -- confirm.
cdef SizesC get_c_sizes(model, int batch_size) except *
|
|
|
|
|
|
|
|
# Prototype: takes a SizesC by value and returns an ActivationsC by value.
# Declared `nogil`, so it may be called without holding the GIL. It has no
# `except` clause, so the signature gives it no way to report a Python
# exception to its caller.
# NOTE(review): presumably allocates the buffers in ActivationsC, to be
# released with `free_activations` -- confirm in the implementation.
cdef ActivationsC alloc_activations(SizesC n) nogil
|
|
|
|
|
|
|
|
# Prototype: takes a pointer to a const ActivationsC and returns nothing.
# Declared `nogil`, so it may be called without holding the GIL.
# NOTE(review): presumably releases the buffers referenced by `A`; the struct
# itself is const here, so its pointer fields are not reset -- confirm.
cdef void free_activations(const ActivationsC* A) nogil
|
|
|
|
|
|
|
|
# Prototype of the batched prediction routine. It returns nothing and is
# declared `nogil`, so it may be called without holding the GIL.
#   cblas  -- CBlas handle cimported from thinc.backends.cblas
#   A      -- mutable activations struct (the only non-const struct pointer)
#   states -- array of StateC pointers
#   W      -- read-only pointer to the model weights
#   n      -- dimensions, passed by value
# NOTE(review): presumably results are written into the buffers of `A` and
# `states` holds `n.states` entries -- confirm in the implementation.
cdef void predict_states(CBlas cblas, ActivationsC* A, StateC** states,
                         const WeightsC* W, SizesC n) nogil
|
|
|
|
|
2023-12-08 22:23:08 +03:00
|
|
|
# Prototype: takes read-only `scores` and `is_valid` arrays plus a C int `n`,
# and returns a C int. Declared `nogil`, so it may be called without the GIL.
# NOTE(review): by its name, presumably returns the index of the highest score
# among entries whose `is_valid` flag is set, with `n` as the array length; the
# return value when no entry is valid is not visible here -- confirm.
cdef int arg_max_if_valid(const weight_t* scores, const int* is_valid, int n) nogil
|
|
|
|
|
2023-12-18 22:02:15 +03:00
|
|
|
# Prototype of the CPU log-loss routine. It returns nothing and is declared
# `nogil`, so it may be called without holding the GIL.
# `d_scores` is the only non-const pointer, so it is the only array this
# function can write through; `costs`, `is_valid` and `scores` are read-only.
# NOTE(review): presumably `O` is the number of output classes (the length of
# each array) and `d_scores` receives the gradient with respect to `scores` --
# confirm in the implementation.
cdef void cpu_log_loss(float* d_scores, const float* costs,
                       const int* is_valid, const float* scores, int O) nogil
|