diff --git a/pkg/kubeletclient/kubeletclient.go b/pkg/kubeletclient/kubeletclient.go index 4c19edb30..0833e42db 100644 --- a/pkg/kubeletclient/kubeletclient.go +++ b/pkg/kubeletclient/kubeletclient.go @@ -37,7 +37,6 @@ import ( const ( defaultKubeletSocket = "kubelet" // which is defined in k8s.io/kubernetes/pkg/kubelet/apis/podresources kubeletConnectionTimeout = 10 * time.Second - defaultKubeletSocketFile = "kubelet.sock" defaultPodResourcesMaxSize = 1024 * 1024 * 16 // 16 Mb defaultPodResourcesPath = "/var/lib/kubelet/pod-resources" unixProtocol = "unix" @@ -45,35 +44,28 @@ const ( // LocalEndpoint returns the full path to a unix socket at the given endpoint // which is in k8s.io/kubernetes/pkg/kubelet/util -func LocalEndpoint(path, file string) (string, error) { - u := url.URL{ +func localEndpoint(path string) *url.URL { + return &url.URL{ Scheme: unixProtocol, - Path: path, + Path: path + ".sock", } - return filepath.Join(u.String(), file+".sock"), nil -} - -func removeUnixProtocol(endpoint string) (string, error) { - u, err := url.Parse(endpoint) - if err != nil { - return "", err - } - if u.Scheme != unixProtocol { - return "", fmt.Errorf("only support unix socket endpoint") - } - return u.Path, nil } // GetResourceClient returns an instance of ResourceClient interface initialized with Pod resource information func GetResourceClient(kubeletSocket string) (types.ResourceClient, error) { - if kubeletSocket == "" { - kubeletSocket, _ = LocalEndpoint(defaultPodResourcesPath, defaultKubeletSocket) + kubeletSocketURL := localEndpoint(filepath.Join(defaultPodResourcesPath, defaultKubeletSocket)) + + if kubeletSocket != "" { + kubeletSocketURL = &url.URL{ + Scheme: unixProtocol, + Path: kubeletSocket, + } } // If Kubelet resource API endpoint exist use that by default // Or else fallback with checkpoint file - if hasKubeletAPIEndpoint(kubeletSocket) { + if hasKubeletAPIEndpoint(kubeletSocketURL) { logging.Debugf("GetResourceClient: using Kubelet resource API endpoint") - return getKubeletClient(kubeletSocket) + return getKubeletClient(kubeletSocketURL) } logging.Debugf("GetResourceClient: using Kubelet device plugin checkpoint") @@ -84,30 +76,23 @@ func dial(ctx context.Context, addr string) (net.Conn, error) { return (&net.Dialer{}).DialContext(ctx, unixProtocol, addr) } -func getKubeletResourceClient(kubeletSocket string, timeout time.Duration) (podresourcesapi.PodResourcesListerClient, *grpc.ClientConn, error) { - addr, err := removeUnixProtocol(kubeletSocket) - if err != nil { - return nil, nil, err - } +func getKubeletResourceClient(kubeletSocketURL *url.URL, timeout time.Duration) (podresourcesapi.PodResourcesListerClient, *grpc.ClientConn, error) { ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() - conn, err := grpc.DialContext(ctx, addr, grpc.WithTransportCredentials(insecure.NewCredentials()), + conn, err := grpc.DialContext(ctx, kubeletSocketURL.Path, grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithContextDialer(dial), grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(defaultPodResourcesMaxSize))) if err != nil { - return nil, nil, fmt.Errorf("error dialing socket %s: %v", kubeletSocket, err) + return nil, nil, fmt.Errorf("error dialing socket %s: %v", kubeletSocketURL.Path, err) } return podresourcesapi.NewPodResourcesListerClient(conn), conn, nil } -func getKubeletClient(kubeletSocket string) (types.ResourceClient, error) { +func getKubeletClient(kubeletSocketURL *url.URL) (types.ResourceClient, error) { newClient := &kubeletClient{} - if kubeletSocket == "" { - kubeletSocket, _ = LocalEndpoint(defaultPodResourcesPath, defaultKubeletSocket) - } - client, conn, err := getKubeletResourceClient(kubeletSocket, 10*time.Second) + client, conn, err := getKubeletResourceClient(kubeletSocketURL, 10*time.Second) if err != nil { return nil, logging.Errorf("getKubeletClient: error getting grpc client: %v\n", err) } @@ -165,13 +150,9 @@ func (rc *kubeletClient) GetPodResourceMap(pod *v1.Pod) (map[string]*types.Resou return resourceMap, nil } -func hasKubeletAPIEndpoint(endpoint string) bool { - u, err := url.Parse(endpoint) - if err != nil { - return false - } +func hasKubeletAPIEndpoint(url *url.URL) bool { // Check for kubelet resource API socket file - if _, err := os.Stat(u.Path); err != nil { + if _, err := os.Stat(url.Path); err != nil { logging.Debugf("hasKubeletAPIEndpoint: error looking up kubelet resource api socket file: %q", err) return false } diff --git a/pkg/kubeletclient/kubeletclient_test.go b/pkg/kubeletclient/kubeletclient_test.go index cb40cc161..8cd946de7 100644 --- a/pkg/kubeletclient/kubeletclient_test.go +++ b/pkg/kubeletclient/kubeletclient_test.go @@ -19,6 +19,7 @@ import ( "context" "fmt" "net" + "net/url" "os" "path/filepath" "testing" @@ -85,7 +86,7 @@ func TestKubeletclient(t *testing.T) { RunSpecs(t, "Kubeletclient Suite") } -var testKubeletSocket string +var testKubeletSocket *url.URL // CreateListener creates a listener on the specified endpoint. // based from k8s.io/kubernetes/pkg/kubelet/util @@ -135,7 +136,7 @@ func setUp() error { socketDir = testingPodResourcesPath socketName = filepath.Join(socketDir, "kubelet.sock") - testKubeletSocket, _ = LocalEndpoint(socketDir, "kubelet") + testKubeletSocket = localEndpoint(filepath.Join(socketDir, "kubelet")) fakeServer = &fakeResourceServer{server: grpc.NewServer()} podresourcesapi.RegisterPodResourcesListerServer(fakeServer.server, fakeServer) @@ -169,7 +170,7 @@ var _ = Describe("Kubelet resource endpoint data read operations", func() { Context("GetResourceClient()", func() { It("should return no error", func() { - _, err := GetResourceClient(testKubeletSocket) + _, err := GetResourceClient(testKubeletSocket.Path) Expect(err).NotTo(HaveOccurred()) }) @@ -178,12 +179,6 @@ var _ = Describe("Kubelet resource endpoint data read operations", func() { Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("error reading file")) }) - - It("should fail with invalid protocol", func() { - _, err := GetResourceClient("tcp:" + socketName) - Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(ContainSubstring("only support unix socket endpoint")) - }) }) Context("GetPodResourceMap() with valid pod name and namespace", func() { It("should return no error", func() { @@ -215,7 +210,9 @@ var _ = Describe("Kubelet resource endpoint data read operations", func() { }) It("should return an error with garbage socket value", func() { - _, err := getKubeletClient("/badfilepath!?//") + u, err := url.Parse("/badfilepath!?//") + Expect(err).NotTo(HaveOccurred()) + _, err = getKubeletClient(u) Expect(err).To(HaveOccurred()) }) })