2016-09-04 22:07:09 +03:00
|
|
|
from binascii import crc32
|
2016-10-03 10:53:41 +03:00
|
|
|
from datetime import timedelta
|
|
|
|
|
2016-09-17 21:42:34 +03:00
|
|
|
from telethon.errors import *
|
2016-11-30 00:29:42 +03:00
|
|
|
from telethon.network import TcpClient
|
2016-09-17 21:42:34 +03:00
|
|
|
from telethon.utils import BinaryWriter
|
2016-08-26 13:58:53 +03:00
|
|
|
|
|
|
|
|
|
|
|
class TcpTransport:
    """Sends and receives framed packets through a TcpClient.

    Original reference: https://core.telegram.org/mtproto#tcp-transport
    The packets are encoded as: total length, sequence number, packet
    and checksum (CRC32). The three integers take 4 bytes each and are
    read back as little-endian values by receive().
    """

    # Bytes used by the framing itself: 12 = size_of (integer) * 3
    # (total length + sequence number + CRC32 checksum).
    _FRAME_OVERHEAD = 12

    def __init__(self, ip_address, port, proxy=None):
        """Stores the endpoint and creates the (not yet connected) client.

        :param ip_address: address later passed to the client's connect().
        :param port: port later passed to the client's connect().
        :param proxy: forwarded unchanged to the TcpClient constructor.
        """
        self.ip = ip_address
        self.port = port
        self.tcp_client = TcpClient(proxy)
        self.send_counter = 0

    def connect(self):
        """Connects to the specified IP address and port.

        The send counter (used as the sequence number of outgoing
        packets) is reset to zero before connecting.
        """
        self.send_counter = 0
        self.tcp_client.connect(self.ip, self.port)

    def send(self, packet):
        """Sends the given packet (bytes array) to the connected peer.

        :param packet: the payload to be wrapped in a frame and written.
        :raises ConnectionError: if the TCP client is not connected.
        """
        if not self.tcp_client.connected:
            raise ConnectionError('Client not connected to server.')

        with BinaryWriter() as writer:
            writer.write_int(len(packet) + self._FRAME_OVERHEAD)
            writer.write_int(self.send_counter)
            writer.write(packet)

            # The checksum covers everything written so far, that is,
            # the length, the sequence number and the payload.
            crc = crc32(writer.get_bytes())
            writer.write_int(crc, signed=False)

            self.send_counter += 1
            self.tcp_client.write(writer.get_bytes())

    def receive(self, timeout=timedelta(seconds=5)):
        """Receives a TCP message (tuple(sequence number, body)) from the
        connected peer.

        There is a default timeout of 5 seconds before the operation is
        cancelled. Timeout can be set to None for no timeout. The value
        is handed as-is to every read performed on the TCP client.

        :param timeout: timeout passed to each read on the TCP client.
        :return: a tuple of (sequence number, body).
        :raises ValueError: if the announced length is smaller than the
            12 bytes every frame needs for its own framing.
        :raises InvalidChecksumError: if the received checksum does not
            match the CRC32 of the received length, sequence and body.
        """
        # First read everything we need
        packet_length_bytes = self.tcp_client.read(4, timeout)
        packet_length = int.from_bytes(packet_length_bytes, byteorder='little')

        # A length below the framing overhead would turn into a negative
        # read size for the body, so reject it before reading any further.
        if packet_length < self._FRAME_OVERHEAD:
            raise ValueError(
                'Invalid packet length {}: a packet takes at least {} bytes.'
                .format(packet_length, self._FRAME_OVERHEAD))

        seq_bytes = self.tcp_client.read(4, timeout)
        seq = int.from_bytes(seq_bytes, byteorder='little')

        body = self.tcp_client.read(
            packet_length - self._FRAME_OVERHEAD, timeout)

        checksum = int.from_bytes(
            self.tcp_client.read(4, timeout), byteorder='little', signed=False)

        # Then perform the checks
        rv = packet_length_bytes + seq_bytes + body
        valid_checksum = crc32(rv)

        if checksum != valid_checksum:
            raise InvalidChecksumError(checksum, valid_checksum)

        # If we passed the tests, we can then return a valid TCP message
        return seq, body

    def close(self):
        """Closes the underlying TCP client."""
        self.tcp_client.close()

    def cancel_receive(self):
        """Cancels (stops) trying to receive from the remote peer by
        calling cancel_read() on the TCP client.

        NOTE(review): the original documentation states that this raises
        a ReadCancelledError; that happens inside the TCP client and is
        not visible from this class, so confirm it there.
        """
        self.tcp_client.cancel_read()

    def get_client_delay(self):
        """Gets the client read delay (the TCP client's ``delay``)."""
        return self.tcp_client.delay
|