mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-24 17:06:29 +03:00
Handle scalar values in doc.from_array()
This commit is contained in:
parent
61e5ce02a4
commit
d6eaa71afc
|
@ -747,7 +747,7 @@ cdef class Doc:
|
||||||
attrs = [attrs]
|
attrs = [attrs]
|
||||||
# Allow strings, e.g. 'lemma' or 'LEMMA'
|
# Allow strings, e.g. 'lemma' or 'LEMMA'
|
||||||
attrs = [(IDS[id_.upper()] if hasattr(id_, "upper") else id_)
|
attrs = [(IDS[id_.upper()] if hasattr(id_, "upper") else id_)
|
||||||
for id_ in attrs]
|
for id_ in attrs]
|
||||||
|
|
||||||
if SENT_START in attrs and HEAD in attrs:
|
if SENT_START in attrs and HEAD in attrs:
|
||||||
raise ValueError(Errors.E032)
|
raise ValueError(Errors.E032)
|
||||||
|
@ -761,6 +761,8 @@ cdef class Doc:
|
||||||
attr_ids = <attr_id_t*>mem.alloc(n_attrs, sizeof(attr_id_t))
|
attr_ids = <attr_id_t*>mem.alloc(n_attrs, sizeof(attr_id_t))
|
||||||
for i, attr_id in enumerate(attrs):
|
for i, attr_id in enumerate(attrs):
|
||||||
attr_ids[i] = attr_id
|
attr_ids[i] = attr_id
|
||||||
|
if len(array.shape) == 1:
|
||||||
|
array = array.reshape((array.size, 1))
|
||||||
# Now load the data
|
# Now load the data
|
||||||
for i in range(self.length):
|
for i in range(self.length):
|
||||||
token = &self.c[i]
|
token = &self.c[i]
|
||||||
|
|
Loading…
Reference in New Issue
Block a user