From aadc8520bb2c26fd35f6a45b8b34b56a9939dfc2 Mon Sep 17 00:00:00 2001 From: Eli Ribble Date: Mon, 19 Jun 2023 19:15:08 -0700 Subject: [PATCH] Working protobuf passing. This is barely working, really. In fact, I have an "unused" field in my enum because if I don't provide it the protobuf serializes down to an empty buffer which breaks all kinds of things. --- README.md | 12 ++++++- client.py | 8 +++++ control.proto | 19 +++++++++++ datajack/__init__.py | 59 ++++++++++++++++++++++++++++++-- main.go | 15 ++++++-- server.py | 81 ++++++++++++++++++++++++++++++++++++++------ 6 files changed, 179 insertions(+), 15 deletions(-) create mode 100644 control.proto diff --git a/README.md b/README.md index 9a538b5..adeec2a 100644 --- a/README.md +++ b/README.md @@ -15,7 +15,13 @@ The point is can we create something that is database-like without being a tradi * JSON, probably >:P 1. Data is never deleted. Time-travel is built-in * Kinda like Datomic? - + +## Incantations + +``` +protoc --python_out proto control.proto +``` + ## Protocol concerns * Built-in behavior for dealing with a cluster @@ -69,10 +75,14 @@ The point is can we create something that is database-like without being a tradi * Figure out the most efficient path between the client and server to avoid unnecessary network hops. * Is this necessary, or does the operating system do this for us? * This includes redirecting to another process or server in load-balanced applications. + * I think we don't need to specify that we are waiting for user authorization if we don't want to, we can just let the client queue up requests and wait to confirm them. ## Client features * Query data lots of neat ways * Subscribe to updates of a particular query, get pushed data on those updates + * Transactions + * Explicit ordering and unordering of reads/writes + * Multiplexed comms to allow parallel reads/writes over a single connection. ## Server features * Triggers? 
diff --git a/client.py b/client.py index 75bc147..858fbdb 100644 --- a/client.py +++ b/client.py @@ -1,12 +1,20 @@ import argparse +import logging import datajack +#@datajack.Table +class Todo: + name: str + + def main(): parser = argparse.ArgumentParser() parser.add_argument("connection_uri", help="The URI to use to connect to the remove datajack.") args = parser.parse_args() + logging.basicConfig(level=logging.DEBUG) + with datajack.connection(args.connection_uri) as dj: dj.send("hi".encode("UTF-8")) diff --git a/control.proto b/control.proto new file mode 100644 index 0000000..f7cd620 --- /dev/null +++ b/control.proto @@ -0,0 +1,19 @@ +syntax = "proto3"; + +package datafortress; + +// Indicates a command +message Command { + enum Type { + UNUSED = 0; + STATUS = 1; + SCHEMA_READ = 2; + SCHEMA_WRITE = 3; + QUERY = 4; + WRITE = 5; + } + + Type type = 1; +} + + diff --git a/datajack/__init__.py b/datajack/__init__.py index 3cd619b..eeb8505 100644 --- a/datajack/__init__.py +++ b/datajack/__init__.py @@ -1,12 +1,48 @@ +import enum +import logging import socket +from typing import Tuple import urllib.parse +from proto import control_pb2 + +LOGGER = logging.getLogger(__name__) + +# Uniquely identifies a datajack connection. +# This indicates the handshake style expected and is used +# to quickly and easily determine if the client is even speaking the +# right protocol for the server. +MAGIC = "e8437140-4347-48cc-a31d-dcdc944ffc15" + +# The list of formats supported by this client in descending order of preference +# This is used for negotiation of the data format that will be used with the server. 
+DATA_FORMATS = ",".join([ + "protobuf", + "json", +]) + +# The types of access that are allowed with this connection +class Permissions(enum.Enum): + pass + +def _parse_netloc(netloc) -> Tuple[str, str, str, int]: + app, _, connection = netloc.partition("@") + appname, _, version = app.partition(":") + host, _, part = connection.partition(":") + return appname, version, host, part + class Connection: def __init__(self, uri): parts = urllib.parse.urlparse(uri) + print(parts) netloc = parts.netloc - self.host, _, self.port = netloc.partition(":") - + assert parts.scheme == "socket" + self.app_name, self.app_version, self.host, self.port = _parse_netloc(netloc) + self.namespace = parts.path.lstrip("/") + self.public_key = "pretend_key" + self.data_format = None + self.address = None + self.server_key = None def __enter__(self): self.connect() @@ -18,6 +54,7 @@ class Connection: def connect(self): self.socket = socket.socket() self.socket.connect((self.host, int(self.port))) + self._handshake() def disconnect(self): pass @@ -25,5 +62,23 @@ class Connection: def send(self, data): self.socket.send(data) + def _handshake(self): + "Handshake with the server, ensure we have all the data we need." 
+ fields = [MAGIC, DATA_FORMATS, self.namespace, self.app_name, self.app_version, self.public_key,] + cliend_hand = " ".join(fields) + self.socket.send(cliend_hand.encode("UTF-8")) + server_hand = self.socket.recv(1024) + if not server_hand: + print("Failed to get server hand") + self.data_format, self.address, self.server_key = server_hand.decode("UTF-8").split(" ") + LOGGER.info("Data format: %s", self.data_format) + command = control_pb2.Command( + type=control_pb2.Command.Type.STATUS + ) + to_send = command.SerializeToString() + LOGGER.info("Sending '%s'", to_send) + self.socket.send(to_send) + + def connection(uri) -> Connection: return Connection(uri) diff --git a/main.go b/main.go index 9b40c8e..2a62580 100644 --- a/main.go +++ b/main.go @@ -4,12 +4,15 @@ import ( "fmt" "net" "os" + "strings" ) const ( SERVER_HOST = "localhost" SERVER_PORT = "9988" SERVER_TYPE = "tcp" + MAGIC_HEADER = "e8437140-4347-48cc-a31d-dcdc944ffc15" ) + func main() { fmt.Println("Server Running...") server, err := net.Listen(SERVER_TYPE, SERVER_HOST+":"+SERVER_PORT) @@ -36,7 +39,15 @@ func processClient(connection net.Conn) { if err != nil { fmt.Println("Error reading:", err.Error()) } - fmt.Println("Received: ", string(buffer[:mLen])) - _, err = connection.Write([]byte("Thanks! Got your message:" + string(buffer[:mLen]))) + // handshake + // split out the elements of the clients handshake + // read the first 40 bytes, check the magic header, reject invalids + // handle data format negotiation and respond with selected protocol + // store information about + // _, err = connection.Write([]byte("Thanks! 
Got your message:" + string(buffer[:mLen]))) + data := string(buffer[:mLen]) + parts := strings.Split(data, " ") + fmt.Println("Received: ", parts) + _, err = connection.Write([]byte("thanks")) connection.Close() } diff --git a/server.py b/server.py index 7f28be3..f15a55d 100644 --- a/server.py +++ b/server.py @@ -1,28 +1,76 @@ import asyncio import asyncio.streams import logging +from typing import Iterable + +from proto import control_pb2 LOGGER = logging.getLogger("server") -MAGIC = "e8437140-4347-48cc-a31d-dcdc944ffc16" +MAGIC = "e8437140-4347-48cc-a31d-dcdc944ffc15" +SUPPORTED_FORMATS = [ + "protobuf", + "json", +] + +class Connection: + def __init__(self, reader, writer, namespace: str, app_name: str, app_version: str, chosen_format: str, client_key: str): + self.reader = reader + self.writer = writer + self.namepsace = namespace + self.app_name = app_name + self.app_version = app_version + self.chosen_format = chosen_format + self.client_key = client_key def main(): logging.basicConfig(level=logging.DEBUG) - asyncio.run(run()) + try: + asyncio.run(run()) + except KeyboardInterrupt: + LOGGER.info("Exiting") async def on_connect(reader, writer): LOGGER.info("connected") - data = await reader.read(36) + try: + connection = await handshake(reader, writer) + while True: + data = await reader.read(4096) + LOGGER.debug("Got data: '%s'", data) + command = control_pb2.Command.FromString(data) + LOGGER.info("Command: %s", command) + await asyncio.sleep(1) + except ClientError: + LOGGER.info("Client error") + +class ClientError(Exception): + pass + +async def handshake(reader, writer) -> Connection: + data = await reader.read(36 + 1) magic = data.decode("UTF-8") LOGGER.debug("Received magic '%s'", magic) - if magic != MAGIC: - writer.write("ERR1: Magic not recognized".encode("UTF-8")) - writer.close() - LOGGER.info("Bad magic, closing connection.") - return + if magic != MAGIC + " ": + await _write_error(writer, 1, "Magic not recognized") LOGGER.debug("Magic looks 
good.") - - + data = await reader.read(1024) + client_hand = data.decode("UTF-8") + client_formats, namespace, app_name, app_version, client_key = client_hand.split(" ") + chosen_format = _select_format(client_formats.split(","), SUPPORTED_FORMATS) + if not chosen_format: + await _write_error(writer, 2, "server does not support any of " + client_formats) + target_url = "127.0.0.1:9988" + server_pub_key = "fakeserverkey" + server_hand = " ".join([ + chosen_format, + target_url, + server_pub_key, + ]) + writer.write(server_hand.encode("UTF-8")) + LOGGER.info("Sending %s", server_hand) + return Connection( + reader, writer, namespace, app_name, app_version, chosen_format, client_key) + async def run(): server = await asyncio.start_server(on_connect, host="localhost", port=9988) async with server: @@ -31,5 +79,18 @@ async def run(): except KeyboardInterrupt: LOGGER.info("Exiting at user request.") +def _select_format(client_formats: Iterable[str], supported_formats: Iterable[str]) -> str: + "Pick a format to use with the client." + for f in supported_formats: + if f in client_formats: + return f + +async def _write_error(writer, code: int, message: str) -> None: + await writer.write(f"ERR{code}: {message}".encode("UTF-8")) + writer.close() + LOGGER.info("Client error: %d %s", code, message) + raise ClientError("Failed") + + if __name__ == "__main__": main()