diff --git a/projects/miragesdk/src/sdk/IO.ml b/projects/miragesdk/src/sdk/IO.ml index 2d9345729..a03189370 100644 --- a/projects/miragesdk/src/sdk/IO.ml +++ b/projects/miragesdk/src/sdk/IO.ml @@ -4,18 +4,20 @@ let src = Logs.Src.create "IO" ~doc:"IO helpers" module Log = (val Logs.src_log src : Logs.LOG) let rec really_write fd buf off len = - Log.debug (fun l -> l "really_write"); match len with | 0 -> Lwt.return_unit | len -> + Log.debug (fun l -> l "really_write off=%d len=%d" off len); Lwt_unix.write fd buf off len >>= fun n -> really_write fd buf (off+n) (len-n) +let write fd buf = really_write fd buf 0 (String.length buf) + let rec really_read fd buf off len = - Log.debug (fun l -> l "really_read"); match len with | 0 -> Lwt.return_unit | len -> + Log.debug (fun l -> l "really_read off=%d len=%d" off len); Lwt_unix.read fd buf off len >>= fun n -> really_read fd buf (off+n) (len-n) @@ -33,7 +35,7 @@ let read_all fd = String.concat "" bufs let read_n fd len = - Log.debug (fun l -> l "read_n"); + Log.debug (fun l -> l "read_n len=%d" len); let buf = Bytes.create len in let rec loop acc len = Lwt_unix.read fd buf 0 len >>= fun n -> diff --git a/projects/miragesdk/src/sdk/IO.mli b/projects/miragesdk/src/sdk/IO.mli index 121ba33c7..b07954d18 100644 --- a/projects/miragesdk/src/sdk/IO.mli +++ b/projects/miragesdk/src/sdk/IO.mli @@ -3,6 +3,9 @@ val really_write: Lwt_unix.file_descr -> string -> int -> int -> unit Lwt.t (** [really_write fd buf off len] writes exactly [len] bytes to [fd]. *) +val write: Lwt_unix.file_descr -> string -> unit Lwt.t +(** [write fd buf] writes all the buffer [buf] in [fd]. *) + val really_read: Lwt_unix.file_descr -> string -> int -> int -> unit Lwt.t (** [really_read fd buf off len] reads exactly [len] bytes from [fd]. *) diff --git a/projects/miragesdk/src/sdk/ctl.ml b/projects/miragesdk/src/sdk/ctl.ml index d561e067b..4bad89563 100644 --- a/projects/miragesdk/src/sdk/ctl.ml +++ b/projects/miragesdk/src/sdk/ctl.ml @@ -41,7 +41,7 @@ module Message = struct type t = { operation: operation; path : string; - payload : string option; + payload : string; } [%%cstruct type message = { @@ -53,48 +53,65 @@ module Message = struct (* to avoid warning 32 *) let _ = hexdump_message - let _ = operation_to_string let _ = string_to_operation - let read_message fd = - IO.read_n fd 4 >>= fun buf -> - let len = - Cstruct.LE.get_uint32 (Cstruct.of_string buf) 0 - |> Int32.to_int - in - IO.read_n fd len >>= fun buf -> - let buf = Cstruct.of_string buf in + let pp ppf t = + Fmt.pf ppf "%s:%S:%S" (operation_to_string t.operation) t.path t.payload + + (* FIXME: allocate less ... *) + + let of_cstruct buf = + Log.debug (fun l -> l "Message.of_cstruct %S" @@ Cstruct.to_string buf); let operation = match int_to_operation (get_message_operation buf) with | None -> failwith "invalid operation" | Some o -> o in let path_len = get_message_path buf in let payload_len = get_message_payload buf in - IO.read_n fd path_len >>= fun path -> - (match payload_len with - | 0 -> Lwt.return None - | n -> IO.read_n fd n >|= fun x -> Some x) - >|= fun payload -> + Log.debug (fun l -> l "XXX path=%d len=%d" path_len payload_len); + let path = + Cstruct.sub buf sizeof_message path_len + |> Cstruct.to_string + in + let payload = + Cstruct.sub buf (sizeof_message+path_len) payload_len + |> Cstruct.to_string + in { operation; path; payload } - let write_message fd msg = + let to_cstruct msg = + Log.debug (fun l -> l "Message.to_cstruct %a" pp msg); let operation = operation_to_int msg.operation in let path = String.length msg.path in - let payload = match msg.payload with - | None -> 0 - | Some x -> String.length x - in + let payload = String.length msg.payload in let len = sizeof_message + path + payload in let buf = Cstruct.create len in set_message_operation buf operation; set_message_path buf path; - set_message_payload buf path; + set_message_payload buf payload; Cstruct.blit_from_bytes msg.path 0 buf sizeof_message path; - let () = match msg.payload with - | None -> () - | Some x -> Cstruct.blit_from_bytes x 0 buf (sizeof_message+path) payload + Cstruct.blit_from_bytes msg.payload 0 buf (sizeof_message+path) payload; + buf + + let read fd = + IO.read_n fd 4 >>= fun buf -> + Log.debug (fun l -> l "Message.read len=%S" buf); + let len = + Cstruct.LE.get_uint32 (Cstruct.of_string buf) 0 + |> Int32.to_int in - IO.really_write fd (Cstruct.to_string buf) 0 len + IO.read_n fd len >|= fun buf -> + of_cstruct (Cstruct.of_string buf) + + let write fd msg = + let buf = to_cstruct msg |> Cstruct.to_string in + let len = + let len = Cstruct.create 4 in + Cstruct.LE.set_uint32 len 0 (Int32.of_int @@ String.length buf); + Cstruct.to_string len + in + IO.write fd len >>= fun () -> + IO.write fd buf end @@ -118,9 +135,7 @@ module Dispatch = struct match msg.operation with | Write -> let info = infof "Updating %a" KV.Key.pp key in - (match msg.payload with - | None -> Fmt.kstrf Lwt.fail_with "dispatch: missing payload" - | Some v -> KV.set db ~info key v) + KV.set db ~info key msg.payload | _ -> failwith "TODO" ) @@ -128,7 +143,7 @@ module Dispatch = struct let msgs = Queue.create () in let cond = Lwt_condition.create () in let rec listen () = - read_message fd >>= fun msg -> + Message.read fd >>= fun msg -> Queue.add msg msgs; Lwt_condition.signal cond (); listen () diff --git a/projects/miragesdk/src/sdk/ctl.mli b/projects/miragesdk/src/sdk/ctl.mli index 5af101be8..452e65cca 100644 --- a/projects/miragesdk/src/sdk/ctl.mli +++ b/projects/miragesdk/src/sdk/ctl.mli @@ -15,14 +15,24 @@ module Message: sig type t = { operation: operation; path : string; - payload : string option; + payload : string; } - val write_message: Lwt_unix.file_descr -> t -> unit Lwt.t - (** [write_message fd t] writes a control message. *) + val pp: t Fmt.t + (** [pp] is the pretty-printer for messages. *) - val read_message: Lwt_unix.file_descr -> t Lwt.t - (** [read_message fd] reads a control message. *) + val of_cstruct: Cstruct.t -> t + (** [of_cstruct buf] is the message [t] such that the serialization + of [t] is [buf]. *) + + val to_cstruct: t -> Cstruct.t + (** [to_cstruct t] is the serialization of [t]. *) + + val write: Lwt_unix.file_descr -> t -> unit Lwt.t + (** [write fd t] writes a control message. *) + + val read: Lwt_unix.file_descr -> t Lwt.t + (** [read fd] reads a control message. *) end diff --git a/projects/miragesdk/src/sdk/init.ml b/projects/miragesdk/src/sdk/init.ml index b9e2cd7be..c44e7ae84 100644 --- a/projects/miragesdk/src/sdk/init.ml +++ b/projects/miragesdk/src/sdk/init.ml @@ -97,6 +97,8 @@ module Pipe = struct type t = Fd.t * Fd.t + let name (x, _) = x.Fd.name + let priv = fst let calf = snd diff --git a/projects/miragesdk/src/sdk/init.mli b/projects/miragesdk/src/sdk/init.mli index fc50863cd..1e2828a26 100644 --- a/projects/miragesdk/src/sdk/init.mli +++ b/projects/miragesdk/src/sdk/init.mli @@ -63,6 +63,9 @@ module Pipe: sig (** The type for pipes. Could be either uni-directional (normal pipes) or a bi-directional (socket pairs). *) + val name: t -> string + (** [name t] is [t]'s name. *) + val priv: t -> Fd.t (** [priv p] is the private side of the pipe [p]. *) diff --git a/projects/miragesdk/src/test/test.ml b/projects/miragesdk/src/test/test.ml index 0aa18877a..925ed7c04 100644 --- a/projects/miragesdk/src/test/test.ml +++ b/projects/miragesdk/src/test/test.ml @@ -2,35 +2,121 @@ open Astring open Lwt.Infix open Sdk -let random_string n = Bytes.create n +let random_string n = + Bytes.init n (fun _ -> char_of_int (Random.int 255)) + +(* workaround https://github.com/mirage/alcotest/issues/88 *) +exception Check_error of string + +let check_raises msg exn f = + Lwt.catch (fun () -> + f () >>= fun () -> + Lwt.fail (Check_error msg) + ) (function + | Check_error e -> Alcotest.fail e + | e -> + if exn e then Lwt.return_unit + else Fmt.kstrf Alcotest.fail "%s raised %a" msg Fmt.exn e) + +let is_unix_error = function + | Unix.Unix_error _ -> true + | _ -> false + +let escape = String.Ascii.escape + +let write fd strs = + Lwt_list.iter_s (fun str -> + IO.really_write fd str 0 (String.length str) + ) strs let test_pipe pipe () = let calf = Init.Fd.fd @@ Init.Pipe.(calf pipe) in let priv = Init.Fd.fd @@ Init.Pipe.(priv pipe) in - let test str = - (* check the the pipe is unidirectional *) - IO.really_write calf str 0 (String.length str) >>= fun () -> + let name = Init.Pipe.name pipe in + let test strs = + let escape_strs = String.concat ~sep:"" @@ List.map escape strs in + (* pipes are unidirectional *) + (* calf -> priv works *) + write calf strs >>= fun () -> IO.read_all priv >>= fun buf -> - Alcotest.(check string) "stdout" - (String.Ascii.escape str) (String.Ascii.escape buf); - Lwt.catch (fun () -> - IO.really_write priv str 0 (String.length str) >|= fun () -> - Alcotest.fail "priv side is writable!" - ) (fun _ -> Lwt.return_unit) - >>= fun () -> - Lwt.catch (fun () -> - IO.read_all calf >|= fun _ -> - Alcotest.fail "calf sid is readable!" - ) (fun _ -> Lwt.return_unit) - >>= fun () -> + let msg = Fmt.strf "%s: calf -> priv" name in + Alcotest.(check string) msg escape_strs (escape buf); + (* priv -> calf don't *) + check_raises (Fmt.strf "%s: priv side is writable!" name) is_unix_error + (fun () -> write priv strs) >>= fun () -> + check_raises (Fmt.strf "%s: calf sid is readable!" name) is_unix_error + (fun () -> IO.read_all calf >|= ignore) >>= fun () -> Lwt.return_unit in - test (random_string 1) >>= fun () -> - test (random_string 100) >>= fun () -> - test (random_string 10241) >>= fun () -> + test [random_string 1] >>= fun () -> + test [random_string 1; random_string 1; random_string 10] >>= fun () -> + test [random_string 100] >>= fun () -> + test [random_string 10241] >>= fun () -> Lwt.return_unit +let test_socketpair pipe () = + let calf = Init.Fd.fd @@ Init.Pipe.(calf pipe) in + let priv = Init.Fd.fd @@ Init.Pipe.(priv pipe) in + let name = Init.Pipe.name pipe in + let test strs = + let escape_strs = String.concat ~sep:"" @@ List.map escape strs in + (* socket pairs are bi-directional *) + (* calf -> priv works *) + write calf strs >>= fun () -> + IO.read_all priv >>= fun buf -> + Alcotest.(check string) (name ^ " calf -> priv") escape_strs (escape buf); + (* priv -> cal works *) + write priv strs >>= fun () -> + IO.read_all calf >>= fun buf -> + Alcotest.(check string) (name ^ " priv -> calf") escape_strs (escape buf); + Lwt.return_unit + in + test [random_string 1] >>= fun () -> + test [random_string 1; random_string 1; random_string 10] >>= fun () -> + test [random_string 100] >>= fun () -> + (* note: if size(writes) > 8192 then the next writes will block (as + we are using SOCK_STREAM *) + let n = 8182 / 4 in + test [ + random_string n; + random_string n; + random_string n; + random_string n; + ] >>= fun () -> + + Lwt.return_unit + +let message = Alcotest.testable Ctl.Message.pp (=) + +let test_message_serialization () = + let test m = + let buf = Ctl.Message.to_cstruct m in + let m' = Ctl.Message.of_cstruct buf in + Alcotest.(check message) "to_cstruct/of_cstruct" m m' + in + List.iter test [ + { operation = Read ; path = "/foo/bar"; payload = "" }; + { operation = Write ; path = "" ; payload = "foo" }; + { operation = Delete; path = "" ; payload = "" }; + { operation = Delete; path = "foo" ; payload = "foo" }; + ] + +let test_message_send () = + let calf = Init.Fd.fd @@ Init.Pipe.(calf ctl) in + let priv = Init.Fd.fd @@ Init.Pipe.(priv ctl) in + let test m = + Ctl.Message.write calf m >>= fun () -> + Ctl.Message.read priv >|= fun m' -> + Alcotest.(check message) "write/read" m m' + in + Lwt_list.iter_s test [ + { operation = Read ; path = "/foo/bar"; payload = "" }; + { operation = Write ; path = "" ; payload = "foo" }; + { operation = Delete; path = "" ; payload = "" }; + { operation = Delete; path = "foo" ; payload = "foo" }; + ] + let run f () = try Lwt_main.run (f ()) with e -> @@ -40,9 +126,13 @@ let run f () = let test_stderr () = () let test = [ - "stdout" , `Quick, run (test_pipe Init.Pipe.stdout); - "stdout" , `Quick, run (test_pipe Init.Pipe.stderr); - ] + "stdout is a pipe" , `Quick, run (test_pipe Init.Pipe.stdout); + "stdout is a pipe" , `Quick, run (test_pipe Init.Pipe.stderr); + "net is a socket pair", `Quick, run (test_socketpair Init.Pipe.net); + "ctl is a socket pair", `Quick, run (test_socketpair Init.Pipe.ctl); + "seralize messages" , `Quick, test_message_serialization; + "send messages" , `Quick, run test_message_send; +] let reporter ?(prefix="") () = let pad n x =