Turn HashChecker into CdnDecrypter to abstract CDN-specific aspects

This commit is contained in:
Lonami Exo 2017-08-28 16:25:10 +02:00
parent b504ce14bc
commit 8afcd0b91f
4 changed files with 129 additions and 112 deletions

View File

@ -1,4 +1,4 @@
from .aes import AES from .aes import AES
from .auth_key import AuthKey from .auth_key import AuthKey
from .factorization import Factorization from .factorization import Factorization
from .hash_checker import HashChecker from .cdn_decrypter import CdnDecrypter

View File

@ -0,0 +1,116 @@
from hashlib import sha256
import pyaes
from ..tl import JsonSession
from ..tl.functions.upload import GetCdnFileRequest, ReuploadCdnFileRequest
from ..tl.types.upload import CdnFileReuploadNeeded
from ..errors import CdnFileTamperedError
class CdnDecrypter:
"""Used when downloading a file results in a 'FileCdnRedirect' to
both prepare the redirect, decrypt the file as it downloads, and
ensure the file hasn't been tampered.
"""
def __init__(self, cdn_client, file_token, cdn_aes, cdn_file_hashes):
self.client = cdn_client
self.file_token = file_token
self.cdn_aes = cdn_aes
self.cdn_file_hashes = cdn_file_hashes
self.shaes = [sha256() for _ in range(len(cdn_file_hashes))]
@staticmethod
def prepare_decrypter(client, client_cls, cdn_redirect, offset, part_size):
"""Prepares a CDN decrypter, returning (decrypter, file data).
'client' should be the original TelegramBareClient that
tried to download the file.
'client_cls' should be the class of the TelegramBareClient.
"""
# TODO Avoid the need for 'client_cls=TelegramBareClient'
# https://core.telegram.org/cdn
# TODO Use libssl if available
cdn_aes = pyaes.AESModeOfOperationCTR(cdn_redirect.encryption_key)
# The returned IV is the counter used on CTR
cdn_aes._counter._counter = list(
cdn_redirect.encryption_iv[:12] +
(offset >> 4).to_bytes(4, 'big')
)
# Create a new client on said CDN
dc = client._get_dc(cdn_redirect.dc_id, cdn=True)
session = JsonSession(client.session)
session.server_address = dc.ip_address
session.port = dc.port
cdn_client = client_cls( # Avoid importing TelegramBareClient
session, client.api_id, client.api_hash,
timeout=client._timeout
)
# This will make use of the new RSA keys for this specific CDN
cdn_file = cdn_client.connect(initial_query=GetCdnFileRequest(
cdn_redirect.file_token, offset, part_size
))
# CDN client is ready, create the resulting CdnDecrypter
decrypter = CdnDecrypter(
cdn_client, cdn_redirect.file_token,
cdn_aes, cdn_redirect.cdn_file_hashes
)
if isinstance(cdn_file, CdnFileReuploadNeeded):
# We need to use the original client here
client(ReuploadCdnFileRequest(
file_token=cdn_redirect.file_token,
request_token=cdn_file.request_token
))
# We want to always return a valid upload.CdnFile
cdn_file = decrypter.get_file(offset, part_size)
else:
cdn_file.bytes = decrypter.cdn_aes.encrypt(cdn_file.bytes)
decrypter.check(offset, cdn_file.bytes)
return decrypter, cdn_file
def get_file(self, offset, limit):
"""Calls GetCdnFileRequest and decrypts its bytes.
Also ensures that the file hasn't been tampered.
"""
result = self.client(GetCdnFileRequest(self.file_token, offset, limit))
result.bytes = self.cdn_aes.encrypt(result.bytes)
self.check(offset, result.bytes)
return result
def check(self, offset, data):
"""Checks the integrity of the given data"""
for cdn_hash, sha in zip(self.cdn_file_hashes, self.shaes):
inter = self.intersect(
cdn_hash.offset, cdn_hash.offset + cdn_hash.limit,
offset, offset + len(data)
)
if inter:
x1, x2 = inter[0] - offset, inter[1] - offset
sha.update(data[x1:x2])
elif offset > cdn_hash.offset:
if cdn_hash.hash == sha.digest():
self.cdn_file_hashes.remove(cdn_hash)
self.shaes.remove(sha)
else:
raise CdnFileTamperedError()
def finish_check(self):
"""Similar to the check method, but for all unchecked hashes"""
for cdn_hash, sha in zip(self.cdn_file_hashes, self.shaes):
if cdn_hash.hash != sha.digest():
raise CdnFileTamperedError()
self.cdn_file_hashes.clear()
self.shaes.clear()
@staticmethod
def intersect(x1, x2, z1, z2):
if x1 > z1:
return None if x1 > z2 else (x1, min(x2, z2))
else:
return (z1, min(x2, z2)) if x2 > z1 else None

View File

@ -1,39 +0,0 @@
from hashlib import sha256
from ..errors import CdnFileTamperedError
class HashChecker:
def __init__(self, cdn_file_hashes):
self.cdn_file_hashes = cdn_file_hashes
self.shaes = [sha256() for _ in range(len(cdn_file_hashes))]
def check(self, offset, data):
for cdn_hash, sha in zip(self.cdn_file_hashes, self.shaes):
inter = self.intersect(
cdn_hash.offset, cdn_hash.offset + cdn_hash.limit,
offset, offset + len(data)
)
if inter:
x1, x2 = inter[0] - offset, inter[1] - offset
sha.update(data[x1:x2])
elif offset > cdn_hash.offset:
if cdn_hash.hash == sha.digest():
self.cdn_file_hashes.remove(cdn_hash)
self.shaes.remove(sha)
else:
raise CdnFileTamperedError()
def finish_check(self):
for cdn_hash, sha in zip(self.cdn_file_hashes, self.shaes):
if cdn_hash.hash != sha.digest():
raise CdnFileTamperedError()
self.cdn_file_hashes.clear()
self.shaes.clear()
@staticmethod
def intersect(x1, x2, z1, z2):
if x1 > z1:
return None if x1 > z2 else (x1, min(x2, z2))
else:
return (z1, min(x2, z2)) if x2 > z1 else None

View File

@ -12,7 +12,7 @@ from .errors import (
) )
from .network import authenticator, MtProtoSender, TcpTransport from .network import authenticator, MtProtoSender, TcpTransport
from .utils import get_appropriated_part_size from .utils import get_appropriated_part_size
from .crypto import rsa, HashChecker from .crypto import rsa, CdnDecrypter
# For sending and receiving requests # For sending and receiving requests
from .tl import TLObject, JsonSession from .tl import TLObject, JsonSession
@ -298,21 +298,6 @@ class TelegramBareClient:
self._cached_clients[dc_id] = client self._cached_clients[dc_id] = client
return client return client
def _get_cdn_client(self, dc_id, query):
"""_get_exported_client counterpart for CDNs.
Returns a tuple of (client, query result)
"""
dc = self._get_dc(dc_id, cdn=True)
session = JsonSession(self.session)
session.server_address = dc.ip_address
session.port = dc.port
client = TelegramBareClient(
session, self.api_id, self.api_hash,
timeout=self._timeout
)
# This will make use of the new RSA keys for this specific CDN
return client, client.connect(initial_query=query)
# endregion # endregion
# region Invoking Telegram requests # region Invoking Telegram requests
@ -485,38 +470,25 @@ class TelegramBareClient:
try: try:
offset_index = 0 offset_index = 0
cdn_file_token = None cdn_decrypter = None
hash_checker = None
def encrypt_method(x):
return x # Defaults to no-op
while True: while True:
offset = offset_index * part_size offset = offset_index * part_size
try: try:
if cdn_file_token: if cdn_decrypter:
result = client(GetCdnFileRequest( result = cdn_decrypter.get_file(offset, part_size)
cdn_file_token, offset, part_size
))
else: else:
result = client(GetFileRequest( result = client(GetFileRequest(
input_location, offset, part_size input_location, offset, part_size
)) ))
if isinstance(result, FileCdnRedirect): if isinstance(result, FileCdnRedirect):
cdn_file_token = result.file_token cdn_decrypter, result = \
hash_checker = HashChecker( CdnDecrypter.prepare_decrypter(
result.cdn_file_hashes client, TelegramBareClient, result,
offset, part_size
) )
client, encrypt_method, result = \
self._prepare_cdn_redirect(
result, offset, part_size
)
if result is None:
# File was not ready on the CDN yet
continue
except FileMigrateError as e: except FileMigrateError as e:
client = self._get_exported_client(e.new_dc) client = self._get_exported_client(e.new_dc)
@ -527,12 +499,11 @@ class TelegramBareClient:
# If we have received no data (0 bytes), the file is over # If we have received no data (0 bytes), the file is over
# So there is nothing left to download and write # So there is nothing left to download and write
if not result.bytes: if not result.bytes:
# Return some extra information, unless it's a cdn file if cdn_decrypter:
hash_checker.finish_check() cdn_decrypter.finish_check()
return getattr(result, 'type', '')
result.bytes = encrypt_method(result.bytes) # Return some extra information, unless it's a CDN file
hash_checker.check(offset, result.bytes) return getattr(result, 'type', '')
f.write(result.bytes) f.write(result.bytes)
if progress_callback: if progress_callback:
@ -541,35 +512,4 @@ class TelegramBareClient:
if isinstance(file, str): if isinstance(file, str):
f.close() f.close()
def _prepare_cdn_redirect(self, cdn_redirect, offset, part_size):
"""Returns (client, encrypt_method, result)"""
# https://core.telegram.org/cdn
# TODO Use libssl if available
cdn_aes = pyaes.AESModeOfOperationCTR(
cdn_redirect.encryption_key
)
# The returned IV is the counter used on CTR
cdn_aes._counter._counter = list(
cdn_redirect.encryption_iv[:12] +
(offset >> 4).to_bytes(4, 'big')
)
client, cdn_file = self._get_cdn_client(
cdn_redirect.dc_id,
GetCdnFileRequest(
cdn_redirect.file_token, offset, part_size
)
)
if isinstance(cdn_file, CdnFileReuploadNeeded):
# We need to use the original client here
self(ReuploadCdnFileRequest(
file_token=cdn_redirect.file_token,
request_token=cdn_file.request_token
))
return client, cdn_aes.encrypt, None
else:
# We have the first bytes for the file
return client, cdn_aes.encrypt, cdn_file
# endregion # endregion