"""Asyncio TCP server: performs a text handshake, then reads control commands."""

import asyncio
import asyncio.streams
import logging
from typing import Iterable, NoReturn, Optional

from proto import control_pb2

LOGGER = logging.getLogger("server")

# Token every client must send first, followed by a single space.
MAGIC = "e8437140-4347-48cc-a31d-dcdc944ffc15"

# Formats this server can speak, in server preference order
# (see _select_format).
SUPPORTED_FORMATS = [
    "protobuf",
    "json",
]


class ClientError(Exception):
    """Raised to abort handling of a client after an error was sent to it."""


class Connection:
    """Per-client state produced by a successful handshake."""

    def __init__(self, reader, writer, namespace: str, app_name: str,
                 app_version: str, chosen_format: str, client_key: str):
        self.reader = reader
        self.writer = writer
        self.namespace = namespace
        self.app_name = app_name
        self.app_version = app_version
        self.chosen_format = chosen_format
        self.client_key = client_key

    # The attribute used to be misspelled ``namepsace``; keep that name working
    # for any existing readers/writers.
    @property
    def namepsace(self) -> str:
        """Deprecated misspelled alias of ``namespace``."""
        return self.namespace

    @namepsace.setter
    def namepsace(self, value: str) -> None:
        self.namespace = value


def main():
    """Configure debug logging and run the server until interrupted."""
    logging.basicConfig(level=logging.DEBUG)
    try:
        asyncio.run(run())
    except KeyboardInterrupt:
        LOGGER.info("Exiting")


async def on_connect(reader, writer):
    """Serve one client: handshake, then log each received command until EOF.

    This is the ``asyncio.start_server`` callback, so it is the top-level
    boundary for this connection. Errors are logged here and the writer is
    always closed on exit.
    """
    LOGGER.info("connected")
    try:
        connection = await handshake(reader, writer)
        LOGGER.info("Handshake complete: namespace=%s app=%s %s format=%s",
                    connection.namespace, connection.app_name,
                    connection.app_version, connection.chosen_format)
        while True:
            data = await reader.read(4096)
            if not data:
                # read() returns b"" only at EOF: the client has gone away.
                # Looping here would spin forever.
                LOGGER.info("Client disconnected")
                break
            LOGGER.debug("Got data: '%s'", data)
            # NOTE(review): assumes each read() yields exactly one complete
            # serialized Command, and always decodes protobuf regardless of
            # connection.chosen_format. Confirm the intended message framing.
            command = control_pb2.Command.FromString(data)
            LOGGER.info("Command: %s", command)
    except ClientError:
        LOGGER.info("Client error")
    except ConnectionError:
        LOGGER.info("Connection lost")
    except Exception:
        # Top-level boundary for this client: log with traceback, don't crash
        # the server.
        LOGGER.exception("Unexpected error while serving client")
    finally:
        writer.close()


async def handshake(reader, writer) -> Connection:
    """Run the server side of the handshake and return the new Connection.

    Wire exchange as implemented here (UTF-8 text, single-space separators):

    1. client -> server: ``MAGIC + " "`` (37 bytes).
    2. client -> server:
       ``<formats> <namespace> <app_name> <app_version> <client_key>``,
       where ``<formats>`` is comma-separated.
    3. server -> client: ``<chosen_format> <target_url> <server_pub_key>``.

    Raises:
        ClientError: after sending ``ERR<code>: <message>`` and closing the
            writer. Codes: 1 = bad magic, 2 = no common format,
            3 = malformed client hand.
    """
    try:
        data = await reader.readexactly(len(MAGIC) + 1)
    except asyncio.IncompleteReadError as err:
        # EOF before the full magic arrived; validate what we got (it will
        # fail the comparison below and produce ERR1).
        data = err.partial
    # errors="replace": invalid bytes simply fail the magic comparison.
    magic = data.decode("UTF-8", errors="replace")
    LOGGER.debug("Received magic '%s'", magic)
    if magic != MAGIC + " ":
        await _write_error(writer, 1, "Magic not recognized")
    LOGGER.debug("Magic looks good.")

    # NOTE(review): the client hand has no terminator, so this assumes it
    # arrives whole in a single read of at most 1024 bytes.
    data = await reader.read(1024)
    try:
        client_hand = data.decode("UTF-8")
        (client_formats, namespace, app_name, app_version,
         client_key) = client_hand.split(" ")
    except ValueError:
        # Covers UnicodeDecodeError (a ValueError subclass) and a wrong
        # number of space-separated fields.
        await _write_error(writer, 3, "Malformed client handshake")

    chosen_format = _select_format(client_formats.split(","), SUPPORTED_FORMATS)
    if not chosen_format:
        await _write_error(writer, 2, "server does not support any of " + client_formats)

    # NOTE(review): hard-coded placeholders. target_url duplicates the port
    # used in run(), and the key is a literal stub.
    target_url = "127.0.0.1:9988"
    server_pub_key = "fakeserverkey"
    server_hand = " ".join([
        chosen_format,
        target_url,
        server_pub_key,
    ])
    writer.write(server_hand.encode("UTF-8"))
    LOGGER.info("Sending %s", server_hand)
    await writer.drain()
    return Connection(
        reader, writer, namespace, app_name, app_version, chosen_format,
        client_key)


async def run():
    """Listen on localhost:9988 and serve clients until cancelled."""
    server = await asyncio.start_server(on_connect, host="localhost", port=9988)
    async with server:
        try:
            await server.serve_forever()
        except (KeyboardInterrupt, asyncio.CancelledError):
            # Since Python 3.11, asyncio.run() turns Ctrl-C into cancellation
            # of the main task. Log, then propagate so shutdown completes
            # normally.
            LOGGER.info("Exiting at user request.")
            raise


def _select_format(client_formats: Iterable[str],
                   supported_formats: Iterable[str]) -> Optional[str]:
    """Pick a format to use with the client.

    Returns the first entry of *supported_formats* (server preference order)
    that also appears in *client_formats*, or ``None`` if there is no overlap.
    """
    # Snapshot once so one-shot iterators work and membership tests are O(1).
    offered = set(client_formats)
    for f in supported_formats:
        if f in offered:
            return f
    return None


async def _write_error(writer, code: int, message: str) -> NoReturn:
    """Send ``ERR<code>: <message>`` to the client, close it, and raise.

    Raises:
        ClientError: always, so callers abort with ``await _write_error(...)``.
    """
    # StreamWriter.write() is a plain method (not awaitable); drain() is the
    # coroutine that waits for the buffer to flush.
    writer.write(f"ERR{code}: {message}".encode("UTF-8"))
    try:
        await writer.drain()
    except ConnectionError:
        pass  # Client already gone; we are closing the connection anyway.
    writer.close()
    LOGGER.info("Client error: %d %s", code, message)
    raise ClientError("Failed")


if __name__ == "__main__":
    main()