mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-24 16:24:16 +03:00
Fix StringStore after symbols changes
This commit is contained in:
parent
d0ca64bb07
commit
bcfe3bd312
|
@ -11,11 +11,15 @@ import srsly
|
|||
|
||||
from .compat import basestring_
|
||||
from .symbols import IDS as SYMBOLS_BY_STR
|
||||
from .symbols import NAMES as SYMBOLS_BY_INT
|
||||
from . import symbols
|
||||
from .typedefs cimport hash_t
|
||||
from .errors import Errors
|
||||
from . import util
|
||||
|
||||
SYMBOLS_BY_INT = {}
|
||||
for name in symbols.NAMES:
|
||||
SYMBOLS_BY_INT[SYMBOLS_BY_STR[name]] = name
|
||||
print(SYMBOLS_BY_INT[6005])
|
||||
|
||||
def get_string_id(key):
|
||||
"""Get a string ID, handling the reserved symbols correctly. If the key is
|
||||
|
@ -116,6 +120,8 @@ cdef class StringStore:
|
|||
return u''
|
||||
elif string_or_id in SYMBOLS_BY_STR:
|
||||
return SYMBOLS_BY_STR[string_or_id]
|
||||
elif string_or_id in SYMBOLS_BY_INT:
|
||||
return SYMBOLS_BY_INT[string_or_id]
|
||||
cdef hash_t key
|
||||
if isinstance(string_or_id, unicode):
|
||||
key = hash_string(string_or_id)
|
||||
|
@ -123,8 +129,6 @@ cdef class StringStore:
|
|||
elif isinstance(string_or_id, bytes):
|
||||
key = hash_utf8(string_or_id, len(string_or_id))
|
||||
return key
|
||||
elif string_or_id < len(SYMBOLS_BY_INT):
|
||||
return SYMBOLS_BY_INT[string_or_id]
|
||||
else:
|
||||
key = string_or_id
|
||||
self.hits.insert(key)
|
||||
|
@ -181,11 +185,14 @@ cdef class StringStore:
|
|||
string (unicode): The string to check.
|
||||
RETURNS (bool): Whether the store contains the string.
|
||||
"""
|
||||
global SYMBOLS_BY_INT
|
||||
cdef hash_t key
|
||||
if isinstance(string, int) or isinstance(string, long):
|
||||
if string == 0:
|
||||
return True
|
||||
key = string
|
||||
if key in SYMBOLS_BY_INT:
|
||||
return True
|
||||
elif len(string) == 0:
|
||||
return True
|
||||
elif string in SYMBOLS_BY_STR:
|
||||
|
@ -195,11 +202,8 @@ cdef class StringStore:
|
|||
else:
|
||||
string = string.encode('utf8')
|
||||
key = hash_utf8(string, len(string))
|
||||
if key < len(SYMBOLS_BY_INT):
|
||||
return True
|
||||
else:
|
||||
self.hits.insert(key)
|
||||
return self._map.get(key) is not NULL
|
||||
self.hits.insert(key)
|
||||
return self._map.get(key) is not NULL
|
||||
|
||||
def __iter__(self):
|
||||
"""Iterate over the strings in the store, in order.
|
||||
|
|
Loading…
Reference in New Issue
Block a user