GCE: Add ListLocations to Cloud TPU API

This commit is contained in:
Yang Guo 2018-08-27 16:56:18 -07:00
parent 3da79f5cab
commit 8b6b81c7c9

View File

@ -38,16 +38,14 @@ func newTPUService(client *http.Client) (*tpuService, error) {
return nil, err
}
return &tpuService{
nodesService: tpuapi.NewProjectsLocationsNodesService(s),
operationsService: tpuapi.NewProjectsLocationsOperationsService(s),
projects: tpuapi.NewProjectsService(s),
}, nil
}
// tpuService encapsulates the TPU services on nodes and the operations on the
// nodes.
type tpuService struct {
nodesService *tpuapi.ProjectsLocationsNodesService
operationsService *tpuapi.ProjectsLocationsOperationsService
projects *tpuapi.ProjectsService
}
// CreateTPU creates the Cloud TPU node with the specified name in the
@ -59,7 +57,7 @@ func (gce *GCECloud) CreateTPU(ctx context.Context, name, zone string, node *tpu
var op *tpuapi.Operation
parent := getTPUParentName(gce.projectID, zone)
op, err = gce.tpuService.nodesService.Create(parent, node).NodeId(name).Do()
op, err = gce.tpuService.projects.Locations.Nodes.Create(parent, node).NodeId(name).Do()
if err != nil {
return nil, err
}
@ -92,7 +90,7 @@ func (gce *GCECloud) DeleteTPU(ctx context.Context, name, zone string) error {
var op *tpuapi.Operation
name = getTPUName(gce.projectID, zone, name)
op, err = gce.tpuService.nodesService.Delete(name).Do()
op, err = gce.tpuService.projects.Locations.Nodes.Delete(name).Do()
if err != nil {
return err
}
@ -114,7 +112,7 @@ func (gce *GCECloud) GetTPU(ctx context.Context, name, zone string) (*tpuapi.Nod
mc := newTPUMetricContext("get", zone)
name = getTPUName(gce.projectID, zone, name)
node, err := gce.tpuService.nodesService.Get(name).Do()
node, err := gce.tpuService.projects.Locations.Nodes.Get(name).Do()
if err != nil {
return nil, mc.Observe(err)
}
@ -126,13 +124,24 @@ func (gce *GCECloud) ListTPUs(ctx context.Context, zone string) ([]*tpuapi.Node,
mc := newTPUMetricContext("list", zone)
parent := getTPUParentName(gce.projectID, zone)
response, err := gce.tpuService.nodesService.List(parent).Do()
response, err := gce.tpuService.projects.Locations.Nodes.List(parent).Do()
if err != nil {
return nil, mc.Observe(err)
}
return response.Nodes, mc.Observe(nil)
}
// ListLocations returns the zones where Cloud TPUs are available.
func (gce *GCECloud) ListLocations(ctx context.Context) ([]*tpuapi.Location, error) {
mc := newTPUMetricContext("list_locations", "")
parent := getTPUProjectURL(gce.projectID)
response, err := gce.tpuService.projects.Locations.List(parent).Do()
if err != nil {
return nil, mc.Observe(err)
}
return response.Locations, mc.Observe(nil)
}
// waitForTPUOp checks whether the op is done every 30 seconds before the ctx
// is cancelled.
func (gce *GCECloud) waitForTPUOp(ctx context.Context, op *tpuapi.Operation) (*tpuapi.Operation, error) {
@ -155,7 +164,7 @@ func (gce *GCECloud) waitForTPUOp(ctx context.Context, op *tpuapi.Operation) (*t
}
var err error
op, err = gce.tpuService.operationsService.Get(op.Name).Do()
op, err = gce.tpuService.projects.Locations.Operations.Get(op.Name).Do()
if err != nil {
return true, err
}
@ -188,6 +197,10 @@ func getErrorFromTPUOp(op *tpuapi.Operation) error {
return nil
}
func getTPUProjectURL(project string) string {
return fmt.Sprintf("projects/%s", project)
}
func getTPUParentName(project, zone string) string {
return fmt.Sprintf("projects/%s/locations/%s", project, zone)
}