diff --git a/virtcontainers/kata_agent.go b/virtcontainers/kata_agent.go index 58ad7931b..2c2e8d464 100644 --- a/virtcontainers/kata_agent.go +++ b/virtcontainers/kata_agent.go @@ -17,6 +17,7 @@ import ( "syscall" "time" + "github.com/gogo/protobuf/proto" kataclient "github.com/kata-containers/agent/protocols/client" "github.com/kata-containers/agent/protocols/grpc" vcAnnotations "github.com/kata-containers/runtime/virtcontainers/pkg/annotations" @@ -74,6 +75,7 @@ type kataAgent struct { shim shim proxy proxy client *kataclient.AgentClient + reqHandlers map[string]reqFunc state KataAgentState keepConn bool proxyBuiltIn bool @@ -917,6 +919,7 @@ func (k *kataAgent) connect() error { return err } + k.installReqFunc(client) k.client = client return nil @@ -930,7 +933,9 @@ func (k *kataAgent) disconnect() error { if err := k.client.Close(); err != nil && err != golangGrpc.ErrClientConnClosing { return err } + k.client = nil + k.reqHandlers = nil return nil } @@ -940,6 +945,50 @@ func (k *kataAgent) check() error { return err } +type reqFunc func(context.Context, interface{}, ...golangGrpc.CallOption) (interface{}, error) + +func (k *kataAgent) installReqFunc(c *kataclient.AgentClient) { + k.reqHandlers = make(map[string]reqFunc) + k.reqHandlers["grpc.CheckRequest"] = func(ctx context.Context, req interface{}, opts ...golangGrpc.CallOption) (interface{}, error) { + ctx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + return k.client.Check(ctx, req.(*grpc.CheckRequest), opts...) + } + k.reqHandlers["grpc.ExecProcessRequest"] = func(ctx context.Context, req interface{}, opts ...golangGrpc.CallOption) (interface{}, error) { + return k.client.ExecProcess(ctx, req.(*grpc.ExecProcessRequest), opts...) + } + k.reqHandlers["grpc.CreateSandboxRequest"] = func(ctx context.Context, req interface{}, opts ...golangGrpc.CallOption) (interface{}, error) { + return k.client.CreateSandbox(ctx, req.(*grpc.CreateSandboxRequest), opts...) + } + k.reqHandlers["grpc.DestroySandboxRequest"] = func(ctx context.Context, req interface{}, opts ...golangGrpc.CallOption) (interface{}, error) { + return k.client.DestroySandbox(ctx, req.(*grpc.DestroySandboxRequest), opts...) + } + k.reqHandlers["grpc.CreateContainerRequest"] = func(ctx context.Context, req interface{}, opts ...golangGrpc.CallOption) (interface{}, error) { + return k.client.CreateContainer(ctx, req.(*grpc.CreateContainerRequest), opts...) + } + k.reqHandlers["grpc.StartContainerRequest"] = func(ctx context.Context, req interface{}, opts ...golangGrpc.CallOption) (interface{}, error) { + return k.client.StartContainer(ctx, req.(*grpc.StartContainerRequest), opts...) + } + k.reqHandlers["grpc.RemoveContainerRequest"] = func(ctx context.Context, req interface{}, opts ...golangGrpc.CallOption) (interface{}, error) { + return k.client.RemoveContainer(ctx, req.(*grpc.RemoveContainerRequest), opts...) + } + k.reqHandlers["grpc.SignalProcessRequest"] = func(ctx context.Context, req interface{}, opts ...golangGrpc.CallOption) (interface{}, error) { + return k.client.SignalProcess(ctx, req.(*grpc.SignalProcessRequest), opts...) + } + k.reqHandlers["grpc.UpdateRoutesRequest"] = func(ctx context.Context, req interface{}, opts ...golangGrpc.CallOption) (interface{}, error) { + return k.client.UpdateRoutes(ctx, req.(*grpc.UpdateRoutesRequest), opts...) + } + k.reqHandlers["grpc.UpdateInterfaceRequest"] = func(ctx context.Context, req interface{}, opts ...golangGrpc.CallOption) (interface{}, error) { + return k.client.UpdateInterface(ctx, req.(*grpc.UpdateInterfaceRequest), opts...) + } + k.reqHandlers["grpc.OnlineCPUMemRequest"] = func(ctx context.Context, req interface{}, opts ...golangGrpc.CallOption) (interface{}, error) { + return k.client.OnlineCPUMem(ctx, req.(*grpc.OnlineCPUMemRequest), opts...) + } + k.reqHandlers["grpc.ListProcessesRequest"] = func(ctx context.Context, req interface{}, opts ...golangGrpc.CallOption) (interface{}, error) { + return k.client.ListProcesses(ctx, req.(*grpc.ListProcessesRequest), opts...) + } +} + func (k *kataAgent) sendReq(request interface{}) (interface{}, error) { if err := k.connect(); err != nil { return nil, err @@ -948,44 +997,11 @@ func (k *kataAgent) sendReq(request interface{}) (interface{}, error) { defer k.disconnect() } - switch req := request.(type) { - case *grpc.CheckRequest: - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - _, err := k.client.Check(ctx, req) - return nil, err - case *grpc.ExecProcessRequest: - _, err := k.client.ExecProcess(context.Background(), req) - return nil, err - case *grpc.CreateSandboxRequest: - _, err := k.client.CreateSandbox(context.Background(), req) - return nil, err - case *grpc.DestroySandboxRequest: - _, err := k.client.DestroySandbox(context.Background(), req) - return nil, err - case *grpc.CreateContainerRequest: - _, err := k.client.CreateContainer(context.Background(), req) - return nil, err - case *grpc.StartContainerRequest: - _, err := k.client.StartContainer(context.Background(), req) - return nil, err - case *grpc.RemoveContainerRequest: - _, err := k.client.RemoveContainer(context.Background(), req) - return nil, err - case *grpc.SignalProcessRequest: - _, err := k.client.SignalProcess(context.Background(), req) - return nil, err - case *grpc.UpdateRoutesRequest: - _, err := k.client.UpdateRoutes(context.Background(), req) - return nil, err - case *grpc.UpdateInterfaceRequest: - ifc, err := k.client.UpdateInterface(context.Background(), req) - return ifc, err - case *grpc.OnlineCPUMemRequest: - return k.client.OnlineCPUMem(context.Background(), req) - case *grpc.ListProcessesRequest: - return k.client.ListProcesses(context.Background(), req) - default: - return nil, fmt.Errorf("Unknown gRPC type %T", req) + msgName := proto.MessageName(request.(proto.Message)) + handler := k.reqHandlers[msgName] + if msgName == "" || handler == nil { + return nil, fmt.Errorf("Invalid request type") } + + return handler(context.Background(), request) }