diff --git a/projects/miragesdk/src/sdk/init.ml b/projects/miragesdk/src/sdk/init.ml index 6817c1e20..1ef40cb05 100644 --- a/projects/miragesdk/src/sdk/init.ml +++ b/projects/miragesdk/src/sdk/init.ml @@ -138,7 +138,8 @@ module Fd = struct let stdin = { name = "stdin" ; fd = Lwt_unix.stdin } let of_int name (i:int) = - let fd = Lwt_unix.of_unix_file_descr (Obj.magic i: Unix.file_descr) in + let fd : Unix.file_descr = Obj.magic i in + let fd = Lwt_unix.of_unix_file_descr fd in { name; fd } let to_int t = @@ -148,7 +149,7 @@ module Fd = struct let close fd = Log.debug (fun l -> l "close %a" pp fd); - Lwt_unix.close fd.fd + Unix.close (Lwt_unix.unix_file_descr fd.fd) let dev_null = Lwt_unix.of_unix_file_descr ~blocking:false @@ -156,9 +157,9 @@ module Fd = struct let redirect_to_dev_null fd = Log.debug (fun l -> l "redirect-stdin-to-dev-null"); - Lwt_unix.close fd.fd >>= fun () -> + Unix.close (Lwt_unix.unix_file_descr fd.fd); Lwt_unix.dup2 dev_null fd.fd; - Lwt_unix.close dev_null + Unix.close (Lwt_unix.unix_file_descr dev_null) let dup2 ~src ~dst = Log.debug (fun l -> l "dup2 %a => %a" pp src pp dst); @@ -220,14 +221,14 @@ end let exec_calf t cmd = Log.info (fun l -> l "child pid is %d" Unix.(getpid ())); - Fd.(redirect_to_dev_null stdin) >>= fun () -> + Fd.(redirect_to_dev_null stdin); (* close parent fds *) - Fd.close Pipe.(priv t.stdout) >>= fun () -> - Fd.close Pipe.(priv t.stderr) >>= fun () -> - Fd.close Pipe.(priv t.ctl) >>= fun () -> - Fd.close Pipe.(priv t.net) >>= fun () -> - Fd.close Pipe.(priv t.metrics) >>= fun () -> + Fd.close Pipe.(priv t.stdout); + Fd.close Pipe.(priv t.stderr); + Fd.close Pipe.(priv t.ctl); + Fd.close Pipe.(priv t.net); + Fd.close Pipe.(priv t.metrics); let cmds = String.concat " " cmd in @@ -239,10 +240,10 @@ let exec_calf t cmd = Log.info (fun l -> l "Executing %s" cmds); (* Move all open fds at the top *) - Fd.dup2 ~src:Pipe.(calf t.stdout) ~dst:calf_stdout >>= fun () -> - Fd.dup2 ~src:Pipe.(calf t.stderr) ~dst:calf_stderr >>= fun () -> - Fd.dup2 ~src:Pipe.(calf t.net) ~dst:calf_net >>= fun () -> - Fd.dup2 ~src:Pipe.(calf t.ctl) ~dst:calf_ctl >>= fun () -> + Fd.dup2 ~src:Pipe.(calf t.net) ~dst:calf_net; + Fd.dup2 ~src:Pipe.(calf t.ctl) ~dst:calf_ctl; + Fd.dup2 ~src:Pipe.(calf t.stderr) ~dst:calf_stderr; + Fd.dup2 ~src:Pipe.(calf t.stdout) ~dst:calf_stdout; (* exec the calf *) Unix.execve (List.hd cmd) (Array.of_list cmd) [||] @@ -255,22 +256,16 @@ let check_exit_status cmd status = | Unix.WSIGNALED i -> failf "%s: signal %d" cmds i | Unix.WSTOPPED i -> failf "%s: stopped %d" cmds i -let exec_priv t ~pid ~cmd ~net ~ctl ~handlers = +let exec_priv t ~pid cmd = - Fd.(redirect_to_dev_null stdin) >>= fun () -> + Fd.(redirect_to_dev_null stdin); (* close child fds *) - Fd.close Pipe.(calf t.stdout) >>= fun () -> - Fd.close Pipe.(calf t.stderr) >>= fun () -> - Fd.close Pipe.(calf t.net) >>= fun () -> - Fd.close Pipe.(calf t.ctl) >>= fun () -> - Fd.close Pipe.(calf t.metrics) >>= fun () -> - - - let priv_net = Fd.flow Pipe.(priv t.net) in - let priv_ctl = Fd.flow Pipe.(priv t.ctl) in - let priv_stdout = Fd.flow Pipe.(priv t.stdout) in - let priv_stderr = Fd.flow Pipe.(priv t.stderr) in + Fd.close Pipe.(calf t.stdout); + Fd.close Pipe.(calf t.stderr); + Fd.close Pipe.(calf t.net); + Fd.close Pipe.(calf t.ctl); + Fd.close Pipe.(calf t.metrics); let wait () = Lwt_unix.waitpid [] pid >>= fun (_pid, w) -> @@ -278,6 +273,21 @@ let exec_priv t ~pid ~cmd ~net ~ctl ~handlers = check_exit_status cmd w in + Lwt.return wait + +let block_for_ever = + let t, _ = Lwt.task () in + fun () -> t + +let exec_and_forward ?(handlers=block_for_ever) ~pid ~cmd ~net ~ctl t = + + exec_priv t ~pid cmd >>= fun wait -> + + let priv_net = Fd.flow Pipe.(priv t.net) in + let priv_ctl = Fd.flow Pipe.(priv t.ctl) in + let priv_stdout = Fd.flow Pipe.(priv t.stdout) in + let priv_stderr = Fd.flow Pipe.(priv t.stderr) in + Lwt.pick ([ wait (); (* data *) @@ -286,13 +296,17 @@ let exec_priv t ~pid ~cmd ~net ~ctl ~handlers = (* redirect the calf stdout to the shim stdout *) IO.forward ~src:priv_stdout ~dst:Fd.(flow stdout); IO.forward ~src:priv_stderr ~dst:Fd.(flow stderr); - (* TODO: Init.Fd.forward ~src:Init.Pipe.(priv metrics) ~dst:Init.Fd.metric; *) + (* TODO: Init.Fd.forward ~src:Init.Pipe.(priv metrics) + ~dst:Init.Fd.metric; *) ctl priv_ctl; handlers (); ]) -let run t ~net ~ctl ~handlers cmd = +let exec t cmd fn = Lwt_io.flush_all () >>= fun () -> match Lwt_unix.fork () with | 0 -> exec_calf t cmd - | pid -> exec_priv t ~pid ~cmd ~net ~ctl ~handlers + | pid -> fn pid + +let run t ~net ~ctl ?handlers cmd = + exec t cmd (fun pid -> exec_and_forward ?handlers ~pid ~cmd ~net ~ctl t) diff --git a/projects/miragesdk/src/sdk/init.mli b/projects/miragesdk/src/sdk/init.mli index 5ec269ec0..f3e225496 100644 --- a/projects/miragesdk/src/sdk/init.mli +++ b/projects/miragesdk/src/sdk/init.mli @@ -23,13 +23,13 @@ module Fd: sig val pp: t Fmt.t (** [pp_fd] pretty prints a file descriptor. *) - val redirect_to_dev_null: t -> unit Lwt.t + val redirect_to_dev_null: t -> unit (** [redirect_to_dev_null fd] redirects [fd] [/dev/null]. *) - val close: t -> unit Lwt.t + val close: t -> unit (** [close fd] closes [fd]. *) - val dup2: src:t -> dst:t -> unit Lwt.t + val dup2: src:t -> dst:t -> unit (** [dup2 ~src ~dst] calls [Unix.dup2] on [src] and [dst]. *) (** {1 Usefull File Descriptors} *) @@ -103,13 +103,17 @@ val rawlink: ?filter:string -> string -> IO.t {{:https://github.com/haesbaert/rawlink}rawlink} for more details on how to build that filter. *) +val exec: Pipe.monitor -> string list -> (int -> unit Lwt.t) -> unit Lwt.t +(** [exec t cmd k] executes [cmd] in an unprivileged calf process and + call [k] with the pid of the parent process. The child and parents + are connected using [t]. *) + (* FIXME(samoht): not very happy with that signatue *) val run: Pipe.monitor -> - net:IO.t -> - ctl:(IO.t -> unit Lwt.t) -> - handlers:(unit -> unit Lwt.t) -> + net:IO.t -> ctl:(IO.t -> unit Lwt.t) -> + ?handlers:(unit -> unit Lwt.t) -> string list -> unit Lwt.t -(** [run m ~net ~ctl ~handlers cmd] runs [cmd] in a unprivileged calf +(** [run m ~net ~ctl ?handlers cmd] runs [cmd] in a unprivileged calf process. [net] is the network interface flow. [ctl] is the control thread connected to the {Pipe.ctl} pipe. [handlers] are the system handler thread which will react to control data to perform diff --git a/projects/miragesdk/src/test/jbuild b/projects/miragesdk/src/test/jbuild index ef137b2c4..c1b135f79 100644 --- a/projects/miragesdk/src/test/jbuild +++ b/projects/miragesdk/src/test/jbuild @@ -2,7 +2,7 @@ (executables ((names (test)) - (libraries (sdk alcotest astring mtime.os)))) + (libraries (sdk alcotest astring mtime.os mirage-flow-lwt)))) (alias ((name runtest) diff --git a/projects/miragesdk/src/test/test.ml b/projects/miragesdk/src/test/test.ml index 7af003631..fe86ea480 100644 --- a/projects/miragesdk/src/test/test.ml +++ b/projects/miragesdk/src/test/test.ml @@ -124,8 +124,8 @@ let test_serialization to_cstruct of_cstruct message messages = List.iter test messages let test_send t write read message messages = - let calf = Init.(Fd.flow Pipe.(calf @@ ctl t)) in - let priv = Init.(Fd.flow Pipe.(priv @@ ctl t)) in + let calf = calf Init.Pipe.(ctl t) in + let priv = priv Init.Pipe.(ctl t) in let test m = write calf m >>= fun () -> read priv >|= function @@ -198,8 +198,8 @@ let delete_should_work t k = | Error (`Msg e) -> failf "write(%s) -> error: %s" k e let test_ctl t () = - let calf = Init.(Fd.flow Init.Pipe.(calf @@ ctl t)) in - let priv = Init.(Fd.flow Init.Pipe.(priv @@ ctl t)) in + let calf = calf Init.Pipe.(ctl t) in + let priv = priv Init.Pipe.(ctl t) in let k1 = "/foo/bar" in let k2 = "a" in let k3 = "b/c" in @@ -236,6 +236,33 @@ let test_ctl t () = server (); ] +let in_memory_flow () = + let flow = Mirage_flow_lwt.F.string () in + IO.create (module Mirage_flow_lwt.F) flow "mem" + +let test_exec () = + let test () = + let check n pipe = + let t = Init.Pipe.v () in + let pipe = pipe t in + Init.exec t ["/bin/sh"; "-c"; "echo foo >& " ^ string_of_int n] @@ fun _pid -> + read @@ priv pipe >>= fun foo -> + let name = Fmt.strf "fork %s" Init.Pipe.(name pipe) in + Alcotest.(check string) name "foo\n" foo; + Lwt.return_unit + in + check 1 Init.Pipe.stdout >>= fun () -> + (* avoid logging interference *) + let level = Logs.level () in + Logs.set_level None; + check 2 Init.Pipe.stderr >>= fun () -> + Logs.set_level level; + check 3 Init.Pipe.net >>= fun () -> + check 4 Init.Pipe.ctl >>= fun () -> + Lwt.return_unit + in + test () + let run f () = try Lwt_main.run (f ()) with e -> @@ -256,6 +283,7 @@ let test = [ "send queries" , `Quick, run (test_query_send t); "send replies" , `Quick, run (test_reply_send t); "ctl" , `Quick, run (test_ctl t); + "exec" , `Quick, run test_exec; ] let reporter ?(prefix="") () =