mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-24 17:06:29 +03:00
* Rework the Span-merge patch, to avoid extending the interface of Doc, and avoid virtualizing the Span.start and Span.end indices, to keep Span usage efficient
This commit is contained in:
parent
83ca4e0b93
commit
56499d89ef
|
@ -19,6 +19,12 @@ ctypedef fused LexemeOrToken:
|
|||
const_TokenC_ptr
|
||||
|
||||
|
||||
cdef int token_by_start(const TokenC* tokens, int length, int start_char) except -2
|
||||
|
||||
|
||||
cdef int token_by_end(const TokenC* tokens, int length, int end_char) except -2
|
||||
|
||||
|
||||
cdef class Doc:
|
||||
cdef readonly Pool mem
|
||||
cdef readonly Vocab vocab
|
||||
|
|
|
@ -438,47 +438,19 @@ cdef class Doc:
|
|||
keep_reading = False
|
||||
yield n_bytes_str + data
|
||||
|
||||
def token_index_start(self, int start_idx):
|
||||
""" Get index of token in doc that has character index start_idx """
|
||||
cdef int i
|
||||
for i in range(self.length):
|
||||
if self.c[i].idx == start_idx:
|
||||
return i
|
||||
return None
|
||||
|
||||
def token_index_end(self, int end_idx):
|
||||
""" Get index+1 of token in doc ending with character index end_idx """
|
||||
cdef int i
|
||||
for i in range(self.length):
|
||||
if (self.c[i].idx + self.c[i].lex.length) == end_idx:
|
||||
return i + 1
|
||||
return None
|
||||
|
||||
def range_from_indices(self, int start_idx, int end_idx):
|
||||
""" Get tuple - span of token indices which correspond to
|
||||
character indices (start_idx, end_idx) if such a span exists"""
|
||||
cdef int i
|
||||
cdef int start = -1
|
||||
cdef int end = -1
|
||||
for i in range(self.length):
|
||||
if self.c[i].idx == start_idx:
|
||||
start = i
|
||||
if (self.c[i].idx + self.c[i].lex.length) == end_idx:
|
||||
if start == -1:
|
||||
return None
|
||||
end = i + 1
|
||||
return (start, end)
|
||||
return None
|
||||
|
||||
# This function is terrible --- need to fix this.
|
||||
def merge(self, int start_idx, int end_idx, unicode tag, unicode lemma,
|
||||
unicode ent_type):
|
||||
"""Merge a multi-word expression into a single token. Currently
|
||||
experimental; API is likely to change."""
|
||||
start_end = self.range_from_indices(start_idx, end_idx)
|
||||
if start_end is None:
|
||||
cdef int start = token_by_start(self.c, self.length, start_idx)
|
||||
if start == -1:
|
||||
return None
|
||||
start, end = start_end
|
||||
cdef int end = token_by_end(self.c, self.length, end_idx)
|
||||
if end == -1:
|
||||
return None
|
||||
# Currently we have the token index, we want the range-end index
|
||||
end += 1
|
||||
|
||||
cdef Span span = self[start:end]
|
||||
# Get LexemeC for newly merged token
|
||||
new_orth = ''.join([t.text_with_ws for t in span])
|
||||
|
@ -541,6 +513,24 @@ cdef class Doc:
|
|||
return self[start]
|
||||
|
||||
|
||||
cdef int token_by_start(const TokenC* tokens, int length, int start_char) except -2:
|
||||
cdef int i
|
||||
for i in range(length):
|
||||
if self.c[i].idx == start_char:
|
||||
return i
|
||||
else:
|
||||
return -1
|
||||
|
||||
|
||||
cdef int token_by_end(const TokenC* tokens, int length, int end_char) except -2:
|
||||
cdef int i
|
||||
for i in range(length):
|
||||
if tokens[i].idx + tokens[i].lex.length == end_char:
|
||||
return i
|
||||
else:
|
||||
return -1
|
||||
|
||||
|
||||
cdef int set_children_from_heads(TokenC* tokens, int length) except -1:
|
||||
cdef TokenC* head
|
||||
cdef TokenC* child
|
||||
|
|
|
@ -4,10 +4,10 @@ from .doc cimport Doc
|
|||
cdef class Span:
|
||||
cdef readonly Doc doc
|
||||
cdef public int i
|
||||
cdef public int start_token
|
||||
cdef public int end_token
|
||||
cdef public int start_idx
|
||||
cdef public int end_idx
|
||||
cdef public int start
|
||||
cdef public int end
|
||||
cdef public int start_char
|
||||
cdef public int end_char
|
||||
cdef readonly int label
|
||||
|
||||
cdef public _vector
|
||||
|
|
|
@ -11,6 +11,7 @@ from ..typedefs cimport flags_t, attr_t
|
|||
from ..attrs cimport attr_id_t
|
||||
from ..parts_of_speech cimport univ_pos_t
|
||||
from ..util import normalize_slice
|
||||
from .doc cimport token_by_start, token_by_end
|
||||
|
||||
|
||||
cdef class Span:
|
||||
|
@ -21,10 +22,10 @@ cdef class Span:
|
|||
raise IndexError
|
||||
|
||||
self.doc = tokens
|
||||
self.start_token = start
|
||||
self.start_idx = self.doc[start].idx
|
||||
self.end_token = end
|
||||
self.end_idx = self.doc[end - 1].idx + len(self.doc[end - 1])
|
||||
self.start = start
|
||||
self.start_char = self.doc[start].idx
|
||||
self.end = end
|
||||
self.end_char = self.doc[end - 1].idx + len(self.doc[end - 1])
|
||||
self.label = label
|
||||
self._vector = vector
|
||||
self._vector_norm = vector_norm
|
||||
|
@ -32,19 +33,20 @@ cdef class Span:
|
|||
def __richcmp__(self, Span other, int op):
|
||||
# Eq
|
||||
if op == 0:
|
||||
return self.start < other.start
|
||||
return self.start_char < other.start_char
|
||||
elif op == 1:
|
||||
return self.start <= other.start
|
||||
return self.start_char <= other.start_char
|
||||
elif op == 2:
|
||||
return self.start == other.start and self.end == other.end
|
||||
return self.start_char == other.start_idx and self.end_char == other.end_char
|
||||
elif op == 3:
|
||||
return self.start != other.start or self.end != other.end
|
||||
return self.start_char != other.start_char or self.end_char != other.end_char
|
||||
elif op == 4:
|
||||
return self.start > other.start
|
||||
return self.start_char > other.start_char
|
||||
elif op == 5:
|
||||
return self.start >= other.start
|
||||
return self.start_char >= other.start_char
|
||||
|
||||
def __len__(self):
|
||||
self.recalculate_indices()
|
||||
if self.end < self.start:
|
||||
return 0
|
||||
return self.end - self.start
|
||||
|
@ -55,68 +57,41 @@ cdef class Span:
|
|||
return self.text.encode('utf-8')
|
||||
|
||||
def __getitem__(self, object i):
|
||||
self.recalculate_indices()
|
||||
if isinstance(i, slice):
|
||||
start, end = normalize_slice(len(self), i.start, i.stop, i.step)
|
||||
start += self.start
|
||||
end += self.start
|
||||
return Span(self.doc, start, end)
|
||||
|
||||
if i < 0:
|
||||
return self.doc[self.end + i]
|
||||
return Span(self.doc, start + self.start, end + self.start)
|
||||
else:
|
||||
return self.doc[self.start + i]
|
||||
if i < 0:
|
||||
return self.doc.c[self.end + i]
|
||||
else:
|
||||
return self.doc.c[self.start + i]
|
||||
|
||||
def __iter__(self):
|
||||
self.recalculate_indices()
|
||||
for i in range(self.start, self.end):
|
||||
yield self.doc[i]
|
||||
|
||||
def merge(self, unicode tag, unicode lemma, unicode ent_type):
|
||||
self.doc.merge(self[0].idx, self[-1].idx + len(self[-1]), tag, lemma, ent_type)
|
||||
self.doc.merge(self.start_char, self.end_char, tag, lemma, ent_type)
|
||||
|
||||
def similarity(self, other):
|
||||
if self.vector_norm == 0.0 or other.vector_norm == 0.0:
|
||||
return 0.0
|
||||
return numpy.dot(self.vector, other.vector) / (self.vector_norm * other.vector_norm)
|
||||
|
||||
property start:
|
||||
""" Get start token index of this span from the Doc."""
|
||||
def __get__(self):
|
||||
# first is the first token of the span - get it from the doc
|
||||
first = None
|
||||
if self.start_token < len(self.doc):
|
||||
first = self.doc[self.start_token]
|
||||
# if we have merged spans in Doc start might have changed.
|
||||
# check if token start index is in doc index range and the token
|
||||
# index is start_idx (it hasn't changed).
|
||||
if first is None or first.idx != self.start_idx:
|
||||
# go through tokens in Doc - find index of token equal to start_idx
|
||||
new_start = self.doc.token_index_start(self.start_idx)
|
||||
if new_start is not None:
|
||||
self.start_token = new_start
|
||||
else:
|
||||
raise IndexError('Something went terribly wrong during a merge.'
|
||||
'No token found with idx %s' % self.start_idx)
|
||||
return self.start_token
|
||||
cpdef int recalculate_indices(self) except -1:
|
||||
if self.end >= doc.length \
|
||||
or tokens[self.start].idx != self.start_char \
|
||||
or (tokens[self.end-1].idx + tokens[self.end-1].length) != self.end_char:
|
||||
start = token_by_start(self.doc.c, self.doc.length, self.start_char)
|
||||
if self.start == -1:
|
||||
raise IndexError("Error calculating span: Can't find start")
|
||||
if end == -1:
|
||||
raise IndexError("Error calculating span: Can't find end")
|
||||
|
||||
property end:
|
||||
""" Get end token index of this span from the Doc."""
|
||||
def __get__(self):
|
||||
# last is the last token of the span - get it from the doc
|
||||
last = None
|
||||
if self.end_token <= len(self.doc):
|
||||
last = self.doc[self.end_token -1]
|
||||
# if we have merged spans in Doc end will have changed.
|
||||
# check if token end index is in doc index range and the token
|
||||
# index is end_idx + len(last_token) (it hasn't changed).
|
||||
if last is None or last.idx + len(last) != self.end_idx:
|
||||
# go through tokens in Doc - find index of token equal to end_idx
|
||||
new_end = self.doc.token_index_end(self.end_idx)
|
||||
if new_end is not None:
|
||||
self.end_token = new_end
|
||||
else:
|
||||
raise IndexError('Something went terribly wrong during a merge.'
|
||||
'No token found with idx %s' % self.end_idx)
|
||||
return self.end_token
|
||||
self.start = start
|
||||
self.end = end + 1
|
||||
|
||||
property vector:
|
||||
def __get__(self):
|
||||
|
@ -179,6 +154,7 @@ cdef class Span:
|
|||
'Autumn'
|
||||
"""
|
||||
def __get__(self):
|
||||
self.recalculate_indices()
|
||||
# This should probably be called 'head', and the other one called
|
||||
# 'gov'. But we went with 'head' elsehwhere, and now we're stuck =/
|
||||
cdef const TokenC* start = &self.doc.c[self.start]
|
||||
|
|
Loading…
Reference in New Issue
Block a user