2023-06-19 18:03:21 -07:00
|
|
|
import asyncio
|
|
|
|
import asyncio.streams
|
2023-12-12 12:40:25 -08:00
|
|
|
import dataclasses
|
2023-06-19 19:40:41 -07:00
|
|
|
import functools
|
2023-06-19 18:03:21 -07:00
|
|
|
import logging
|
2023-06-19 19:15:08 -07:00
|
|
|
from typing import Iterable
|
|
|
|
|
|
|
|
from proto import control_pb2
|
2023-06-19 18:03:21 -07:00
|
|
|
|
|
|
|
# Module-wide logger; its records are emitted under the logger name "server".
LOGGER = logging.getLogger("server")

# Token a client must send first, followed by a single space (see handshake()).
MAGIC: str = "e8437140-4347-48cc-a31d-dcdc944ffc15"

# Wire formats the server can negotiate, listed in server preference order
# (_select_format returns the first entry here that the client also offers).
# NOTE(review): on_connect decodes every command with protobuf, so negotiating
# "json" does not appear to work end-to-end yet — confirm before relying on it.
SUPPORTED_FORMATS = ["protobuf", "json"]

# Declared schema column type -> Python type used to represent that column.
SCHEMA_TYPE_TO_PYTHON_TYPE = {
    control_pb2.SchemaEntry.Type.STRING: str,
    control_pb2.SchemaEntry.Type.INT: int,
}
|
|
|
|
|
2023-06-19 19:15:08 -07:00
|
|
|
@dataclasses.dataclass(eq=False)
class Connection:
    """Session state negotiated with one client during the handshake.

    Field order defines the positional constructor signature
    ``Connection(reader, writer, namespace, app_name, app_version,
    chosen_format, client_key)`` — keep it stable. ``eq=False`` keeps
    identity-based equality and hashing, as for a plain class.
    """

    reader: asyncio.StreamReader  # stream the client's data is read from
    writer: asyncio.StreamWriter  # stream replies to the client are written to
    namespace: str  # namespace field of the client hand
    app_name: str  # application name field of the client hand
    app_version: str  # application version field of the client hand
    chosen_format: str  # wire format selected by the server
    client_key: str  # key field sent by the client

    @property
    def namepsace(self) -> str:
        """Misspelled alias of ``namespace``, kept for backward compatibility."""
        return self.namespace

    @namepsace.setter
    def namepsace(self, value: str) -> None:
        self.namespace = value
|
2023-06-19 18:10:47 -07:00
|
|
|
|
2023-06-19 18:03:21 -07:00
|
|
|
def main():
    """Script entry point: enable DEBUG logging and serve until Ctrl-C."""
    logging.basicConfig(level=logging.DEBUG)
    try:
        server_coro = run()
        asyncio.run(server_coro)
    except KeyboardInterrupt:
        LOGGER.info("Exiting")
|
2023-06-19 18:03:21 -07:00
|
|
|
|
2023-06-19 19:40:41 -07:00
|
|
|
async def on_connect(command_queue: asyncio.Queue, reader, writer):
    """Serve one client connection accepted by ``asyncio.start_server``.

    Runs the handshake, then repeatedly reads a chunk from the client,
    decodes it as a ``control_pb2.Command`` and enqueues
    ``(command, writer)`` so the processor can reply on this connection.

    Args:
        command_queue: Queue shared with the command processor.
        reader: Stream reader for this client.
        writer: Stream writer for this client; closed when handling ends.
    """
    LOGGER.info("connected")
    try:
        await handshake(reader, writer)
        while True:
            data = await reader.read(4096)
            if not data:
                # StreamReader.read() returns b"" only at EOF; retrying would
                # spin forever on a disconnected client.
                LOGGER.info("disconnected")
                break
            LOGGER.debug("Got data: '%s'", data)
            # NOTE(review): no message framing — this assumes each read()
            # yields exactly one serialized Command. TCP may split or merge
            # messages; a length prefix would make this robust.
            command = control_pb2.Command.FromString(data)
            await command_queue.put((command, writer))
    except ClientError:
        LOGGER.info("Client error")
    finally:
        # Release the transport. Responses still queued for this client will
        # be dropped by the closed transport rather than raising.
        writer.close()
|
|
|
|
|
|
|
|
class ClientError(Exception):
    """Raised to abandon a client connection after its request was rejected."""
|
|
|
|
|
|
|
|
async def handshake(reader, writer) -> Connection:
    """Negotiate a session with a newly connected client.

    Protocol as implemented here:

    1. Client sends ``MAGIC`` followed by one space.
    2. Client sends ``"<formats> <namespace> <app_name> <app_version> <client_key>"``
       where ``<formats>`` is a comma-separated list of wire formats.
    3. Server replies ``"<chosen_format> <target_url> <server_pub_key>"``.

    Returns:
        A Connection describing the negotiated session.

    Raises:
        Each rejection (bad magic, malformed client hand, no shared format)
        is reported through ``_write_error``, which is expected to raise
        ClientError.
    """
    magic_len = len(MAGIC) + 1  # MAGIC plus its trailing space separator
    try:
        # A single read() may return fewer bytes than the client sent.
        data = await reader.readexactly(magic_len)
    except asyncio.IncompleteReadError as err:
        # Client closed before sending a full magic; check what did arrive
        # (it cannot match) so the client still gets an error reply.
        data = err.partial
    # errors="replace" so garbage input fails the comparison instead of raising.
    magic = data.decode("UTF-8", errors="replace")
    LOGGER.debug("Received magic '%s'", magic)
    if magic != MAGIC + " ":
        await _write_error(writer, 1, "Magic not recognized")
    LOGGER.debug("Magic looks good.")

    # NOTE(review): no framing — assumes the whole client hand arrives in a
    # single read of at most 1024 bytes.
    data = await reader.read(1024)
    try:
        client_hand = data.decode("UTF-8")
        client_formats, namespace, app_name, app_version, client_key = client_hand.split(" ")
    except ValueError:
        # Covers invalid UTF-8 (UnicodeDecodeError) and a wrong field count.
        await _write_error(writer, 3, "Malformed client handshake")

    chosen_format = _select_format(client_formats.split(","), SUPPORTED_FORMATS)
    if not chosen_format:
        await _write_error(writer, 2, "server does not support any of " + client_formats)

    # NOTE(review): fixed values; target_url duplicates the port run() binds
    # (9988) and server_pub_key is a placeholder.
    target_url = "127.0.0.1:9988"
    server_pub_key = "fakeserverkey"
    server_hand = " ".join([
        chosen_format,
        target_url,
        server_pub_key,
    ])
    writer.write(server_hand.encode("UTF-8"))
    await writer.drain()
    LOGGER.info("Sending %s", server_hand)
    return Connection(
        reader, writer, namespace, app_name, app_version, chosen_format, client_key)
|
|
|
|
|
2023-12-12 12:40:25 -08:00
|
|
|
class DataManager:
    """Processes queued client commands and keeps an in-memory table schema."""

    def __init__(self):
        # table name -> {column name -> Python type}; filled by SCHEMA_WRITE.
        self.schema = {}
        # Background task started by process(); None until then.
        self.process_task = None

    def process(self, command_queue: asyncio.Queue) -> None:
        """Start draining *command_queue* in a background task.

        Must be called while an event loop is running
        (``asyncio.create_task`` raises RuntimeError otherwise).
        """
        self.process_task = asyncio.create_task(self._processor(command_queue))

    async def _processor(self, command_queue: asyncio.Queue) -> None:
        """Handle ``(command, writer)`` pairs from *command_queue* forever.

        For SCHEMA_WRITE commands the schema is updated and a
        ``CommandResponse`` carrying the command id is written back through
        the client's writer. Other command types are logged and skipped.
        """
        while True:
            command, writer = await command_queue.get()
            try:
                if command:
                    LOGGER.info("Processing %s", command)
                    if command.type == control_pb2.Command.Type.SCHEMA_WRITE:
                        self._write_schema(command.table_name, command.schema_entry)
                        result = control_pb2.CommandResponse(id=command.id)
                        # NOTE(review): unframed write, mirroring on_connect's
                        # one-Command-per-read assumption; confirm the client
                        # side expects this.
                        writer.write(result.SerializeToString())
                        await writer.drain()
                    else:
                        LOGGER.warning("Unsupported command type: %s", command.type)
            except Exception:
                # This one task serves every client: a bad command (e.g. an
                # unmapped schema type) or a vanished client must not stop it.
                LOGGER.exception("Failed to process command %s", command)
            finally:
                command_queue.task_done()
            # Queue.get() can complete without suspending while items are
            # available, so yield explicitly to avoid starving other tasks.
            await asyncio.sleep(0)

    def _write_schema(self, table_name: str, schema_entries: Iterable[control_pb2.SchemaEntry]) -> None:
        """Merge *schema_entries* into the stored schema of *table_name*.

        Raises:
            KeyError: if an entry's type has no mapping in
                SCHEMA_TYPE_TO_PYTHON_TYPE; the schema is left unchanged.
        """
        # Resolve every column first so one bad entry cannot half-apply.
        columns = {entry.name: SCHEMA_TYPE_TO_PYTHON_TYPE[entry.type] for entry in schema_entries}
        # TODO: check that existing columns are not being overwritten.
        # setdefault so that a new table's dict is actually stored in self.schema.
        self.schema.setdefault(table_name, {}).update(columns)
|
|
|
|
|
|
|
|
def write_schema(table_name: str, schema_entries) -> None:
    """Log a schema write request for *table_name*.

    NOTE(review): this only logs; DataManager._write_schema is what records
    schemas. It is declared without ``self`` — confirm whether it is meant to
    be a module-level helper or a DataManager method, or can be removed.
    """
    LOGGER.info("Writing schema for table %s: %s", table_name, schema_entries)
|
|
|
|
|
2023-06-19 19:40:41 -07:00
|
|
|
|
2023-06-19 18:03:21 -07:00
|
|
|
async def run():
    """Serve client connections on localhost:9988 until stopped.

    Commands from every connection are pushed onto one shared queue, which a
    single DataManager background task drains.
    """
    queue = asyncio.Queue()
    manager = DataManager()
    manager.process(queue)

    handler = functools.partial(on_connect, queue)
    server = await asyncio.start_server(handler, host="localhost", port=9988)
    async with server:
        try:
            await server.serve_forever()
        except KeyboardInterrupt:
            # NOTE(review): on Python 3.11+ asyncio.run() turns Ctrl-C into task
            # cancellation, so main() is where the interrupt is normally seen.
            LOGGER.info("Exiting at user request.")
|
2023-06-19 18:03:21 -07:00
|
|
|
|
2023-06-19 19:15:08 -07:00
|
|
|
def _select_format(client_formats: Iterable[str], supported_formats: Iterable[str]) -> str:
|
|
|
|
"Pick a format to use with the client."
|
|
|
|
for f in supported_formats:
|
|
|
|
if f in client_formats:
|
|
|
|
return f
|
|
|
|
|
|
|
|
async def _write_error(writer, code: int, message: str) -> None:
    """Send ``ERR<code>: <message>`` to the client, close it, and bail out.

    Args:
        writer: The client's stream writer; closed by this function.
        code: Numeric error code included in the reply.
        message: Human-readable reason included in the reply.

    Raises:
        ClientError: always, so callers unwind to the connection handler.
    """
    # StreamWriter.write() is a plain method returning None — it must not be
    # awaited (awaiting None raises TypeError).
    writer.write(f"ERR{code}: {message}".encode("UTF-8"))
    try:
        await writer.drain()
    except ConnectionError:
        # Best effort: the client may already be gone; still close and raise.
        LOGGER.debug("Could not deliver error %d to client", code)
    writer.close()
    LOGGER.info("Client error: %d %s", code, message)
    raise ClientError("Failed")
|
|
|
|
|
|
|
|
|
2023-06-19 18:03:21 -07:00
|
|
|
# Start the server only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|