Handle scalar values in doc.from_array()

This commit is contained in:
Matthew Honnibal 2019-03-10 16:54:03 +01:00
parent 61e5ce02a4
commit d6eaa71afc

View File

@ -747,7 +747,7 @@ cdef class Doc:
attrs = [attrs]
# Allow strings, e.g. 'lemma' or 'LEMMA'
attrs = [(IDS[id_.upper()] if hasattr(id_, "upper") else id_)
for id_ in attrs]
for id_ in attrs]
if SENT_START in attrs and HEAD in attrs:
raise ValueError(Errors.E032)
@ -761,6 +761,8 @@ cdef class Doc:
attr_ids = <attr_id_t*>mem.alloc(n_attrs, sizeof(attr_id_t))
for i, attr_id in enumerate(attrs):
attr_ids[i] = attr_id
if len(array.shape) == 1:
array = array.reshape((array.size, 1))
# Now load the data
for i in range(self.length):
token = &self.c[i]