diff --git a/client.py b/client.py index 858fbdb..60636de 100644 --- a/client.py +++ b/client.py @@ -6,6 +6,7 @@ import datajack #@datajack.Table class Todo: name: str + index: int def main(): @@ -16,7 +17,14 @@ def main(): logging.basicConfig(level=logging.DEBUG) with datajack.connection(args.connection_uri) as dj: - dj.send("hi".encode("UTF-8")) + dj.write_schema( + "todo", + { + "name": str, + "index": int, + }, + ) + if __name__ == "__main__": main() diff --git a/control.proto b/control.proto index f7cd620..047b38c 100644 --- a/control.proto +++ b/control.proto @@ -2,6 +2,15 @@ syntax = "proto3"; package datafortress; +message SchemaEntry { + enum Type { + STRING = 0; + INT = 1; + } + Type type = 1; + string name = 2; +} + // Indicates a command message Command { enum Type { @@ -14,6 +23,8 @@ message Command { } Type type = 1; + + optional string table_name = 2; + repeated SchemaEntry schema_entry = 3; } - diff --git a/datajack/__init__.py b/datajack/__init__.py index eeb8505..6ffe2f3 100644 --- a/datajack/__init__.py +++ b/datajack/__init__.py @@ -1,7 +1,7 @@ import enum import logging import socket -from typing import Tuple +from typing import Mapping, Tuple, Type import urllib.parse from proto import control_pb2 @@ -21,6 +21,11 @@ DATA_FORMATS = ",".join([ "json", ]) +PYTHON_TYPE_TO_SCHEMA_TYPE = { + str: control_pb2.SchemaEntry.STRING, + int: control_pb2.SchemaEntry.INT, +} + # The types of access that are allowed with this connection class Permissions(enum.Enum): pass @@ -62,6 +67,18 @@ class Connection: def send(self, data): self.socket.send(data) + def write_schema(self, table_name: str, schema: Mapping[str, Type]) -> None: + "Send a command to write the given schema." + command = control_pb2.Command( + type=control_pb2.Command.Type.SCHEMA_WRITE, + table_name=table_name, + ) + for k, v in schema.items(): + entry = command.schema_entry.add() + entry.name = k + entry.type = PYTHON_TYPE_TO_SCHEMA_TYPE[v] + self.socket.send(command.SerializeToString()) + def _handshake(self): "Handshake with the server, ensure we have all the data we need." fields = [MAGIC, DATA_FORMATS, self.namespace, self.app_name, self.app_version, self.public_key,] diff --git a/server.py b/server.py index f15a55d..2d171a9 100644 --- a/server.py +++ b/server.py @@ -36,10 +36,12 @@ async def on_connect(reader, writer): connection = await handshake(reader, writer) while True: data = await reader.read(4096) + if not data: + await asyncio.sleep(0) + continue LOGGER.debug("Got data: '%s'", data) command = control_pb2.Command.FromString(data) LOGGER.info("Command: %s", command) - await asyncio.sleep(1) except ClientError: LOGGER.info("Client error")