mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-24 00:46:28 +03:00
Normalize TokenC.sent_start values for Matcher (#5346)
Normalize TokenC.sent_start values to booleans for the `Matcher`.
This commit is contained in:
parent
bdff76dede
commit
3f43c73d37
|
@ -14,7 +14,7 @@ import warnings
|
|||
from ..typedefs cimport attr_t
|
||||
from ..structs cimport TokenC
|
||||
from ..vocab cimport Vocab
|
||||
from ..tokens.doc cimport Doc, get_token_attr
|
||||
from ..tokens.doc cimport Doc, get_token_attr_for_matcher
|
||||
from ..tokens.span cimport Span
|
||||
from ..tokens.token cimport Token
|
||||
from ..attrs cimport ID, attr_id_t, NULL_ATTR, ORTH, POS, TAG, DEP, LEMMA
|
||||
|
@ -549,7 +549,7 @@ cdef char get_is_match(PatternStateC state,
|
|||
spec = state.pattern
|
||||
if spec.nr_attr > 0:
|
||||
for attr in spec.attrs[:spec.nr_attr]:
|
||||
if get_token_attr(token, attr.attr) != attr.value:
|
||||
if get_token_attr_for_matcher(token, attr.attr) != attr.value:
|
||||
return 0
|
||||
for i in range(spec.nr_extra_attr):
|
||||
if spec.extra_attrs[i].value != extra_attrs[spec.extra_attrs[i].index]:
|
||||
|
@ -720,7 +720,7 @@ class _RegexPredicate(object):
|
|||
if self.is_extension:
|
||||
value = token._.get(self.attr)
|
||||
else:
|
||||
value = token.vocab.strings[get_token_attr(token.c, self.attr)]
|
||||
value = token.vocab.strings[get_token_attr_for_matcher(token.c, self.attr)]
|
||||
return bool(self.value.search(value))
|
||||
|
||||
|
||||
|
@ -741,7 +741,7 @@ class _SetMemberPredicate(object):
|
|||
if self.is_extension:
|
||||
value = get_string_id(token._.get(self.attr))
|
||||
else:
|
||||
value = get_token_attr(token.c, self.attr)
|
||||
value = get_token_attr_for_matcher(token.c, self.attr)
|
||||
if self.predicate == "IN":
|
||||
return value in self.value
|
||||
else:
|
||||
|
@ -768,7 +768,7 @@ class _ComparisonPredicate(object):
|
|||
if self.is_extension:
|
||||
value = token._.get(self.attr)
|
||||
else:
|
||||
value = get_token_attr(token.c, self.attr)
|
||||
value = get_token_attr_for_matcher(token.c, self.attr)
|
||||
if self.predicate == "==":
|
||||
return value == self.value
|
||||
if self.predicate == "!=":
|
||||
|
|
|
@ -8,6 +8,7 @@ from ..attrs cimport attr_id_t
|
|||
|
||||
|
||||
cdef attr_t get_token_attr(const TokenC* token, attr_id_t feat_name) nogil
|
||||
cdef attr_t get_token_attr_for_matcher(const TokenC* token, attr_id_t feat_name) nogil
|
||||
|
||||
|
||||
ctypedef const LexemeC* const_Lexeme_ptr
|
||||
|
|
|
@ -79,6 +79,16 @@ cdef attr_t get_token_attr(const TokenC* token, attr_id_t feat_name) nogil:
|
|||
return Lexeme.get_struct_attr(token.lex, feat_name)
|
||||
|
||||
|
||||
cdef attr_t get_token_attr_for_matcher(const TokenC* token, attr_id_t feat_name) nogil:
|
||||
if feat_name == SENT_START:
|
||||
if token.sent_start == 1:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
else:
|
||||
return get_token_attr(token, feat_name)
|
||||
|
||||
|
||||
def _get_chunker(lang):
|
||||
try:
|
||||
cls = util.get_lang_class(lang)
|
||||
|
|
Loading…
Reference in New Issue
Block a user