From bcfe3bd3122a61147d31425566060da90e997115 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Thu, 7 Mar 2019 12:51:11 +0100 Subject: [PATCH] Fix StringStore after symbols changes --- spacy/strings.pyx | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/spacy/strings.pyx b/spacy/strings.pyx index 64954503f..0565b2a0a 100644 --- a/spacy/strings.pyx +++ b/spacy/strings.pyx @@ -11,11 +11,15 @@ import srsly from .compat import basestring_ from .symbols import IDS as SYMBOLS_BY_STR -from .symbols import NAMES as SYMBOLS_BY_INT +from . import symbols from .typedefs cimport hash_t from .errors import Errors from . import util +SYMBOLS_BY_INT = {} +for name in symbols.NAMES: + SYMBOLS_BY_INT[SYMBOLS_BY_STR[name]] = name +print(SYMBOLS_BY_INT[6005]) def get_string_id(key): """Get a string ID, handling the reserved symbols correctly. If the key is @@ -116,6 +120,8 @@ cdef class StringStore: return u'' elif string_or_id in SYMBOLS_BY_STR: return SYMBOLS_BY_STR[string_or_id] + elif string_or_id in SYMBOLS_BY_INT: + return SYMBOLS_BY_INT[string_or_id] cdef hash_t key if isinstance(string_or_id, unicode): key = hash_string(string_or_id) @@ -123,8 +129,6 @@ cdef class StringStore: elif isinstance(string_or_id, bytes): key = hash_utf8(string_or_id, len(string_or_id)) return key - elif string_or_id < len(SYMBOLS_BY_INT): - return SYMBOLS_BY_INT[string_or_id] else: key = string_or_id self.hits.insert(key) @@ -181,11 +185,14 @@ cdef class StringStore: string (unicode): The string to check. RETURNS (bool): Whether the store contains the string. """ + global SYMBOLS_BY_INT cdef hash_t key if isinstance(string, int) or isinstance(string, long): if string == 0: return True key = string + if key in SYMBOLS_BY_INT: + return True elif len(string) == 0: return True elif string in SYMBOLS_BY_STR: @@ -195,11 +202,8 @@ cdef class StringStore: else: string = string.encode('utf8') key = hash_utf8(string, len(string)) - if key < len(SYMBOLS_BY_INT): - return True - else: - self.hits.insert(key) - return self._map.get(key) is not NULL + self.hits.insert(key) + return self._map.get(key) is not NULL def __iter__(self): """Iterate over the strings in the store, in order.