diff --git a/projects/miragesdk/src/sdk/ctl.ml b/projects/miragesdk/src/sdk/ctl.ml index 4bad89563..7b092ff4d 100644 --- a/projects/miragesdk/src/sdk/ctl.ml +++ b/projects/miragesdk/src/sdk/ctl.ml @@ -1,4 +1,5 @@ open Lwt.Infix +open Astring let src = Logs.Src.create "init" ~doc:"Init steps" module Log = (val Logs.src_log src : Logs.LOG) @@ -28,7 +29,9 @@ let v path = let () = Irmin.Private.Watch.set_listen_dir_hook Irmin_watcher.hook -module Message = struct +module Query = struct + + (* FIXME: this should probably be replaced by protobuf *) [%%cenum type operation = @@ -39,58 +42,68 @@ module Message = struct ] type t = { + version : int32; + id : int32; operation: operation; path : string; payload : string; } - [%%cstruct type message = { - operation : uint8_t; (* = type operation *) + [%%cstruct type msg = { + version : uint32_t; (* protocol version *) + id : uint32_t; (* session identifier *) + operation : uint8_t; (* = type operation *) path : uint16_t; - payload : uint16_t; + payload : uint32_t; } [@@little_endian] ] (* to avoid warning 32 *) - let _ = hexdump_message + let _ = hexdump_msg let _ = string_to_operation let pp ppf t = - Fmt.pf ppf "%s:%S:%S" (operation_to_string t.operation) t.path t.payload + Fmt.pf ppf "%ld:%s:%S:%S" + t.id (operation_to_string t.operation) t.path t.payload (* FIXME: allocate less ... *) let of_cstruct buf = - Log.debug (fun l -> l "Message.of_cstruct %S" @@ Cstruct.to_string buf); - let operation = match int_to_operation (get_message_operation buf) with - | None -> failwith "invalid operation" - | Some o -> o - in - let path_len = get_message_path buf in - let payload_len = get_message_payload buf in - Log.debug (fun l -> l "XXX path=%d len=%d" path_len payload_len); + let open Rresult.R in + Log.debug (fun l -> l "Query.of_cstruct %S" @@ Cstruct.to_string buf); + let version = get_msg_version buf in + let id = get_msg_id buf in + (match int_to_operation (get_msg_operation buf) with + | None -> Error (`Msg "invalid operation") + | Some o -> Ok o) + >>= fun operation -> + let path_len = get_msg_path buf in + let payload_len = get_msg_payload buf in let path = - Cstruct.sub buf sizeof_message path_len + Cstruct.sub buf sizeof_msg path_len |> Cstruct.to_string in let payload = - Cstruct.sub buf (sizeof_message+path_len) payload_len + Cstruct.sub buf (sizeof_msg + path_len) (Int32.to_int payload_len) |> Cstruct.to_string in - { operation; path; payload } + if String.Ascii.is_valid path then Ok { version; id; operation; path; payload } + else Error (`Msg "invalid path") let to_cstruct msg = - Log.debug (fun l -> l "Message.to_cstruct %a" pp msg); + Log.debug (fun l -> l "Query.to_cstruct %a" pp msg); let operation = operation_to_int msg.operation in let path = String.length msg.path in let payload = String.length msg.payload in - let len = sizeof_message + path + payload in + let len = sizeof_msg + path + payload in let buf = Cstruct.create len in - set_message_operation buf operation; - set_message_path buf path; - set_message_payload buf payload; - Cstruct.blit_from_bytes msg.path 0 buf sizeof_message path; - Cstruct.blit_from_bytes msg.payload 0 buf (sizeof_message+path) payload; + set_msg_version buf msg.version; + set_msg_id buf msg.id; + set_msg_operation buf operation; + set_msg_path buf path; + set_msg_payload buf (Int32.of_int payload); + Cstruct.blit_from_bytes msg.path 0 buf sizeof_msg path; + Cstruct.blit_from_bytes msg.payload 0 buf (sizeof_msg+path) payload; buf let read fd = @@ -115,14 +128,178 @@ module Message = struct end -module Dispatch = struct +module Reply = struct - open Message + (* FIXME: this should probably be replaced by protobuf *) - let with_key msg f = - match KV.Key.of_string msg.path with + [%%cenum + type status = + | Ok + | Error + [@@uint8_t] + ] + + type t = { + id : int32; + status : status; + payload: string; + } + + [%%cstruct type msg = { + id : uint32_t; (* session identifier *) + status : uint8_t; (* = type operation *) + payload: uint32_t; + } [@@little_endian] + ] + + (* to avoid warning 32 *) + let _ = hexdump_msg + let _ = string_to_status + + let pp ppf t = + Fmt.pf ppf "%ld:%s:%S" t.id (status_to_string t.status) t.payload + + (* FIXME: allocate less ... *) + + let of_cstruct buf = + let open Rresult.R in + Log.debug (fun l -> l "Message.of_cstruct %S" @@ Cstruct.to_string buf); + let id = get_msg_id buf in + (match int_to_status (get_msg_status buf) with + | None -> Error (`Msg "invalid operation") + | Some o -> Ok o) + >>= fun status -> + let payload_len = Int32.to_int (get_msg_payload buf) in + let payload = + Cstruct.sub buf sizeof_msg payload_len + |> Cstruct.to_string + in + Ok { id; status; payload } + + let to_cstruct msg = + Log.debug (fun l -> l "Message.to_cstruct %a" pp msg); + let status = status_to_int msg.status in + let payload = String.length msg.payload in + let len = sizeof_msg + payload in + let buf = Cstruct.create len in + set_msg_id buf msg.id; + set_msg_status buf status; + set_msg_payload buf (Int32.of_int payload); + Cstruct.blit_from_bytes msg.payload 0 buf sizeof_msg payload; + buf + + let read fd = + IO.read_n fd 4 >>= fun buf -> + Log.debug (fun l -> l "Message.read len=%S" buf); + let len = + Cstruct.LE.get_uint32 (Cstruct.of_string buf) 0 + |> Int32.to_int + in + IO.read_n fd len >|= fun buf -> + of_cstruct (Cstruct.of_string buf) + + let write fd msg = + let buf = to_cstruct msg |> Cstruct.to_string in + let len = + let len = Cstruct.create 4 in + Cstruct.LE.set_uint32 len 0 (Int32.of_int @@ String.length buf); + Cstruct.to_string len + in + IO.write fd len >>= fun () -> + IO.write fd buf + +end + +let err_not_found = "err-not-found" + +module Client = struct + + let new_id = + let n = ref 0l in + fun () -> n := Int32.succ !n; !n + + let version = 0l + + module K = struct + type t = int32 + let equal = Int32.equal + let hash = Hashtbl.hash + end + module Cache = Hashtbl.Make(K) + + type t = { + fd : Lwt_unix.file_descr; + replies: Reply.t Cache.t; + } + + let v fd = { fd; replies = Cache.create 12 } + + let call t query = + let id = query.Query.id in + Query.write t.fd query >>= fun () -> + let rec loop () = + try + let r = Cache.find t.replies id in + Cache.remove t.replies id; + Lwt.return r + with Not_found -> + Reply.read t.fd >>= function + | Error (`Msg e) -> + Log.err (fun l -> l "Got %s while waiting for a reply to %ld" e id); + loop () + | Ok r -> + if r.id = id then Lwt.return r + else ( + (* FIXME: maybe we want to check if id is not already + allocated *) + Cache.add t.replies r.id r; + loop () + ) + in + loop () >|= fun r -> + assert (r.Reply.id = id); + match r.Reply.status with + | Ok -> Ok r.Reply.payload + | Error -> Error (`Msg r.Reply.payload) + + let query operation path payload = + let id = new_id () in + { Query.version; id; operation; path; payload } + + let read t path = + call t (query Read path "") >|= function + | Ok x -> Ok (Some x) + | Error (`Msg e) -> + if e = err_not_found then Ok None + else Error (`Msg e) + + let write t path v = + call t (query Write path v) >|= function + | Ok "" -> Ok () + | Ok _ -> Error (`Msg "invalid return") + | Error _ as e -> e + + let delete t path = + call t (query Delete path "") >|= function + | Ok "" -> Ok () + | Ok _ -> Error (`Msg "invalid return") + | Error _ as e -> e + +end + +module Server = struct + + let ok q payload = + { Reply.id = q.Query.id; status = Reply.Ok; payload } + + let error q payload = + { Reply.id = q.Query.id; status = Reply.Error; payload } + + let with_key q f = + match KV.Key.of_string q.Query.path with | Ok x -> f x - | Error (`Msg e) -> Fmt.kstrf Lwt.fail_with "invalid key: %s" e + | Error (`Msg e) -> + Fmt.kstrf (fun msg -> Lwt.return (error q msg)) "invalid key: %s" e let infof fmt = Fmt.kstrf (fun msg () -> @@ -130,31 +307,56 @@ module Dispatch = struct Irmin.Info.v ~date ~author:"calf" msg ) fmt - let dispatch db msg = - with_key msg (fun key -> - match msg.operation with + let dispatch db q = + with_key q (fun key -> + match q.Query.operation with | Write -> let info = infof "Updating %a" KV.Key.pp key in - KV.set db ~info key msg.payload - | _ -> failwith "TODO" + KV.set db ~info key q.payload >|= fun () -> + ok q "" + | Delete -> + let info = infof "Removing %a" KV.Key.pp key in + KV.remove db ~info key >|= fun () -> + ok q "" + | Read -> + KV.find db key >|= function + | None -> error q err_not_found + | Some v -> ok q v ) - let serve fd db ~routes = - let msgs = Queue.create () in + + let int_of_fd (t:Lwt_unix.file_descr) = + (Obj.magic (Lwt_unix.unix_file_descr t): int) + + let listen ~routes db fd = + Lwt_unix.blocking fd >>= fun blocking -> + Log.debug (fun l -> + l "Serving the control state over fd:%d (blocking=%b)" + (int_of_fd fd) blocking + ); + let queries = Queue.create () in let cond = Lwt_condition.create () in let rec listen () = - Message.read fd >>= fun msg -> - Queue.add msg msgs; - Lwt_condition.signal cond (); - listen () + Query.read fd >>= function + | Error (`Msg e) -> + Log.err (fun l -> l "received invalid message: %s" e); + listen () + | Ok q -> + Queue.add q queries; + Lwt_condition.signal cond (); + listen () in let rec process () = Lwt_condition.wait cond >>= fun () -> - let msg = Queue.pop msgs in - (if List.mem msg.path routes then dispatch db msg - else ( - Log.err (fun l -> l "%s is not an allowed path" msg.path); - Lwt.return_unit; + let q = Queue.pop queries in + let path = q.Query.path in + (if List.mem path routes then ( + dispatch db q >>= fun r -> + Reply.write fd r + ) else ( + let err = Fmt.strf "%s is not an allowed path" path in + Log.err (fun l -> l "%ld: %s" q.Query.id path); + Reply.write fd (error q err) )) >>= fun () -> process () in @@ -164,14 +366,3 @@ module Dispatch = struct ] end - -let int_of_fd (t:Lwt_unix.file_descr) = - (Obj.magic (Lwt_unix.unix_file_descr t): int) - -let serve ~routes db fd = - Lwt_unix.blocking fd >>= fun blocking -> - Log.debug (fun l -> - l "Serving the control state over fd:%d (blocking=%b)" - (int_of_fd fd) blocking - ); - Dispatch.serve fd db ~routes diff --git a/projects/miragesdk/src/sdk/ctl.mli b/projects/miragesdk/src/sdk/ctl.mli index 452e65cca..0dad5aa8e 100644 --- a/projects/miragesdk/src/sdk/ctl.mli +++ b/projects/miragesdk/src/sdk/ctl.mli @@ -1,9 +1,7 @@ (** [Control] handle the server part of the control path, running in the privileged container. *) -module KV: Irmin.KV with type contents = string - -module Message: sig +module Query: sig (** The type for operations. *) type operation = @@ -11,39 +9,109 @@ module Message: sig | Read | Delete - (** The type for control messages. *) + (** The type for control plane queries. *) type t = { + version : int32; (** Protocol version. *) + id : int32; (** Session identifier. *) operation: operation; - path : string; - payload : string; + path : string; (** Should be only valid ASCII. *) + payload : string; (** Arbitrary payload. *) } val pp: t Fmt.t - (** [pp] is the pretty-printer for messages. *) + (** [pp] is the pretty-printer for queries. *) - val of_cstruct: Cstruct.t -> t - (** [of_cstruct buf] is the message [t] such that the serialization - of [t] is [buf]. *) + val of_cstruct: Cstruct.t -> (t, [`Msg of string]) result + (** [of_cstruct buf] is the query [t] such that the serialization of + [t] is [buf]. *) val to_cstruct: t -> Cstruct.t (** [to_cstruct t] is the serialization of [t]. *) val write: Lwt_unix.file_descr -> t -> unit Lwt.t - (** [write fd t] writes a control message. *) + (** [write fd t] writes a query message. *) - val read: Lwt_unix.file_descr -> t Lwt.t - (** [read fd] reads a control message. *) + val read: Lwt_unix.file_descr -> (t, [`Msg of string]) result Lwt.t + (** [read fd] reads a query message. *) end -val v: string -> KV.t Lwt.t -(** [v p] is the KV store storing the control state, located at path - [p] in the filesystem of the privileged container. *) +module Reply: sig -val serve: routes:string list -> KV.t -> Lwt_unix.file_descr -> unit Lwt.t -(** [serve ~routes kv fd] is the thread exposing the KV store [kv], - holding control state, running inside the privileged container. - [routes] are the routes exposed by the server (currently over a - simple HTTP server -- but will change to something else later, - probably protobuf) to the calf and [kv] is the control state - handler. *) + (** The type for status. *) + type status = + | Ok + | Error + + (** The type for control plane replies. *) + type t = { + id : int32; (** Session identifier. *) + status : status; (** Status of the operation. *) + payload: string; (** Arbitrary payload. *) + } + + val pp: t Fmt.t + (** [pp] is the pretty-printer for replies. *) + + val of_cstruct: Cstruct.t -> (t, [`Msg of string]) result + (** [of_cstruct buf] is the reply [t] such that the serialization of + [t] is [buf]. *) + + val to_cstruct: t -> Cstruct.t + (** [to_cstruct t] is the serialization of [t]. *) + + val write: Lwt_unix.file_descr -> t -> unit Lwt.t + (** [write fd t] writes a reply message. *) + + val read: Lwt_unix.file_descr -> (t, [`Msg of string]) result Lwt.t + (** [read fd] reads a reply message. *) + +end + +module Client: sig + + (** Client-side of the control plane. The control plane state is a + simple KV store that the client can query with read/write/delete + operations. + + TODO: decide if we want to support test_and_set (instead of + write) and some kind of watches. *) + + type t + (** The type for client state. *) + + val v: Lwt_unix.file_descr -> t + (** [v fd] is the client state using [fd] to send requests to the + server. A client state also stores some state for all the + incomplete client queries. *) + + val read: t -> string -> (string option, [`Msg of string]) result Lwt.t + (** [read t k] is the value associated with the key [k] in the + control plane state. Return [None] if no value is associated to + [k]. *) + + val write: t -> string -> string -> (unit, [`Msg of string]) result Lwt.t + (** [write t p v] associates [v] to the key [k] in the control plane + state. *) + + val delete: t -> string -> (unit, [`Msg of string]) result Lwt.t + (** [delete t k] remove [k]'s binding in the control plane state. *) + +end + +(** [KV] stores tje control plane state. *) +module KV: Irmin.KV with type contents = string + +val v: string -> KV.t Lwt.t +(** [v p] is the KV store storing the control plane state, located at + path [p] in the filesystem of the privileged container. *) + +module Server: sig + + val listen: routes:string list -> KV.t -> Lwt_unix.file_descr -> unit Lwt.t + (** [listen ~routes kv fd] is the thread exposing the KV store [kv], + holding control plane state, running inside the privileged + container. [routes] are the routes exposed by the server to the + calf and [kv] is the control plane state. *) + +end diff --git a/projects/miragesdk/src/sdk/jbuild b/projects/miragesdk/src/sdk/jbuild index 53403e2ad..e357caa71 100644 --- a/projects/miragesdk/src/sdk/jbuild +++ b/projects/miragesdk/src/sdk/jbuild @@ -4,6 +4,6 @@ ((name sdk) (libraries (threads cstruct.lwt cmdliner fmt.cli logs.fmt logs.cli fmt.tty decompress irmin irmin-git lwt.unix rawlink tuntap dispatch - irmin-watcher inotify)) + irmin-watcher inotify astring rresult)) (preprocess (per_file ((pps (cstruct.ppx)) (ctl)))) )) diff --git a/projects/miragesdk/src/test/test.ml b/projects/miragesdk/src/test/test.ml index 925ed7c04..5876f1805 100644 --- a/projects/miragesdk/src/test/test.ml +++ b/projects/miragesdk/src/test/test.ml @@ -87,34 +87,147 @@ let test_socketpair pipe () = Lwt.return_unit -let message = Alcotest.testable Ctl.Message.pp (=) +let query = Alcotest.testable Ctl.Query.pp (=) +let reply = Alcotest.testable Ctl.Reply.pp (=) -let test_message_serialization () = - let test m = - let buf = Ctl.Message.to_cstruct m in - let m' = Ctl.Message.of_cstruct buf in - Alcotest.(check message) "to_cstruct/of_cstruct" m m' - in - List.iter test [ - { operation = Read ; path = "/foo/bar"; payload = "" }; - { operation = Write ; path = "" ; payload = "foo" }; - { operation = Delete; path = "" ; payload = "" }; - { operation = Delete; path = "foo" ; payload = "foo" }; +let queries = + let open Ctl.Query in + [ + { version = 0l; id = 0l; operation = Read; path = "/foo/bar"; payload = "" }; + { version = Int32.max_int; id = Int32.max_int; operation = Write ; path = ""; payload = "foo" }; + { version = 1l;id = 0l; operation = Delete; path = ""; payload = "" }; + { version = -2l; id = -3l; operation = Delete; path = "foo"; payload = "foo" }; ] -let test_message_send () = +let replies = + let open Ctl.Reply in + [ + { id = 0l; status = Ok; payload = "" }; + { id = Int32.max_int; status = Ok; payload = "foo" }; + { id = 0l; status = Error; payload = "" }; + { id = -3l; status = Error; payload = "foo" }; + ] + +let test_serialization to_cstruct of_cstruct message messages = + let test m = + let buf = to_cstruct m in + match of_cstruct buf with + | Ok m' -> Alcotest.(check message) "to_cstruct/of_cstruct" m m' + | Error (`Msg e) -> Alcotest.fail ("Message.of_cstruct: " ^ e) + in + List.iter test messages + +let test_send write read message messages = let calf = Init.Fd.fd @@ Init.Pipe.(calf ctl) in let priv = Init.Fd.fd @@ Init.Pipe.(priv ctl) in let test m = - Ctl.Message.write calf m >>= fun () -> - Ctl.Message.read priv >|= fun m' -> - Alcotest.(check message) "write/read" m m' + write calf m >>= fun () -> + read priv >|= function + | Ok m' -> Alcotest.(check message) "write/read" m m' + | Error (`Msg e) -> Alcotest.fail ("Message.read: " ^ e) in - Lwt_list.iter_s test [ - { operation = Read ; path = "/foo/bar"; payload = "" }; - { operation = Write ; path = "" ; payload = "foo" }; - { operation = Delete; path = "" ; payload = "" }; - { operation = Delete; path = "foo" ; payload = "foo" }; + Lwt_list.iter_s test messages + +let test_query_serialization () = + let open Ctl.Query in + test_serialization to_cstruct of_cstruct query queries + +let test_reply_serialization () = + let open Ctl.Reply in + test_serialization to_cstruct of_cstruct reply replies + +let test_query_send () = + let open Ctl.Query in + test_send write read query queries + +let test_reply_send () = + let open Ctl.Reply in + test_send write read reply replies + +let failf fmt = Fmt.kstrf Alcotest.fail fmt + +(* read ops *) + +let read_should_err t k = + Ctl.Client.read t k >|= function + | Error (`Msg _) -> () + | Ok None -> failf "read(%s) -> got: none, expected: err" k + | Ok Some v -> failf "read(%s) -> got: found:%S, expected: err" k v + +let read_should_none t k = + Ctl.Client.read t k >|= function + | Error (`Msg e) -> failf "read(%s) -> got: error:%s, expected none" k e + | Ok None -> () + | Ok Some v -> failf "read(%s) -> got: found:%S, expected none" k v + +let read_should_work t k v = + Ctl.Client.read t k >|= function + | Error (`Msg e) -> failf "read(%s) -> got: error:%s, expected ok" k e + | Ok None -> failf "read(%s) -> got: none, expected ok" k + | Ok Some v' -> + if v <> v' then failf "read(%s) -> got: ok:%S, expected: ok:%S" k v' v + +(* write ops *) + +let write_should_err t k v = + Ctl.Client.write t k v >|= function + | Ok () -> failf "write(%s) -> ok" k + | Error _ -> () + +let write_should_work t k v = + Ctl.Client.write t k v >|= function + | Ok () -> () + | Error (`Msg e) -> failf "write(%s) -> error: %s" k e + +(* del ops *) + +let delete_should_err t k = + Ctl.Client.delete t k >|= function + | Ok () -> failf "del(%s) -> ok" k + | Error _ -> () + +let delete_should_work t k = + Ctl.Client.delete t k >|= function + | Ok () -> () + | Error (`Msg e) -> failf "write(%s) -> error: %s" k e + +let test_ctl () = + let calf = Init.Fd.fd @@ Init.Pipe.(calf ctl) in + let priv = Init.Fd.fd @@ Init.Pipe.(priv ctl) in + let k1 = "/foo/bar" in + let k2 = "a" in + let k3 = "b/c" in + let k4 = "xxxxxx" in + let routes = [k1; k2; k3] in + let git_root = "/tmp/sdk/ctl" in + let _ = Sys.command (Fmt.strf "rm -rf %s" git_root) in + Ctl.v git_root >>= fun ctl -> + let server () = Ctl.Server.listen ~routes ctl priv in + let client () = + let t = Ctl.Client.v calf in + let allowed k v = + delete_should_work t k >>= fun () -> + read_should_none t k >>= fun () -> + write_should_work t k v >>= fun () -> + read_should_work t k v >>= fun () -> + let path = String.cuts ~empty:false ~sep:"/" k in + Ctl.KV.get ctl path >|= fun v' -> + Alcotest.(check string) "in the db" v v' + in + let disallowed k v = + read_should_err t k >>= fun () -> + write_should_err t k v >>= fun () -> + delete_should_err t k + in + allowed k1 "" >>= fun () -> + allowed k2 "xxx" >>= fun () -> + allowed k3 (random_string (255 * 1024)) >>= fun () -> + disallowed k4 "" >>= fun () -> + Lwt.return_unit + in + Lwt.pick [ + client (); + server (); ] let run f () = @@ -130,8 +243,11 @@ let test = [ "stdout is a pipe" , `Quick, run (test_pipe Init.Pipe.stderr); "net is a socket pair", `Quick, run (test_socketpair Init.Pipe.net); "ctl is a socket pair", `Quick, run (test_socketpair Init.Pipe.ctl); - "seralize messages" , `Quick, test_message_serialization; - "send messages" , `Quick, run test_message_send; + "seralize queries" , `Quick, test_query_serialization; + "seralize replies" , `Quick, test_reply_serialization; + "send queries" , `Quick, run test_query_send; + "send replies" , `Quick, run test_reply_send; + "ctl" , `Quick, run test_ctl; ] let reporter ?(prefix="") () =