Refactor to remove duplicate slicing logic

This commit is contained in:
Yubing (Tom) Dong 2015-10-07 01:25:35 -07:00
parent 97685aecb7
commit 3fd3bc79aa
3 changed files with 28 additions and 30 deletions

View File

@ -21,6 +21,7 @@ from ..lexeme cimport Lexeme
from .spans cimport Span from .spans cimport Span
from .token cimport Token from .token cimport Token
from ..serialize.bits cimport BitArray from ..serialize.bits cimport BitArray
from ..util import normalize_slice
DEF PADDING = 5 DEF PADDING = 5
@ -87,14 +88,8 @@ cdef class Doc:
token (Token): token (Token):
""" """
if isinstance(i, slice): if isinstance(i, slice):
if not (i.step is None or i.step == 1): start, stop = normalize_slice(len(self), i.start, i.stop, i.step)
raise ValueError("Stepped slices not supported in Span objects." return Span(self, start, stop, label=0)
"Try: list(doc)[start:stop:step] instead.")
if i.start is None:
i = slice(0, i.stop)
if i.stop is None:
i = slice(i.start, len(self))
return Span(self, i.start, i.stop, label=0)
if i < 0: if i < 0:
i = self.length + i i = self.length + i

View File

@ -9,19 +9,15 @@ from ..structs cimport TokenC, LexemeC
from ..typedefs cimport flags_t, attr_t from ..typedefs cimport flags_t, attr_t
from ..attrs cimport attr_id_t from ..attrs cimport attr_id_t
from ..parts_of_speech cimport univ_pos_t from ..parts_of_speech cimport univ_pos_t
from ..util import normalize_slice
cdef class Span: cdef class Span:
"""A slice from a Doc object.""" """A slice from a Doc object."""
def __cinit__(self, Doc tokens, int start, int end, int label=0, vector=None, def __cinit__(self, Doc tokens, int start, int end, int label=0, vector=None,
vector_norm=None): vector_norm=None):
if start < 0: if not (0 <= start <= end <= len(tokens)):
start = tokens.length + start raise IndexError
start = min(tokens.length, max(0, start))
if end < 0:
end = tokens.length + end
end = min(tokens.length, max(start, end))
self.doc = tokens self.doc = tokens
self.start = start self.start = start
@ -52,23 +48,10 @@ cdef class Span:
def __getitem__(self, object i): def __getitem__(self, object i):
if isinstance(i, slice): if isinstance(i, slice):
start, end, step = i.start, i.stop, i.step start, end = normalize_slice(len(self), i.start, i.stop, i.step)
if start is None:
start = 0
elif start < 0:
start += len(self)
start = min(len(self), max(0, start))
if end is None:
end = len(self)
elif end < 0:
end += len(self)
end = min(len(self), max(start, end))
start += self.start start += self.start
end += self.start end += self.start
return Span(self.doc, start, end)
return self.doc[start:end:i.step]
if i < 0: if i < 0:
return self.doc[self.end + i] return self.doc[self.end + i]

View File

@ -7,6 +7,26 @@ from .attrs import TAG, HEAD, DEP, ENT_IOB, ENT_TYPE
DATA_DIR = path.join(path.dirname(__file__), '..', 'data') DATA_DIR = path.join(path.dirname(__file__), '..', 'data')
def normalize_slice(length, start, stop, step=None):
    """Resolve slice components into concrete, clamped ``(start, stop)`` bounds.

    Args:
        length: Length of the sequence being sliced.
        start: Slice start; ``None`` means 0, negative values count from
            the end.
        stop: Slice stop; ``None`` means ``length``, negative values count
            from the end.
        step: Slice step; only ``None`` or ``1`` is accepted.

    Returns:
        A ``(start, stop)`` tuple satisfying
        ``0 <= start <= stop <= length``.

    Raises:
        ValueError: If ``step`` is neither ``None`` nor ``1``.
    """
    if not (step is None or step == 1):
        # The two literals are joined implicitly, so the separating space
        # must live inside one of them.
        raise ValueError("Stepped slices not supported in Span objects. "
                         "Try: list(tokens)[start:stop:step] instead.")

    if start is None:
        start = 0
    elif start < 0:
        start += length
    # Clamp into [0, length] so out-of-range bounds behave like list slicing.
    start = min(length, max(0, start))

    if stop is None:
        stop = length
    elif stop < 0:
        stop += length
    # A stop that lies before start collapses to an empty range at start.
    stop = min(length, max(start, stop))

    # Internal sanity check on the clamping above, not input validation.
    assert 0 <= start <= stop <= length
    return start, stop
def utf8open(loc, mode='r'): def utf8open(loc, mode='r'):
return codecs.open(loc, mode, 'utf8') return codecs.open(loc, mode, 'utf8')