2023-06-19 18:03:21 -07:00
|
|
|
import asyncio
|
|
|
|
import asyncio.streams
|
|
|
|
import logging
|
2023-06-19 19:15:08 -07:00
|
|
|
from typing import Iterable, NoReturn, Optional
|
|
|
|
|
|
|
|
from proto import control_pb2
|
2023-06-19 18:03:21 -07:00
|
|
|
|
|
|
|
# Logger used by every handler in this module.
LOGGER = logging.getLogger("server")

# Fixed UUID a client must send (followed by a single space) to open a handshake.
MAGIC = "e8437140-4347-48cc-a31d-dcdc944ffc15"

# Wire formats this server speaks, most preferred first: _select_format walks
# this list in order and picks the first entry the client also offered.
SUPPORTED_FORMATS = ["protobuf", "json"]
|
|
|
|
|
|
|
|
class Connection:
    """State for one client that has completed the handshake.

    Holds the client's stream pair plus the fields negotiated in
    ``handshake``: the client's namespace, app name/version, the wire format
    both sides agreed on, and the key the client presented.
    """

    def __init__(self, reader, writer, namespace: str, app_name: str, app_version: str, chosen_format: str, client_key: str):
        self.reader = reader
        self.writer = writer
        # Stored under the correctly spelled name; ``namepsace`` remains
        # available as an alias (see the property below) for older callers.
        self.namespace = namespace
        self.app_name = app_name
        self.app_version = app_version
        self.chosen_format = chosen_format
        self.client_key = client_key

    @property
    def namepsace(self) -> str:
        """Misspelled alias of ``namespace``, kept for backward compatibility."""
        return self.namespace

    @namepsace.setter
    def namepsace(self, value: str) -> None:
        self.namespace = value
|
2023-06-19 18:10:47 -07:00
|
|
|
|
2023-06-19 18:03:21 -07:00
|
|
|
def main():
    """Enable debug logging and serve until the user interrupts with Ctrl-C."""
    logging.basicConfig(level=logging.DEBUG)
    server_task = run()
    try:
        asyncio.run(server_task)
    except KeyboardInterrupt:
        # Ctrl-C is the normal way to stop the server; log it rather than
        # letting a traceback reach the terminal.
        LOGGER.info("Exiting")
|
2023-06-19 18:03:21 -07:00
|
|
|
|
|
|
|
async def on_connect(reader, writer):
    """Handle one client connection accepted by ``asyncio.start_server``.

    Runs ``handshake``, then repeatedly reads from the client and parses each
    chunk as a ``control_pb2.Command``. The loop ends when the client closes
    its side of the stream (``read`` returns ``b""``) or a ``ClientError`` is
    raised. The writer is closed on every exit path.
    """
    LOGGER.info("connected")
    try:
        connection = await handshake(reader, writer)
        while True:
            data = await reader.read(4096)
            if not data:
                # EOF. Without this check the loop would keep parsing empty
                # Commands forever after the client disconnects.
                LOGGER.info("disconnected: %s", connection.app_name)
                break
            LOGGER.debug("Got data: '%s'", data)
            # NOTE(review): assumes each read() yields exactly one complete
            # serialized Command; there is no length prefix — confirm with
            # the client implementation.
            command = control_pb2.Command.FromString(data)
            LOGGER.info("Command: %s", command)
            # Throttles the loop to at most one read per second.
            await asyncio.sleep(1)
    except ClientError:
        LOGGER.info("Client error")
    finally:
        # Closing an already-closed writer (e.g. after _write_error) is a no-op.
        writer.close()
|
|
|
|
|
|
|
|
class ClientError(Exception):
    """Raised by ``_write_error`` after an error has been reported to a client."""
|
|
|
|
|
|
|
|
async def handshake(reader, writer) -> Connection:
    """Perform the server side of the connection handshake.

    The client first sends the 36-character ``MAGIC`` string followed by one
    space. It then sends a space-separated "hand" of five fields:
    comma-separated formats, namespace, app name, app version and client key.
    The server replies with its chosen format, target URL and public key,
    separated by spaces.

    Args:
        reader: stream the client's handshake is read from.
        writer: stream the server's reply (or an error) is written to.

    Returns:
        A ``Connection`` describing the negotiated session.

    Raises:
        ClientError: via ``_write_error``, when the magic is wrong, the
            client hand is malformed, or no common format exists.
    """
    expected_magic = (MAGIC + " ").encode("UTF-8")
    try:
        # read(n) may legally return fewer than n bytes, which would make a
        # valid magic look wrong; readexactly waits for the full prefix.
        data = await reader.readexactly(len(expected_magic))
    except asyncio.IncompleteReadError as err:
        # Client closed early; the partial prefix fails the check below.
        data = err.partial
    LOGGER.debug("Received magic '%s'", data.decode("UTF-8", errors="replace"))
    # Compare bytes so a non-UTF-8 prefix is rejected cleanly instead of
    # raising UnicodeDecodeError.
    if data != expected_magic:
        await _write_error(writer, 1, "Magic not recognized")
    LOGGER.debug("Magic looks good.")

    # NOTE(review): assumes one read() returns the whole client hand and
    # nothing else; there is no framing here — confirm with the client.
    data = await reader.read(1024)
    try:
        client_hand = data.decode("UTF-8")
        client_formats, namespace, app_name, app_version, client_key = client_hand.split(" ")
    except ValueError:
        # Covers both bad UTF-8 (UnicodeDecodeError is a ValueError) and a
        # wrong field count. Error code 3 is new; previously this case
        # escaped as an unhandled exception.
        await _write_error(writer, 3, "malformed client handshake")

    chosen_format = _select_format(client_formats.split(","), SUPPORTED_FORMATS)
    if not chosen_format:
        await _write_error(writer, 2, "server does not support any of " + client_formats)

    # Hard-coded values: the URL matches run()'s default port 9988, and the
    # key is the placeholder "fakeserverkey".
    target_url = "127.0.0.1:9988"
    server_pub_key = "fakeserverkey"
    server_hand = " ".join([
        chosen_format,
        target_url,
        server_pub_key,
    ])
    writer.write(server_hand.encode("UTF-8"))
    LOGGER.info("Sending %s", server_hand)
    # Flush the reply (and respect flow control) before the caller reads.
    await writer.drain()
    return Connection(
        reader, writer, namespace, app_name, app_version, chosen_format, client_key)
|
|
|
|
|
2023-06-19 18:03:21 -07:00
|
|
|
async def run(host: str = "localhost", port: int = 9988):
    """Serve ``on_connect`` on *host*:*port* until cancelled or interrupted.

    Args:
        host: interface to bind. Defaults to ``"localhost"``.
        port: TCP port to listen on. Defaults to 9988, the port that
            ``handshake`` advertises in its hard-coded target URL.
    """
    server = await asyncio.start_server(on_connect, host=host, port=port)
    LOGGER.info("Listening on %s:%d", host, port)
    async with server:
        try:
            await server.serve_forever()
        except KeyboardInterrupt:
            # NOTE(review): under asyncio.run, Ctrl-C usually surfaces in
            # main() instead of here; kept as a best-effort handler.
            LOGGER.info("Exiting at user request.")
|
2023-06-19 18:03:21 -07:00
|
|
|
|
2023-06-19 19:15:08 -07:00
|
|
|
def _select_format(client_formats: Iterable[str], supported_formats: Iterable[str]) -> str:
|
|
|
|
"Pick a format to use with the client."
|
|
|
|
for f in supported_formats:
|
|
|
|
if f in client_formats:
|
|
|
|
return f
|
|
|
|
|
|
|
|
async def _write_error(writer, code: int, message: str) -> NoReturn:
    """Send ``ERR<code>: <message>`` to the client, close the stream, and raise.

    Args:
        writer: stream to report the error on.
        code: numeric error code sent to the client.
        message: human-readable description sent to the client.

    Raises:
        ClientError: always, so callers can abort with
            ``await _write_error(...)``.
    """
    # StreamWriter.write() is synchronous and returns None, so it must not be
    # awaited; drain() is the awaitable that flushes the buffer.
    writer.write(f"ERR{code}: {message}".encode("UTF-8"))
    try:
        await writer.drain()
    except ConnectionError:
        # The client may already be gone; reporting the error is best-effort.
        pass
    writer.close()
    LOGGER.info("Client error: %d %s", code, message)
    raise ClientError("Failed")
|
|
|
|
|
|
|
|
|
2023-06-19 18:03:21 -07:00
|
|
|
if __name__ == "__main__":
    # Script entry point: start the server when executed directly.
    main()
|