diff --git a/pkg/registry/core/pod/strategy.go b/pkg/registry/core/pod/strategy.go index 4569a9ec413..e7ea508d7a4 100644 --- a/pkg/registry/core/pod/strategy.go +++ b/pkg/registry/core/pod/strategy.go @@ -250,7 +250,7 @@ func getPod(getter ResourceGetter, ctx context.Context, name string) (*api.Pod, return pod, nil } -// returns primary IP for a Pod +// getPodIP returns primary IP for a Pod func getPodIP(pod *api.Pod) string { if pod == nil { return "" @@ -327,25 +327,9 @@ func LogLocation( // Try to figure out a container // If a container was provided, it must be valid container := opts.Container - if len(container) == 0 { - switch len(pod.Spec.Containers) { - case 1: - container = pod.Spec.Containers[0].Name - case 0: - return nil, nil, errors.NewBadRequest(fmt.Sprintf("a container name must be specified for pod %s", name)) - default: - containerNames := getContainerNames(pod.Spec.Containers) - initContainerNames := getContainerNames(pod.Spec.InitContainers) - err := fmt.Sprintf("a container name must be specified for pod %s, choose one of: [%s]", name, containerNames) - if len(initContainerNames) > 0 { - err += fmt.Sprintf(" or one of the init containers: [%s]", initContainerNames) - } - return nil, nil, errors.NewBadRequest(err) - } - } else { - if !podHasContainerWithName(pod, container) { - return nil, nil, errors.NewBadRequest(fmt.Sprintf("container %s is not valid for pod %s", container, name)) - } + container, err = validateContainer(container, pod) + if err != nil { + return nil, nil, err } nodeName := types.NodeName(pod.Spec.NodeName) if len(nodeName) == 0 { @@ -488,26 +472,11 @@ func streamLocation( // Try to figure out a container // If a container was provided, it must be valid - if container == "" { - switch len(pod.Spec.Containers) { - case 1: - container = pod.Spec.Containers[0].Name - case 0: - return nil, nil, errors.NewBadRequest(fmt.Sprintf("a container name must be specified for pod %s", name)) - default: - containerNames := getContainerNames(pod.Spec.Containers) - initContainerNames := getContainerNames(pod.Spec.InitContainers) - err := fmt.Sprintf("a container name must be specified for pod %s, choose one of: [%s]", name, containerNames) - if len(initContainerNames) > 0 { - err += fmt.Sprintf(" or one of the init containers: [%s]", initContainerNames) - } - return nil, nil, errors.NewBadRequest(err) - } - } else { - if !podHasContainerWithName(pod, container) { - return nil, nil, errors.NewBadRequest(fmt.Sprintf("container %s is not valid for pod %s", container, name)) - } + container, err = validateContainer(container, pod) + if err != nil { + return nil, nil, err } + nodeName := types.NodeName(pod.Spec.NodeName) if len(nodeName) == 0 { // If pod has not been assigned a host, return an empty location @@ -564,3 +533,29 @@ func PortForwardLocation( } return loc, nodeInfo.Transport, nil } + +// validateContainer validate container is valid for pod, return valid container +func validateContainer(container string, pod *api.Pod) (string, error) { + if len(container) == 0 { + switch len(pod.Spec.Containers) { + case 1: + container = pod.Spec.Containers[0].Name + case 0: + return "", errors.NewBadRequest(fmt.Sprintf("a container name must be specified for pod %s", pod.Name)) + default: + containerNames := getContainerNames(pod.Spec.Containers) + initContainerNames := getContainerNames(pod.Spec.InitContainers) + err := fmt.Sprintf("a container name must be specified for pod %s, choose one of: [%s]", pod.Name, containerNames) + if len(initContainerNames) > 0 { + err += fmt.Sprintf(" or one of the init containers: [%s]", initContainerNames) + } + return "", errors.NewBadRequest(err) + } + } else { + if !podHasContainerWithName(pod, container) { + return "", errors.NewBadRequest(fmt.Sprintf("container %s is not valid for pod %s", container, pod.Name)) + } + } + + return container, nil +} diff --git a/pkg/registry/core/pod/strategy_test.go b/pkg/registry/core/pod/strategy_test.go index 39d94bd7667..22c104dfd97 100644 --- a/pkg/registry/core/pod/strategy_test.go +++ b/pkg/registry/core/pod/strategy_test.go @@ -320,6 +320,7 @@ func (g mockPodGetter) Get(context.Context, string, *metav1.GetOptions) (runtime func TestCheckLogLocation(t *testing.T) { ctx := genericapirequest.NewDefaultContext() + fakePodName := "test" tcs := []struct { name string in *api.Pod @@ -330,6 +331,7 @@ func TestCheckLogLocation(t *testing.T) { { name: "simple", in: &api.Pod{ + ObjectMeta: metav1.ObjectMeta{Name: fakePodName}, Spec: api.PodSpec{ Containers: []api.Container{ {Name: "mycontainer"}, @@ -345,6 +347,7 @@ func TestCheckLogLocation(t *testing.T) { { name: "insecure", in: &api.Pod{ + ObjectMeta: metav1.ObjectMeta{Name: fakePodName}, Spec: api.PodSpec{ Containers: []api.Container{ {Name: "mycontainer"}, @@ -362,8 +365,9 @@ func TestCheckLogLocation(t *testing.T) { { name: "missing container", in: &api.Pod{ - Spec: api.PodSpec{}, - Status: api.PodStatus{}, + ObjectMeta: metav1.ObjectMeta{Name: fakePodName}, + Spec: api.PodSpec{}, + Status: api.PodStatus{}, }, opts: &api.PodLogOptions{}, expectedErr: errors.NewBadRequest("a container name must be specified for pod test"), @@ -372,6 +376,7 @@ func TestCheckLogLocation(t *testing.T) { { name: "choice of two containers", in: &api.Pod{ + ObjectMeta: metav1.ObjectMeta{Name: fakePodName}, Spec: api.PodSpec{ Containers: []api.Container{ {Name: "container1"}, @@ -387,6 +392,7 @@ func TestCheckLogLocation(t *testing.T) { { name: "initcontainers", in: &api.Pod{ + ObjectMeta: metav1.ObjectMeta{Name: fakePodName}, Spec: api.PodSpec{ Containers: []api.Container{ {Name: "container1"}, @@ -405,6 +411,7 @@ func TestCheckLogLocation(t *testing.T) { { name: "bad container", in: &api.Pod{ + ObjectMeta: metav1.ObjectMeta{Name: fakePodName}, Spec: api.PodSpec{ Containers: []api.Container{ {Name: "container1"}, @@ -422,6 +429,7 @@ func TestCheckLogLocation(t *testing.T) { { name: "good with two containers", in: &api.Pod{ + ObjectMeta: metav1.ObjectMeta{Name: fakePodName}, Spec: api.PodSpec{ Containers: []api.Container{ {Name: "container1"}, @@ -446,7 +454,7 @@ func TestCheckLogLocation(t *testing.T) { InsecureSkipTLSVerifyTransport: fakeInsecureRoundTripper, }} - _, actualTransport, err := LogLocation(getter, connectionGetter, ctx, "test", tc.opts) + _, actualTransport, err := LogLocation(getter, connectionGetter, ctx, fakePodName, tc.opts) if !reflect.DeepEqual(err, tc.expectedErr) { t.Errorf("expected %v, got %v", tc.expectedErr, err) }