Temporary WIP for adding commands

This commit is contained in:
Eli Ribble 2023-12-12 13:40:25 -07:00
parent 61f4e2cdae
commit 090f55e225
2 changed files with 55 additions and 11 deletions

View File

@ -22,9 +22,22 @@ message Command {
WRITE = 5; WRITE = 5;
} }
Type type = 1; // unique identifier for the command supplied by the client.
// Any response to the command will reference this ID.
string id = 1;
Type type = 2;
optional string table_name = 2; optional string table_name = 3;
repeated SchemaEntry schema_entry = 3; repeated SchemaEntry schema_entry = 4;
} }
message Error {
int32 code = 1;
string message = 2;
}
message CommandResponse {
string id = 1;
bool success = 2;
optional Error error = 3;
}

View File

@ -1,5 +1,6 @@
import asyncio import asyncio
import asyncio.streams import asyncio.streams
import dataclasses
import functools import functools
import logging import logging
from typing import Iterable from typing import Iterable
@ -14,6 +15,11 @@ SUPPORTED_FORMATS = [
"json", "json",
] ]
SCHEMA_TYPE_TO_PYTHON_TYPE = {
control_pb2.SchemaEntry.Type.STRING: str,
control_pb2.SchemaEntry.Type.INT: int,
}
class Connection: class Connection:
def __init__(self, reader, writer, namespace: str, app_name: str, app_version: str, chosen_format: str, client_key: str): def __init__(self, reader, writer, namespace: str, app_name: str, app_version: str, chosen_format: str, client_key: str):
self.reader = reader self.reader = reader
@ -42,7 +48,7 @@ async def on_connect(command_queue: asyncio.Queue, reader, writer):
continue continue
LOGGER.debug("Got data: '%s'", data) LOGGER.debug("Got data: '%s'", data)
command = control_pb2.Command.FromString(data) command = control_pb2.Command.FromString(data)
await command_queue.put(command) await command_queue.put((command, write))
except ClientError: except ClientError:
LOGGER.info("Client error") LOGGER.info("Client error")
@ -74,16 +80,41 @@ async def handshake(reader, writer) -> Connection:
return Connection( return Connection(
reader, writer, namespace, app_name, app_version, chosen_format, client_key) reader, writer, namespace, app_name, app_version, chosen_format, client_key)
async def processor(command_queue): class DataManager:
def __init__(self):
self.schema = {}
self.process_task = None
def process(self, command_queue: asyncio.Queue) -> None:
self.process_task = asyncio.create_task(self._processor(command_queue))
async def _processor(self, command_queue: asyncio.Queue) -> None:
while True: while True:
command = await command_queue.get() command, writer = await command_queue.get()
if command: if command:
LOGGER.info("Processing %s", command) LOGGER.info("Processing %s", command)
if command.type == control_pb2.Command.Type.SCHEMA_WRITE:
_write_schema(command.table_name, command.schema_entry)
result = control_pb2.CommandResponse(
id=command.id,
writer.write(
await asyncio.sleep(0) await asyncio.sleep(0)
def _write_schema(self, table_name: str, schema_entries: Iterable[control_pb2.SchemaEntry]) -> None:
table_schema = self.schema.get(table_name, {})
# Probably should check I'm not overwriting schema
for entry in schema_entries:
table_schema[entry.name] = SCHEMA_TYPE_TO_PYTHON_TYPE[entry.type]
def write_schema(table_name: str, schema_entries) -> None:
LOGGER.info("Writing schumea: %s", schema_entries)
async def run(): async def run():
command_queue = asyncio.Queue() command_queue = asyncio.Queue()
command_processor = asyncio.create_task(processor(command_queue)) data_manager = DataManager()
data_manager.process(command_queue)
client_handler = functools.partial(on_connect, command_queue) client_handler = functools.partial(on_connect, command_queue)
server = await asyncio.start_server(client_handler, host="localhost", port=9988) server = await asyncio.start_server(client_handler, host="localhost", port=9988)
async with server: async with server: