From 8b6b81c7c982e9a8c0190efa0aca31627e7f1002 Mon Sep 17 00:00:00 2001
From: Yang Guo <ygg@google.com>
Date: Mon, 27 Aug 2018 16:56:18 -0700
Subject: [PATCH] GCE: Add ListLocations to Cloud TPU API

---
 pkg/cloudprovider/providers/gce/gce_tpu.go | 31 +++++++++++++++-------
 1 file changed, 22 insertions(+), 9 deletions(-)

diff --git a/pkg/cloudprovider/providers/gce/gce_tpu.go b/pkg/cloudprovider/providers/gce/gce_tpu.go
index 01d3cce275f..0a78f62cb3a 100644
--- a/pkg/cloudprovider/providers/gce/gce_tpu.go
+++ b/pkg/cloudprovider/providers/gce/gce_tpu.go
@@ -38,16 +38,14 @@ func newTPUService(client *http.Client) (*tpuService, error) {
 		return nil, err
 	}
 	return &tpuService{
-		nodesService:      tpuapi.NewProjectsLocationsNodesService(s),
-		operationsService: tpuapi.NewProjectsLocationsOperationsService(s),
+		projects: tpuapi.NewProjectsService(s),
 	}, nil
 }
 
 // tpuService encapsulates the TPU services on nodes and the operations on the
 // nodes.
 type tpuService struct {
-	nodesService      *tpuapi.ProjectsLocationsNodesService
-	operationsService *tpuapi.ProjectsLocationsOperationsService
+	projects *tpuapi.ProjectsService
 }
 
 // CreateTPU creates the Cloud TPU node with the specified name in the
@@ -59,7 +57,7 @@ func (gce *GCECloud) CreateTPU(ctx context.Context, name, zone string, node *tpu
 
 	var op *tpuapi.Operation
 	parent := getTPUParentName(gce.projectID, zone)
-	op, err = gce.tpuService.nodesService.Create(parent, node).NodeId(name).Do()
+	op, err = gce.tpuService.projects.Locations.Nodes.Create(parent, node).NodeId(name).Do()
 	if err != nil {
 		return nil, err
 	}
@@ -92,7 +90,7 @@ func (gce *GCECloud) DeleteTPU(ctx context.Context, name, zone string) error {
 
 	var op *tpuapi.Operation
 	name = getTPUName(gce.projectID, zone, name)
-	op, err = gce.tpuService.nodesService.Delete(name).Do()
+	op, err = gce.tpuService.projects.Locations.Nodes.Delete(name).Do()
 	if err != nil {
 		return err
 	}
@@ -114,7 +112,7 @@ func (gce *GCECloud) GetTPU(ctx context.Context, name, zone string) (*tpuapi.Nod
 	mc := newTPUMetricContext("get", zone)
 
 	name = getTPUName(gce.projectID, zone, name)
-	node, err := gce.tpuService.nodesService.Get(name).Do()
+	node, err := gce.tpuService.projects.Locations.Nodes.Get(name).Do()
 	if err != nil {
 		return nil, mc.Observe(err)
 	}
@@ -126,13 +124,24 @@ func (gce *GCECloud) ListTPUs(ctx context.Context, zone string) ([]*tpuapi.Node,
 	mc := newTPUMetricContext("list", zone)
 
 	parent := getTPUParentName(gce.projectID, zone)
-	response, err := gce.tpuService.nodesService.List(parent).Do()
+	response, err := gce.tpuService.projects.Locations.Nodes.List(parent).Do()
 	if err != nil {
 		return nil, mc.Observe(err)
 	}
 	return response.Nodes, mc.Observe(nil)
 }
 
+// ListLocations returns the zones where Cloud TPUs are available.
+func (gce *GCECloud) ListLocations(ctx context.Context) ([]*tpuapi.Location, error) {
+	mc := newTPUMetricContext("list_locations", "")
+	parent := getTPUProjectURL(gce.projectID)
+	response, err := gce.tpuService.projects.Locations.List(parent).Do()
+	if err != nil {
+		return nil, mc.Observe(err)
+	}
+	return response.Locations, mc.Observe(nil)
+}
+
 // waitForTPUOp checks whether the op is done every 30 seconds before the ctx
 // is cancelled.
 func (gce *GCECloud) waitForTPUOp(ctx context.Context, op *tpuapi.Operation) (*tpuapi.Operation, error) {
@@ -155,7 +164,7 @@ func (gce *GCECloud) waitForTPUOp(ctx context.Context, op *tpuapi.Operation) (*t
 		}
 
 		var err error
-		op, err = gce.tpuService.operationsService.Get(op.Name).Do()
+		op, err = gce.tpuService.projects.Locations.Operations.Get(op.Name).Do()
 		if err != nil {
 			return true, err
 		}
@@ -188,6 +197,10 @@ func getErrorFromTPUOp(op *tpuapi.Operation) error {
 	return nil
 }
 
+func getTPUProjectURL(project string) string {
+	return fmt.Sprintf("projects/%s", project)
+}
+
 func getTPUParentName(project, zone string) string {
 	return fmt.Sprintf("projects/%s/locations/%s", project, zone)
 }