diff --git a/pkg/service/endpoints_controller.go b/pkg/service/endpoints_controller.go index 93eddb6f893..6fd0d963467 100644 --- a/pkg/service/endpoints_controller.go +++ b/pkg/service/endpoints_controller.go @@ -138,21 +138,36 @@ func endpointsEqual(e *api.Endpoints, endpoints []string) bool { // findPort locates the container port for the given manifest and portName. func findPort(manifest *api.ContainerManifest, portName util.IntOrString) (int, error) { - if ((portName.Kind == util.IntstrString && len(portName.StrVal) == 0) || - (portName.Kind == util.IntstrInt && portName.IntVal == 0)) && - len(manifest.Containers[0].Ports) > 0 { - return manifest.Containers[0].Ports[0].ContainerPort, nil + firstContainerPort := 0 + if len(manifest.Containers[0].Ports) > 0 { + firstContainerPort = manifest.Containers[0].Ports[0].ContainerPort } - if portName.Kind == util.IntstrInt { - return portName.IntVal, nil - } - name := portName.StrVal - for _, container := range manifest.Containers { - for _, port := range container.Ports { - if port.Name == name { - return port.ContainerPort, nil + + switch portName.Kind { + case util.IntstrString: + if len(portName.StrVal) == 0 { + if firstContainerPort != 0 { + return firstContainerPort, nil + } + break + } + name := portName.StrVal + for _, container := range manifest.Containers { + for _, port := range container.Ports { + if port.Name == name { + return port.ContainerPort, nil + } } } + case util.IntstrInt: + if portName.IntVal == 0 { + if firstContainerPort != 0 { + return firstContainerPort, nil + } + break + } + return portName.IntVal, nil } - return -1, fmt.Errorf("no suitable port for manifest: %s", manifest.ID) + + return 0, fmt.Errorf("no suitable port for manifest: %s", manifest.ID) } diff --git a/pkg/service/endpoints_controller_test.go b/pkg/service/endpoints_controller_test.go index 4833721304c..e456bd3c97a 100644 --- a/pkg/service/endpoints_controller_test.go +++ b/pkg/service/endpoints_controller_test.go @@ -79,45 +79,86 @@ func TestFindPort(t *testing.T) { }, }, } - port, err := findPort(&manifest, util.IntOrString{Kind: util.IntstrString, StrVal: "foo"}) - if err != nil { - t.Errorf("unexpected error: %v", err) + emptyPortsManifest := api.ContainerManifest{ + Containers: []api.Container{ + { + Ports: []api.Port{}, + }, + }, } - if port != 8080 { - t.Errorf("Expected 8080, Got %d", port) + tests := []struct { + manifest api.ContainerManifest + portName util.IntOrString + + wport int + werr bool + }{ + { + manifest, + util.IntOrString{Kind: util.IntstrString, StrVal: "foo"}, + 8080, + false, + }, + { + manifest, + util.IntOrString{Kind: util.IntstrString, StrVal: "bar"}, + 8000, + false, + }, + { + manifest, + util.IntOrString{Kind: util.IntstrInt, IntVal: 8000}, + 8000, + false, + }, + { + manifest, + util.IntOrString{Kind: util.IntstrInt, IntVal: 7000}, + 7000, + false, + }, + { + manifest, + util.IntOrString{Kind: util.IntstrString, StrVal: "baz"}, + 0, + true, + }, + { + manifest, + util.IntOrString{Kind: util.IntstrString, StrVal: ""}, + 8080, + false, + }, + { + manifest, + util.IntOrString{Kind: util.IntstrInt, IntVal: 0}, + 8080, + false, + }, + { + emptyPortsManifest, + util.IntOrString{Kind: util.IntstrString, StrVal: ""}, + 0, + true, + }, + { + emptyPortsManifest, + util.IntOrString{Kind: util.IntstrInt, IntVal: 0}, + 0, + true, + }, } - port, err = findPort(&manifest, util.IntOrString{Kind: util.IntstrString, StrVal: "bar"}) - if err != nil { - t.Errorf("unexpected error: %v", err) - } - if port != 8000 { - t.Errorf("Expected 8000, Got %d", port) - } - port, err = findPort(&manifest, util.IntOrString{Kind: util.IntstrInt, IntVal: 8000}) - if port != 8000 { - t.Errorf("Expected 8000, Got %d", port) - } - port, err = findPort(&manifest, util.IntOrString{Kind: util.IntstrInt, IntVal: 7000}) - if port != 7000 { - t.Errorf("Expected 7000, Got %d", port) - } - port, err = findPort(&manifest, util.IntOrString{Kind: util.IntstrString, StrVal: "baz"}) - if err == nil { - t.Error("unexpected non-error") - } - port, err = findPort(&manifest, util.IntOrString{Kind: util.IntstrString, StrVal: ""}) - if err != nil { - t.Errorf("unexpected error: %v", err) - } - if port != 8080 { - t.Errorf("Expected 8080, Got %d", port) - } - port, err = findPort(&manifest, util.IntOrString{}) - if err != nil { - t.Errorf("unexpected error: %v", err) - } - if port != 8080 { - t.Errorf("Expected 8080, Got %d", port) + for _, test := range tests { + port, err := findPort(&test.manifest, test.portName) + if port != test.wport { + t.Errorf("Expected port %d, Got %d", test.wport, port) + } + if err == nil && test.werr { + t.Errorf("unexpected non-error") + } + if err != nil && test.werr == false { + t.Errorf("unexpected error: %v", err) + } } }