Replace Entity/MatchStruct with SpanC (#4459)

* Replace MatchStruct with Entity

Replace MatchStruct with Entity since the existing Entity struct is
nearly identical.

* Replace Entity with more general SpanC
This commit is contained in:
adrianeboyd 2019-10-18 11:01:47 +02:00 committed by Matthew Honnibal
parent 29e3da6493
commit d359da9687
5 changed files with 24 additions and 25 deletions

View File

@ -4,6 +4,7 @@ from cymem.cymem cimport Pool
from preshed.maps cimport key_t, MapStruct
from ..attrs cimport attr_id_t
from ..structs cimport SpanC
from ..tokens.doc cimport Doc
from ..vocab cimport Vocab
@ -18,10 +19,4 @@ cdef class PhraseMatcher:
cdef Pool mem
cdef key_t _terminal_hash
cdef void find_matches(self, Doc doc, vector[MatchStruct] *matches) nogil
# Old PhraseMatcher-only match record: a match key plus the start and end
# positions of the match. The commit title says this struct is replaced by
# SpanC.
cdef struct MatchStruct:
key_t match_id
int start
int end
cdef void find_matches(self, Doc doc, vector[SpanC] *matches) nogil

View File

@ -9,6 +9,7 @@ from preshed.maps cimport map_init, map_set, map_get, map_clear, map_iter
from ..attrs cimport ORTH, POS, TAG, DEP, LEMMA
from ..structs cimport TokenC
from ..tokens.token cimport Token
from ..typedefs cimport attr_t
from ._schemas import TOKEN_PATTERN_SCHEMA
from ..errors import Errors, Warnings, deprecation_warning, user_warning
@ -222,17 +223,17 @@ cdef class PhraseMatcher:
# if doc is empty or None just return empty list
return matches
cdef vector[MatchStruct] c_matches
cdef vector[SpanC] c_matches
self.find_matches(doc, &c_matches)
for i in range(c_matches.size()):
matches.append((c_matches[i].match_id, c_matches[i].start, c_matches[i].end))
matches.append((c_matches[i].label, c_matches[i].start, c_matches[i].end))
for i, (ent_id, start, end) in enumerate(matches):
on_match = self._callbacks.get(self.vocab.strings[ent_id])
if on_match is not None:
on_match(self, doc, i, matches)
return matches
cdef void find_matches(self, Doc doc, vector[MatchStruct] *matches) nogil:
cdef void find_matches(self, Doc doc, vector[SpanC] *matches) nogil:
cdef MapStruct* current_node = self.c_map
cdef int start = 0
cdef int idx = 0
@ -240,7 +241,7 @@ cdef class PhraseMatcher:
cdef key_t key
cdef void* value
cdef int i = 0
cdef MatchStruct ms
cdef SpanC ms
cdef void* result
while idx < doc.length:
start = idx
@ -255,7 +256,7 @@ cdef class PhraseMatcher:
if result:
i = 0
while map_iter(<MapStruct*>result, &i, &key, &value):
ms = make_matchstruct(key, start, idy)
ms = make_spanstruct(key, start, idy)
matches.push_back(ms)
inner_token = Token.get_struct_attr(&doc.c[idy], self.attr)
result = map_get(current_node, inner_token)
@ -270,7 +271,7 @@ cdef class PhraseMatcher:
if result:
i = 0
while map_iter(<MapStruct*>result, &i, &key, &value):
ms = make_matchstruct(key, start, idy)
ms = make_spanstruct(key, start, idy)
matches.push_back(ms)
current_node = self.c_map
idx += 1
@ -320,9 +321,9 @@ def unpickle_matcher(vocab, docs, callbacks, attr):
return matcher
# Old helper: packs a match key and its start/end positions into a
# MatchStruct, assigning all three of its fields. Superseded by
# make_spanstruct, which follows it in this hunk.
cdef MatchStruct make_matchstruct(key_t match_id, int start, int end) nogil:
cdef MatchStruct ms
ms.match_id = match_id
ms.start = start
ms.end = end
return ms
# Build a SpanC for one phrase match: `label` receives the match key and
# `start`/`end` the positions passed in by find_matches.
# NOTE(review): only label, start and end are assigned here. The SpanC
# declaration in this change also has id, start_char, end_char and kb_id,
# which are left unset on the returned struct. The visible caller reads only
# label/start/end; confirm no other consumer relies on the remaining fields.
cdef SpanC make_spanstruct(attr_t label, int start, int end) nogil:
cdef SpanC spanc
spanc.label = label
spanc.start = start
spanc.end = end
return spanc

View File

@ -47,11 +47,14 @@ cdef struct SerializedLexemeC:
# + sizeof(float) # l2_norm
# Diff view with markers stripped: the first header line is the old name
# (Entity), the second is the new name (SpanC). The fields below are those of
# the resulting SpanC struct.
# In the code shown in this change, PhraseMatcher fills only label, start and
# end (see make_spanstruct), and StateC stores its entities as an array of
# this struct (`_ents`).
cdef struct Entity:
cdef struct SpanC:
hash_t id
int start
int end
int start_char
int end_char
attr_t label
attr_t kb_id
cdef struct TokenC:

View File

@ -7,7 +7,7 @@ from cpython.exc cimport PyErr_CheckSignals, PyErr_SetFromErrno
from murmurhash.mrmr cimport hash64
from ..vocab cimport EMPTY_LEXEME
from ..structs cimport TokenC, Entity
from ..structs cimport TokenC, SpanC
from ..lexeme cimport Lexeme
from ..symbols cimport punct
from ..attrs cimport IS_SPACE
@ -40,7 +40,7 @@ cdef cppclass StateC:
int* _buffer
bint* shifted
TokenC* _sent
Entity* _ents
SpanC* _ents
TokenC _empty_token
RingBufferC _hist
int length
@ -56,7 +56,7 @@ cdef cppclass StateC:
this._stack = <int*>calloc(length + (PADDING * 2), sizeof(int))
this.shifted = <bint*>calloc(length + (PADDING * 2), sizeof(bint))
this._sent = <TokenC*>calloc(length + (PADDING * 2), sizeof(TokenC))
this._ents = <Entity*>calloc(length + (PADDING * 2), sizeof(Entity))
this._ents = <SpanC*>calloc(length + (PADDING * 2), sizeof(SpanC))
if not (this._buffer and this._stack and this.shifted
and this._sent and this._ents):
with gil:
@ -406,7 +406,7 @@ cdef cppclass StateC:
memcpy(this._sent, src._sent, this.length * sizeof(TokenC))
memcpy(this._stack, src._stack, this.length * sizeof(int))
memcpy(this._buffer, src._buffer, this.length * sizeof(int))
memcpy(this._ents, src._ents, this.length * sizeof(Entity))
memcpy(this._ents, src._ents, this.length * sizeof(SpanC))
memcpy(this.shifted, src.shifted, this.length * sizeof(this.shifted[0]))
this._b_i = src._b_i
this._s_i = src._s_i

View File

@ -3,7 +3,7 @@ from libc.string cimport memcpy, memset
from cymem.cymem cimport Pool
cimport cython
from ..structs cimport TokenC, Entity
from ..structs cimport TokenC, SpanC
from ..typedefs cimport attr_t
from ..vocab cimport EMPTY_LEXEME