Add alignment_mode argument to Span.char_span()

This commit is contained in:
Simon Gurcke 2023-01-23 11:25:51 +10:00 committed by GitHub
parent f9e020dd67
commit 02f2af3ad8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 9 additions and 3 deletions

View File

@ -98,6 +98,7 @@ class Span:
label: Union[int, str] = ...,
kb_id: Union[int, str] = ...,
vector: Optional[Floats1d] = ...,
alignment_mode: str = ...,
) -> Span: ...
@property
def conjuncts(self) -> Tuple[Token]: ...

View File

@ -362,7 +362,7 @@ cdef class Span:
result = xp.dot(vector, other.vector) / (self.vector_norm * other.vector_norm)
# ensure we get a scalar back (numpy does this automatically but cupy doesn't)
return result.item()
cpdef np.ndarray to_array(self, object py_attr_ids):
"""Given a list of M attribute IDs, export the tokens to a numpy
`ndarray` of shape `(N, M)`, where `N` is the length of the document.
@ -639,7 +639,7 @@ cdef class Span:
else:
return self.doc[root]
def char_span(self, int start_idx, int end_idx, label=0, kb_id=0, vector=None, id=0):
def char_span(self, int start_idx, int end_idx, label=0, kb_id=0, vector=None, alignment_mode="strict", id=0):
"""Create a `Span` object from the slice `span.text[start : end]`.
start (int): The index of the first character of the span.
@ -649,11 +649,16 @@ cdef class Span:
kb_id (uint64 or string): An ID from a KB to capture the meaning of a named entity.
vector (ndarray[ndim=1, dtype='float32']): A meaning representation of
the span.
alignment_mode (str): How character indices are aligned to token
boundaries. Options: "strict" (character indices must be aligned
with token boundaries), "contract" (span of all tokens completely
within the character span), "expand" (span of all tokens at least
partially covered by the character span). Defaults to "strict".
RETURNS (Span): The newly constructed object.
"""
start_idx += self.c.start_char
end_idx += self.c.start_char
return self.doc.char_span(start_idx, end_idx, label=label, kb_id=kb_id, vector=vector)
return self.doc.char_span(start_idx, end_idx, label=label, kb_id=kb_id, vector=vector, alignment_mode=alignment_mode)
@property
def conjuncts(self):