diff --git a/projects/miragesdk/examples/mirage-dhcp.yml b/projects/miragesdk/examples/mirage-dhcp.yml index 82418a0ba..f3c7cef50 100644 --- a/projects/miragesdk/examples/mirage-dhcp.yml +++ b/projects/miragesdk/examples/mirage-dhcp.yml @@ -19,7 +19,7 @@ system: command: [/usr/bin/binfmt, -dir, /etc/binfmt.d/, -mount, /binfmt_misc] - name: dhcp-client network_mode: host - image: "mobylinux/dhcp-client:f6ef2cc4c3bf7dcad643f22fbd3d355af1725105" + image: "mobylinux/dhcp-client:aaf811d77ff8d8b2e16ca4dd9d0a2849ef8977b6" capabilities: - CAP_NET_ADMIN # to bring eth0 up - CAP_NET_RAW # to read /dev/eth0 diff --git a/projects/miragesdk/src/Dockerfile.build b/projects/miragesdk/src/Dockerfile.build index 449b1b5ad..4c0cfb568 100644 --- a/projects/miragesdk/src/Dockerfile.build +++ b/projects/miragesdk/src/Dockerfile.build @@ -6,6 +6,7 @@ RUN opam pin -n add mirage-net-unix https://github.com/samoht/mirage-net-unix.gi RUN opam depext -iy mirage-net-unix logs-syslog irmin-unix cohttp decompress RUN opam depext -iy rawlink tuntap.1.0.0 jbuilder irmin-watcher inotify +RUN opam install rresult RUN sudo mkdir -p /src COPY ./sdk /src/sdk @@ -15,5 +16,7 @@ RUN sudo chown opam -R /src USER opam WORKDIR /src +RUN opam pin add cstruct --dev # for ppx/jbuilder + RUN opam config exec -- jbuilder build dhcp-client/main.exe RUN sudo cp /src/_build/default/dhcp-client/main.exe /dhcp-client diff --git a/projects/miragesdk/src/Makefile b/projects/miragesdk/src/Makefile index 88a8bc1c2..9fb07d3f0 100644 --- a/projects/miragesdk/src/Makefile +++ b/projects/miragesdk/src/Makefile @@ -1,6 +1,7 @@ BASE=ocaml/opam:alpine-3.5_ocaml-4.04.0 FILES=$(shell find . -name jbuild) \ - $(shell find sdk/ -regex '.*\.mli?') \ + $(shell find sdk/ -name '*.ml') \ + $(shell find sdk/ -name '*.mli') \ dhcp-client/bpf/dhcp.c dhcp-client/main.ml IMAGE=dhcp-client OBJS=obj/dhcp-client @@ -36,9 +37,9 @@ enter-dev: .dev $(CALF_OBJS): $(CALF_FILES) mkdir -p obj/bin ( cd obj && \ - tar -C ../calf -cf - $(CALF_FILES:calf/%=%) | \ + tar -C ../dhcp-client/calf -cf - $(CALF_FILES:dhcp-client/calf/%=%) | \ docker run --rm -i --log-driver=none $(MIRAGE_COMPILE) -o dhcp-client-calf | \ - tar xf - ) && \ + tar xf - || exit 1) && \ touch $@ $(OBJS): .build $(FILES) @@ -47,7 +48,7 @@ $(OBJS): .build $(FILES) docker run --rm --net=none --log-driver=none -i $(IMAGE):build tar -cf - $(OBJS:obj/%=/%) | tar xf - ) && \ touch $@ -hash: Dockerfile.build Dockerfile.pkg $(FILES) $(CALF_FILES) .build +hash: Makefile Dockerfile.build Dockerfile.pkg $(FILES) $(CALF_FILES) .build { cat $^; \ docker run --rm --entrypoint sh $(IMAGE):build -c 'cat /lib/apk/db/installed'; \ docker run --rm --entrypoint sh $(IMAGE):build -c 'opam list'; } \ diff --git a/projects/miragesdk/src/dhcp-client/main.ml b/projects/miragesdk/src/dhcp-client/main.ml index e88041f87..35809c587 100644 --- a/projects/miragesdk/src/dhcp-client/main.ml +++ b/projects/miragesdk/src/dhcp-client/main.ml @@ -52,7 +52,7 @@ let run () cmd ethif path = ] in Ctl.v "/data" >>= fun ctl -> let fd = Init.(Fd.fd @@ Pipe.(priv ctl)) in - let ctl () = Ctl.serve ~routes ctl fd in + let ctl () = Ctl.Server.listen ~routes ctl fd in let handlers () = Handlers.watch path in Init.run ~net ~ctl ~handlers cmd ) diff --git a/projects/miragesdk/src/sdk/IO.ml b/projects/miragesdk/src/sdk/IO.ml index 2d9345729..a03189370 100644 --- a/projects/miragesdk/src/sdk/IO.ml +++ b/projects/miragesdk/src/sdk/IO.ml @@ -4,18 +4,20 @@ let src = Logs.Src.create "IO" ~doc:"IO helpers" module Log = (val Logs.src_log src : Logs.LOG) let rec really_write fd buf off len = - Log.debug (fun l -> l "really_write"); match len with | 0 -> Lwt.return_unit | len -> + Log.debug (fun l -> l "really_write off=%d len=%d" off len); Lwt_unix.write fd buf off len >>= fun n -> really_write fd buf (off+n) (len-n) +let write fd buf = really_write fd buf 0 (String.length buf) + let rec really_read fd buf off len = - Log.debug (fun l -> l "really_read"); match len with | 0 -> Lwt.return_unit | len -> + Log.debug (fun l -> l "really_read off=%d len=%d" off len); Lwt_unix.read fd buf off len >>= fun n -> really_read fd buf (off+n) (len-n) @@ -33,7 +35,7 @@ let read_all fd = String.concat "" bufs let read_n fd len = - Log.debug (fun l -> l "read_n"); + Log.debug (fun l -> l "read_n len=%d" len); let buf = Bytes.create len in let rec loop acc len = Lwt_unix.read fd buf 0 len >>= fun n -> diff --git a/projects/miragesdk/src/sdk/IO.mli b/projects/miragesdk/src/sdk/IO.mli index 121ba33c7..b07954d18 100644 --- a/projects/miragesdk/src/sdk/IO.mli +++ b/projects/miragesdk/src/sdk/IO.mli @@ -3,6 +3,9 @@ val really_write: Lwt_unix.file_descr -> string -> int -> int -> unit Lwt.t (** [really_write fd buf off len] writes exactly [len] bytes to [fd]. *) +val write: Lwt_unix.file_descr -> string -> unit Lwt.t +(** [write fd buf] writes all the buffer [buf] in [fd]. *) + val really_read: Lwt_unix.file_descr -> string -> int -> int -> unit Lwt.t (** [really_read fd buf off len] reads exactly [len] bytes from [fd]. *) diff --git a/projects/miragesdk/src/sdk/ctl.ml b/projects/miragesdk/src/sdk/ctl.ml index d561e067b..7b092ff4d 100644 --- a/projects/miragesdk/src/sdk/ctl.ml +++ b/projects/miragesdk/src/sdk/ctl.ml @@ -1,4 +1,5 @@ open Lwt.Infix +open Astring let src = Logs.Src.create "init" ~doc:"Init steps" module Log = (val Logs.src_log src : Logs.LOG) @@ -28,7 +29,9 @@ let v path = let () = Irmin.Private.Watch.set_listen_dir_hook Irmin_watcher.hook -module Message = struct +module Query = struct + + (* FIXME: this should probably be replaced by protobuf *) [%%cenum type operation = @@ -39,73 +42,264 @@ module Message = struct ] type t = { + version : int32; + id : int32; operation: operation; path : string; - payload : string option; + payload : string; } - [%%cstruct type message = { - operation : uint8_t; (* = type operation *) + [%%cstruct type msg = { + version : uint32_t; (* protocol version *) + id : uint32_t; (* session identifier *) + operation : uint8_t; (* = type operation *) path : uint16_t; - payload : uint16_t; + payload : uint32_t; } [@@little_endian] ] (* to avoid warning 32 *) - let _ = hexdump_message - let _ = operation_to_string + let _ = hexdump_msg let _ = string_to_operation - let read_message fd = + let pp ppf t = + Fmt.pf ppf "%ld:%s:%S:%S" + t.id (operation_to_string t.operation) t.path t.payload + + (* FIXME: allocate less ... *) + + let of_cstruct buf = + let open Rresult.R in + Log.debug (fun l -> l "Query.of_cstruct %S" @@ Cstruct.to_string buf); + let version = get_msg_version buf in + let id = get_msg_id buf in + (match int_to_operation (get_msg_operation buf) with + | None -> Error (`Msg "invalid operation") + | Some o -> Ok o) + >>= fun operation -> + let path_len = get_msg_path buf in + let payload_len = get_msg_payload buf in + let path = + Cstruct.sub buf sizeof_msg path_len + |> Cstruct.to_string + in + let payload = + Cstruct.sub buf (sizeof_msg + path_len) (Int32.to_int payload_len) + |> Cstruct.to_string + in + if String.Ascii.is_valid path then Ok { version; id; operation; path; payload } + else Error (`Msg "invalid path") + + let to_cstruct msg = + Log.debug (fun l -> l "Query.to_cstruct %a" pp msg); + let operation = operation_to_int msg.operation in + let path = String.length msg.path in + let payload = String.length msg.payload in + let len = sizeof_msg + path + payload in + let buf = Cstruct.create len in + set_msg_version buf msg.version; + set_msg_id buf msg.id; + set_msg_operation buf operation; + set_msg_path buf path; + set_msg_payload buf (Int32.of_int payload); + Cstruct.blit_from_bytes msg.path 0 buf sizeof_msg path; + Cstruct.blit_from_bytes msg.payload 0 buf (sizeof_msg+path) payload; + buf + + let read fd = IO.read_n fd 4 >>= fun buf -> + Log.debug (fun l -> l "Message.read len=%S" buf); let len = Cstruct.LE.get_uint32 (Cstruct.of_string buf) 0 |> Int32.to_int in - IO.read_n fd len >>= fun buf -> - let buf = Cstruct.of_string buf in - let operation = match int_to_operation (get_message_operation buf) with - | None -> failwith "invalid operation" - | Some o -> o - in - let path_len = get_message_path buf in - let payload_len = get_message_payload buf in - IO.read_n fd path_len >>= fun path -> - (match payload_len with - | 0 -> Lwt.return None - | n -> IO.read_n fd n >|= fun x -> Some x) - >|= fun payload -> - { operation; path; payload } + IO.read_n fd len >|= fun buf -> + of_cstruct (Cstruct.of_string buf) - let write_message fd msg = - let operation = operation_to_int msg.operation in - let path = String.length msg.path in - let payload = match msg.payload with - | None -> 0 - | Some x -> String.length x + let write fd msg = + let buf = to_cstruct msg |> Cstruct.to_string in + let len = + let len = Cstruct.create 4 in + Cstruct.LE.set_uint32 len 0 (Int32.of_int @@ String.length buf); + Cstruct.to_string len in - let len = sizeof_message + path + payload in - let buf = Cstruct.create len in - set_message_operation buf operation; - set_message_path buf path; - set_message_payload buf path; - Cstruct.blit_from_bytes msg.path 0 buf sizeof_message path; - let () = match msg.payload with - | None -> () - | Some x -> Cstruct.blit_from_bytes x 0 buf (sizeof_message+path) payload - in - IO.really_write fd (Cstruct.to_string buf) 0 len + IO.write fd len >>= fun () -> + IO.write fd buf end -module Dispatch = struct +module Reply = struct - open Message + (* FIXME: this should probably be replaced by protobuf *) - let with_key msg f = - match KV.Key.of_string msg.path with + [%%cenum + type status = + | Ok + | Error + [@@uint8_t] + ] + + type t = { + id : int32; + status : status; + payload: string; + } + + [%%cstruct type msg = { + id : uint32_t; (* session identifier *) + status : uint8_t; (* = type operation *) + payload: uint32_t; + } [@@little_endian] + ] + + (* to avoid warning 32 *) + let _ = hexdump_msg + let _ = string_to_status + + let pp ppf t = + Fmt.pf ppf "%ld:%s:%S" t.id (status_to_string t.status) t.payload + + (* FIXME: allocate less ... *) + + let of_cstruct buf = + let open Rresult.R in + Log.debug (fun l -> l "Message.of_cstruct %S" @@ Cstruct.to_string buf); + let id = get_msg_id buf in + (match int_to_status (get_msg_status buf) with + | None -> Error (`Msg "invalid operation") + | Some o -> Ok o) + >>= fun status -> + let payload_len = Int32.to_int (get_msg_payload buf) in + let payload = + Cstruct.sub buf sizeof_msg payload_len + |> Cstruct.to_string + in + Ok { id; status; payload } + + let to_cstruct msg = + Log.debug (fun l -> l "Message.to_cstruct %a" pp msg); + let status = status_to_int msg.status in + let payload = String.length msg.payload in + let len = sizeof_msg + payload in + let buf = Cstruct.create len in + set_msg_id buf msg.id; + set_msg_status buf status; + set_msg_payload buf (Int32.of_int payload); + Cstruct.blit_from_bytes msg.payload 0 buf sizeof_msg payload; + buf + + let read fd = + IO.read_n fd 4 >>= fun buf -> + Log.debug (fun l -> l "Message.read len=%S" buf); + let len = + Cstruct.LE.get_uint32 (Cstruct.of_string buf) 0 + |> Int32.to_int + in + IO.read_n fd len >|= fun buf -> + of_cstruct (Cstruct.of_string buf) + + let write fd msg = + let buf = to_cstruct msg |> Cstruct.to_string in + let len = + let len = Cstruct.create 4 in + Cstruct.LE.set_uint32 len 0 (Int32.of_int @@ String.length buf); + Cstruct.to_string len + in + IO.write fd len >>= fun () -> + IO.write fd buf + +end + +let err_not_found = "err-not-found" + +module Client = struct + + let new_id = + let n = ref 0l in + fun () -> n := Int32.succ !n; !n + + let version = 0l + + module K = struct + type t = int32 + let equal = Int32.equal + let hash = Hashtbl.hash + end + module Cache = Hashtbl.Make(K) + + type t = { + fd : Lwt_unix.file_descr; + replies: Reply.t Cache.t; + } + + let v fd = { fd; replies = Cache.create 12 } + + let call t query = + let id = query.Query.id in + Query.write t.fd query >>= fun () -> + let rec loop () = + try + let r = Cache.find t.replies id in + Cache.remove t.replies id; + Lwt.return r + with Not_found -> + Reply.read t.fd >>= function + | Error (`Msg e) -> + Log.err (fun l -> l "Got %s while waiting for a reply to %ld" e id); + loop () + | Ok r -> + if r.id = id then Lwt.return r + else ( + (* FIXME: maybe we want to check if id is not already + allocated *) + Cache.add t.replies r.id r; + loop () + ) + in + loop () >|= fun r -> + assert (r.Reply.id = id); + match r.Reply.status with + | Ok -> Ok r.Reply.payload + | Error -> Error (`Msg r.Reply.payload) + + let query operation path payload = + let id = new_id () in + { Query.version; id; operation; path; payload } + + let read t path = + call t (query Read path "") >|= function + | Ok x -> Ok (Some x) + | Error (`Msg e) -> + if e = err_not_found then Ok None + else Error (`Msg e) + + let write t path v = + call t (query Write path v) >|= function + | Ok "" -> Ok () + | Ok _ -> Error (`Msg "invalid return") + | Error _ as e -> e + + let delete t path = + call t (query Delete path "") >|= function + | Ok "" -> Ok () + | Ok _ -> Error (`Msg "invalid return") + | Error _ as e -> e + +end + +module Server = struct + + let ok q payload = + { Reply.id = q.Query.id; status = Reply.Ok; payload } + + let error q payload = + { Reply.id = q.Query.id; status = Reply.Error; payload } + + let with_key q f = + match KV.Key.of_string q.Query.path with | Ok x -> f x - | Error (`Msg e) -> Fmt.kstrf Lwt.fail_with "invalid key: %s" e + | Error (`Msg e) -> + Fmt.kstrf (fun msg -> Lwt.return (error q msg)) "invalid key: %s" e let infof fmt = Fmt.kstrf (fun msg () -> @@ -113,33 +307,56 @@ module Dispatch = struct Irmin.Info.v ~date ~author:"calf" msg ) fmt - let dispatch db msg = - with_key msg (fun key -> - match msg.operation with + let dispatch db q = + with_key q (fun key -> + match q.Query.operation with | Write -> let info = infof "Updating %a" KV.Key.pp key in - (match msg.payload with - | None -> Fmt.kstrf Lwt.fail_with "dispatch: missing payload" - | Some v -> KV.set db ~info key v) - | _ -> failwith "TODO" + KV.set db ~info key q.payload >|= fun () -> + ok q "" + | Delete -> + let info = infof "Removing %a" KV.Key.pp key in + KV.remove db ~info key >|= fun () -> + ok q "" + | Read -> + KV.find db key >|= function + | None -> error q err_not_found + | Some v -> ok q v ) - let serve fd db ~routes = - let msgs = Queue.create () in + + let int_of_fd (t:Lwt_unix.file_descr) = + (Obj.magic (Lwt_unix.unix_file_descr t): int) + + let listen ~routes db fd = + Lwt_unix.blocking fd >>= fun blocking -> + Log.debug (fun l -> + l "Serving the control state over fd:%d (blocking=%b)" + (int_of_fd fd) blocking + ); + let queries = Queue.create () in let cond = Lwt_condition.create () in let rec listen () = - read_message fd >>= fun msg -> - Queue.add msg msgs; - Lwt_condition.signal cond (); - listen () + Query.read fd >>= function + | Error (`Msg e) -> + Log.err (fun l -> l "received invalid message: %s" e); + listen () + | Ok q -> + Queue.add q queries; + Lwt_condition.signal cond (); + listen () in let rec process () = Lwt_condition.wait cond >>= fun () -> - let msg = Queue.pop msgs in - (if List.mem msg.path routes then dispatch db msg - else ( - Log.err (fun l -> l "%s is not an allowed path" msg.path); - Lwt.return_unit; + let q = Queue.pop queries in + let path = q.Query.path in + (if List.mem path routes then ( + dispatch db q >>= fun r -> + Reply.write fd r + ) else ( + let err = Fmt.strf "%s is not an allowed path" path in + Log.err (fun l -> l "%ld: %s" q.Query.id path); + Reply.write fd (error q err) )) >>= fun () -> process () in @@ -149,14 +366,3 @@ module Dispatch = struct ] end - -let int_of_fd (t:Lwt_unix.file_descr) = - (Obj.magic (Lwt_unix.unix_file_descr t): int) - -let serve ~routes db fd = - Lwt_unix.blocking fd >>= fun blocking -> - Log.debug (fun l -> - l "Serving the control state over fd:%d (blocking=%b)" - (int_of_fd fd) blocking - ); - Dispatch.serve fd db ~routes diff --git a/projects/miragesdk/src/sdk/ctl.mli b/projects/miragesdk/src/sdk/ctl.mli index 5af101be8..0dad5aa8e 100644 --- a/projects/miragesdk/src/sdk/ctl.mli +++ b/projects/miragesdk/src/sdk/ctl.mli @@ -1,9 +1,7 @@ (** [Control] handle the server part of the control path, running in the privileged container. *) -module KV: Irmin.KV with type contents = string - -module Message: sig +module Query: sig (** The type for operations. *) type operation = @@ -11,29 +9,109 @@ module Message: sig | Read | Delete - (** The type for control messages. *) + (** The type for control plane queries. *) type t = { + version : int32; (** Protocol version. *) + id : int32; (** Session identifier. *) operation: operation; - path : string; - payload : string option; + path : string; (** Should be only valid ASCII. *) + payload : string; (** Arbitrary payload. *) } - val write_message: Lwt_unix.file_descr -> t -> unit Lwt.t - (** [write_message fd t] writes a control message. *) + val pp: t Fmt.t + (** [pp] is the pretty-printer for queries. *) - val read_message: Lwt_unix.file_descr -> t Lwt.t - (** [read_message fd] reads a control message. *) + val of_cstruct: Cstruct.t -> (t, [`Msg of string]) result + (** [of_cstruct buf] is the query [t] such that the serialization of + [t] is [buf]. *) + + val to_cstruct: t -> Cstruct.t + (** [to_cstruct t] is the serialization of [t]. *) + + val write: Lwt_unix.file_descr -> t -> unit Lwt.t + (** [write fd t] writes a query message. *) + + val read: Lwt_unix.file_descr -> (t, [`Msg of string]) result Lwt.t + (** [read fd] reads a query message. *) end -val v: string -> KV.t Lwt.t -(** [v p] is the KV store storing the control state, located at path - [p] in the filesystem of the privileged container. *) +module Reply: sig -val serve: routes:string list -> KV.t -> Lwt_unix.file_descr -> unit Lwt.t -(** [serve ~routes kv fd] is the thread exposing the KV store [kv], - holding control state, running inside the privileged container. - [routes] are the routes exposed by the server (currently over a - simple HTTP server -- but will change to something else later, - probably protobuf) to the calf and [kv] is the control state - handler. *) + (** The type for status. *) + type status = + | Ok + | Error + + (** The type for control plane replies. *) + type t = { + id : int32; (** Session identifier. *) + status : status; (** Status of the operation. *) + payload: string; (** Arbitrary payload. *) + } + + val pp: t Fmt.t + (** [pp] is the pretty-printer for replies. *) + + val of_cstruct: Cstruct.t -> (t, [`Msg of string]) result + (** [of_cstruct buf] is the reply [t] such that the serialization of + [t] is [buf]. *) + + val to_cstruct: t -> Cstruct.t + (** [to_cstruct t] is the serialization of [t]. *) + + val write: Lwt_unix.file_descr -> t -> unit Lwt.t + (** [write fd t] writes a reply message. *) + + val read: Lwt_unix.file_descr -> (t, [`Msg of string]) result Lwt.t + (** [read fd] reads a reply message. *) + +end + +module Client: sig + + (** Client-side of the control plane. The control plane state is a + simple KV store that the client can query with read/write/delete + operations. + + TODO: decide if we want to support test_and_set (instead of + write) and some kind of watches. *) + + type t + (** The type for client state. *) + + val v: Lwt_unix.file_descr -> t + (** [v fd] is the client state using [fd] to send requests to the + server. A client state also stores some state for all the + incomplete client queries. *) + + val read: t -> string -> (string option, [`Msg of string]) result Lwt.t + (** [read t k] is the value associated with the key [k] in the + control plane state. Return [None] if no value is associated to + [k]. *) + + val write: t -> string -> string -> (unit, [`Msg of string]) result Lwt.t + (** [write t p v] associates [v] to the key [k] in the control plane + state. *) + + val delete: t -> string -> (unit, [`Msg of string]) result Lwt.t + (** [delete t k] remove [k]'s binding in the control plane state. *) + +end + +(** [KV] stores tje control plane state. *) +module KV: Irmin.KV with type contents = string + +val v: string -> KV.t Lwt.t +(** [v p] is the KV store storing the control plane state, located at + path [p] in the filesystem of the privileged container. *) + +module Server: sig + + val listen: routes:string list -> KV.t -> Lwt_unix.file_descr -> unit Lwt.t + (** [listen ~routes kv fd] is the thread exposing the KV store [kv], + holding control plane state, running inside the privileged + container. [routes] are the routes exposed by the server to the + calf and [kv] is the control plane state. *) + +end diff --git a/projects/miragesdk/src/sdk/init.ml b/projects/miragesdk/src/sdk/init.ml index b9e2cd7be..b77e48f88 100644 --- a/projects/miragesdk/src/sdk/init.ml +++ b/projects/miragesdk/src/sdk/init.ml @@ -78,9 +78,7 @@ module Fd = struct let buf = Bytes.create len in let rec loop () = Lwt_unix.read src.fd buf 0 len >>= fun len -> - if len = 0 then - (* FIXME: why this ever happen *) - Fmt.kstrf Lwt.fail_with "FORWARD[%a => %a]: EOF" pp src pp dst + if len = 0 then Lwt.return_unit (* EOF *) else ( Log.debug (fun l -> l "FORWARD[%a => %a]: %S (%d)" @@ -97,6 +95,8 @@ module Pipe = struct type t = Fd.t * Fd.t + let name (x, _) = x.Fd.name + let priv = fst let calf = snd diff --git a/projects/miragesdk/src/sdk/init.mli b/projects/miragesdk/src/sdk/init.mli index fc50863cd..1e2828a26 100644 --- a/projects/miragesdk/src/sdk/init.mli +++ b/projects/miragesdk/src/sdk/init.mli @@ -63,6 +63,9 @@ module Pipe: sig (** The type for pipes. Could be either uni-directional (normal pipes) or a bi-directional (socket pairs). *) + val name: t -> string + (** [name t] is [t]'s name. *) + val priv: t -> Fd.t (** [priv p] is the private side of the pipe [p]. *) diff --git a/projects/miragesdk/src/sdk/jbuild b/projects/miragesdk/src/sdk/jbuild index 53403e2ad..e357caa71 100644 --- a/projects/miragesdk/src/sdk/jbuild +++ b/projects/miragesdk/src/sdk/jbuild @@ -4,6 +4,6 @@ ((name sdk) (libraries (threads cstruct.lwt cmdliner fmt.cli logs.fmt logs.cli fmt.tty decompress irmin irmin-git lwt.unix rawlink tuntap dispatch - irmin-watcher inotify)) + irmin-watcher inotify astring rresult)) (preprocess (per_file ((pps (cstruct.ppx)) (ctl)))) )) diff --git a/projects/miragesdk/src/test/test.ml b/projects/miragesdk/src/test/test.ml index 0aa18877a..5876f1805 100644 --- a/projects/miragesdk/src/test/test.ml +++ b/projects/miragesdk/src/test/test.ml @@ -2,35 +2,234 @@ open Astring open Lwt.Infix open Sdk -let random_string n = Bytes.create n +let random_string n = + Bytes.init n (fun _ -> char_of_int (Random.int 255)) + +(* workaround https://github.com/mirage/alcotest/issues/88 *) +exception Check_error of string + +let check_raises msg exn f = + Lwt.catch (fun () -> + f () >>= fun () -> + Lwt.fail (Check_error msg) + ) (function + | Check_error e -> Alcotest.fail e + | e -> + if exn e then Lwt.return_unit + else Fmt.kstrf Alcotest.fail "%s raised %a" msg Fmt.exn e) + +let is_unix_error = function + | Unix.Unix_error _ -> true + | _ -> false + +let escape = String.Ascii.escape + +let write fd strs = + Lwt_list.iter_s (fun str -> + IO.really_write fd str 0 (String.length str) + ) strs let test_pipe pipe () = let calf = Init.Fd.fd @@ Init.Pipe.(calf pipe) in let priv = Init.Fd.fd @@ Init.Pipe.(priv pipe) in - let test str = - (* check the the pipe is unidirectional *) - IO.really_write calf str 0 (String.length str) >>= fun () -> + let name = Init.Pipe.name pipe in + let test strs = + let escape_strs = String.concat ~sep:"" @@ List.map escape strs in + (* pipes are unidirectional *) + (* calf -> priv works *) + write calf strs >>= fun () -> IO.read_all priv >>= fun buf -> - Alcotest.(check string) "stdout" - (String.Ascii.escape str) (String.Ascii.escape buf); - Lwt.catch (fun () -> - IO.really_write priv str 0 (String.length str) >|= fun () -> - Alcotest.fail "priv side is writable!" - ) (fun _ -> Lwt.return_unit) - >>= fun () -> - Lwt.catch (fun () -> - IO.read_all calf >|= fun _ -> - Alcotest.fail "calf sid is readable!" - ) (fun _ -> Lwt.return_unit) - >>= fun () -> + let msg = Fmt.strf "%s: calf -> priv" name in + Alcotest.(check string) msg escape_strs (escape buf); + (* priv -> calf don't *) + check_raises (Fmt.strf "%s: priv side is writable!" name) is_unix_error + (fun () -> write priv strs) >>= fun () -> + check_raises (Fmt.strf "%s: calf sid is readable!" name) is_unix_error + (fun () -> IO.read_all calf >|= ignore) >>= fun () -> Lwt.return_unit in - test (random_string 1) >>= fun () -> - test (random_string 100) >>= fun () -> - test (random_string 10241) >>= fun () -> + test [random_string 1] >>= fun () -> + test [random_string 1; random_string 1; random_string 10] >>= fun () -> + test [random_string 100] >>= fun () -> + test [random_string 10241] >>= fun () -> Lwt.return_unit +let test_socketpair pipe () = + let calf = Init.Fd.fd @@ Init.Pipe.(calf pipe) in + let priv = Init.Fd.fd @@ Init.Pipe.(priv pipe) in + let name = Init.Pipe.name pipe in + let test strs = + let escape_strs = String.concat ~sep:"" @@ List.map escape strs in + (* socket pairs are bi-directional *) + (* calf -> priv works *) + write calf strs >>= fun () -> + IO.read_all priv >>= fun buf -> + Alcotest.(check string) (name ^ " calf -> priv") escape_strs (escape buf); + (* priv -> cal works *) + write priv strs >>= fun () -> + IO.read_all calf >>= fun buf -> + Alcotest.(check string) (name ^ " priv -> calf") escape_strs (escape buf); + Lwt.return_unit + in + test [random_string 1] >>= fun () -> + test [random_string 1; random_string 1; random_string 10] >>= fun () -> + test [random_string 100] >>= fun () -> + (* note: if size(writes) > 8192 then the next writes will block (as + we are using SOCK_STREAM *) + let n = 8182 / 4 in + test [ + random_string n; + random_string n; + random_string n; + random_string n; + ] >>= fun () -> + + Lwt.return_unit + +let query = Alcotest.testable Ctl.Query.pp (=) +let reply = Alcotest.testable Ctl.Reply.pp (=) + +let queries = + let open Ctl.Query in + [ + { version = 0l; id = 0l; operation = Read; path = "/foo/bar"; payload = "" }; + { version = Int32.max_int; id = Int32.max_int; operation = Write ; path = ""; payload = "foo" }; + { version = 1l;id = 0l; operation = Delete; path = ""; payload = "" }; + { version = -2l; id = -3l; operation = Delete; path = "foo"; payload = "foo" }; + ] + +let replies = + let open Ctl.Reply in + [ + { id = 0l; status = Ok; payload = "" }; + { id = Int32.max_int; status = Ok; payload = "foo" }; + { id = 0l; status = Error; payload = "" }; + { id = -3l; status = Error; payload = "foo" }; + ] + +let test_serialization to_cstruct of_cstruct message messages = + let test m = + let buf = to_cstruct m in + match of_cstruct buf with + | Ok m' -> Alcotest.(check message) "to_cstruct/of_cstruct" m m' + | Error (`Msg e) -> Alcotest.fail ("Message.of_cstruct: " ^ e) + in + List.iter test messages + +let test_send write read message messages = + let calf = Init.Fd.fd @@ Init.Pipe.(calf ctl) in + let priv = Init.Fd.fd @@ Init.Pipe.(priv ctl) in + let test m = + write calf m >>= fun () -> + read priv >|= function + | Ok m' -> Alcotest.(check message) "write/read" m m' + | Error (`Msg e) -> Alcotest.fail ("Message.read: " ^ e) + in + Lwt_list.iter_s test messages + +let test_query_serialization () = + let open Ctl.Query in + test_serialization to_cstruct of_cstruct query queries + +let test_reply_serialization () = + let open Ctl.Reply in + test_serialization to_cstruct of_cstruct reply replies + +let test_query_send () = + let open Ctl.Query in + test_send write read query queries + +let test_reply_send () = + let open Ctl.Reply in + test_send write read reply replies + +let failf fmt = Fmt.kstrf Alcotest.fail fmt + +(* read ops *) + +let read_should_err t k = + Ctl.Client.read t k >|= function + | Error (`Msg _) -> () + | Ok None -> failf "read(%s) -> got: none, expected: err" k + | Ok Some v -> failf "read(%s) -> got: found:%S, expected: err" k v + +let read_should_none t k = + Ctl.Client.read t k >|= function + | Error (`Msg e) -> failf "read(%s) -> got: error:%s, expected none" k e + | Ok None -> () + | Ok Some v -> failf "read(%s) -> got: found:%S, expected none" k v + +let read_should_work t k v = + Ctl.Client.read t k >|= function + | Error (`Msg e) -> failf "read(%s) -> got: error:%s, expected ok" k e + | Ok None -> failf "read(%s) -> got: none, expected ok" k + | Ok Some v' -> + if v <> v' then failf "read(%s) -> got: ok:%S, expected: ok:%S" k v' v + +(* write ops *) + +let write_should_err t k v = + Ctl.Client.write t k v >|= function + | Ok () -> failf "write(%s) -> ok" k + | Error _ -> () + +let write_should_work t k v = + Ctl.Client.write t k v >|= function + | Ok () -> () + | Error (`Msg e) -> failf "write(%s) -> error: %s" k e + +(* del ops *) + +let delete_should_err t k = + Ctl.Client.delete t k >|= function + | Ok () -> failf "del(%s) -> ok" k + | Error _ -> () + +let delete_should_work t k = + Ctl.Client.delete t k >|= function + | Ok () -> () + | Error (`Msg e) -> failf "write(%s) -> error: %s" k e + +let test_ctl () = + let calf = Init.Fd.fd @@ Init.Pipe.(calf ctl) in + let priv = Init.Fd.fd @@ Init.Pipe.(priv ctl) in + let k1 = "/foo/bar" in + let k2 = "a" in + let k3 = "b/c" in + let k4 = "xxxxxx" in + let routes = [k1; k2; k3] in + let git_root = "/tmp/sdk/ctl" in + let _ = Sys.command (Fmt.strf "rm -rf %s" git_root) in + Ctl.v git_root >>= fun ctl -> + let server () = Ctl.Server.listen ~routes ctl priv in + let client () = + let t = Ctl.Client.v calf in + let allowed k v = + delete_should_work t k >>= fun () -> + read_should_none t k >>= fun () -> + write_should_work t k v >>= fun () -> + read_should_work t k v >>= fun () -> + let path = String.cuts ~empty:false ~sep:"/" k in + Ctl.KV.get ctl path >|= fun v' -> + Alcotest.(check string) "in the db" v v' + in + let disallowed k v = + read_should_err t k >>= fun () -> + write_should_err t k v >>= fun () -> + delete_should_err t k + in + allowed k1 "" >>= fun () -> + allowed k2 "xxx" >>= fun () -> + allowed k3 (random_string (255 * 1024)) >>= fun () -> + disallowed k4 "" >>= fun () -> + Lwt.return_unit + in + Lwt.pick [ + client (); + server (); + ] + let run f () = try Lwt_main.run (f ()) with e -> @@ -40,9 +239,16 @@ let run f () = let test_stderr () = () let test = [ - "stdout" , `Quick, run (test_pipe Init.Pipe.stdout); - "stdout" , `Quick, run (test_pipe Init.Pipe.stderr); - ] + "stdout is a pipe" , `Quick, run (test_pipe Init.Pipe.stdout); + "stdout is a pipe" , `Quick, run (test_pipe Init.Pipe.stderr); + "net is a socket pair", `Quick, run (test_socketpair Init.Pipe.net); + "ctl is a socket pair", `Quick, run (test_socketpair Init.Pipe.ctl); + "seralize queries" , `Quick, test_query_serialization; + "seralize replies" , `Quick, test_reply_serialization; + "send queries" , `Quick, run test_query_send; + "send replies" , `Quick, run test_reply_send; + "ctl" , `Quick, run test_ctl; +] let reporter ?(prefix="") () = let pad n x =