diff --git a/pkg/proxy/util/utils.go b/pkg/proxy/util/utils.go index 98655939bee..ad2cf88de74 100644 --- a/pkg/proxy/util/utils.go +++ b/pkg/proxy/util/utils.go @@ -391,7 +391,13 @@ func NewFilteredDialContext(wrapped DialContext, resolv Resolver, opts *Filtered return wrapped } return func(ctx context.Context, network, address string) (net.Conn, error) { - resp, err := resolv.LookupIPAddr(ctx, address) + // DialContext is given host:port. LookupIPAddress expects host. + addressToResolve, _, err := net.SplitHostPort(address) + if err != nil { + addressToResolve = address + } + + resp, err := resolv.LookupIPAddr(ctx, addressToResolve) if err != nil { return nil, err } diff --git a/pkg/proxy/util/utils_test.go b/pkg/proxy/util/utils_test.go index aa1e6e18aba..895b3b7c245 100644 --- a/pkg/proxy/util/utils_test.go +++ b/pkg/proxy/util/utils_test.go @@ -277,6 +277,125 @@ func TestShouldSkipService(t *testing.T) { } } +func TestNewFilteredDialContext(t *testing.T) { + + _, cidr, _ := net.ParseCIDR("1.1.1.1/28") + + testCases := []struct { + name string + + // opts passed to NewFilteredDialContext + opts *FilteredDialOptions + + // value passed to dial + dial string + + // value expected to be passed to resolve + expectResolve string + // result from resolver + resolveTo []net.IPAddr + resolveErr error + + // expect the wrapped dialer to be called + expectWrappedDial bool + // expect an error result + expectErr string + }{ + { + name: "allow with nil opts", + opts: nil, + dial: "127.0.0.1:8080", + expectResolve: "", // resolver not called, no-op opts + expectWrappedDial: true, + expectErr: "", + }, + { + name: "allow localhost", + opts: &FilteredDialOptions{AllowLocalLoopback: true}, + dial: "127.0.0.1:8080", + expectResolve: "", // resolver not called, no-op opts + expectWrappedDial: true, + expectErr: "", + }, + { + name: "disallow localhost", + opts: &FilteredDialOptions{AllowLocalLoopback: false}, + dial: "127.0.0.1:8080", + expectResolve: "127.0.0.1", + resolveTo: []net.IPAddr{{IP: net.ParseIP("127.0.0.1")}}, + expectWrappedDial: false, + expectErr: "address not allowed", + }, + { + name: "disallow IP", + opts: &FilteredDialOptions{AllowLocalLoopback: false, DialHostCIDRDenylist: []*net.IPNet{cidr}}, + dial: "foo.com:8080", + expectResolve: "foo.com", + resolveTo: []net.IPAddr{{IP: net.ParseIP("1.1.1.1")}}, + expectWrappedDial: false, + expectErr: "address not allowed", + }, + { + name: "allow IP", + opts: &FilteredDialOptions{AllowLocalLoopback: false, DialHostCIDRDenylist: []*net.IPNet{cidr}}, + dial: "foo.com:8080", + expectResolve: "foo.com", + resolveTo: []net.IPAddr{{IP: net.ParseIP("2.2.2.2")}}, + expectWrappedDial: true, + expectErr: "", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + wrappedDialer := &testDialer{} + testResolver := &testResolver{addrs: tc.resolveTo, err: tc.resolveErr} + dialer := NewFilteredDialContext(wrappedDialer.DialContext, testResolver, tc.opts) + _, err := dialer(context.TODO(), "tcp", tc.dial) + + if tc.expectResolve != testResolver.resolveAddress { + t.Fatalf("expected to resolve %s, got %s", tc.expectResolve, testResolver.resolveAddress) + } + if tc.expectWrappedDial != wrappedDialer.called { + t.Fatalf("expected wrapped dialer called %v, got %v", tc.expectWrappedDial, wrappedDialer.called) + } + + if err != nil { + if len(tc.expectErr) == 0 { + t.Fatalf("unexpected error: %v", err) + } else if !strings.Contains(err.Error(), tc.expectErr) { + t.Fatalf("expected error containing %q, got %v", tc.expectErr, err) + } + } else { + if len(tc.expectErr) > 0 { + t.Fatalf("expected error, got none") + } + } + }) + } +} + +type testDialer struct { + called bool +} + +func (t *testDialer) DialContext(_ context.Context, network, address string) (net.Conn, error) { + t.called = true + return nil, nil +} + +type testResolver struct { + addrs []net.IPAddr + err error + + resolveAddress string +} + +func (t *testResolver) LookupIPAddr(_ context.Context, address string) ([]net.IPAddr, error) { + t.resolveAddress = address + return t.addrs, t.err +} + type InterfaceAddrsPair struct { itf net.Interface addrs []net.Addr