Cast to uint64 for all array-based doc representations

This commit is contained in:
Adriane Boyd 2022-12-06 18:00:05 +01:00
parent ca0cae2074
commit b9c524917a
6 changed files with 15 additions and 13 deletions

View File

@ -295,7 +295,7 @@ def make_docs(nlp, batch, min_length, max_length):
raise ValueError(Errors.E138.format(text=record)) raise ValueError(Errors.E138.format(text=record))
if "heads" in record: if "heads" in record:
heads = record["heads"] heads = record["heads"]
heads = numpy.asarray(heads, dtype="uint64") heads = numpy.asarray([numpy.array(h).astype(numpy.uint64) for h in heads], dtype="uint64")
heads = heads.reshape((len(doc), 1)) heads = heads.reshape((len(doc), 1))
doc = doc.from_array([HEAD], heads) doc = doc.from_array([HEAD], heads)
if len(doc) >= min_length and len(doc) < max_length: if len(doc) >= min_length and len(doc) < max_length:

View File

@ -1,6 +1,7 @@
# coding: utf-8 # coding: utf-8
from __future__ import unicode_literals from __future__ import unicode_literals
import numpy
import pytest import pytest
from spacy.tokens import Doc from spacy.tokens import Doc
from spacy.attrs import ORTH, SHAPE, POS, DEP from spacy.attrs import ORTH, SHAPE, POS, DEP
@ -91,14 +92,14 @@ def test_doc_from_array_heads_in_bounds(en_vocab):
# head before start # head before start
arr = doc.to_array(["HEAD"]) arr = doc.to_array(["HEAD"])
arr[0] = -1 arr[0] = numpy.array(-1).astype(numpy.uint64)
doc_from_array = Doc(en_vocab, words=words) doc_from_array = Doc(en_vocab, words=words)
with pytest.raises(ValueError): with pytest.raises(ValueError):
doc_from_array.from_array(["HEAD"], arr) doc_from_array.from_array(["HEAD"], arr)
# head after end # head after end
arr = doc.to_array(["HEAD"]) arr = doc.to_array(["HEAD"])
arr[0] = 5 arr[0] = numpy.array(5).astype(numpy.uint64)
doc_from_array = Doc(en_vocab, words=words) doc_from_array = Doc(en_vocab, words=words)
with pytest.raises(ValueError): with pytest.raises(ValueError):
doc_from_array.from_array(["HEAD"], arr) doc_from_array.from_array(["HEAD"], arr)

View File

@ -37,9 +37,9 @@ def test_en_noun_chunks_not_nested(en_vocab):
[0, root], [0, root],
[4, amod], [4, amod],
[3, nmod], [3, nmod],
[-1, cc], [numpy.array(-1).astype(numpy.uint64), cc],
[-2, conj], [numpy.array(-2).astype(numpy.uint64), conj],
[-5, dobj], [numpy.array(-5).astype(numpy.uint64), dobj],
], ],
dtype="uint64", dtype="uint64",
), ),

View File

@ -58,11 +58,12 @@ def get_doc(
for annot in annotations: for annot in annotations:
if annot: if annot:
if annot is heads: if annot is heads:
annot = numpy.array(heads).astype(numpy.uint64)
for i in range(len(words)): for i in range(len(words)):
if attrs.ndim == 1: if attrs.ndim == 1:
attrs[i] = heads[i] attrs[i] = annot[i]
else: else:
attrs[i, j] = heads[i] attrs[i, j] = annot[i]
else: else:
for i in range(len(words)): for i in range(len(words)):
if attrs.ndim == 1: if attrs.ndim == 1:

View File

@ -805,7 +805,7 @@ cdef class Doc:
`(M, N)` array of attributes. `(M, N)` array of attributes.
attrs (list) A list of attribute ID ints. attrs (list) A list of attribute ID ints.
array (numpy.ndarray[ndim=2, dtype='int32']): The attribute values. array (numpy.ndarray[ndim=2, dtype='uint64']): The attribute values.
RETURNS (Doc): Itself. RETURNS (Doc): Itself.
DOCS: https://spacy.io/api/doc#from_array DOCS: https://spacy.io/api/doc#from_array
@ -845,9 +845,9 @@ cdef class Doc:
col = attrs.index(HEAD) col = attrs.index(HEAD)
for i in range(length): for i in range(length):
# cast index to signed int # cast index to signed int
abs_head_index = numpy.int32(array[i, col]) + i abs_head_index = array[i, col].astype(numpy.int32) + i
if abs_head_index < 0 or abs_head_index >= length: if abs_head_index < 0 or abs_head_index >= length:
raise ValueError(Errors.E190.format(index=i, value=array[i, col], rel_head_index=numpy.int32(array[i, col]))) raise ValueError(Errors.E190.format(index=i, value=array[i, col], rel_head_index=abs_head_index-i))
# Do TAG first. This lets subsequent loop override stuff like POS, LEMMA # Do TAG first. This lets subsequent loop override stuff like POS, LEMMA
if TAG in attrs: if TAG in attrs:
col = attrs.index(TAG) col = attrs.index(TAG)

View File

@ -272,7 +272,7 @@ cdef class Span:
for ancestor in ancestors: for ancestor in ancestors:
ancestor_i = ancestor.i - self.start ancestor_i = ancestor.i - self.start
if ancestor_i in range(length): if ancestor_i in range(length):
array[i, head_col] = ancestor_i - i array[i, head_col] = numpy.array(ancestor_i - i).astype(numpy.uint64)
# if there is no appropriate ancestor, define a new artificial root # if there is no appropriate ancestor, define a new artificial root
value = array[i, head_col] value = array[i, head_col]
@ -280,7 +280,7 @@ cdef class Span:
new_root = old_to_new_root.get(ancestor_i, None) new_root = old_to_new_root.get(ancestor_i, None)
if new_root is not None: if new_root is not None:
# take the same artificial root as a previous token from the same sentence # take the same artificial root as a previous token from the same sentence
array[i, head_col] = new_root - i array[i, head_col] = numpy.array(new_root - i).astype(numpy.uint64)
else: else:
# set this token as the new artificial root # set this token as the new artificial root
array[i, head_col] = 0 array[i, head_col] = 0