mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-26 21:51:24 +03:00 
			
		
		
		
	* Move `thinc.extra.search` to `spacy.pipeline._parser_internals` Backport of: https://github.com/explosion/spaCy/pull/11317 Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com> * Replace references to `thinc.backends.linalg` with `CBlas` Backport of: https://github.com/explosion/spaCy/pull/11292 Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com> * Use cross entropy from `thinc.legacy` * Require thinc>=9.0.0.dev0,<9.1.0 Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com>
		
			
				
	
	
		
			120 lines
		
	
	
		
			3.1 KiB
		
	
	
	
		
			Cython
		
	
	
	
	
	
			
		
		
	
	
			120 lines
		
	
	
		
			3.1 KiB
		
	
	
	
		
			Cython
		
	
	
	
	
	
| # cython: infer_types=True, binding=True
 | |
| from spacy.pipeline._parser_internals.search cimport Beam, MaxViolation
 | |
| from spacy.typedefs cimport class_t, weight_t
 | |
| from cymem.cymem cimport Pool
 | |
| 
 | |
| from ..conftest import cytest
 | |
| import pytest
 | |
| 
 | |
| cdef struct TestState:
 | |
|     int length
 | |
|     int x
 | |
|     Py_UNICODE* string
 | |
| 
 | |
| 
 | |
| cdef int transition(void* dest, void* src, class_t clas, void* extra_args) except -1:
 | |
|     dest_state = <TestState*>dest
 | |
|     src_state = <TestState*>src
 | |
|     dest_state.length = src_state.length
 | |
|     dest_state.x = src_state.x
 | |
|     dest_state.x += clas
 | |
|     if extra_args != NULL:
 | |
|         dest_state.string = <Py_UNICODE*>extra_args
 | |
|     else:
 | |
|         dest_state.string = src_state.string
 | |
| 
 | |
| 
 | |
| cdef void* initialize(Pool mem, int n, void* extra_args) except NULL:
 | |
|     state = <TestState*>mem.alloc(1, sizeof(TestState))
 | |
|     state.length = n
 | |
|     state.x = 1
 | |
|     if extra_args == NULL:
 | |
|         state.string = u'default'
 | |
|     else:
 | |
|         state.string = <Py_UNICODE*>extra_args
 | |
|     return state
 | |
| 
 | |
| 
 | |
| cdef int destroy(Pool mem, void* state, void* extra_args) except -1:
 | |
|     state = <TestState*>state
 | |
|     mem.free(state)
 | |
| 
 | |
| @cytest
 | |
| @pytest.mark.parametrize("nr_class,beam_width",
 | |
|     [
 | |
|         (2, 3),
 | |
|         (3, 6),
 | |
|         (4, 20),
 | |
|     ]
 | |
| )
 | |
| def test_init(nr_class, beam_width):
 | |
|     b = Beam(nr_class, beam_width)
 | |
|     assert b.size == 1
 | |
|     assert b.width == beam_width
 | |
|     assert b.nr_class == nr_class
 | |
| 
 | |
| @cytest
 | |
| def test_init_violn():
 | |
|     MaxViolation()
 | |
| 
 | |
| @cytest
 | |
| @pytest.mark.parametrize("nr_class,beam_width,length",
 | |
|     [
 | |
|         (2, 3, 3),
 | |
|         (3, 6, 15),
 | |
|         (4, 20, 32),
 | |
|     ]
 | |
| )
 | |
| def test_initialize(nr_class, beam_width, length):
 | |
|     b = Beam(nr_class, beam_width)
 | |
|     b.initialize(initialize, destroy, length, NULL)
 | |
|     for i in range(b.width):
 | |
|         s = <TestState*>b.at(i)
 | |
|         assert s.length == length, s.length
 | |
|         assert s.string == 'default'
 | |
| 
 | |
| 
 | |
| @cytest
 | |
| @pytest.mark.parametrize("nr_class,beam_width,length,extra",
 | |
|     [
 | |
|         (2, 3, 4, None),
 | |
|         (3, 6, 15, u"test beam 1"),
 | |
|     ]
 | |
| )
 | |
| def test_initialize_extra(nr_class, beam_width, length, extra):
 | |
|     b = Beam(nr_class, beam_width)
 | |
|     if extra is None:
 | |
|         b.initialize(initialize, destroy, length, NULL)
 | |
|     else:
 | |
|         b.initialize(initialize, destroy, length, <void*><Py_UNICODE*>extra)
 | |
|     for i in range(b.width):
 | |
|         s = <TestState*>b.at(i)
 | |
|         assert s.length == length
 | |
| 
 | |
| 
 | |
| @cytest
 | |
| @pytest.mark.parametrize("nr_class,beam_width,length",
 | |
|     [
 | |
|         (3, 6, 15),
 | |
|         (4, 20, 32),
 | |
|     ]
 | |
| )
 | |
| def test_transition(nr_class, beam_width, length):
 | |
|     b = Beam(nr_class, beam_width)
 | |
|     b.initialize(initialize, destroy, length, NULL)
 | |
|     b.set_cell(0, 2, 30, True, 0)
 | |
|     b.set_cell(0, 1, 42, False, 0)
 | |
|     b.advance(transition, NULL, NULL)
 | |
|     assert b.size == 1, b.size
 | |
|     assert b.score == 30, b.score
 | |
|     s = <TestState*>b.at(0)
 | |
|     assert s.x == 3
 | |
|     assert b._states[0].score == 30, b._states[0].score
 | |
|     b.set_cell(0, 1, 10, True, 0)
 | |
|     b.set_cell(0, 2, 20, True, 0)
 | |
|     b.advance(transition, NULL, NULL)
 | |
|     assert b._states[0].score == 50, b._states[0].score
 | |
|     assert b._states[1].score == 40
 | |
|     s = <TestState*>b.at(0)
 | |
|     assert s.x == 5
 |