mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-23 15:54:13 +03:00
Fix serializer
This commit is contained in:
parent
b16ae75824
commit
4efb391994
|
@ -1,4 +1,6 @@
|
|||
# coding: utf8
|
||||
# cython: infer_types=True
|
||||
# cython: bounds_check=False
|
||||
from __future__ import unicode_literals
|
||||
|
||||
cimport cython
|
||||
|
@ -565,7 +567,7 @@ cdef class Doc:
|
|||
for i in range(self.length):
|
||||
self.c[i] = parsed[i]
|
||||
|
||||
def from_array(self, attrs, array):
|
||||
def from_array(self, attrs, int[:, :] array):
|
||||
"""
|
||||
Write to a `Doc` object, from an `(M, N)` array of attributes.
|
||||
"""
|
||||
|
@ -573,34 +575,23 @@ cdef class Doc:
|
|||
cdef attr_id_t attr_id
|
||||
cdef TokenC* tokens = self.c
|
||||
cdef int length = len(array)
|
||||
cdef attr_t[:] values
|
||||
# Get set up for fast loading
|
||||
cdef Pool mem = Pool()
|
||||
cdef int n_attrs = len(attrs)
|
||||
attr_ids = <attr_id_t*>mem.alloc(n_attrs, sizeof(attr_id_t))
|
||||
for i, attr_id in enumerate(attrs):
|
||||
attr_ids[i] = attr_id
|
||||
# Now load the data
|
||||
for i in range(self.length):
|
||||
token = &self.c[i]
|
||||
for j in range(n_attrs):
|
||||
Token.set_struct_attr(token, attr_ids[j], array[i, j])
|
||||
# Auxiliary loading logic
|
||||
for col, attr_id in enumerate(attrs):
|
||||
values = array[:, col]
|
||||
if attr_id == HEAD:
|
||||
if attr_id == TAG:
|
||||
for i in range(length):
|
||||
tokens[i].head = values[i]
|
||||
if values[i] >= 1:
|
||||
tokens[i + values[i]].l_kids += 1
|
||||
elif values[i] < 0:
|
||||
tokens[i + values[i]].r_kids += 1
|
||||
elif attr_id == TAG:
|
||||
for i in range(length):
|
||||
if values[i] != 0:
|
||||
self.vocab.morphology.assign_tag(&tokens[i], values[i])
|
||||
elif attr_id == POS:
|
||||
for i in range(length):
|
||||
tokens[i].pos = <univ_pos_t>values[i]
|
||||
elif attr_id == DEP:
|
||||
for i in range(length):
|
||||
tokens[i].dep = values[i]
|
||||
elif attr_id == ENT_IOB:
|
||||
for i in range(length):
|
||||
tokens[i].ent_iob = values[i]
|
||||
elif attr_id == ENT_TYPE:
|
||||
for i in range(length):
|
||||
tokens[i].ent_type = values[i]
|
||||
else:
|
||||
raise ValueError("Unknown attribute ID: %d" % attr_id)
|
||||
if array[i, col] != 0:
|
||||
self.vocab.morphology.assign_tag(&tokens[i], array[i, col])
|
||||
set_children_from_heads(self.c, self.length)
|
||||
self.is_parsed = bool(HEAD in attrs or DEP in attrs)
|
||||
self.is_tagged = bool(TAG in attrs or POS in attrs)
|
||||
|
@ -645,9 +636,9 @@ cdef class Doc:
|
|||
self.push_back(lex, has_space)
|
||||
|
||||
start = end + has_space
|
||||
self.from_array(attrs[:, 2:],
|
||||
[TAG,LEMMA,HEAD,DEP,ENT_IOB,ENT_TYPE])
|
||||
|
||||
self.from_array([TAG,LEMMA,HEAD,DEP,ENT_IOB,ENT_TYPE],
|
||||
attrs[:, 2:])
|
||||
return self
|
||||
|
||||
def merge(self, int start_idx, int end_idx, *args, **attributes):
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue
Block a user