* Reorganize the serialization functions on Doc

This commit is contained in:
Matthew Honnibal 2015-07-22 04:53:01 +02:00
parent 109106a949
commit 4d61239eac

View File

@ -5,7 +5,6 @@ import numpy
import struct import struct
from ..lexeme cimport EMPTY_LEXEME from ..lexeme cimport EMPTY_LEXEME
from ..strings cimport slice_unicode
from ..typedefs cimport attr_t, flags_t from ..typedefs cimport attr_t, flags_t
from ..attrs cimport attr_id_t from ..attrs cimport attr_id_t
from ..attrs cimport ID, ORTH, NORM, LOWER, SHAPE, PREFIX, SUFFIX, LENGTH, CLUSTER from ..attrs cimport ID, ORTH, NORM, LOWER, SHAPE, PREFIX, SUFFIX, LENGTH, CLUSTER
@ -15,7 +14,6 @@ from ..parts_of_speech cimport CONJ, PUNCT
from ..lexeme cimport check_flag from ..lexeme cimport check_flag
from ..lexeme cimport get_attr as get_lex_attr from ..lexeme cimport get_attr as get_lex_attr
from .spans import Span from .spans import Span
from ..structs cimport UniStr
from .token cimport Token from .token cimport Token
from ..serialize.bits cimport BitArray from ..serialize.bits cimport BitArray
@ -273,10 +271,62 @@ cdef class Doc:
cdef int set_parse(self, const TokenC* parsed) except -1: cdef int set_parse(self, const TokenC* parsed) except -1:
# TODO: This method is fairly misleading atm. It's used by Parser # TODO: This method is fairly misleading atm. It's used by Parser
# to actually apply the parse calculated. Need to rethink this. # to actually apply the parse calculated. Need to rethink this.
# Probably we should use from_array?
self.is_parsed = True self.is_parsed = True
for i in range(self.length): for i in range(self.length):
self.data[i] = parsed[i] self.data[i] = parsed[i]
def from_array(self, attrs, array):
cdef int i, col
cdef attr_id_t attr_id
cdef TokenC* tokens = self.data
cdef int length = len(array)
for col, attr_id in enumerate(attrs):
values = array[:, col]
if attr_id == HEAD:
# TODO: Set left and right children
for i in range(length):
tokens[i].head = values[i]
elif attr_id == TAG:
for i in range(length):
tokens[i].tag = values[i]
elif attr_id == DEP:
for i in range(length):
tokens[i].dep = values[i]
elif attr_id == ENT_IOB:
for i in range(length):
tokens[i].ent_iob = values[i]
elif attr_id == ENT_TYPE:
for i in range(length):
tokens[i].ent_type = values[i]
return self
def to_bytes(self):
bits = self.vocab.packer.pack(self)
return struct.pack('I', len(bits)) + bits.as_bytes()
def from_bytes(self, data):
bits = BitArray(data)
self.vocab.packer.unpack_into(bits, self)
return self
@staticmethod
def read_bytes(file_):
keep_reading = True
while keep_reading:
try:
n_bits_str = file_.read(4)
if len(n_bits_str) < 4:
break
n_bits = struct.unpack('I', n_bits_str)[0]
n_bytes = n_bits // 8 + bool(n_bits % 8)
data = file_.read(n_bytes)
except StopIteration:
keep_reading = False
yield data
# This function is terrible --- need to fix this.
def merge(self, int start_idx, int end_idx, unicode tag, unicode lemma, def merge(self, int start_idx, int end_idx, unicode tag, unicode lemma,
unicode ent_type): unicode ent_type):
"""Merge a multi-word expression into a single token. Currently """Merge a multi-word expression into a single token. Currently
@ -296,9 +346,8 @@ cdef class Doc:
return None return None
cdef unicode string = self.string cdef unicode string = self.string
# Get LexemeC for newly merged token # Get LexemeC for newly merged token
cdef UniStr new_orth_c new_orth = string[start_idx:end_idx]
slice_unicode(&new_orth_c, string, start_idx, end_idx) cdef const LexemeC* lex = self.vocab.get(self.mem, new_orth)
cdef const LexemeC* lex = self.vocab.get(self.mem, &new_orth_c)
# House the new merged token where it starts # House the new merged token where it starts
cdef TokenC* token = &self.data[start] cdef TokenC* token = &self.data[start]
# Update fields # Update fields
@ -361,47 +410,3 @@ cdef class Doc:
# Return the merged Python object # Return the merged Python object
return self[start] return self[start]
def from_array(self, attrs, array):
cdef int i, col
cdef attr_id_t attr_id
cdef TokenC* tokens = self.data
cdef int length = len(array)
for col, attr_id in enumerate(attrs):
values = array[:, col]
if attr_id == HEAD:
for i in range(length):
tokens[i].head = values[i]
elif attr_id == TAG:
for i in range(length):
tokens[i].tag = values[i]
elif attr_id == DEP:
for i in range(length):
tokens[i].dep = values[i]
elif attr_id == ENT_IOB:
for i in range(length):
tokens[i].ent_iob = values[i]
elif attr_id == ENT_TYPE:
for i in range(length):
tokens[i].ent_type = values[i]
def to_bytes(self):
bits = self.vocab.packer.pack(self)
return struct.pack('I', len(bits)) + bits.as_bytes()
@staticmethod
def from_bytes(Vocab vocab, file_):
keep_reading = True
while keep_reading:
try:
n_bits_str = file_.read(4)
if len(n_bits_str) < 4:
break
n_bits = struct.unpack('I', n_bits_str)[0]
n_bytes = n_bits // 8 + bool(n_bits % 8)
data = file_.read(n_bytes)
except StopIteration:
keep_reading = False
bits = BitArray(data)
doc = vocab.packer.unpack(bits)
yield doc