* Rework the Span-merge patch, to avoid extending the interface of Doc, and avoid virtualizing the Span.start and Span.end indices, to keep Span usage efficient

This commit is contained in:
Matthew Honnibal 2015-11-07 08:56:49 +11:00
parent 56499d89ef
commit a9b612abdf
3 changed files with 15 additions and 11 deletions

View File

@ -516,7 +516,7 @@ cdef class Doc:
cdef int token_by_start(const TokenC* tokens, int length, int start_char) except -2: cdef int token_by_start(const TokenC* tokens, int length, int start_char) except -2:
cdef int i cdef int i
for i in range(length): for i in range(length):
if self.c[i].idx == start_char: if tokens[i].idx == start_char:
return i return i
else: else:
return -1 return -1

View File

@ -12,3 +12,6 @@ cdef class Span:
cdef public _vector cdef public _vector
cdef public _vector_norm cdef public _vector_norm
cpdef int _recalculate_indices(self) except -1

View File

@ -46,7 +46,7 @@ cdef class Span:
return self.start_char >= other.start_char return self.start_char >= other.start_char
def __len__(self): def __len__(self):
self.recalculate_indices() self._recalculate_indices()
if self.end < self.start: if self.end < self.start:
return 0 return 0
return self.end - self.start return self.end - self.start
@ -57,18 +57,18 @@ cdef class Span:
return self.text.encode('utf-8') return self.text.encode('utf-8')
def __getitem__(self, object i): def __getitem__(self, object i):
self.recalculate_indices() self._recalculate_indices()
if isinstance(i, slice): if isinstance(i, slice):
start, end = normalize_slice(len(self), i.start, i.stop, i.step) start, end = normalize_slice(len(self), i.start, i.stop, i.step)
return Span(self.doc, start + self.start, end + self.start) return Span(self.doc, start + self.start, end + self.start)
else: else:
if i < 0: if i < 0:
return self.doc.c[self.end + i] return self.doc[self.end + i]
else: else:
return self.doc.c[self.start + i] return self.doc[self.start + i]
def __iter__(self): def __iter__(self):
self.recalculate_indices() self._recalculate_indices()
for i in range(self.start, self.end): for i in range(self.start, self.end):
yield self.doc[i] yield self.doc[i]
@ -80,13 +80,14 @@ cdef class Span:
return 0.0 return 0.0
return numpy.dot(self.vector, other.vector) / (self.vector_norm * other.vector_norm) return numpy.dot(self.vector, other.vector) / (self.vector_norm * other.vector_norm)
cpdef int recalculate_indices(self) except -1: cpdef int _recalculate_indices(self) except -1:
if self.end >= doc.length \ if self.end >= self.doc.length \
or tokens[self.start].idx != self.start_char \ or self.doc.c[self.start].idx != self.start_char \
or (tokens[self.end-1].idx + tokens[self.end-1].length) != self.end_char: or (self.doc.c[self.end-1].idx + self.doc.c[self.end-1].lex.length) != self.end_char:
start = token_by_start(self.doc.c, self.doc.length, self.start_char) start = token_by_start(self.doc.c, self.doc.length, self.start_char)
if self.start == -1: if self.start == -1:
raise IndexError("Error calculating span: Can't find start") raise IndexError("Error calculating span: Can't find start")
end = token_by_end(self.doc.c, self.doc.length, self.end_char)
if end == -1: if end == -1:
raise IndexError("Error calculating span: Can't find end") raise IndexError("Error calculating span: Can't find end")
@ -154,7 +155,7 @@ cdef class Span:
'Autumn' 'Autumn'
""" """
def __get__(self): def __get__(self):
self.recalculate_indices() self._recalculate_indices()
# This should probably be called 'head', and the other one called # This should probably be called 'head', and the other one called
# 'gov'. But we went with 'head' elsehwhere, and now we're stuck =/ # 'gov'. But we went with 'head' elsehwhere, and now we're stuck =/
cdef const TokenC* start = &self.doc.c[self.start] cdef const TokenC* start = &self.doc.c[self.start]