refactor(pod log):refactor for container valiate, little cleanup

bug(pod log):TestCheckLogLocation should point out pod name

modify container if statement

fix typo
This commit is contained in:
yuzhiquan 2019-12-25 19:43:33 +08:00
parent 6c1080b3ca
commit ca69051475
2 changed files with 45 additions and 42 deletions

View File

@ -250,7 +250,7 @@ func getPod(getter ResourceGetter, ctx context.Context, name string) (*api.Pod,
return pod, nil
}
// returns primary IP for a Pod
// getPodIP returns primary IP for a Pod
func getPodIP(pod *api.Pod) string {
if pod == nil {
return ""
@ -327,25 +327,9 @@ func LogLocation(
// Try to figure out a container
// If a container was provided, it must be valid
container := opts.Container
if len(container) == 0 {
switch len(pod.Spec.Containers) {
case 1:
container = pod.Spec.Containers[0].Name
case 0:
return nil, nil, errors.NewBadRequest(fmt.Sprintf("a container name must be specified for pod %s", name))
default:
containerNames := getContainerNames(pod.Spec.Containers)
initContainerNames := getContainerNames(pod.Spec.InitContainers)
err := fmt.Sprintf("a container name must be specified for pod %s, choose one of: [%s]", name, containerNames)
if len(initContainerNames) > 0 {
err += fmt.Sprintf(" or one of the init containers: [%s]", initContainerNames)
}
return nil, nil, errors.NewBadRequest(err)
}
} else {
if !podHasContainerWithName(pod, container) {
return nil, nil, errors.NewBadRequest(fmt.Sprintf("container %s is not valid for pod %s", container, name))
}
container, err = validateContainer(container, pod)
if err != nil {
return nil, nil, err
}
nodeName := types.NodeName(pod.Spec.NodeName)
if len(nodeName) == 0 {
@ -488,26 +472,11 @@ func streamLocation(
// Try to figure out a container
// If a container was provided, it must be valid
if container == "" {
switch len(pod.Spec.Containers) {
case 1:
container = pod.Spec.Containers[0].Name
case 0:
return nil, nil, errors.NewBadRequest(fmt.Sprintf("a container name must be specified for pod %s", name))
default:
containerNames := getContainerNames(pod.Spec.Containers)
initContainerNames := getContainerNames(pod.Spec.InitContainers)
err := fmt.Sprintf("a container name must be specified for pod %s, choose one of: [%s]", name, containerNames)
if len(initContainerNames) > 0 {
err += fmt.Sprintf(" or one of the init containers: [%s]", initContainerNames)
}
return nil, nil, errors.NewBadRequest(err)
}
} else {
if !podHasContainerWithName(pod, container) {
return nil, nil, errors.NewBadRequest(fmt.Sprintf("container %s is not valid for pod %s", container, name))
}
container, err = validateContainer(container, pod)
if err != nil {
return nil, nil, err
}
nodeName := types.NodeName(pod.Spec.NodeName)
if len(nodeName) == 0 {
// If pod has not been assigned a host, return an empty location
@ -564,3 +533,29 @@ func PortForwardLocation(
}
return loc, nodeInfo.Transport, nil
}
// validateContainer validate container is valid for pod, return valid container
func validateContainer(container string, pod *api.Pod) (string, error) {
if len(container) == 0 {
switch len(pod.Spec.Containers) {
case 1:
container = pod.Spec.Containers[0].Name
case 0:
return "", errors.NewBadRequest(fmt.Sprintf("a container name must be specified for pod %s", pod.Name))
default:
containerNames := getContainerNames(pod.Spec.Containers)
initContainerNames := getContainerNames(pod.Spec.InitContainers)
err := fmt.Sprintf("a container name must be specified for pod %s, choose one of: [%s]", pod.Name, containerNames)
if len(initContainerNames) > 0 {
err += fmt.Sprintf(" or one of the init containers: [%s]", initContainerNames)
}
return "", errors.NewBadRequest(err)
}
} else {
if !podHasContainerWithName(pod, container) {
return "", errors.NewBadRequest(fmt.Sprintf("container %s is not valid for pod %s", container, pod.Name))
}
}
return container, nil
}

View File

@ -320,6 +320,7 @@ func (g mockPodGetter) Get(context.Context, string, *metav1.GetOptions) (runtime
func TestCheckLogLocation(t *testing.T) {
ctx := genericapirequest.NewDefaultContext()
fakePodName := "test"
tcs := []struct {
name string
in *api.Pod
@ -330,6 +331,7 @@ func TestCheckLogLocation(t *testing.T) {
{
name: "simple",
in: &api.Pod{
ObjectMeta: metav1.ObjectMeta{Name: fakePodName},
Spec: api.PodSpec{
Containers: []api.Container{
{Name: "mycontainer"},
@ -345,6 +347,7 @@ func TestCheckLogLocation(t *testing.T) {
{
name: "insecure",
in: &api.Pod{
ObjectMeta: metav1.ObjectMeta{Name: fakePodName},
Spec: api.PodSpec{
Containers: []api.Container{
{Name: "mycontainer"},
@ -362,8 +365,9 @@ func TestCheckLogLocation(t *testing.T) {
{
name: "missing container",
in: &api.Pod{
Spec: api.PodSpec{},
Status: api.PodStatus{},
ObjectMeta: metav1.ObjectMeta{Name: fakePodName},
Spec: api.PodSpec{},
Status: api.PodStatus{},
},
opts: &api.PodLogOptions{},
expectedErr: errors.NewBadRequest("a container name must be specified for pod test"),
@ -372,6 +376,7 @@ func TestCheckLogLocation(t *testing.T) {
{
name: "choice of two containers",
in: &api.Pod{
ObjectMeta: metav1.ObjectMeta{Name: fakePodName},
Spec: api.PodSpec{
Containers: []api.Container{
{Name: "container1"},
@ -387,6 +392,7 @@ func TestCheckLogLocation(t *testing.T) {
{
name: "initcontainers",
in: &api.Pod{
ObjectMeta: metav1.ObjectMeta{Name: fakePodName},
Spec: api.PodSpec{
Containers: []api.Container{
{Name: "container1"},
@ -405,6 +411,7 @@ func TestCheckLogLocation(t *testing.T) {
{
name: "bad container",
in: &api.Pod{
ObjectMeta: metav1.ObjectMeta{Name: fakePodName},
Spec: api.PodSpec{
Containers: []api.Container{
{Name: "container1"},
@ -422,6 +429,7 @@ func TestCheckLogLocation(t *testing.T) {
{
name: "good with two containers",
in: &api.Pod{
ObjectMeta: metav1.ObjectMeta{Name: fakePodName},
Spec: api.PodSpec{
Containers: []api.Container{
{Name: "container1"},
@ -446,7 +454,7 @@ func TestCheckLogLocation(t *testing.T) {
InsecureSkipTLSVerifyTransport: fakeInsecureRoundTripper,
}}
_, actualTransport, err := LogLocation(getter, connectionGetter, ctx, "test", tc.opts)
_, actualTransport, err := LogLocation(getter, connectionGetter, ctx, fakePodName, tc.opts)
if !reflect.DeepEqual(err, tc.expectedErr) {
t.Errorf("expected %v, got %v", tc.expectedErr, err)
}