diff --git a/pkg/kubectl/cmd/exec.go b/pkg/kubectl/cmd/exec.go index 5febff45ef8..787b6ed0fff 100644 --- a/pkg/kubectl/cmd/exec.go +++ b/pkg/kubectl/cmd/exec.go @@ -80,27 +80,29 @@ type execParams struct { tty bool } -func extractPodAndContainer(cmd *cobra.Command, args []string, p *execParams) (podName string, containerName string, err error) { - if len(p.podName) == 0 && len(args) == 0 { - return "", "", cmdutil.UsageError(cmd, "POD is required for exec") +func extractPodAndContainer(cmd *cobra.Command, argsIn []string, p *execParams) (podName string, containerName string, args []string, err error) { + if len(p.podName) == 0 && len(argsIn) == 0 { + return "", "", nil, cmdutil.UsageError(cmd, "POD is required for exec") } if len(p.podName) != 0 { printDeprecationWarning("exec POD", "-p POD") podName = p.podName - if len(args) < 1 { - return "", "", cmdutil.UsageError(cmd, "COMMAND is required for exec") + if len(argsIn) < 1 { + return "", "", nil, cmdutil.UsageError(cmd, "COMMAND is required for exec") } + args = argsIn } else { - podName = args[0] - if len(args) < 2 { - return "", "", cmdutil.UsageError(cmd, "COMMAND is required for exec") + podName = argsIn[0] + args = argsIn[1:] + if len(args) < 1 { + return "", "", nil, cmdutil.UsageError(cmd, "COMMAND is required for exec") } } - return podName, p.containerName, nil + return podName, p.containerName, args, nil } -func RunExec(f *cmdutil.Factory, cmd *cobra.Command, cmdIn io.Reader, cmdOut, cmdErr io.Writer, p *execParams, args []string, re remoteExecutor) error { - podName, containerName, err := extractPodAndContainer(cmd, args, p) +func RunExec(f *cmdutil.Factory, cmd *cobra.Command, cmdIn io.Reader, cmdOut, cmdErr io.Writer, p *execParams, argsIn []string, re remoteExecutor) error { + podName, containerName, args, err := extractPodAndContainer(cmd, argsIn, p) namespace, err := f.DefaultNamespace() if err != nil { return err diff --git a/pkg/kubectl/cmd/exec_test.go b/pkg/kubectl/cmd/exec_test.go index 05e5f9ecb78..31a5e448463 100644 --- a/pkg/kubectl/cmd/exec_test.go +++ b/pkg/kubectl/cmd/exec_test.go @@ -21,6 +21,7 @@ import ( "fmt" "io" "net/http" + "reflect" "testing" "github.com/spf13/cobra" @@ -43,58 +44,73 @@ func TestPodAndContainer(t *testing.T) { tests := []struct { args []string p *execParams + name string expectError bool expectedPod string expectedContainer string + expectedArgs []string }{ { p: &execParams{}, expectError: true, + name: "empty", }, { p: &execParams{podName: "foo"}, expectError: true, + name: "no cmd", }, { p: &execParams{podName: "foo", containerName: "bar"}, expectError: true, + name: "no cmd, w/ container", }, { - p: &execParams{podName: "foo"}, - args: []string{"cmd"}, - expectedPod: "foo", + p: &execParams{podName: "foo"}, + args: []string{"cmd"}, + expectedPod: "foo", + expectedArgs: []string{"cmd"}, + name: "pod in flags", }, { p: &execParams{}, args: []string{"foo"}, expectError: true, + name: "no cmd, w/o flags", }, { - p: &execParams{}, - args: []string{"foo", "cmd"}, - expectedPod: "foo", + p: &execParams{}, + args: []string{"foo", "cmd"}, + expectedPod: "foo", + expectedArgs: []string{"cmd"}, + name: "cmd, w/o flags", }, { p: &execParams{containerName: "bar"}, args: []string{"foo", "cmd"}, expectedPod: "foo", expectedContainer: "bar", + expectedArgs: []string{"cmd"}, + name: "cmd, container in flag", }, } for _, test := range tests { cmd := &cobra.Command{} - podName, containerName, err := extractPodAndContainer(cmd, test.args, test.p) + podName, containerName, args, err := extractPodAndContainer(cmd, test.args, test.p) if podName != test.expectedPod { - t.Errorf("expected: %s, got: %s", test.expectedPod, podName) + t.Errorf("expected: %s, got: %s (%s)", test.expectedPod, podName, test.name) } if containerName != test.expectedContainer { - t.Errorf("expected: %s, got: %s", test.expectedContainer, containerName) + t.Errorf("expected: %s, got: %s (%s)", test.expectedContainer, containerName, test.name) } if test.expectError && err == nil { - t.Error("unexpected non-error") + t.Errorf("unexpected non-error (%s)", test.name) } if !test.expectError && err != nil { - t.Errorf("unexpected error: %v", err) + t.Errorf("unexpected error: %v (%s)", err, test.name) + } + if !reflect.DeepEqual(test.expectedArgs, args) { + t.Errorf("expected: %v, got %v (%s)", test.expectedArgs, args, test.name) } } }