diff --git a/staging/src/k8s.io/apiserver/pkg/server/egressselector/egress_selector.go b/staging/src/k8s.io/apiserver/pkg/server/egressselector/egress_selector.go index 68b87ae3492..9da0e2a099c 100644 --- a/staging/src/k8s.io/apiserver/pkg/server/egressselector/egress_selector.go +++ b/staging/src/k8s.io/apiserver/pkg/server/egressselector/egress_selector.go @@ -358,6 +358,16 @@ func NewEgressSelector(config *apiserver.EgressSelectorConfiguration) (*EgressSe return cs, nil } +// NewEgressSelectorWithMap returns a EgressSelector with the supplied EgressType to DialFunc map. +func NewEgressSelectorWithMap(m map[EgressType]utilnet.DialFunc) *EgressSelector { + if m == nil { + m = make(map[EgressType]utilnet.DialFunc) + } + return &EgressSelector{ + egressToDialer: m, + } +} + // Lookup gets the dialer function for the network context. // This is configured for the Kubernetes API Server at startup. func (cs *EgressSelector) Lookup(networkContext NetworkContext) (utilnet.DialFunc, error) { diff --git a/staging/src/k8s.io/kube-aggregator/pkg/apiserver/handler_proxy_test.go b/staging/src/k8s.io/kube-aggregator/pkg/apiserver/handler_proxy_test.go index 1dabf5bb862..9249ec56b9b 100644 --- a/staging/src/k8s.io/kube-aggregator/pkg/apiserver/handler_proxy_test.go +++ b/staging/src/k8s.io/kube-aggregator/pkg/apiserver/handler_proxy_test.go @@ -17,10 +17,12 @@ limitations under the License. package apiserver import ( + "context" "crypto/tls" "crypto/x509" "fmt" "io/ioutil" + "net" "net/http" "net/http/httptest" "net/http/httputil" @@ -38,10 +40,12 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/types" + utilnet "k8s.io/apimachinery/pkg/util/net" "k8s.io/apimachinery/pkg/util/proxy" "k8s.io/apimachinery/pkg/util/sets" "k8s.io/apiserver/pkg/authentication/user" genericapirequest "k8s.io/apiserver/pkg/endpoints/request" + "k8s.io/apiserver/pkg/server/egressselector" "k8s.io/component-base/metrics" "k8s.io/component-base/metrics/legacyregistry" apiregistration "k8s.io/kube-aggregator/pkg/apis/apiregistration/v1" @@ -379,11 +383,42 @@ func TestProxyHandler(t *testing.T) { } } +type mockEgressDialer struct { + called int +} + +func (m *mockEgressDialer) dial(ctx context.Context, net, addr string) (net.Conn, error) { + m.called++ + return http.DefaultTransport.(*http.Transport).DialContext(ctx, net, addr) +} + +func (m *mockEgressDialer) dialBroken(ctx context.Context, net, addr string) (net.Conn, error) { + m.called++ + return nil, fmt.Errorf("Broken dialer") +} + +func newDialerAndSelector() (*mockEgressDialer, *egressselector.EgressSelector) { + dialer := &mockEgressDialer{} + m := make(map[egressselector.EgressType]utilnet.DialFunc) + m[egressselector.Cluster] = dialer.dial + es := egressselector.NewEgressSelectorWithMap(m) + return dialer, es +} + +func newBrokenDialerAndSelector() (*mockEgressDialer, *egressselector.EgressSelector) { + dialer := &mockEgressDialer{} + m := make(map[egressselector.EgressType]utilnet.DialFunc) + m[egressselector.Cluster] = dialer.dialBroken + es := egressselector.NewEgressSelectorWithMap(m) + return dialer, es +} + func TestProxyUpgrade(t *testing.T) { testcases := map[string]struct { - APIService *apiregistration.APIService - ExpectError bool - ExpectCalled bool + APIService *apiregistration.APIService + NewEgressSelector func() (*mockEgressDialer, *egressselector.EgressSelector) + ExpectError bool + ExpectCalled bool }{ "valid hostname + CABundle": { APIService: &apiregistration.APIService{ @@ -436,14 +471,49 @@ func TestProxyUpgrade(t *testing.T) { ExpectError: true, ExpectCalled: false, }, + "valid hostname + CABundle + egress selector": { + APIService: &apiregistration.APIService{ + Spec: apiregistration.APIServiceSpec{ + CABundle: testCACrt, + Group: "mygroup", + Version: "v1", + Service: &apiregistration.ServiceReference{Name: "test-service", Namespace: "test-ns", Port: pointer.Int32Ptr(443)}, + }, + Status: apiregistration.APIServiceStatus{ + Conditions: []apiregistration.APIServiceCondition{ + {Type: apiregistration.Available, Status: apiregistration.ConditionTrue}, + }, + }, + }, + NewEgressSelector: newDialerAndSelector, + ExpectError: false, + ExpectCalled: true, + }, + "valid hostname + CABundle + egress selector non working": { + APIService: &apiregistration.APIService{ + Spec: apiregistration.APIServiceSpec{ + CABundle: testCACrt, + Group: "mygroup", + Version: "v1", + Service: &apiregistration.ServiceReference{Name: "test-service", Namespace: "test-ns", Port: pointer.Int32Ptr(443)}, + }, + Status: apiregistration.APIServiceStatus{ + Conditions: []apiregistration.APIServiceCondition{ + {Type: apiregistration.Available, Status: apiregistration.ConditionTrue}, + }, + }, + }, + NewEgressSelector: newBrokenDialerAndSelector, + ExpectError: true, + ExpectCalled: false, + }, } for k, tc := range testcases { tcName := k - path := "/apis/" + tc.APIService.Spec.Group + "/" + tc.APIService.Spec.Version + "/foo" - timesCalled := int32(0) - - func() { // Cleanup after each test case. + t.Run(tcName, func(t *testing.T) { + path := "/apis/" + tc.APIService.Spec.Group + "/" + tc.APIService.Spec.Version + "/foo" + timesCalled := int32(0) backendHandler := http.NewServeMux() backendHandler.Handle(path, websocket.Handler(func(ws *websocket.Conn) { atomic.AddInt32(×Called, 1) @@ -475,6 +545,14 @@ func TestProxyUpgrade(t *testing.T) { proxyTransport: &http.Transport{}, proxyCurrentCertKeyContent: func() ([]byte, []byte) { return emptyCert(), emptyCert() }, } + + var dialer *mockEgressDialer + var selector *egressselector.EgressSelector + if tc.NewEgressSelector != nil { + dialer, selector = tc.NewEgressSelector() + proxyHandler.egressSelector = selector + } + proxyHandler.updateAPIService(tc.APIService) aggregator := httptest.NewServer(contextHandler(proxyHandler, &user.DefaultInfo{Name: "username"})) defer aggregator.Close() @@ -487,6 +565,12 @@ func TestProxyUpgrade(t *testing.T) { return } defer ws.Close() + + // if the egressselector is configured assume it has to be called + if dialer != nil && dialer.called != 1 { + t.Errorf("expect egress dialer gets called %d times, got %d", 1, dialer.called) + } + if tc.ExpectError { t.Errorf("%s: expected websocket error, got none", tcName) return @@ -507,7 +591,7 @@ func TestProxyUpgrade(t *testing.T) { t.Errorf("%s: expected '%#v', got '%#v'", tcName, e, a) return } - }() + }) } }