Move thinc.extra.search to spacy.pipeline._parser_internals (#11317)

* `search`: Move from `thinc.extra.search`
Fix NPE in `Beam.__dealloc__`

* `pytest`: Add support for executing Cython tests
Move `search` tests from thinc and patch them to run with `pytest`

* `mypy` fix

* Update comment

* `conftest`: Expose `register_cython_tests`

* Remove unused import
This commit is contained in:
Madeesh Kannan 2022-08-17 17:37:08 +02:00 committed by GitHub
parent dc14ee01ac
commit 6109770fc6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 557 additions and 8 deletions

View File

@ -46,6 +46,7 @@ MOD_NAMES = [
"spacy.pipeline._parser_internals.batch",
"spacy.pipeline._parser_internals.ner",
"spacy.pipeline._parser_internals.nonproj",
"spacy.pipeline._parser_internals.search",
"spacy.pipeline._parser_internals._state",
"spacy.pipeline._parser_internals.stateclass",
"spacy.pipeline._parser_internals.transition_system",
@ -65,6 +66,7 @@ MOD_NAMES = [
"spacy.matcher.dependencymatcher",
"spacy.symbols",
"spacy.vectors",
"spacy.tests.parser._search",
]
COMPILE_OPTIONS = {
"msvc": ["/Ox", "/EHsc"],

View File

@ -1,6 +1,6 @@
from ...typedefs cimport class_t, hash_t
# These are passed as callbacks to thinc.search.Beam
# These are passed as callbacks to .search.Beam
cdef int transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1
cdef int check_final_state(void* _state, void* extra_args) except -1

View File

@ -3,18 +3,17 @@
cimport numpy as np
import numpy
from cpython.ref cimport PyObject, Py_XDECREF
from thinc.extra.search cimport Beam
from thinc.extra.search import MaxViolation
from thinc.extra.search cimport MaxViolation
from ...typedefs cimport hash_t, class_t
from .transition_system cimport TransitionSystem, Transition
from ...errors import Errors
from .batch cimport Batch
from .search cimport Beam, MaxViolation
from .search import MaxViolation
from .stateclass cimport StateC, StateClass
# These are passed as callbacks to thinc.search.Beam
# These are passed as callbacks to .search.Beam
cdef int transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1:
dest = <StateC*>_dest
src = <StateC*>_src

View File

@ -14,7 +14,7 @@ from ...training.example cimport Example
from .stateclass cimport StateClass
from ._state cimport StateC, ArcC
from ...errors import Errors
from thinc.extra.search cimport Beam
from .search cimport Beam
cdef weight_t MIN_SCORE = -90000
cdef attr_t SUBTOK_LABEL = hash_string('subtok')

View File

@ -6,7 +6,6 @@ from libcpp.vector cimport vector
from cymem.cymem cimport Pool
from collections import Counter
from thinc.extra.search cimport Beam
from ...tokens.doc cimport Doc
from ...tokens.span import Span
@ -16,6 +15,7 @@ from ...lexeme cimport Lexeme
from ...attrs cimport IS_SPACE
from ...structs cimport TokenC, SpanC
from ...training.example cimport Example
from .search cimport Beam
from .stateclass cimport StateClass
from ._state cimport StateC
from .transition_system cimport Transition, do_func_t

View File

@ -0,0 +1,89 @@
from cymem.cymem cimport Pool
from libc.stdint cimport uint32_t
from libc.stdint cimport uint64_t
from libcpp.pair cimport pair
from libcpp.queue cimport priority_queue
from libcpp.vector cimport vector
from ...typedefs cimport class_t, weight_t, hash_t
ctypedef pair[weight_t, size_t] Entry
ctypedef priority_queue[Entry] Queue
ctypedef int (*trans_func_t)(void* dest, void* src, class_t clas, void* x) except -1
ctypedef void* (*init_func_t)(Pool mem, int n, void* extra_args) except NULL
ctypedef int (*del_func_t)(Pool mem, void* state, void* extra_args) except -1
ctypedef int (*finish_func_t)(void* state, void* extra_args) except -1
ctypedef hash_t (*hash_func_t)(void* state, void* x) except 0
cdef struct _State:
void* content
class_t* hist
weight_t score
weight_t loss
int i
int t
bint is_done
cdef class Beam:
cdef Pool mem
cdef class_t nr_class
cdef class_t width
cdef class_t size
cdef public weight_t min_density
cdef int t
cdef readonly bint is_done
cdef list histories
cdef list _parent_histories
cdef weight_t** scores
cdef int** is_valid
cdef weight_t** costs
cdef _State* _parents
cdef _State* _states
cdef del_func_t del_func
cdef int _fill(self, Queue* q, weight_t** scores, int** is_valid) except -1
cdef inline void* at(self, int i) nogil:
return self._states[i].content
cdef int initialize(self, init_func_t init_func, del_func_t del_func, int n, void* extra_args) except -1
cdef int advance(self, trans_func_t transition_func, hash_func_t hash_func,
void* extra_args) except -1
cdef int check_done(self, finish_func_t finish_func, void* extra_args) except -1
cdef inline void set_cell(self, int i, int j, weight_t score, int is_valid, weight_t cost) nogil:
self.scores[i][j] = score
self.is_valid[i][j] = is_valid
self.costs[i][j] = cost
cdef int set_row(self, int i, const weight_t* scores, const int* is_valid,
const weight_t* costs) except -1
cdef int set_table(self, weight_t** scores, int** is_valid, weight_t** costs) except -1
cdef class MaxViolation:
cdef Pool mem
cdef weight_t cost
cdef weight_t delta
cdef readonly weight_t p_score
cdef readonly weight_t g_score
cdef readonly double Z
cdef readonly double gZ
cdef class_t n
cdef readonly list p_hist
cdef readonly list g_hist
cdef readonly list p_probs
cdef readonly list g_probs
cpdef int check(self, Beam pred, Beam gold) except -1
cpdef int check_crf(self, Beam pred, Beam gold) except -1

View File

@ -0,0 +1,306 @@
# cython: profile=True, experimental_cpp_class_def=True, cdivision=True, infer_types=True
cimport cython
from libc.string cimport memset, memcpy
from libc.math cimport log, exp
import math
from cymem.cymem cimport Pool
from preshed.maps cimport PreshMap
cdef class Beam:
def __init__(self, class_t nr_class, class_t width, weight_t min_density=0.0):
assert nr_class != 0
assert width != 0
self.nr_class = nr_class
self.width = width
self.min_density = min_density
self.size = 1
self.t = 0
self.mem = Pool()
self.del_func = NULL
self._parents = <_State*>self.mem.alloc(self.width, sizeof(_State))
self._states = <_State*>self.mem.alloc(self.width, sizeof(_State))
cdef int i
self.histories = [[] for i in range(self.width)]
self._parent_histories = [[] for i in range(self.width)]
self.scores = <weight_t**>self.mem.alloc(self.width, sizeof(weight_t*))
self.is_valid = <int**>self.mem.alloc(self.width, sizeof(weight_t*))
self.costs = <weight_t**>self.mem.alloc(self.width, sizeof(weight_t*))
for i in range(self.width):
self.scores[i] = <weight_t*>self.mem.alloc(self.nr_class, sizeof(weight_t))
self.is_valid[i] = <int*>self.mem.alloc(self.nr_class, sizeof(int))
self.costs[i] = <weight_t*>self.mem.alloc(self.nr_class, sizeof(weight_t))
def __len__(self):
return self.size
property score:
def __get__(self):
return self._states[0].score
property min_score:
def __get__(self):
return self._states[self.size-1].score
property loss:
def __get__(self):
return self._states[0].loss
property probs:
def __get__(self):
return _softmax([self._states[i].score for i in range(self.size)])
property scores:
def __get__(self):
return [self._states[i].score for i in range(self.size)]
property histories:
def __get__(self):
return self.histories
cdef int set_row(self, int i, const weight_t* scores, const int* is_valid,
const weight_t* costs) except -1:
cdef int j
for j in range(self.nr_class):
self.scores[i][j] = scores[j]
self.is_valid[i][j] = is_valid[j]
self.costs[i][j] = costs[j]
cdef int set_table(self, weight_t** scores, int** is_valid, weight_t** costs) except -1:
cdef int i, j
for i in range(self.width):
memcpy(self.scores[i], scores[i], sizeof(weight_t) * self.nr_class)
memcpy(self.is_valid[i], is_valid[i], sizeof(bint) * self.nr_class)
memcpy(self.costs[i], costs[i], sizeof(int) * self.nr_class)
cdef int initialize(self, init_func_t init_func, del_func_t del_func, int n, void* extra_args) except -1:
for i in range(self.width):
self._states[i].content = init_func(self.mem, n, extra_args)
self._parents[i].content = init_func(self.mem, n, extra_args)
self.del_func = del_func
def __dealloc__(self):
if self.del_func == NULL:
return
for i in range(self.width):
self.del_func(self.mem, self._states[i].content, NULL)
self.del_func(self.mem, self._parents[i].content, NULL)
@cython.cdivision(True)
cdef int advance(self, trans_func_t transition_func, hash_func_t hash_func,
void* extra_args) except -1:
cdef weight_t** scores = self.scores
cdef int** is_valid = self.is_valid
cdef weight_t** costs = self.costs
cdef Queue* q = new Queue()
self._fill(q, scores, is_valid)
# For a beam of width k, we only ever need 2k state objects. How?
# Each transition takes a parent and a class and produces a new state.
# So, we don't need the whole history --- just the parent. So at
# each step, we take a parent, and apply one or more extensions to
# it.
self._parents, self._states = self._states, self._parents
self._parent_histories, self.histories = self.histories, self._parent_histories
cdef weight_t score
cdef int p_i
cdef int i = 0
cdef class_t clas
cdef _State* parent
cdef _State* state
cdef hash_t key
cdef PreshMap seen_states = PreshMap(self.width)
cdef uint64_t is_seen
cdef uint64_t one = 1
while i < self.width and not q.empty():
data = q.top()
p_i = data.second / self.nr_class
clas = data.second % self.nr_class
score = data.first
q.pop()
parent = &self._parents[p_i]
# Indicates terminal state reached; i.e. state is done
if parent.is_done:
# Now parent will not be changed, so we don't have to copy.
# Once finished, should also be unbranching.
self._states[i], parent[0] = parent[0], self._states[i]
parent.i = self._states[i].i
parent.t = self._states[i].t
parent.is_done = self._states[i].t
self._states[i].score = score
self.histories[i] = list(self._parent_histories[p_i])
i += 1
else:
state = &self._states[i]
# The supplied transition function should adjust the destination
# state to be the result of applying the class to the source state
transition_func(state.content, parent.content, clas, extra_args)
key = hash_func(state.content, extra_args) if hash_func is not NULL else 0
is_seen = <uint64_t>seen_states.get(key)
if key == 0 or key == 1 or not is_seen:
if key != 0 and key != 1:
seen_states.set(key, <void*>one)
state.score = score
state.loss = parent.loss + costs[p_i][clas]
self.histories[i] = list(self._parent_histories[p_i])
self.histories[i].append(clas)
i += 1
del q
self.size = i
assert self.size >= 1
for i in range(self.width):
memset(self.scores[i], 0, sizeof(weight_t) * self.nr_class)
memset(self.costs[i], 0, sizeof(weight_t) * self.nr_class)
memset(self.is_valid[i], 0, sizeof(int) * self.nr_class)
self.t += 1
cdef int check_done(self, finish_func_t finish_func, void* extra_args) except -1:
cdef int i
for i in range(self.size):
if not self._states[i].is_done:
self._states[i].is_done = finish_func(self._states[i].content, extra_args)
for i in range(self.size):
if not self._states[i].is_done:
self.is_done = False
break
else:
self.is_done = True
@cython.cdivision(True)
cdef int _fill(self, Queue* q, weight_t** scores, int** is_valid) except -1:
"""Populate the queue from a k * n matrix of scores, where k is the
beam-width, and n is the number of classes.
"""
cdef Entry entry
cdef weight_t score
cdef _State* s
cdef int i, j, move_id
assert self.size >= 1
cdef vector[Entry] entries
for i in range(self.size):
s = &self._states[i]
move_id = i * self.nr_class
if s.is_done:
# Update score by path average, following TACL '13 paper.
if self.histories[i]:
entry.first = s.score + (s.score / self.t)
else:
entry.first = s.score
entry.second = move_id
entries.push_back(entry)
else:
for j in range(self.nr_class):
if is_valid[i][j]:
entry.first = s.score + scores[i][j]
entry.second = move_id + j
entries.push_back(entry)
cdef double max_, Z, cutoff
if self.min_density == 0.0:
for i in range(entries.size()):
q.push(entries[i])
elif not entries.empty():
max_ = entries[0].first
Z = 0.
cutoff = 0.
# Softmax into probabilities, so we can prune
for i in range(entries.size()):
if entries[i].first > max_:
max_ = entries[i].first
for i in range(entries.size()):
Z += exp(entries[i].first-max_)
cutoff = (1. / Z) * self.min_density
for i in range(entries.size()):
prob = exp(entries[i].first-max_) / Z
if prob >= cutoff:
q.push(entries[i])
cdef class MaxViolation:
def __init__(self):
self.p_score = 0.0
self.g_score = 0.0
self.Z = 0.0
self.gZ = 0.0
self.delta = -1
self.cost = 0
self.p_hist = []
self.g_hist = []
self.p_probs = []
self.g_probs = []
cpdef int check(self, Beam pred, Beam gold) except -1:
cdef _State* p = &pred._states[0]
cdef _State* g = &gold._states[0]
cdef weight_t d = p.score - g.score
if p.loss >= 1 and (self.cost == 0 or d > self.delta):
self.cost = p.loss
self.delta = d
self.p_hist = list(pred.histories[0])
self.g_hist = list(gold.histories[0])
self.p_score = p.score
self.g_score = g.score
self.Z = 1e-10
self.gZ = 1e-10
for i in range(pred.size):
if pred._states[i].loss > 0:
self.Z += exp(pred._states[i].score)
for i in range(gold.size):
if gold._states[i].loss == 0:
prob = exp(gold._states[i].score)
self.Z += prob
self.gZ += prob
cpdef int check_crf(self, Beam pred, Beam gold) except -1:
d = pred.score - gold.score
seen_golds = set([tuple(gold.histories[i]) for i in range(gold.size)])
if pred.loss > 0 and (self.cost == 0 or d > self.delta):
p_hist = []
p_scores = []
g_hist = []
g_scores = []
for i in range(pred.size):
if pred._states[i].loss > 0:
p_scores.append(pred._states[i].score)
p_hist.append(list(pred.histories[i]))
# This can happen from non-monotonic actions
# If we find a better gold analysis this way, be sure to keep it.
elif pred._states[i].loss <= 0 \
and tuple(pred.histories[i]) not in seen_golds:
g_scores.append(pred._states[i].score)
g_hist.append(list(pred.histories[i]))
for i in range(gold.size):
if gold._states[i].loss == 0:
g_scores.append(gold._states[i].score)
g_hist.append(list(gold.histories[i]))
all_probs = _softmax(p_scores + g_scores)
p_probs = all_probs[:len(p_scores)]
g_probs_all = all_probs[len(p_scores):]
g_probs = _softmax(g_scores)
self.cost = pred.loss
self.delta = d
self.p_hist = p_hist
self.g_hist = g_hist
# TODO: These variables are misnamed! These are the gradients of the loss.
self.p_probs = p_probs
# Intuition here:
# The gradient of the loss is:
# P(model) - P(truth)
# Normally, P(truth) is 1 for the gold
# But, if we want to do the "partial credit" scheme, we want
# to create a distribution over the gold, proportional to the scores
# awarded.
self.g_probs = [x-y for x, y in zip(g_probs_all, g_probs)]
def _softmax(nums):
if not nums:
return []
max_ = max(nums)
nums = [(exp(n-max_) if n is not None else None) for n in nums]
Z = sum(n for n in nums if n is not None)
return [(n/Z if n is not None else None) for n in nums]

View File

@ -12,13 +12,13 @@ import contextlib
import srsly
from thinc.api import set_dropout_rate, CupyOps, get_array_module
from thinc.extra.search cimport Beam
from thinc.types import Ints1d
import numpy.random
import numpy
import warnings
from ._parser_internals.stateclass cimport StateC, StateClass
from ._parser_internals.search cimport Beam
from ..tokens.doc cimport Doc
from .trainable_pipe import TrainablePipe
from ._parser_internals cimport _beam_utils

View File

@ -1,5 +1,9 @@
import pytest
from spacy.util import get_lang_class
import functools
import inspect
import importlib
import sys
def pytest_addoption(parser):
@ -41,6 +45,33 @@ def pytest_runtest_setup(item):
pytest.skip("not referencing any issues")
# Decorator for Cython-built tests
# https://shwina.github.io/cython-testing/
def cytest(func):
"""
Wraps `func` in a plain Python function.
"""
@functools.wraps(func)
def wrapped(*args, **kwargs):
bound = inspect.signature(func).bind(*args, **kwargs)
return func(*bound.args, **bound.kwargs)
return wrapped
def register_cython_tests(cython_mod_name: str, test_mod_name: str):
"""
Registers all callables with name `test_*` in Cython module `cython_mod_name`
as attributes in module `test_mod_name`, making them discoverable by pytest.
"""
cython_mod = importlib.import_module(cython_mod_name)
for name in dir(cython_mod):
item = getattr(cython_mod, name)
if callable(item) and name.startswith("test_"):
setattr(sys.modules[test_mod_name], name, item)
# Fixtures for language tokenizers (languages sorted alphabetically)

View File

@ -0,0 +1,119 @@
# cython: infer_types=True, binding=True
from spacy.pipeline._parser_internals.search cimport Beam, MaxViolation
from spacy.typedefs cimport class_t, weight_t
from cymem.cymem cimport Pool
from ..conftest import cytest
import pytest
cdef struct TestState:
int length
int x
Py_UNICODE* string
cdef int transition(void* dest, void* src, class_t clas, void* extra_args) except -1:
dest_state = <TestState*>dest
src_state = <TestState*>src
dest_state.length = src_state.length
dest_state.x = src_state.x
dest_state.x += clas
if extra_args != NULL:
dest_state.string = <Py_UNICODE*>extra_args
else:
dest_state.string = src_state.string
cdef void* initialize(Pool mem, int n, void* extra_args) except NULL:
state = <TestState*>mem.alloc(1, sizeof(TestState))
state.length = n
state.x = 1
if extra_args == NULL:
state.string = u'default'
else:
state.string = <Py_UNICODE*>extra_args
return state
cdef int destroy(Pool mem, void* state, void* extra_args) except -1:
state = <TestState*>state
mem.free(state)
@cytest
@pytest.mark.parametrize("nr_class,beam_width",
[
(2, 3),
(3, 6),
(4, 20),
]
)
def test_init(nr_class, beam_width):
b = Beam(nr_class, beam_width)
assert b.size == 1
assert b.width == beam_width
assert b.nr_class == nr_class
@cytest
def test_init_violn():
MaxViolation()
@cytest
@pytest.mark.parametrize("nr_class,beam_width,length",
[
(2, 3, 3),
(3, 6, 15),
(4, 20, 32),
]
)
def test_initialize(nr_class, beam_width, length):
b = Beam(nr_class, beam_width)
b.initialize(initialize, destroy, length, NULL)
for i in range(b.width):
s = <TestState*>b.at(i)
assert s.length == length, s.length
assert s.string == 'default'
@cytest
@pytest.mark.parametrize("nr_class,beam_width,length,extra",
[
(2, 3, 4, None),
(3, 6, 15, u"test beam 1"),
]
)
def test_initialize_extra(nr_class, beam_width, length, extra):
b = Beam(nr_class, beam_width)
if extra is None:
b.initialize(initialize, destroy, length, NULL)
else:
b.initialize(initialize, destroy, length, <void*><Py_UNICODE*>extra)
for i in range(b.width):
s = <TestState*>b.at(i)
assert s.length == length
@cytest
@pytest.mark.parametrize("nr_class,beam_width,length",
[
(3, 6, 15),
(4, 20, 32),
]
)
def test_transition(nr_class, beam_width, length):
b = Beam(nr_class, beam_width)
b.initialize(initialize, destroy, length, NULL)
b.set_cell(0, 2, 30, True, 0)
b.set_cell(0, 1, 42, False, 0)
b.advance(transition, NULL, NULL)
assert b.size == 1, b.size
assert b.score == 30, b.score
s = <TestState*>b.at(0)
assert s.x == 3
assert b._states[0].score == 30, b._states[0].score
b.set_cell(0, 1, 10, True, 0)
b.set_cell(0, 2, 20, True, 0)
b.advance(transition, NULL, NULL)
assert b._states[0].score == 50, b._states[0].score
assert b._states[1].score == 40
s = <TestState*>b.at(0)
assert s.x == 5

View File

@ -0,0 +1,3 @@
from ..conftest import register_cython_tests
register_cython_tests("spacy.tests.parser._search", __name__)