Bugfix initializing DocBin with attributes (#4368)

* docbin init fix + documentation fix + unit tests

* newline

* try with zlib instead of gzip (python 2 incompatibilities)
This commit is contained in:
Sofie Van Landeghem 2019-10-03 14:48:45 +02:00 committed by Matthew Honnibal
parent ce1d441de5
commit 4e7259c6cf
6 changed files with 61 additions and 23 deletions

View File

@ -12,6 +12,7 @@ import json
import spacy import spacy
import spacy.util import spacy.util
from bin.ud import conll17_ud_eval
from spacy.tokens import Token, Doc from spacy.tokens import Token, Doc
from spacy.gold import GoldParse from spacy.gold import GoldParse
from spacy.util import compounding, minibatch, minibatch_by_words from spacy.util import compounding, minibatch, minibatch_by_words
@ -25,8 +26,6 @@ import itertools
import random import random
import numpy.random import numpy.random
import conll17_ud_eval
from spacy import lang from spacy import lang
from spacy.lang import zh from spacy.lang import zh
from spacy.lang import ja from spacy.lang import ja

View File

@ -142,14 +142,8 @@ def intify_attrs(stringy_attrs, strings_map=None, _do_deprecated=False):
elif key.upper() in stringy_attrs: elif key.upper() in stringy_attrs:
stringy_attrs.pop(key.upper()) stringy_attrs.pop(key.upper())
for name, value in stringy_attrs.items(): for name, value in stringy_attrs.items():
if isinstance(name, int): int_key = intify_attr(name)
int_key = name if int_key is not None:
elif name in IDS:
int_key = IDS[name]
elif name.upper() in IDS:
int_key = IDS[name.upper()]
else:
continue
if strings_map is not None and isinstance(value, basestring): if strings_map is not None and isinstance(value, basestring):
if hasattr(strings_map, 'add'): if hasattr(strings_map, 'add'):
value = strings_map.add(value) value = strings_map.add(value)
@ -157,3 +151,19 @@ def intify_attrs(stringy_attrs, strings_map=None, _do_deprecated=False):
value = strings_map[value] value = strings_map[value]
inty_attrs[int_key] = value inty_attrs[int_key] = value
return inty_attrs return inty_attrs
def intify_attr(name):
    """
    Normalize an attribute name, converting it to int.

    name (string / int): Attribute name. An int is returned unchanged; a
        string is looked up in IDS as given, then in its upper-cased form.
    RETURNS (int): int representation of the attribute, or None if it
        couldn't be converted.
    """
    if isinstance(name, int):
        # Already an integer ID: nothing to resolve.
        return name
    if name in IDS:
        return IDS[name]
    # Accept lower/mixed-case spellings by retrying with the upper-cased name.
    upper_name = name.upper()
    if upper_name in IDS:
        return IDS[upper_name]
    return None

View File

@ -0,0 +1,11 @@
# coding: utf8
from __future__ import unicode_literals
from spacy.tokens import DocBin
def test_issue4367():
    """Check that a DocBin can be constructed with and without attrs."""
    # Each constructor call only has to complete without raising.
    DocBin()
    DocBin(attrs=["LEMMA"])
    DocBin(attrs=["LEMMA", "ENT_IOB", "ENT_TYPE"])

View File

@ -1,8 +1,12 @@
# coding: utf-8 # coding: utf-8
from __future__ import unicode_literals from __future__ import unicode_literals
import spacy
import pytest import pytest
from spacy.tokens import Doc
from spacy.lang.en import English
from spacy.tokens import Doc, DocBin
from spacy.compat import path2str from spacy.compat import path2str
from ..util import make_tempdir from ..util import make_tempdir
@ -57,3 +61,17 @@ def test_serialize_doc_exclude(en_vocab):
doc.to_bytes(user_data=False) doc.to_bytes(user_data=False)
with pytest.raises(ValueError): with pytest.raises(ValueError):
Doc(en_vocab).from_bytes(doc.to_bytes(), tensor=False) Doc(en_vocab).from_bytes(doc.to_bytes(), tensor=False)
def test_serialize_doc_bin():
    """Round-trip a DocBin through bytes and check the docs come back."""
    doc_bin = DocBin(attrs=["LEMMA", "ENT_IOB", "ENT_TYPE"], store_user_data=True)
    texts = ["Some text", "Lots of texts...", "..."]
    nlp = English()
    for doc in nlp.pipe(texts):
        doc_bin.add(doc)
    bytes_data = doc_bin.to_bytes()

    # Deserialize later, e.g. in a new process
    nlp = spacy.blank("en")
    doc_bin = DocBin().from_bytes(bytes_data)
    docs = list(doc_bin.get_docs(nlp.vocab))

    # Without these checks the test only proves that nothing raised.
    assert len(docs) == len(texts)
    for text, doc in zip(texts, docs):
        assert doc.text == text

View File

@ -2,13 +2,13 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import numpy import numpy
import gzip import zlib
import srsly import srsly
from thinc.neural.ops import NumpyOps from thinc.neural.ops import NumpyOps
from ..compat import copy_reg from ..compat import copy_reg
from ..tokens import Doc from ..tokens import Doc
from ..attrs import SPACY, ORTH, intify_attrs from ..attrs import SPACY, ORTH, intify_attr
from ..errors import Errors from ..errors import Errors
@ -53,7 +53,7 @@ class DocBin(object):
DOCS: https://spacy.io/api/docbin#init DOCS: https://spacy.io/api/docbin#init
""" """
attrs = attrs or [] attrs = attrs or []
attrs = sorted(intify_attrs(attrs)) attrs = sorted([intify_attr(attr) for attr in attrs])
self.attrs = [attr for attr in attrs if attr != ORTH and attr != SPACY] self.attrs = [attr for attr in attrs if attr != ORTH and attr != SPACY]
self.attrs.insert(0, ORTH) # Ensure ORTH is always attrs[0] self.attrs.insert(0, ORTH) # Ensure ORTH is always attrs[0]
self.tokens = [] self.tokens = []
@ -142,7 +142,7 @@ class DocBin(object):
} }
if self.store_user_data: if self.store_user_data:
msg["user_data"] = self.user_data msg["user_data"] = self.user_data
return gzip.compress(srsly.msgpack_dumps(msg)) return zlib.compress(srsly.msgpack_dumps(msg))
def from_bytes(self, bytes_data): def from_bytes(self, bytes_data):
"""Deserialize the DocBin's annotations from a bytestring. """Deserialize the DocBin's annotations from a bytestring.
@ -152,7 +152,7 @@ class DocBin(object):
DOCS: https://spacy.io/api/docbin#from_bytes DOCS: https://spacy.io/api/docbin#from_bytes
""" """
msg = srsly.msgpack_loads(gzip.decompress(bytes_data)) msg = srsly.msgpack_loads(zlib.decompress(bytes_data))
self.attrs = msg["attrs"] self.attrs = msg["attrs"]
self.strings = set(msg["strings"]) self.strings = set(msg["strings"])
lengths = numpy.fromstring(msg["lengths"], dtype="int32") lengths = numpy.fromstring(msg["lengths"], dtype="int32")

View File

@ -84,7 +84,7 @@ texts = ["Some text", "Lots of texts...", "..."]
nlp = spacy.load("en_core_web_sm") nlp = spacy.load("en_core_web_sm")
for doc in nlp.pipe(texts): for doc in nlp.pipe(texts):
doc_bin.add(doc) doc_bin.add(doc)
bytes_data = docbin.to_bytes() bytes_data = doc_bin.to_bytes()
# Deserialize later, e.g. in a new process # Deserialize later, e.g. in a new process
nlp = spacy.blank("en") nlp = spacy.blank("en")