import asyncio
import asyncio.streams
import dataclasses
import functools
import logging
from typing import Iterable, NoReturn, Optional

from proto import control_pb2

LOGGER = logging.getLogger("server")

# Preamble every client must send first, immediately followed by one space.
MAGIC = "e8437140-4347-48cc-a31d-dcdc944ffc15"

# Wire formats this server can speak, in server preference order.
SUPPORTED_FORMATS = [
    "protobuf",
    "json",
]

# Maps protobuf schema entry types to the Python type recorded in the schema.
SCHEMA_TYPE_TO_PYTHON_TYPE = {
    control_pb2.SchemaEntry.Type.STRING: str,
    control_pb2.SchemaEntry.Type.INT: int,
}


class Connection:
    """State negotiated with one client during the handshake."""

    def __init__(self, reader, writer, namespace: str, app_name: str,
                 app_version: str, chosen_format: str, client_key: str):
        self.reader = reader
        self.writer = writer
        self.namespace = namespace
        self.app_name = app_name
        self.app_version = app_version
        self.chosen_format = chosen_format
        self.client_key = client_key

    @property
    def namepsace(self) -> str:
        """Deprecated, misspelled alias of ``namespace`` kept for old callers."""
        return self.namespace


def main():
    """Configure logging and run the server until interrupted."""
    logging.basicConfig(level=logging.DEBUG)
    try:
        asyncio.run(run())
    except KeyboardInterrupt:
        LOGGER.info("Exiting")


class ClientError(Exception):
    """Raised when a client violates the protocol; aborts its connection."""


async def on_connect(command_queue: asyncio.Queue, reader, writer):
    """Handle one client: run the handshake, then enqueue its commands.

    Each decoded command is put on *command_queue* as ``(command, writer)``
    so the processor can send the response back to this client.
    """
    LOGGER.info("connected")
    try:
        await handshake(reader, writer)
        while True:
            data = await reader.read(4096)
            if not data:
                # read() returns b"" only at EOF: the client hung up.
                LOGGER.info("Client disconnected")
                break
            LOGGER.debug("Got data: '%s'", data)
            # NOTE(review): assumes each read() yields exactly one serialized
            # Command. TCP does not preserve message boundaries, so explicit
            # framing (e.g. a length prefix) is probably needed — confirm.
            command = control_pb2.Command.FromString(data)
            await command_queue.put((command, writer))
    except ClientError:
        LOGGER.info("Client error")
    except ConnectionError:
        LOGGER.info("Connection lost")
    finally:
        # Always release the transport; close() is a no-op if already closing.
        writer.close()


async def handshake(reader, writer) -> Connection:
    """Run the server side of the connection handshake.

    The client sends ``MAGIC`` plus a space, then a space-separated hello
    ``formats namespace app_name app_version client_key`` where *formats* is
    a comma-separated list. The server answers with
    ``chosen_format target_url server_pub_key``.

    Raises:
        ClientError: if the client disconnects mid-preamble, sends the wrong
            magic, sends a malformed hello, or shares no format with us.
    """
    try:
        data = await reader.readexactly(len(MAGIC) + 1)
    except asyncio.IncompleteReadError:
        raise ClientError("Connection closed during handshake") from None
    # errors="replace" so non-UTF-8 garbage is rejected as a bad magic
    # instead of crashing with UnicodeDecodeError.
    magic = data.decode("UTF-8", errors="replace")
    LOGGER.debug("Received magic '%s'", magic)
    if magic != MAGIC + " ":
        await _write_error(writer, 1, "Magic not recognized")
    LOGGER.debug("Magic looks good.")

    data = await reader.read(1024)
    client_hand = data.decode("UTF-8", errors="replace")
    fields = client_hand.split(" ")
    if len(fields) != 5:
        # NOTE(review): error code 3 is new; confirm it fits the protocol.
        await _write_error(writer, 3, "Malformed client handshake")
    client_formats, namespace, app_name, app_version, client_key = fields

    chosen_format = _select_format(client_formats.split(","), SUPPORTED_FORMATS)
    if not chosen_format:
        await _write_error(
            writer, 2, "server does not support any of " + client_formats)

    # These look like placeholders; presumably they should come from config.
    target_url = "127.0.0.1:9988"
    server_pub_key = "fakeserverkey"
    server_hand = " ".join([
        chosen_format,
        target_url,
        server_pub_key,
    ])
    writer.write(server_hand.encode("UTF-8"))
    LOGGER.info("Sending %s", server_hand)
    await writer.drain()
    return Connection(
        reader, writer, namespace, app_name, app_version, chosen_format,
        client_key)


class DataManager:
    """Consumes ``(command, writer)`` pairs and maintains per-table schemas."""

    def __init__(self):
        # table name -> {column name -> Python type}
        self.schema = {}
        self.process_task = None

    def process(self, command_queue: asyncio.Queue) -> None:
        """Start the background task draining *command_queue*.

        Uses asyncio.create_task, so it must be called inside a running loop.
        """
        self.process_task = asyncio.create_task(self._processor(command_queue))

    async def _processor(self, command_queue: asyncio.Queue) -> None:
        """Process queued commands forever, one at a time."""
        while True:
            command, writer = await command_queue.get()
            try:
                self._handle(command, writer)
            except Exception:
                # Top-level boundary: one bad command must not stop the
                # processor for every other client.
                LOGGER.exception("Failed to process command")
            finally:
                command_queue.task_done()

    def _handle(self, command, writer) -> None:
        """Apply one command and write a serialized CommandResponse to *writer*."""
        if not command:
            return
        LOGGER.info("Processing %s", command)
        if command.type == control_pb2.Command.Type.SCHEMA_WRITE:
            self._write_schema(command.table_name, command.schema_entry)
        # NOTE(review): only ``id`` is set; any other CommandResponse fields
        # are unknown from here — confirm against the .proto definition.
        result = control_pb2.CommandResponse(id=command.id)
        # No drain(): a slow client must not stall the shared processor.
        writer.write(result.SerializeToString())

    def _write_schema(self, table_name: str,
                      schema_entries: Iterable[control_pb2.SchemaEntry]) -> None:
        """Record the column types from *schema_entries* for *table_name*.

        All entries are resolved before anything is stored, so an unknown
        type leaves the existing schema untouched.

        Raises:
            KeyError: if an entry's type is not in SCHEMA_TYPE_TO_PYTHON_TYPE.
        """
        resolved = {
            entry.name: SCHEMA_TYPE_TO_PYTHON_TYPE[entry.type]
            for entry in schema_entries
        }
        # TODO: probably should check existing columns aren't being overwritten.
        self.schema.setdefault(table_name, {}).update(resolved)


def write_schema(table_name: str, schema_entries) -> None:
    """Log a schema write request (does not store anything)."""
    LOGGER.info("Writing schema: %s", schema_entries)


async def run():
    """Start the command processor and serve client connections forever."""
    command_queue = asyncio.Queue()
    data_manager = DataManager()
    data_manager.process(command_queue)
    client_handler = functools.partial(on_connect, command_queue)
    server = await asyncio.start_server(
        client_handler, host="localhost", port=9988)
    try:
        async with server:
            try:
                await server.serve_forever()
            except KeyboardInterrupt:
                LOGGER.info("Exiting at user request.")
    finally:
        # Don't leave the background processor task running after shutdown.
        data_manager.process_task.cancel()


def _select_format(client_formats: Iterable[str],
                   supported_formats: Iterable[str]) -> Optional[str]:
    """Pick a format to use with the client.

    Returns the first entry of *supported_formats* (server preference order)
    that also appears in *client_formats*, or None when there is no overlap.
    """
    # Materialize once so an iterator argument isn't exhausted by the scan.
    offered = set(client_formats)
    return next((f for f in supported_formats if f in offered), None)


async def _write_error(writer, code: int, message: str) -> NoReturn:
    """Send ``ERR<code>: <message>`` to the client, close it, and raise.

    Raises:
        ClientError: always, so the caller's connection handler aborts.
    """
    # StreamWriter.write() is a plain method; drain() is the awaitable flush.
    writer.write(f"ERR{code}: {message}".encode("UTF-8"))
    try:
        await writer.drain()
    except ConnectionError:
        pass  # Client is already gone; nothing more we can tell it.
    writer.close()
    LOGGER.info("Client error: %d %s", code, message)
    raise ClientError("Failed")


if __name__ == "__main__":
    main()