Normalize TokenC.sent_start values for Matcher (#5346)

Normalize TokenC.sent_start values to booleans for the `Matcher`.
Author: adrianeboyd, 2020-04-29 12:57:30 +02:00 (committed by GitHub)
Parent: bdff76dede
Commit: 3f43c73d37
3 changed files with 16 additions and 5 deletions
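
For context, this is the kind of pattern the change affects: IS_SENT_START in Matcher patterns is compared against TokenC.sent_start, which stores 1 / 0 / -1 rather than a boolean. A minimal usage sketch, assuming the v2.2+ Matcher.add(key, [pattern]) signature; the blank pipeline, example text and pattern keys are illustrative:

import spacy
from spacy.matcher import Matcher

nlp = spacy.blank("en")
doc = nlp("This is a sentence. This is another one.")

# Setting boundaries by hand writes 1 / -1 into TokenC.sent_start.
doc[5].is_sent_start = True    # "This" starts the second sentence
doc[1].is_sent_start = False   # "is" is explicitly not a sentence start (stored as -1)

matcher = Matcher(nlp.vocab)
matcher.add("SENT_START", [[{"IS_SENT_START": True}]])
matcher.add("NOT_SENT_START", [[{"IS_SENT_START": False}]])

for match_id, start, end in matcher(doc):
    print(nlp.vocab.strings[match_id], doc[start:end].text)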

spacy/matcher/matcher.pyx

@@ -14,7 +14,7 @@ import warnings
 from ..typedefs cimport attr_t
 from ..structs cimport TokenC
 from ..vocab cimport Vocab
-from ..tokens.doc cimport Doc, get_token_attr
+from ..tokens.doc cimport Doc, get_token_attr_for_matcher
 from ..tokens.span cimport Span
 from ..tokens.token cimport Token
 from ..attrs cimport ID, attr_id_t, NULL_ATTR, ORTH, POS, TAG, DEP, LEMMA
@@ -549,7 +549,7 @@ cdef char get_is_match(PatternStateC state,
     spec = state.pattern
     if spec.nr_attr > 0:
         for attr in spec.attrs[:spec.nr_attr]:
-            if get_token_attr(token, attr.attr) != attr.value:
+            if get_token_attr_for_matcher(token, attr.attr) != attr.value:
                 return 0
     for i in range(spec.nr_extra_attr):
         if spec.extra_attrs[i].value != extra_attrs[spec.extra_attrs[i].index]:
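
get_is_match compares each compiled pattern value against the token attribute with !=. Boolean pattern values such as {"IS_SENT_START": False} are compiled down to 1/0, while TokenC.sent_start can also hold -1 for "explicitly not a sentence start", so a raw comparison misses that case. A pure-Python sketch of the mismatch; the values are illustrative, not the Cython implementation:

# TokenC.sent_start: 1 = sentence start, 0 = unknown, -1 = explicitly not a start.
raw_sent_start = -1
pattern_value = int(False)  # an {"IS_SENT_START": False} pattern stores 0

# Raw comparison, as before this change: -1 != 0, so the token does not match.
matches_before = raw_sent_start == pattern_value

# Normalized comparison, as done via get_token_attr_for_matcher (added below):
normalized = int(raw_sent_start == 1)
matches_after = normalized == pattern_value

print(matches_before, matches_after)  # False True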
@@ -720,7 +720,7 @@ class _RegexPredicate(object):
         if self.is_extension:
             value = token._.get(self.attr)
         else:
-            value = token.vocab.strings[get_token_attr(token.c, self.attr)]
+            value = token.vocab.strings[get_token_attr_for_matcher(token.c, self.attr)]
         return bool(self.value.search(value))
@@ -741,7 +741,7 @@ class _SetMemberPredicate(object):
         if self.is_extension:
             value = get_string_id(token._.get(self.attr))
         else:
-            value = get_token_attr(token.c, self.attr)
+            value = get_token_attr_for_matcher(token.c, self.attr)
         if self.predicate == "IN":
             return value in self.value
         else:
@@ -768,7 +768,7 @@ class _ComparisonPredicate(object):
         if self.is_extension:
             value = token._.get(self.attr)
         else:
-            value = get_token_attr(token.c, self.attr)
+            value = get_token_attr_for_matcher(token.c, self.attr)
         if self.predicate == "==":
             return value == self.value
         if self.predicate == "!=":
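
The three predicate classes above back the dict-valued pattern attributes, so they all go through the same attribute lookup. A sketch of patterns that exercise each of them; the pipeline, text and match keys are illustrative:

import spacy
from spacy.matcher import Matcher

nlp = spacy.blank("en")
matcher = Matcher(nlp.vocab)

# REGEX -> _RegexPredicate, IN / NOT_IN -> _SetMemberPredicate,
# ==, !=, >=, <=, >, < -> _ComparisonPredicate.
matcher.add("REGEX_EXAMPLE", [[{"TEXT": {"REGEX": r"deff?in[ia]tely"}}]])
matcher.add("SET_EXAMPLE", [[{"LOWER": {"IN": ["hello", "hi", "hey"]}}]])
matcher.add("COMPARISON_EXAMPLE", [[{"LENGTH": {">=": 10}}]])

doc = nlp("Hey, I deffinately think antidisestablishmentarianism is long.")
print([(nlp.vocab.strings[m_id], doc[s:e].text) for m_id, s, e in matcher(doc)])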

spacy/tokens/doc.pxd

@@ -8,6 +8,7 @@ from ..attrs cimport attr_id_t
 cdef attr_t get_token_attr(const TokenC* token, attr_id_t feat_name) nogil
+cdef attr_t get_token_attr_for_matcher(const TokenC* token, attr_id_t feat_name) nogil
 ctypedef const LexemeC* const_Lexeme_ptr

spacy/tokens/doc.pyx

@@ -79,6 +79,16 @@ cdef attr_t get_token_attr(const TokenC* token, attr_id_t feat_name) nogil:
         return Lexeme.get_struct_attr(token.lex, feat_name)
 
 
+cdef attr_t get_token_attr_for_matcher(const TokenC* token, attr_id_t feat_name) nogil:
+    if feat_name == SENT_START:
+        if token.sent_start == 1:
+            return True
+        else:
+            return False
+    else:
+        return get_token_attr(token, feat_name)
+
+
 def _get_chunker(lang):
     try:
         cls = util.get_lang_class(lang)
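
For illustration, the normalization the new helper applies, written out in plain Python; the Cython function above is the actual implementation and the constant names here are only for readability:

SENT_START_TRUE, SENT_START_UNKNOWN, SENT_START_FALSE = 1, 0, -1

def sent_start_for_matcher(raw_value):
    # Mirror get_token_attr_for_matcher: only an explicit 1 maps to True;
    # both "unknown" (0) and "explicitly not a start" (-1) map to False.
    return raw_value == 1

assert sent_start_for_matcher(SENT_START_TRUE) is True
assert sent_start_for_matcher(SENT_START_UNKNOWN) is False
assert sent_start_for_matcher(SENT_START_FALSE) is False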