Make RPCError subclasses unpicklable again (#1387)

This commit is contained in:
Tulir Asokan 2020-02-14 19:22:17 +02:00 committed by GitHub
parent 8bd60f7cde
commit c6bd620555
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 48 additions and 5 deletions

View File

@ -26,13 +26,16 @@ def generate_errors(errors, f):
# Error classes generation # Error classes generation
for error in errors: for error in errors:
f.write('\n\nclass {}({}):\n' f.write('\n\nclass {}({}):\n '.format(error.name, error.subclass))
' def __init__(self, **kwargs):\n'
' '.format(error.name, error.subclass))
if error.has_captures: if error.has_captures:
f.write("self.{} = int(kwargs.get('capture', 0))\n " f.write('def __init__(self, request, capture=0):\n '
' self.request = request\n ')
f.write(' self.{} = int(capture)\n '
.format(error.capture_name)) .format(error.capture_name))
else:
f.write('def __init__(self, request):\n '
' self.request = request\n ')
f.write('super(Exception, self).__init__(' f.write('super(Exception, self).__init__('
'{}'.format(repr(error.description))) '{}'.format(repr(error.description)))
@ -40,7 +43,12 @@ def generate_errors(errors, f):
if error.has_captures: if error.has_captures:
f.write('.format({0}=self.{0})'.format(error.capture_name)) f.write('.format({0}=self.{0})'.format(error.capture_name))
f.write(" + self._fmt_request(kwargs['request']))\n") f.write(' + self._fmt_request(self.request))\n\n')
f.write(' def __reduce__(self):\n ')
if error.has_captures:
f.write('return type(self), (self.request, self.{})\n'.format(error.capture_name))
else:
f.write('return type(self), (self.request,)\n')
# Create the actual {CODE: ErrorClassName} dict once classes are defined # Create the actual {CODE: ErrorClassName} dict once classes are defined
f.write('\n\nrpc_errors_dict = {\n') f.write('\n\nrpc_errors_dict = {\n')

View File

@ -0,0 +1,35 @@
import pickle
from telethon.errors import RPCError, BadRequestError, FileIdInvalidError, NetworkMigrateError
def _assert_equality(error, unpickled_error):
assert error.code == unpickled_error.code
assert error.message == unpickled_error.message
assert type(error) == type(unpickled_error)
assert str(error) == str(unpickled_error)
def test_base_rpcerror_pickle():
error = RPCError("request", "message", 123)
unpickled_error = pickle.loads(pickle.dumps(error))
_assert_equality(error, unpickled_error)
def test_rpcerror_pickle():
error = BadRequestError("request", "BAD_REQUEST", 400)
unpickled_error = pickle.loads(pickle.dumps(error))
_assert_equality(error, unpickled_error)
def test_fancy_rpcerror_pickle():
error = FileIdInvalidError("request")
unpickled_error = pickle.loads(pickle.dumps(error))
_assert_equality(error, unpickled_error)
def test_fancy_rpcerror_capture_pickle():
error = NetworkMigrateError(request="request", capture=5)
unpickled_error = pickle.loads(pickle.dumps(error))
_assert_equality(error, unpickled_error)
assert error.new_dc == unpickled_error.new_dc