mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-24 17:06:29 +03:00
Bugfix initializing DocBin with attributes (#4368)
* docbin init fix + documentation fix + unit tests * newline * try with zlib instead of gzip (python 2 incompatibilities)
This commit is contained in:
parent
ce1d441de5
commit
4e7259c6cf
|
@ -12,6 +12,7 @@ import json
|
|||
|
||||
import spacy
|
||||
import spacy.util
|
||||
from bin.ud import conll17_ud_eval
|
||||
from spacy.tokens import Token, Doc
|
||||
from spacy.gold import GoldParse
|
||||
from spacy.util import compounding, minibatch, minibatch_by_words
|
||||
|
@ -25,8 +26,6 @@ import itertools
|
|||
import random
|
||||
import numpy.random
|
||||
|
||||
import conll17_ud_eval
|
||||
|
||||
from spacy import lang
|
||||
from spacy.lang import zh
|
||||
from spacy.lang import ja
|
||||
|
|
|
@ -142,14 +142,8 @@ def intify_attrs(stringy_attrs, strings_map=None, _do_deprecated=False):
|
|||
elif key.upper() in stringy_attrs:
|
||||
stringy_attrs.pop(key.upper())
|
||||
for name, value in stringy_attrs.items():
|
||||
if isinstance(name, int):
|
||||
int_key = name
|
||||
elif name in IDS:
|
||||
int_key = IDS[name]
|
||||
elif name.upper() in IDS:
|
||||
int_key = IDS[name.upper()]
|
||||
else:
|
||||
continue
|
||||
int_key = intify_attr(name)
|
||||
if int_key is not None:
|
||||
if strings_map is not None and isinstance(value, basestring):
|
||||
if hasattr(strings_map, 'add'):
|
||||
value = strings_map.add(value)
|
||||
|
@ -157,3 +151,19 @@ def intify_attrs(stringy_attrs, strings_map=None, _do_deprecated=False):
|
|||
value = strings_map[value]
|
||||
inty_attrs[int_key] = value
|
||||
return inty_attrs
|
||||
|
||||
|
||||
def intify_attr(name):
|
||||
"""
|
||||
Normalize an attribute name, converting it to int.
|
||||
|
||||
stringy_attr (string): Attribute string name. Can also be int (will then be left unchanged)
|
||||
RETURNS (int): int representation of the attribute, or None if it couldn't be converted.
|
||||
"""
|
||||
if isinstance(name, int):
|
||||
return name
|
||||
elif name in IDS:
|
||||
return IDS[name]
|
||||
elif name.upper() in IDS:
|
||||
return IDS[name.upper()]
|
||||
return None
|
||||
|
|
11
spacy/tests/regression/test_issue4367.py
Normal file
11
spacy/tests/regression/test_issue4367.py
Normal file
|
@ -0,0 +1,11 @@
|
|||
# coding: utf8
|
||||
from __future__ import unicode_literals
|
||||
|
||||
from spacy.tokens import DocBin
|
||||
|
||||
|
||||
def test_issue4367():
|
||||
"""Test that docbin init goes well"""
|
||||
doc_bin_1 = DocBin()
|
||||
doc_bin_2 = DocBin(attrs=["LEMMA"])
|
||||
doc_bin_3 = DocBin(attrs=["LEMMA", "ENT_IOB", "ENT_TYPE"])
|
|
@ -1,8 +1,12 @@
|
|||
# coding: utf-8
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import spacy
|
||||
|
||||
import pytest
|
||||
from spacy.tokens import Doc
|
||||
|
||||
from spacy.lang.en import English
|
||||
from spacy.tokens import Doc, DocBin
|
||||
from spacy.compat import path2str
|
||||
|
||||
from ..util import make_tempdir
|
||||
|
@ -57,3 +61,17 @@ def test_serialize_doc_exclude(en_vocab):
|
|||
doc.to_bytes(user_data=False)
|
||||
with pytest.raises(ValueError):
|
||||
Doc(en_vocab).from_bytes(doc.to_bytes(), tensor=False)
|
||||
|
||||
|
||||
def test_serialize_doc_bin():
|
||||
doc_bin = DocBin(attrs=["LEMMA", "ENT_IOB", "ENT_TYPE"], store_user_data=True)
|
||||
texts = ["Some text", "Lots of texts...", "..."]
|
||||
nlp = English()
|
||||
for doc in nlp.pipe(texts):
|
||||
doc_bin.add(doc)
|
||||
bytes_data = doc_bin.to_bytes()
|
||||
|
||||
# Deserialize later, e.g. in a new process
|
||||
nlp = spacy.blank("en")
|
||||
doc_bin = DocBin().from_bytes(bytes_data)
|
||||
docs = list(doc_bin.get_docs(nlp.vocab))
|
||||
|
|
|
@ -2,13 +2,13 @@
|
|||
from __future__ import unicode_literals
|
||||
|
||||
import numpy
|
||||
import gzip
|
||||
import zlib
|
||||
import srsly
|
||||
from thinc.neural.ops import NumpyOps
|
||||
|
||||
from ..compat import copy_reg
|
||||
from ..tokens import Doc
|
||||
from ..attrs import SPACY, ORTH, intify_attrs
|
||||
from ..attrs import SPACY, ORTH, intify_attr
|
||||
from ..errors import Errors
|
||||
|
||||
|
||||
|
@ -53,7 +53,7 @@ class DocBin(object):
|
|||
DOCS: https://spacy.io/api/docbin#init
|
||||
"""
|
||||
attrs = attrs or []
|
||||
attrs = sorted(intify_attrs(attrs))
|
||||
attrs = sorted([intify_attr(attr) for attr in attrs])
|
||||
self.attrs = [attr for attr in attrs if attr != ORTH and attr != SPACY]
|
||||
self.attrs.insert(0, ORTH) # Ensure ORTH is always attrs[0]
|
||||
self.tokens = []
|
||||
|
@ -142,7 +142,7 @@ class DocBin(object):
|
|||
}
|
||||
if self.store_user_data:
|
||||
msg["user_data"] = self.user_data
|
||||
return gzip.compress(srsly.msgpack_dumps(msg))
|
||||
return zlib.compress(srsly.msgpack_dumps(msg))
|
||||
|
||||
def from_bytes(self, bytes_data):
|
||||
"""Deserialize the DocBin's annotations from a bytestring.
|
||||
|
@ -152,7 +152,7 @@ class DocBin(object):
|
|||
|
||||
DOCS: https://spacy.io/api/docbin#from_bytes
|
||||
"""
|
||||
msg = srsly.msgpack_loads(gzip.decompress(bytes_data))
|
||||
msg = srsly.msgpack_loads(zlib.decompress(bytes_data))
|
||||
self.attrs = msg["attrs"]
|
||||
self.strings = set(msg["strings"])
|
||||
lengths = numpy.fromstring(msg["lengths"], dtype="int32")
|
||||
|
|
|
@ -84,7 +84,7 @@ texts = ["Some text", "Lots of texts...", "..."]
|
|||
nlp = spacy.load("en_core_web_sm")
|
||||
for doc in nlp.pipe(texts):
|
||||
doc_bin.add(doc)
|
||||
bytes_data = docbin.to_bytes()
|
||||
bytes_data = doc_bin.to_bytes()
|
||||
|
||||
# Deserialize later, e.g. in a new process
|
||||
nlp = spacy.blank("en")
|
||||
|
|
Loading…
Reference in New Issue
Block a user