Add morphology serialization

This commit is contained in:
Matthew Honnibal 2019-08-29 21:17:34 +02:00
parent c94fc9edb9
commit fc0a3c8c38

View File

@ -15,6 +15,7 @@ from .parts_of_speech cimport SPACE
from .parts_of_speech import IDS as POS_IDS
from .lexeme cimport Lexeme
from .errors import Errors
from .util import ensure_path
cdef enum univ_field_t:
@ -162,12 +163,7 @@ cdef class Morphology:
self.n_tags = len(tag_map)
self.reverse_index = {}
self._feat_map = MorphologyClassMap(FEATURES)
for i, (tag_str, attrs) in enumerate(sorted(tag_map.items())):
attrs = _normalize_props(attrs)
self.add({self._feat_map.id2feat[feat] for feat in attrs
if feat in self._feat_map.id2feat})
self.tag_map[tag_str] = dict(attrs)
self.reverse_index[self.strings.add(tag_str)] = i
self._load_from_tag_map(tag_map)
self._cache = PreshMapArray(self.n_tags)
self.exc = {}
@ -177,6 +173,14 @@ cdef class Morphology:
self.add_special_case(
self.strings.as_string(tag), self.strings.as_string(orth), attrs)
def _load_from_tag_map(self, tag_map):
for i, (tag_str, attrs) in enumerate(sorted(tag_map.items())):
attrs = _normalize_props(attrs)
self.add({self._feat_map.id2feat[feat] for feat in attrs
if feat in self._feat_map.id2feat})
self.tag_map[tag_str] = dict(attrs)
self.reverse_index[self.strings.add(tag_str)] = i
def __reduce__(self):
return (Morphology, (self.strings, self.tag_map, self.lemmatizer,
self.exc), None, None)
@ -188,6 +192,7 @@ cdef class Morphology:
for f in features:
if isinstance(f, basestring_):
self.strings.add(f)
string_features = features
features = intify_features(features)
cdef attr_t feature
for feature in features:
@ -321,22 +326,34 @@ cdef class Morphology:
for form_str, attrs in entries.items():
self.add_special_case(tag_str, form_str, attrs)
def to_bytes(self):
json_tags = []
def to_bytes(self, exclude=tuple(), **kwargs):
tag_map = {}
for key in self.tags:
tag_ptr = <MorphAnalysisC*>self.tags.get(key)
if tag_ptr != NULL:
json_tags.append(tag_to_json(tag_ptr))
return srsly.json_dumps(json_tags)
tag_map[key] = tag_to_json(tag_ptr)
exceptions = {}
for (tag_str, orth_int), attrs in sorted(self.exc.items()):
exceptions.setdefault(tag_str, {})
exceptions[tag_str][self.strings[orth_int]] = attrs
data = {"tag_map": tag_map, "exceptions": exceptions}
return srsly.msgpack_dumps(data)
def from_bytes(self, byte_string):
raise NotImplementedError
msg = srsly.msgpack_loads(byte_string)
self._load_from_tag_map(msg["tag_map"])
self.load_morph_exceptions(msg["exceptions"])
return self
def to_disk(self, path):
raise NotImplementedError
def to_disk(self, path, exclude=tuple(), **kwargs):
path = ensure_path(path)
with path.open("wb") as file_:
file_.write(self.to_bytes())
def from_disk(self, path):
raise NotImplementedError
with path.open("rb") as file_:
byte_string = file_.read()
return self.from_bytes(byte_string)
@classmethod
def create_class_map(cls):