mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 02:06:31 +03:00
Bugfix initializing DocBin with attributes (#4368)
* docbin init fix + documentation fix + unit tests * newline * try with zlib instead of gzip (python 2 incompatibilities)
This commit is contained in:
parent
ce1d441de5
commit
4e7259c6cf
|
@ -12,6 +12,7 @@ import json
|
||||||
|
|
||||||
import spacy
|
import spacy
|
||||||
import spacy.util
|
import spacy.util
|
||||||
|
from bin.ud import conll17_ud_eval
|
||||||
from spacy.tokens import Token, Doc
|
from spacy.tokens import Token, Doc
|
||||||
from spacy.gold import GoldParse
|
from spacy.gold import GoldParse
|
||||||
from spacy.util import compounding, minibatch, minibatch_by_words
|
from spacy.util import compounding, minibatch, minibatch_by_words
|
||||||
|
@ -25,8 +26,6 @@ import itertools
|
||||||
import random
|
import random
|
||||||
import numpy.random
|
import numpy.random
|
||||||
|
|
||||||
import conll17_ud_eval
|
|
||||||
|
|
||||||
from spacy import lang
|
from spacy import lang
|
||||||
from spacy.lang import zh
|
from spacy.lang import zh
|
||||||
from spacy.lang import ja
|
from spacy.lang import ja
|
||||||
|
|
|
@ -142,18 +142,28 @@ def intify_attrs(stringy_attrs, strings_map=None, _do_deprecated=False):
|
||||||
elif key.upper() in stringy_attrs:
|
elif key.upper() in stringy_attrs:
|
||||||
stringy_attrs.pop(key.upper())
|
stringy_attrs.pop(key.upper())
|
||||||
for name, value in stringy_attrs.items():
|
for name, value in stringy_attrs.items():
|
||||||
if isinstance(name, int):
|
int_key = intify_attr(name)
|
||||||
int_key = name
|
if int_key is not None:
|
||||||
elif name in IDS:
|
if strings_map is not None and isinstance(value, basestring):
|
||||||
int_key = IDS[name]
|
if hasattr(strings_map, 'add'):
|
||||||
elif name.upper() in IDS:
|
value = strings_map.add(value)
|
||||||
int_key = IDS[name.upper()]
|
else:
|
||||||
else:
|
value = strings_map[value]
|
||||||
continue
|
inty_attrs[int_key] = value
|
||||||
if strings_map is not None and isinstance(value, basestring):
|
|
||||||
if hasattr(strings_map, 'add'):
|
|
||||||
value = strings_map.add(value)
|
|
||||||
else:
|
|
||||||
value = strings_map[value]
|
|
||||||
inty_attrs[int_key] = value
|
|
||||||
return inty_attrs
|
return inty_attrs
|
||||||
|
|
||||||
|
|
||||||
|
def intify_attr(name):
|
||||||
|
"""
|
||||||
|
Normalize an attribute name, converting it to int.
|
||||||
|
|
||||||
|
stringy_attr (string): Attribute string name. Can also be int (will then be left unchanged)
|
||||||
|
RETURNS (int): int representation of the attribute, or None if it couldn't be converted.
|
||||||
|
"""
|
||||||
|
if isinstance(name, int):
|
||||||
|
return name
|
||||||
|
elif name in IDS:
|
||||||
|
return IDS[name]
|
||||||
|
elif name.upper() in IDS:
|
||||||
|
return IDS[name.upper()]
|
||||||
|
return None
|
||||||
|
|
11
spacy/tests/regression/test_issue4367.py
Normal file
11
spacy/tests/regression/test_issue4367.py
Normal file
|
@ -0,0 +1,11 @@
|
||||||
|
# coding: utf8
|
||||||
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
|
from spacy.tokens import DocBin
|
||||||
|
|
||||||
|
|
||||||
|
def test_issue4367():
|
||||||
|
"""Test that docbin init goes well"""
|
||||||
|
doc_bin_1 = DocBin()
|
||||||
|
doc_bin_2 = DocBin(attrs=["LEMMA"])
|
||||||
|
doc_bin_3 = DocBin(attrs=["LEMMA", "ENT_IOB", "ENT_TYPE"])
|
|
@ -1,8 +1,12 @@
|
||||||
# coding: utf-8
|
# coding: utf-8
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
|
import spacy
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from spacy.tokens import Doc
|
|
||||||
|
from spacy.lang.en import English
|
||||||
|
from spacy.tokens import Doc, DocBin
|
||||||
from spacy.compat import path2str
|
from spacy.compat import path2str
|
||||||
|
|
||||||
from ..util import make_tempdir
|
from ..util import make_tempdir
|
||||||
|
@ -57,3 +61,17 @@ def test_serialize_doc_exclude(en_vocab):
|
||||||
doc.to_bytes(user_data=False)
|
doc.to_bytes(user_data=False)
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
Doc(en_vocab).from_bytes(doc.to_bytes(), tensor=False)
|
Doc(en_vocab).from_bytes(doc.to_bytes(), tensor=False)
|
||||||
|
|
||||||
|
|
||||||
|
def test_serialize_doc_bin():
|
||||||
|
doc_bin = DocBin(attrs=["LEMMA", "ENT_IOB", "ENT_TYPE"], store_user_data=True)
|
||||||
|
texts = ["Some text", "Lots of texts...", "..."]
|
||||||
|
nlp = English()
|
||||||
|
for doc in nlp.pipe(texts):
|
||||||
|
doc_bin.add(doc)
|
||||||
|
bytes_data = doc_bin.to_bytes()
|
||||||
|
|
||||||
|
# Deserialize later, e.g. in a new process
|
||||||
|
nlp = spacy.blank("en")
|
||||||
|
doc_bin = DocBin().from_bytes(bytes_data)
|
||||||
|
docs = list(doc_bin.get_docs(nlp.vocab))
|
||||||
|
|
|
@ -2,13 +2,13 @@
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
import numpy
|
import numpy
|
||||||
import gzip
|
import zlib
|
||||||
import srsly
|
import srsly
|
||||||
from thinc.neural.ops import NumpyOps
|
from thinc.neural.ops import NumpyOps
|
||||||
|
|
||||||
from ..compat import copy_reg
|
from ..compat import copy_reg
|
||||||
from ..tokens import Doc
|
from ..tokens import Doc
|
||||||
from ..attrs import SPACY, ORTH, intify_attrs
|
from ..attrs import SPACY, ORTH, intify_attr
|
||||||
from ..errors import Errors
|
from ..errors import Errors
|
||||||
|
|
||||||
|
|
||||||
|
@ -53,7 +53,7 @@ class DocBin(object):
|
||||||
DOCS: https://spacy.io/api/docbin#init
|
DOCS: https://spacy.io/api/docbin#init
|
||||||
"""
|
"""
|
||||||
attrs = attrs or []
|
attrs = attrs or []
|
||||||
attrs = sorted(intify_attrs(attrs))
|
attrs = sorted([intify_attr(attr) for attr in attrs])
|
||||||
self.attrs = [attr for attr in attrs if attr != ORTH and attr != SPACY]
|
self.attrs = [attr for attr in attrs if attr != ORTH and attr != SPACY]
|
||||||
self.attrs.insert(0, ORTH) # Ensure ORTH is always attrs[0]
|
self.attrs.insert(0, ORTH) # Ensure ORTH is always attrs[0]
|
||||||
self.tokens = []
|
self.tokens = []
|
||||||
|
@ -142,7 +142,7 @@ class DocBin(object):
|
||||||
}
|
}
|
||||||
if self.store_user_data:
|
if self.store_user_data:
|
||||||
msg["user_data"] = self.user_data
|
msg["user_data"] = self.user_data
|
||||||
return gzip.compress(srsly.msgpack_dumps(msg))
|
return zlib.compress(srsly.msgpack_dumps(msg))
|
||||||
|
|
||||||
def from_bytes(self, bytes_data):
|
def from_bytes(self, bytes_data):
|
||||||
"""Deserialize the DocBin's annotations from a bytestring.
|
"""Deserialize the DocBin's annotations from a bytestring.
|
||||||
|
@ -152,7 +152,7 @@ class DocBin(object):
|
||||||
|
|
||||||
DOCS: https://spacy.io/api/docbin#from_bytes
|
DOCS: https://spacy.io/api/docbin#from_bytes
|
||||||
"""
|
"""
|
||||||
msg = srsly.msgpack_loads(gzip.decompress(bytes_data))
|
msg = srsly.msgpack_loads(zlib.decompress(bytes_data))
|
||||||
self.attrs = msg["attrs"]
|
self.attrs = msg["attrs"]
|
||||||
self.strings = set(msg["strings"])
|
self.strings = set(msg["strings"])
|
||||||
lengths = numpy.fromstring(msg["lengths"], dtype="int32")
|
lengths = numpy.fromstring(msg["lengths"], dtype="int32")
|
||||||
|
|
|
@ -84,7 +84,7 @@ texts = ["Some text", "Lots of texts...", "..."]
|
||||||
nlp = spacy.load("en_core_web_sm")
|
nlp = spacy.load("en_core_web_sm")
|
||||||
for doc in nlp.pipe(texts):
|
for doc in nlp.pipe(texts):
|
||||||
doc_bin.add(doc)
|
doc_bin.add(doc)
|
||||||
bytes_data = docbin.to_bytes()
|
bytes_data = doc_bin.to_bytes()
|
||||||
|
|
||||||
# Deserialize later, e.g. in a new process
|
# Deserialize later, e.g. in a new process
|
||||||
nlp = spacy.blank("en")
|
nlp = spacy.blank("en")
|
||||||
|
|
Loading…
Reference in New Issue
Block a user