Move thinc.extra.search to spacy.pipeline._parser_internals (#11317)

* `search`: Move from `thinc.extra.search`
Fix NPE in `Beam.__dealloc__`

* `pytest`: Add support for executing Cython tests
Move `search` tests from thinc and patch them to run with `pytest`

* `mypy` fix

* Update comment

* `conftest`: Expose `register_cython_tests`

* Remove unused import
This commit is contained in:
Madeesh Kannan 2022-08-17 17:37:08 +02:00 committed by GitHub
parent dc14ee01ac
commit 6109770fc6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 557 additions and 8 deletions

View File

@ -46,6 +46,7 @@ MOD_NAMES = [
"spacy.pipeline._parser_internals.batch", "spacy.pipeline._parser_internals.batch",
"spacy.pipeline._parser_internals.ner", "spacy.pipeline._parser_internals.ner",
"spacy.pipeline._parser_internals.nonproj", "spacy.pipeline._parser_internals.nonproj",
"spacy.pipeline._parser_internals.search",
"spacy.pipeline._parser_internals._state", "spacy.pipeline._parser_internals._state",
"spacy.pipeline._parser_internals.stateclass", "spacy.pipeline._parser_internals.stateclass",
"spacy.pipeline._parser_internals.transition_system", "spacy.pipeline._parser_internals.transition_system",
@ -65,6 +66,7 @@ MOD_NAMES = [
"spacy.matcher.dependencymatcher", "spacy.matcher.dependencymatcher",
"spacy.symbols", "spacy.symbols",
"spacy.vectors", "spacy.vectors",
"spacy.tests.parser._search",
] ]
COMPILE_OPTIONS = { COMPILE_OPTIONS = {
"msvc": ["/Ox", "/EHsc"], "msvc": ["/Ox", "/EHsc"],

View File

@ -1,6 +1,6 @@
from ...typedefs cimport class_t, hash_t from ...typedefs cimport class_t, hash_t
# These are passed as callbacks to thinc.search.Beam # These are passed as callbacks to .search.Beam
cdef int transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1 cdef int transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1
cdef int check_final_state(void* _state, void* extra_args) except -1 cdef int check_final_state(void* _state, void* extra_args) except -1

View File

@ -3,18 +3,17 @@
cimport numpy as np cimport numpy as np
import numpy import numpy
from cpython.ref cimport PyObject, Py_XDECREF from cpython.ref cimport PyObject, Py_XDECREF
from thinc.extra.search cimport Beam
from thinc.extra.search import MaxViolation
from thinc.extra.search cimport MaxViolation
from ...typedefs cimport hash_t, class_t from ...typedefs cimport hash_t, class_t
from .transition_system cimport TransitionSystem, Transition from .transition_system cimport TransitionSystem, Transition
from ...errors import Errors from ...errors import Errors
from .batch cimport Batch from .batch cimport Batch
from .search cimport Beam, MaxViolation
from .search import MaxViolation
from .stateclass cimport StateC, StateClass from .stateclass cimport StateC, StateClass
# These are passed as callbacks to thinc.search.Beam # These are passed as callbacks to .search.Beam
cdef int transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1: cdef int transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1:
dest = <StateC*>_dest dest = <StateC*>_dest
src = <StateC*>_src src = <StateC*>_src

View File

@ -14,7 +14,7 @@ from ...training.example cimport Example
from .stateclass cimport StateClass from .stateclass cimport StateClass
from ._state cimport StateC, ArcC from ._state cimport StateC, ArcC
from ...errors import Errors from ...errors import Errors
from thinc.extra.search cimport Beam from .search cimport Beam
cdef weight_t MIN_SCORE = -90000 cdef weight_t MIN_SCORE = -90000
cdef attr_t SUBTOK_LABEL = hash_string('subtok') cdef attr_t SUBTOK_LABEL = hash_string('subtok')

View File

@ -6,7 +6,6 @@ from libcpp.vector cimport vector
from cymem.cymem cimport Pool from cymem.cymem cimport Pool
from collections import Counter from collections import Counter
from thinc.extra.search cimport Beam
from ...tokens.doc cimport Doc from ...tokens.doc cimport Doc
from ...tokens.span import Span from ...tokens.span import Span
@ -16,6 +15,7 @@ from ...lexeme cimport Lexeme
from ...attrs cimport IS_SPACE from ...attrs cimport IS_SPACE
from ...structs cimport TokenC, SpanC from ...structs cimport TokenC, SpanC
from ...training.example cimport Example from ...training.example cimport Example
from .search cimport Beam
from .stateclass cimport StateClass from .stateclass cimport StateClass
from ._state cimport StateC from ._state cimport StateC
from .transition_system cimport Transition, do_func_t from .transition_system cimport Transition, do_func_t

View File

@ -0,0 +1,89 @@
from cymem.cymem cimport Pool
from libc.stdint cimport uint32_t
from libc.stdint cimport uint64_t
from libcpp.pair cimport pair
from libcpp.queue cimport priority_queue
from libcpp.vector cimport vector
from ...typedefs cimport class_t, weight_t, hash_t
# One candidate extension of the beam: `first` is the candidate's score,
# `second` is a move id. In search.pyx the move id is built as
# `state_index * nr_class + class` and decoded again with `/` and `%`.
ctypedef pair[weight_t, size_t] Entry
# Priority queue of candidate entries; `Beam.advance` consumes it via
# `top()` / `pop()`.
ctypedef priority_queue[Entry] Queue


# Callback signatures supplied by users of `Beam`.
# Applies class `clas` to the state `src`, writing the result into `dest`;
# `x` is the caller's opaque extra-arguments pointer.
ctypedef int (*trans_func_t)(void* dest, void* src, class_t clas, void* x) except -1

# Creates one state's content; `Beam.initialize` passes its own Pool, `n` and
# the caller's extra arguments, and stores the returned pointer.
ctypedef void* (*init_func_t)(Pool mem, int n, void* extra_args) except NULL

# Releases one state's content; `Beam.__dealloc__` calls it with
# `extra_args` set to NULL.
ctypedef int (*del_func_t)(Pool mem, void* state, void* extra_args) except -1

# Reports whether a state is finished; `Beam.check_done` stores the result in
# the state's `is_done` flag.
ctypedef int (*finish_func_t)(void* state, void* extra_args) except -1

# Optional state hash used by `Beam.advance` to skip duplicate states; the
# beam accepts a NULL function pointer here.
ctypedef hash_t (*hash_func_t)(void* state, void* x) except 0
# Bookkeeping record for one slot of the beam.
cdef struct _State:
    # Opaque user state created by the `init_func_t` callback.
    void* content
    # Not referenced by the Beam code in search.pyx; histories are kept in
    # Python lists on the Beam instead.
    class_t* hist
    # Score assigned to this state by `Beam.advance`.
    weight_t score
    # Accumulated cost: parent's loss plus the cost of the applied class.
    weight_t loss
    # `i` and `t` are only copied between slots in `Beam.advance`.
    int i
    int t
    # Set from the `finish_func_t` callback in `Beam.check_done`.
    bint is_done
# C-level declaration of the beam; the implementation lives in search.pyx.
cdef class Beam:
    # Pool that owns every allocation made by the beam.
    cdef Pool mem
    # Number of transition classes (columns of the tables below).
    cdef class_t nr_class
    # Maximum number of states the beam can hold (rows of the tables below).
    cdef class_t width
    # Number of states currently in use; starts at 1.
    cdef class_t size
    # Pruning threshold applied when the candidate queue is filled;
    # 0.0 disables pruning.
    cdef public weight_t min_density
    # Number of completed `advance` calls.
    cdef int t
    # Set by `check_done`: True when every active state is finished.
    cdef readonly bint is_done
    # Per-slot lists of the classes applied so far, for the current states
    # and for their parents respectively.
    cdef list histories
    cdef list _parent_histories
    # `width` x `nr_class` tables filled by the caller before `advance`.
    cdef weight_t** scores
    cdef int** is_valid
    cdef weight_t** costs
    # The two state arrays (each of length `width`) that `advance` swaps.
    cdef _State* _parents
    cdef _State* _states
    # Destructor callback registered by `initialize`; NULL until then.
    cdef del_func_t del_func

    cdef int _fill(self, Queue* q, weight_t** scores, int** is_valid) except -1

    # Return the opaque content pointer of the state in slot `i`.
    cdef inline void* at(self, int i) nogil:
        return self._states[i].content

    cdef int initialize(self, init_func_t init_func, del_func_t del_func, int n, void* extra_args) except -1
    cdef int advance(self, trans_func_t transition_func, hash_func_t hash_func,
                     void* extra_args) except -1
    cdef int check_done(self, finish_func_t finish_func, void* extra_args) except -1

    # Write the score, validity flag and cost of class `j` for state `i`.
    # No bounds checking is performed.
    cdef inline void set_cell(self, int i, int j, weight_t score, int is_valid, weight_t cost) nogil:
        self.scores[i][j] = score
        self.is_valid[i][j] = is_valid
        self.costs[i][j] = cost

    cdef int set_row(self, int i, const weight_t* scores, const int* is_valid,
                     const weight_t* costs) except -1
    cdef int set_table(self, weight_t** scores, int** is_valid, weight_t** costs) except -1
# C-level declaration of the violation tracker implemented in search.pyx.
cdef class MaxViolation:
    cdef Pool mem
    # Loss of the prediction at the recorded violation; 0 means none recorded.
    cdef weight_t cost
    # Score difference (prediction minus gold) at the recorded violation.
    cdef weight_t delta
    # Scores of the top predicted / gold state at the recorded violation.
    cdef readonly weight_t p_score
    cdef readonly weight_t g_score
    # Sums of exponentiated scores accumulated by `check`.
    cdef readonly double Z
    cdef readonly double gZ
    cdef class_t n
    # Histories recorded for the prediction and gold beams.
    cdef readonly list p_hist
    cdef readonly list g_hist
    # Values computed by `check_crf`; the implementation notes that these are
    # gradients of the loss rather than probabilities.
    cdef readonly list p_probs
    cdef readonly list g_probs

    cpdef int check(self, Beam pred, Beam gold) except -1
    cpdef int check_crf(self, Beam pred, Beam gold) except -1

View File

@ -0,0 +1,306 @@
# cython: profile=True, experimental_cpp_class_def=True, cdivision=True, infer_types=True
cimport cython
from libc.string cimport memset, memcpy
from libc.math cimport log, exp
import math
from cymem.cymem cimport Pool
from preshed.maps cimport PreshMap
cdef class Beam:
    """Beam of at most `width` candidate states over `nr_class` classes.

    Typical use: call `initialize` once, then repeatedly fill the
    score/validity/cost tables (`set_cell`, `set_row` or `set_table`), call
    `advance` to extend the beam by one step, and call `check_done` to refresh
    the completion flags.
    """
    def __init__(self, class_t nr_class, class_t width, weight_t min_density=0.0):
        """Allocate the two state arrays and the `width` x `nr_class` tables.

        nr_class: number of transition classes; must be non-zero.
        width: maximum number of states kept in the beam; must be non-zero.
        min_density: pruning threshold used by `_fill`; 0.0 disables pruning.
        """
        assert nr_class != 0
        assert width != 0
        self.nr_class = nr_class
        self.width = width
        self.min_density = min_density
        self.size = 1
        self.t = 0
        self.mem = Pool()
        # Stays NULL until `initialize` runs, which lets `__dealloc__` tell
        # whether there is any state content to release.
        self.del_func = NULL
        self._parents = <_State*>self.mem.alloc(self.width, sizeof(_State))
        self._states = <_State*>self.mem.alloc(self.width, sizeof(_State))
        cdef int i
        self.histories = [[] for i in range(self.width)]
        self._parent_histories = [[] for i in range(self.width)]

        self.scores = <weight_t**>self.mem.alloc(self.width, sizeof(weight_t*))
        # Size each row-pointer array by its own element type.
        self.is_valid = <int**>self.mem.alloc(self.width, sizeof(int*))
        self.costs = <weight_t**>self.mem.alloc(self.width, sizeof(weight_t*))
        for i in range(self.width):
            self.scores[i] = <weight_t*>self.mem.alloc(self.nr_class, sizeof(weight_t))
            self.is_valid[i] = <int*>self.mem.alloc(self.nr_class, sizeof(int))
            self.costs[i] = <weight_t*>self.mem.alloc(self.nr_class, sizeof(weight_t))

    def __len__(self):
        """Return the number of states currently in the beam."""
        return self.size

    property score:
        # Score of the state in slot 0.
        def __get__(self):
            return self._states[0].score

    property min_score:
        # Score of the last active state (slot `size - 1`).
        def __get__(self):
            return self._states[self.size-1].score

    property loss:
        # Accumulated loss of the state in slot 0.
        def __get__(self):
            return self._states[0].loss

    property probs:
        # Softmax over the scores of the active states.
        def __get__(self):
            return _softmax([self._states[i].score for i in range(self.size)])

    property scores:
        # Scores of the active states, in slot order.
        def __get__(self):
            return [self._states[i].score for i in range(self.size)]

    property histories:
        # Per-slot lists of the classes applied so far.
        def __get__(self):
            return self.histories

    cdef int set_row(self, int i, const weight_t* scores, const int* is_valid,
                     const weight_t* costs) except -1:
        """Copy `nr_class` scores, validity flags and costs into row `i`."""
        cdef int j
        for j in range(self.nr_class):
            self.scores[i][j] = scores[j]
            self.is_valid[i][j] = is_valid[j]
            self.costs[i][j] = costs[j]

    cdef int set_table(self, weight_t** scores, int** is_valid, weight_t** costs) except -1:
        """Copy complete `width` x `nr_class` tables into the beam."""
        cdef int i
        for i in range(self.width):
            # Each copy is sized by the element type of the array being
            # copied: `is_valid` holds `int`, `costs` holds `weight_t`.
            memcpy(self.scores[i], scores[i], sizeof(weight_t) * self.nr_class)
            memcpy(self.is_valid[i], is_valid[i], sizeof(int) * self.nr_class)
            memcpy(self.costs[i], costs[i], sizeof(weight_t) * self.nr_class)

    cdef int initialize(self, init_func_t init_func, del_func_t del_func, int n, void* extra_args) except -1:
        """Create the content of every state and parent slot with `init_func`.

        `del_func` is remembered only after all slots were created, and is
        used by `__dealloc__` to release the contents again.
        """
        for i in range(self.width):
            self._states[i].content = init_func(self.mem, n, extra_args)
            self._parents[i].content = init_func(self.mem, n, extra_args)
        self.del_func = del_func

    def __dealloc__(self):
        # `initialize` was never (successfully) called: there is no destructor
        # to invoke, so calling through the NULL pointer must be avoided.
        if self.del_func == NULL:
            return

        for i in range(self.width):
            self.del_func(self.mem, self._states[i].content, NULL)
            self.del_func(self.mem, self._parents[i].content, NULL)

    @cython.cdivision(True)
    cdef int advance(self, trans_func_t transition_func, hash_func_t hash_func,
                     void* extra_args) except -1:
        """Extend the beam by one step using the currently filled tables.

        transition_func: applies a class to a parent state, writing the result
            into the destination state.
        hash_func: optional (may be NULL) state hash used to drop duplicate
            states; keys 0 and 1 are never treated as duplicates.
        extra_args: opaque pointer forwarded to both callbacks.

        Afterwards `size` is the number of states produced, the tables are
        zeroed and `t` is incremented.
        """
        cdef weight_t** scores = self.scores
        cdef int** is_valid = self.is_valid
        cdef weight_t** costs = self.costs
        cdef weight_t score
        cdef int p_i
        cdef int i = 0
        cdef class_t clas
        cdef _State* parent
        cdef _State* state
        cdef hash_t key
        cdef PreshMap seen_states = PreshMap(self.width)
        cdef uint64_t is_seen
        cdef uint64_t one = 1
        # Allocated last so that nothing between the allocation and the `try`
        # can raise; the `finally` clause then always releases the queue, even
        # when `_fill` or one of the callbacks raises.
        cdef Queue* q = new Queue()
        try:
            self._fill(q, scores, is_valid)
            # For a beam of width k, we only ever need 2k state objects. How?
            # Each transition takes a parent and a class and produces a new state.
            # So, we don't need the whole history --- just the parent. So at
            # each step, we take a parent, and apply one or more extensions to
            # it.
            self._parents, self._states = self._states, self._parents
            self._parent_histories, self.histories = self.histories, self._parent_histories
            while i < self.width and not q.empty():
                data = q.top()
                # Decode the move id written by `_fill`.
                p_i = data.second / self.nr_class
                clas = data.second % self.nr_class
                score = data.first
                q.pop()
                parent = &self._parents[p_i]
                # Indicates terminal state reached; i.e. state is done
                if parent.is_done:
                    # Now parent will not be changed, so we don't have to copy.
                    # Once finished, should also be unbranching.
                    self._states[i], parent[0] = parent[0], self._states[i]
                    parent.i = self._states[i].i
                    parent.t = self._states[i].t
                    # NOTE(review): this copies the `t` field, not `is_done`.
                    # Left unchanged because it affects which slots are later
                    # treated as finished -- confirm the intent.
                    parent.is_done = self._states[i].t
                    self._states[i].score = score
                    self.histories[i] = list(self._parent_histories[p_i])
                    i += 1
                else:
                    state = &self._states[i]
                    # The supplied transition function should adjust the destination
                    # state to be the result of applying the class to the source state
                    transition_func(state.content, parent.content, clas, extra_args)
                    key = hash_func(state.content, extra_args) if hash_func is not NULL else 0
                    is_seen = <uint64_t>seen_states.get(key)
                    if key == 0 or key == 1 or not is_seen:
                        if key != 0 and key != 1:
                            seen_states.set(key, <void*>one)
                        state.score = score
                        state.loss = parent.loss + costs[p_i][clas]
                        self.histories[i] = list(self._parent_histories[p_i])
                        self.histories[i].append(clas)
                        i += 1
        finally:
            del q
        self.size = i
        assert self.size >= 1
        # Reset the tables so the caller starts the next step from zeroes.
        for i in range(self.width):
            memset(self.scores[i], 0, sizeof(weight_t) * self.nr_class)
            memset(self.costs[i], 0, sizeof(weight_t) * self.nr_class)
            memset(self.is_valid[i], 0, sizeof(int) * self.nr_class)
        self.t += 1

    cdef int check_done(self, finish_func_t finish_func, void* extra_args) except -1:
        """Refresh the per-state `is_done` flags and the beam-level `is_done`.

        `finish_func` is only consulted for states not already marked done.
        The beam is done when every active state is done.
        """
        cdef int i
        for i in range(self.size):
            if not self._states[i].is_done:
                self._states[i].is_done = finish_func(self._states[i].content, extra_args)
        for i in range(self.size):
            if not self._states[i].is_done:
                self.is_done = False
                break
        else:
            self.is_done = True

    @cython.cdivision(True)
    cdef int _fill(self, Queue* q, weight_t** scores, int** is_valid) except -1:
        """Populate the queue from a k * n matrix of scores, where k is the
        beam-width, and n is the number of classes.
        """
        cdef Entry entry
        cdef weight_t score
        cdef _State* s
        cdef int i, j, move_id
        assert self.size >= 1
        cdef vector[Entry] entries
        for i in range(self.size):
            s = &self._states[i]
            move_id = i * self.nr_class
            if s.is_done:
                # Update score by path average, following TACL '13 paper.
                if self.histories[i]:
                    entry.first = s.score + (s.score / self.t)
                else:
                    entry.first = s.score
                # A finished state contributes exactly one entry (class 0).
                entry.second = move_id
                entries.push_back(entry)
            else:
                for j in range(self.nr_class):
                    if is_valid[i][j]:
                        entry.first = s.score + scores[i][j]
                        entry.second = move_id + j
                        entries.push_back(entry)
        cdef double max_, Z, cutoff
        if self.min_density == 0.0:
            # No pruning: every candidate goes into the queue.
            for i in range(entries.size()):
                q.push(entries[i])
        elif not entries.empty():
            max_ = entries[0].first
            Z = 0.
            cutoff = 0.
            # Softmax into probabilities, so we can prune
            for i in range(entries.size()):
                if entries[i].first > max_:
                    max_ = entries[i].first
            for i in range(entries.size()):
                Z += exp(entries[i].first-max_)
            cutoff = (1. / Z) * self.min_density
            for i in range(entries.size()):
                prob = exp(entries[i].first-max_) / Z
                if prob >= cutoff:
                    q.push(entries[i])
cdef class MaxViolation:
    # Tracks the largest violation seen so far between a prediction beam and
    # a gold beam, i.e. the step where the lossy prediction out-scores the
    # gold analysis by the widest margin.
    def __init__(self):
        self.p_score = 0.0
        self.g_score = 0.0
        self.Z = 0.0
        self.gZ = 0.0
        self.delta = -1
        # cost == 0 doubles as the "no violation recorded yet" marker in
        # `check` and `check_crf`.
        self.cost = 0
        self.p_hist = []
        self.g_hist = []
        self.p_probs = []
        self.g_probs = []

    cpdef int check(self, Beam pred, Beam gold) except -1:
        # Compare the first state of each beam and record the violation when
        # the prediction has loss >= 1 and either nothing was recorded yet or
        # the score margin `d` exceeds the recorded `delta`.
        cdef _State* p = &pred._states[0]
        cdef _State* g = &gold._states[0]
        cdef weight_t d = p.score - g.score
        if p.loss >= 1 and (self.cost == 0 or d > self.delta):
            self.cost = p.loss
            self.delta = d
            self.p_hist = list(pred.histories[0])
            self.g_hist = list(gold.histories[0])
            self.p_score = p.score
            self.g_score = g.score
            # Both sums start from a small positive constant.
            self.Z = 1e-10
            self.gZ = 1e-10
            # Z sums exp(score) over lossy predicted states and zero-loss gold
            # states; gZ sums over the zero-loss gold states only.
            for i in range(pred.size):
                if pred._states[i].loss > 0:
                    self.Z += exp(pred._states[i].score)
            for i in range(gold.size):
                if gold._states[i].loss == 0:
                    prob = exp(gold._states[i].score)
                    self.Z += prob
                    self.gZ += prob

    cpdef int check_crf(self, Beam pred, Beam gold) except -1:
        # Like `check`, but records the histories of all states in both beams
        # together with softmax-derived values in `p_probs` / `g_probs`.
        d = pred.score - gold.score
        seen_golds = set([tuple(gold.histories[i]) for i in range(gold.size)])
        if pred.loss > 0 and (self.cost == 0 or d > self.delta):
            p_hist = []
            p_scores = []
            g_hist = []
            g_scores = []
            for i in range(pred.size):
                if pred._states[i].loss > 0:
                    p_scores.append(pred._states[i].score)
                    p_hist.append(list(pred.histories[i]))
                # This can happen from non-monotonic actions
                # If we find a better gold analysis this way, be sure to keep it.
                elif pred._states[i].loss <= 0 \
                and tuple(pred.histories[i]) not in seen_golds:
                    g_scores.append(pred._states[i].score)
                    g_hist.append(list(pred.histories[i]))
            for i in range(gold.size):
                if gold._states[i].loss == 0:
                    g_scores.append(gold._states[i].score)
                    g_hist.append(list(gold.histories[i]))

            # Normalise once over predictions and golds together, and once
            # over the golds alone.
            all_probs = _softmax(p_scores + g_scores)
            p_probs = all_probs[:len(p_scores)]
            g_probs_all = all_probs[len(p_scores):]
            g_probs = _softmax(g_scores)

            self.cost = pred.loss
            self.delta = d
            self.p_hist = p_hist
            self.g_hist = g_hist
            # TODO: These variables are misnamed! These are the gradients of the loss.
            self.p_probs = p_probs
            # Intuition here:
            # The gradient of the loss is:
            # P(model) - P(truth)
            # Normally, P(truth) is 1 for the gold
            # But, if we want to do the "partial credit" scheme, we want
            # to create a distribution over the gold, proportional to the scores
            # awarded.
            self.g_probs = [x-y for x, y in zip(g_probs_all, g_probs)]
def _softmax(nums):
if not nums:
return []
max_ = max(nums)
nums = [(exp(n-max_) if n is not None else None) for n in nums]
Z = sum(n for n in nums if n is not None)
return [(n/Z if n is not None else None) for n in nums]

View File

@ -12,13 +12,13 @@ import contextlib
import srsly import srsly
from thinc.api import set_dropout_rate, CupyOps, get_array_module from thinc.api import set_dropout_rate, CupyOps, get_array_module
from thinc.extra.search cimport Beam
from thinc.types import Ints1d from thinc.types import Ints1d
import numpy.random import numpy.random
import numpy import numpy
import warnings import warnings
from ._parser_internals.stateclass cimport StateC, StateClass from ._parser_internals.stateclass cimport StateC, StateClass
from ._parser_internals.search cimport Beam
from ..tokens.doc cimport Doc from ..tokens.doc cimport Doc
from .trainable_pipe import TrainablePipe from .trainable_pipe import TrainablePipe
from ._parser_internals cimport _beam_utils from ._parser_internals cimport _beam_utils

View File

@ -1,5 +1,9 @@
import pytest import pytest
from spacy.util import get_lang_class from spacy.util import get_lang_class
import functools
import inspect
import importlib
import sys
def pytest_addoption(parser): def pytest_addoption(parser):
@ -41,6 +45,33 @@ def pytest_runtest_setup(item):
pytest.skip("not referencing any issues") pytest.skip("not referencing any issues")
# Decorator for Cython-built tests
# https://shwina.github.io/cython-testing/
def cytest(func):
    """
    Decorator returning a plain Python function that forwards to `func`.

    On every call the arguments are bound against `func`'s signature and the
    bound positional and keyword arguments are passed on to `func`.
    """

    @functools.wraps(func)
    def plain_python_test(*args, **kwargs):
        call = inspect.signature(func).bind(*args, **kwargs)
        positional, keyword = call.args, call.kwargs
        return func(*positional, **keyword)

    return plain_python_test
def register_cython_tests(cython_mod_name: str, test_mod_name: str):
    """
    Import the Cython module `cython_mod_name` and copy each of its callables
    whose name starts with `test_` onto the already-imported module
    `test_mod_name`, so that pytest can collect them from there.
    """
    source_module = importlib.import_module(cython_mod_name)
    for attr_name in dir(source_module):
        candidate = getattr(source_module, attr_name)
        if not (attr_name.startswith("test_") and callable(candidate)):
            continue
        setattr(sys.modules[test_mod_name], attr_name, candidate)
# Fixtures for language tokenizers (languages sorted alphabetically) # Fixtures for language tokenizers (languages sorted alphabetically)

View File

@ -0,0 +1,119 @@
# cython: infer_types=True, binding=True
from spacy.pipeline._parser_internals.search cimport Beam, MaxViolation
from spacy.typedefs cimport class_t, weight_t
from cymem.cymem cimport Pool
from ..conftest import cytest
import pytest
# Minimal state object used as the beam's opaque `content` in these tests.
cdef struct TestState:
    # Set from the `n` argument of `initialize`.
    int length
    # Starts at 1; `transition` adds the applied class id to it.
    int x
    # Either the literal u'default' or the caller's extra-arguments pointer.
    Py_UNICODE* string
cdef int transition(void* dest, void* src, class_t clas, void* extra_args) except -1:
    # Transition callback for the tests: copy `src` into `dest`, then add the
    # class id to `x`. The string is taken from `extra_args` when one is
    # given and inherited from the source state otherwise.
    dest_state = <TestState*>dest
    src_state = <TestState*>src
    dest_state.length = src_state.length
    dest_state.x = src_state.x
    dest_state.x += clas
    if extra_args != NULL:
        dest_state.string = <Py_UNICODE*>extra_args
    else:
        dest_state.string = src_state.string
cdef void* initialize(Pool mem, int n, void* extra_args) except NULL:
    # Init callback for the tests: allocate one TestState from `mem` with
    # length `n` and x == 1. The string defaults to u'default' unless
    # `extra_args` supplies one.
    state = <TestState*>mem.alloc(1, sizeof(TestState))
    state.length = n
    state.x = 1
    if extra_args == NULL:
        state.string = u'default'
    else:
        state.string = <Py_UNICODE*>extra_args
    return state
cdef int destroy(Pool mem, void* state, void* extra_args) except -1:
    # Destructor callback for the tests: return the state's memory to the
    # pool it was allocated from. `extra_args` is unused.
    state = <TestState*>state
    mem.free(state)
@cytest
@pytest.mark.parametrize("nr_class,beam_width",
    [
        (2, 3),
        (3, 6),
        (4, 20),
    ]
)
def test_init(nr_class, beam_width):
    # A freshly constructed beam has size 1 and keeps the constructor
    # arguments unchanged.
    b = Beam(nr_class, beam_width)
    assert b.size == 1
    assert b.width == beam_width
    assert b.nr_class == nr_class
@cytest
def test_init_violn():
    # Smoke test: constructing a MaxViolation must not raise.
    MaxViolation()
@cytest
@pytest.mark.parametrize("nr_class,beam_width,length",
    [
        (2, 3, 3),
        (3, 6, 15),
        (4, 20, 32),
    ]
)
def test_initialize(nr_class, beam_width, length):
    # After `initialize` without extra arguments, every slot holds a state
    # with the requested length and the default string.
    b = Beam(nr_class, beam_width)
    b.initialize(initialize, destroy, length, NULL)
    for i in range(b.width):
        s = <TestState*>b.at(i)
        assert s.length == length, s.length
        assert s.string == 'default'
@cytest
@pytest.mark.parametrize("nr_class,beam_width,length,extra",
    [
        (2, 3, 4, None),
        (3, 6, 15, u"test beam 1"),
    ]
)
def test_initialize_extra(nr_class, beam_width, length, extra):
    # `initialize` must work both without extra arguments (NULL) and with a
    # unicode string passed through as the opaque extra-arguments pointer.
    b = Beam(nr_class, beam_width)
    if extra is None:
        b.initialize(initialize, destroy, length, NULL)
    else:
        b.initialize(initialize, destroy, length, <void*><Py_UNICODE*>extra)
    for i in range(b.width):
        s = <TestState*>b.at(i)
        assert s.length == length
@cytest
@pytest.mark.parametrize("nr_class,beam_width,length",
    [
        (3, 6, 15),
        (4, 20, 32),
    ]
)
def test_transition(nr_class, beam_width, length):
    b = Beam(nr_class, beam_width)
    b.initialize(initialize, destroy, length, NULL)
    # Step 1: class 2 is valid with score 30; class 1 scores higher but is
    # marked invalid, so only class 2 may be applied.
    b.set_cell(0, 2, 30, True, 0)
    b.set_cell(0, 1, 42, False, 0)
    b.advance(transition, NULL, NULL)
    assert b.size == 1, b.size
    assert b.score == 30, b.score
    s = <TestState*>b.at(0)
    # x starts at 1 and class 2 was applied.
    assert s.x == 3
    assert b._states[0].score == 30, b._states[0].score
    # Step 2: both classes are valid, so the single state branches into two;
    # scores accumulate on top of the previous 30.
    b.set_cell(0, 1, 10, True, 0)
    b.set_cell(0, 2, 20, True, 0)
    b.advance(transition, NULL, NULL)
    assert b._states[0].score == 50, b._states[0].score
    assert b._states[1].score == 40
    s = <TestState*>b.at(0)
    # The best state applied class 2 again: 3 + 2.
    assert s.x == 5

View File

@ -0,0 +1,3 @@
from ..conftest import register_cython_tests
# Copy the `test_*` callables of the compiled Cython test module into this
# module's namespace so that pytest can discover them here.
register_cython_tests("spacy.tests.parser._search", __name__)