From 59a6489e84714fab7ddea8bf7bd274d6c05ed7d8 Mon Sep 17 00:00:00 2001 From: Daniel Smith Date: Wed, 25 Jun 2014 13:21:32 -0700 Subject: [PATCH] Add tracked operations to apiserver --- pkg/apiserver/apiserver.go | 130 ++++++++++++++---------- pkg/apiserver/apiserver_test.go | 120 ++++++++++++++++------ pkg/apiserver/operation.go | 175 ++++++++++++++++++++++++++++++++ pkg/apiserver/operation_test.go | 69 +++++++++++++ 4 files changed, 410 insertions(+), 84 deletions(-) create mode 100644 pkg/apiserver/operation.go create mode 100644 pkg/apiserver/operation_test.go diff --git a/pkg/apiserver/apiserver.go b/pkg/apiserver/apiserver.go index 5507d6f0fe7..4b7bc537c7e 100644 --- a/pkg/apiserver/apiserver.go +++ b/pkg/apiserver/apiserver.go @@ -41,11 +41,25 @@ type RESTStorage interface { Update(interface{}) (<-chan interface{}, error) } -func MakeAsync(fn func() interface{}) <-chan interface{} { - channel := make(chan interface{}, 1) +// MakeAsync takes a function and executes it, delivering the result in the way required +// by RESTStorage's Update, Delete, and Create methods. +func MakeAsync(fn func() (interface{}, error)) <-chan interface{} { + channel := make(chan interface{}) go func() { defer util.HandleCrash() - channel <- fn() + obj, err := fn() + if err != nil { + channel <- &api.Status{ + Status: api.StatusFailure, + Details: err.Error(), + } + } else { + channel <- obj + } + // 'close' is used to signal that no further values will + // be written to the channel. Not strictly necessary, but + // also won't hurt. + close(channel) }() return channel } @@ -59,6 +73,7 @@ func MakeAsync(fn func() interface{}) <-chan interface{} { type ApiServer struct { prefix string storage map[string]RESTStorage + ops *Operations } // New creates a new ApiServer object. @@ -68,6 +83,7 @@ func New(storage map[string]RESTStorage, prefix string) *ApiServer { return &ApiServer{ storage: storage, prefix: prefix, + ops: NewOperations(), } } @@ -108,6 +124,10 @@ func (server *ApiServer) ServeHTTP(w http.ResponseWriter, req *http.Request) { server.notFound(req, w) return } + if requestParts[0] == "operations" { + server.handleOperationRequest(requestParts[1:], w, req) + return + } storage := server.storage[requestParts[0]] if storage == nil { logger.Addf("'%v' has no storage object", requestParts[0]) @@ -144,15 +164,30 @@ func (server *ApiServer) readBody(req *http.Request) ([]byte, error) { return body, err } -func (server *ApiServer) waitForObject(out <-chan interface{}, timeout time.Duration) (interface{}, error) { - tick := time.After(timeout) - var obj interface{} - select { - case obj = <-out: - return obj, nil - case <-tick: - return nil, fmt.Errorf("Timed out waiting for synchronization.") +// finishReq finishes up a request, waiting until the operation finishes or, after a timeout, creating an +// Operation to recieve the result and returning its ID down the writer. +func (server *ApiServer) finishReq(out <-chan interface{}, sync bool, timeout time.Duration, w http.ResponseWriter) { + op := server.ops.NewOperation(out) + if sync { + op.WaitFor(timeout) } + obj, complete := op.Describe() + if complete { + server.write(http.StatusOK, obj, w) + } else { + server.write(http.StatusAccepted, obj, w) + } +} + +func parseTimeout(str string) time.Duration { + if str != "" { + timeout, err := time.ParseDuration(str) + if err == nil { + return timeout + } + glog.Errorf("Failed to parse: %#v '%s'", err, str) + } + return 30 * time.Second } // handleREST is the main dispatcher for the server. It switches on the HTTP method, and then @@ -170,11 +205,7 @@ func (server *ApiServer) waitForObject(out <-chan interface{}, timeout time.Dura // labels= Used for filtering list operations func (server *ApiServer) handleREST(parts []string, requestUrl *url.URL, req *http.Request, w http.ResponseWriter, storage RESTStorage) { sync := requestUrl.Query().Get("sync") == "true" - timeout, err := time.ParseDuration(requestUrl.Query().Get("timeout")) - if err != nil && len(requestUrl.Query().Get("timeout")) > 0 { - glog.Errorf("Failed to parse: %#v '%s'", err, requestUrl.Query().Get("timeout")) - timeout = time.Second * 30 - } + timeout := parseTimeout(requestUrl.Query().Get("timeout")) switch req.Method { case "GET": switch len(parts) { @@ -184,12 +215,12 @@ func (server *ApiServer) handleREST(parts []string, requestUrl *url.URL, req *ht server.error(err, w) return } - controllers, err := storage.List(selector) + list, err := storage.List(selector) if err != nil { server.error(err, w) return } - server.write(http.StatusOK, controllers, w) + server.write(http.StatusOK, list, w) case 2: item, err := storage.Get(parts[1]) if err != nil { @@ -204,7 +235,6 @@ func (server *ApiServer) handleREST(parts []string, requestUrl *url.URL, req *ht default: server.notFound(req, w) } - return case "POST": if len(parts) != 1 { server.notFound(req, w) @@ -221,44 +251,22 @@ func (server *ApiServer) handleREST(parts []string, requestUrl *url.URL, req *ht return } out, err := storage.Create(obj) - if err == nil && sync { - obj, err = server.waitForObject(out, timeout) - } if err != nil { server.error(err, w) return } - var statusCode int - if sync { - statusCode = http.StatusOK - } else { - statusCode = http.StatusAccepted - } - server.write(statusCode, obj, w) - return + server.finishReq(out, sync, timeout, w) case "DELETE": if len(parts) != 2 { server.notFound(req, w) return } out, err := storage.Delete(parts[1]) - var obj interface{} - obj = api.Status{Status: api.StatusSuccess} - if err == nil && sync { - obj, err = server.waitForObject(out, timeout) - } if err != nil { server.error(err, w) return } - var statusCode int - if sync { - statusCode = http.StatusOK - } else { - statusCode = http.StatusAccepted - } - server.write(statusCode, obj, w) - return + server.finishReq(out, sync, timeout, w) case "PUT": if len(parts) != 2 { server.notFound(req, w) @@ -274,22 +282,36 @@ func (server *ApiServer) handleREST(parts []string, requestUrl *url.URL, req *ht return } out, err := storage.Update(obj) - if err == nil && sync { - obj, err = server.waitForObject(out, timeout) - } if err != nil { server.error(err, w) return } - var statusCode int - if sync { - statusCode = http.StatusOK - } else { - statusCode = http.StatusAccepted - } - server.write(statusCode, obj, w) - return + server.finishReq(out, sync, timeout, w) default: server.notFound(req, w) } } + +func (server *ApiServer) handleOperationRequest(parts []string, w http.ResponseWriter, req *http.Request) { + if req.Method != "GET" { + server.notFound(req, w) + } + if len(parts) == 0 { + // List outstanding operations. + list := server.ops.List() + server.write(http.StatusOK, list, w) + return + } + + op := server.ops.Get(parts[0]) + if op == nil { + server.notFound(req, w) + } + + obj, complete := op.Describe() + if complete { + server.write(http.StatusOK, obj, w) + } else { + server.write(http.StatusAccepted, obj, w) + } +} diff --git a/pkg/apiserver/apiserver_test.go b/pkg/apiserver/apiserver_test.go index 8b0e535feeb..45a597b3346 100644 --- a/pkg/apiserver/apiserver_test.go +++ b/pkg/apiserver/apiserver_test.go @@ -18,7 +18,6 @@ package apiserver import ( "bytes" - "encoding/json" "fmt" "io/ioutil" "net/http" @@ -26,6 +25,7 @@ import ( "reflect" "sync" "testing" + "time" "github.com/GoogleCloudPlatform/kubernetes/pkg/api" "github.com/GoogleCloudPlatform/kubernetes/pkg/labels" @@ -58,7 +58,10 @@ type SimpleRESTStorage struct { item Simple deleted string updated Simple - channel <-chan interface{} + created Simple + + // called when answering update, delete, create + injectedFunction func(obj interface{}) (returnObj interface{}, err error) } func (storage *SimpleRESTStorage) List(labels.Selector) (interface{}, error) { @@ -74,7 +77,15 @@ func (storage *SimpleRESTStorage) Get(id string) (interface{}, error) { func (storage *SimpleRESTStorage) Delete(id string) (<-chan interface{}, error) { storage.deleted = id - return storage.channel, storage.err + if storage.err != nil { + return nil, storage.err + } + return MakeAsync(func() (interface{}, error) { + if storage.injectedFunction != nil { + return storage.injectedFunction(id) + } + return api.Status{Status: api.StatusSuccess}, nil + }), nil } func (storage *SimpleRESTStorage) Extract(body []byte) (interface{}, error) { @@ -83,13 +94,30 @@ func (storage *SimpleRESTStorage) Extract(body []byte) (interface{}, error) { return item, storage.err } -func (storage *SimpleRESTStorage) Create(interface{}) (<-chan interface{}, error) { - return storage.channel, storage.err +func (storage *SimpleRESTStorage) Create(obj interface{}) (<-chan interface{}, error) { + storage.created = obj.(Simple) + if storage.err != nil { + return nil, storage.err + } + return MakeAsync(func() (interface{}, error) { + if storage.injectedFunction != nil { + return storage.injectedFunction(obj) + } + return obj, nil + }), nil } -func (storage *SimpleRESTStorage) Update(object interface{}) (<-chan interface{}, error) { - storage.updated = object.(Simple) - return storage.channel, storage.err +func (storage *SimpleRESTStorage) Update(obj interface{}) (<-chan interface{}, error) { + storage.updated = obj.(Simple) + if storage.err != nil { + return nil, storage.err + } + return MakeAsync(func() (interface{}, error) { + if storage.injectedFunction != nil { + return storage.injectedFunction(obj) + } + return obj, nil + }), nil } func extractBody(response *http.Response, object interface{}) (string, error) { @@ -214,7 +242,7 @@ func TestUpdate(t *testing.T) { item := Simple{ Name: "bar", } - body, err := json.Marshal(item) + body, err := api.Encode(item) expectNoError(t, err) client := http.Client{} request, err := http.NewRequest("PUT", server.URL+"/prefix/version/simple/"+ID, bytes.NewReader(body)) @@ -270,14 +298,15 @@ func TestMissingStorage(t *testing.T) { } func TestCreate(t *testing.T) { + simpleStorage := &SimpleRESTStorage{} handler := New(map[string]RESTStorage{ - "foo": &SimpleRESTStorage{}, + "foo": simpleStorage, }, "/prefix/version") server := httptest.NewServer(handler) client := http.Client{} simple := Simple{Name: "foo"} - data, _ := json.Marshal(simple) + data, _ := api.Encode(simple) request, err := http.NewRequest("POST", server.URL+"/prefix/version/foo", bytes.NewBuffer(data)) expectNoError(t, err) response, err := client.Do(request) @@ -286,18 +315,32 @@ func TestCreate(t *testing.T) { t.Errorf("Unexpected response %#v", response) } - var itemOut Simple + var itemOut api.Status body, err := extractBody(response, &itemOut) expectNoError(t, err) - if !reflect.DeepEqual(itemOut, simple) { - t.Errorf("Unexpected data: %#v, expected %#v (%s)", itemOut, simple, string(body)) + if itemOut.Status != api.StatusWorking || itemOut.Details == "" { + t.Errorf("Unexpected status: %#v (%s)", itemOut, string(body)) + } +} + +func TestParseTimeout(t *testing.T) { + if d := parseTimeout(""); d != 30*time.Second { + t.Errorf("blank timeout produces %v", d) + } + if d := parseTimeout("not a timeout"); d != 30*time.Second { + t.Errorf("bad timeout produces %v", d) + } + if d := parseTimeout("10s"); d != 10*time.Second { + t.Errorf("10s timeout produced: %v", d) } } func TestSyncCreate(t *testing.T) { - channel := make(chan interface{}, 1) storage := SimpleRESTStorage{ - channel: channel, + injectedFunction: func(obj interface{}) (interface{}, error) { + time.Sleep(2 * time.Second) + return obj, nil + }, } handler := New(map[string]RESTStorage{ "foo": &storage, @@ -306,7 +349,7 @@ func TestSyncCreate(t *testing.T) { client := http.Client{} simple := Simple{Name: "foo"} - data, _ := json.Marshal(simple) + data, _ := api.Encode(simple) request, err := http.NewRequest("POST", server.URL+"/prefix/version/foo?sync=true", bytes.NewBuffer(data)) expectNoError(t, err) wg := sync.WaitGroup{} @@ -314,37 +357,54 @@ func TestSyncCreate(t *testing.T) { var response *http.Response go func() { response, err = client.Do(request) - expectNoError(t, err) - if response.StatusCode != 200 { - t.Errorf("Unexpected response %#v", response) - } wg.Done() }() - output := Simple{Name: "bar"} - channel <- output wg.Wait() + expectNoError(t, err) var itemOut Simple body, err := extractBody(response, &itemOut) expectNoError(t, err) - if !reflect.DeepEqual(itemOut, output) { + if !reflect.DeepEqual(itemOut, simple) { t.Errorf("Unexpected data: %#v, expected %#v (%s)", itemOut, simple, string(body)) } + if response.StatusCode != http.StatusOK { + t.Errorf("Unexpected status: %d, Expected: %d, %#v", response.StatusCode, http.StatusOK, response) + } } func TestSyncCreateTimeout(t *testing.T) { + storage := SimpleRESTStorage{ + injectedFunction: func(obj interface{}) (interface{}, error) { + time.Sleep(10 * time.Second) + return obj, nil + }, + } handler := New(map[string]RESTStorage{ - "foo": &SimpleRESTStorage{}, + "foo": &storage, }, "/prefix/version") server := httptest.NewServer(handler) client := http.Client{} simple := Simple{Name: "foo"} - data, _ := json.Marshal(simple) - request, err := http.NewRequest("POST", server.URL+"/prefix/version/foo?sync=true&timeout=1us", bytes.NewBuffer(data)) + data, _ := api.Encode(simple) + request, err := http.NewRequest("POST", server.URL+"/prefix/version/foo?sync=true&timeout=2s", bytes.NewBuffer(data)) expectNoError(t, err) - response, err := client.Do(request) + wg := sync.WaitGroup{} + wg.Add(1) + var response *http.Response + go func() { + response, err = client.Do(request) + wg.Done() + }() + wg.Wait() expectNoError(t, err) - if response.StatusCode != 500 { - t.Errorf("Unexpected response %#v", response) + var itemOut api.Status + _, err = extractBody(response, &itemOut) + expectNoError(t, err) + if itemOut.Status != api.StatusWorking || itemOut.Details == "" { + t.Errorf("Unexpected status %#v", itemOut) + } + if response.StatusCode != http.StatusAccepted { + t.Errorf("Unexpected status: %d, Expected: %d, %#v", response.StatusCode, 202, response) } } diff --git a/pkg/apiserver/operation.go b/pkg/apiserver/operation.go new file mode 100644 index 00000000000..41c20c07e54 --- /dev/null +++ b/pkg/apiserver/operation.go @@ -0,0 +1,175 @@ +/* +Copyright 2014 Google Inc. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package apiserver + +import ( + "fmt" + "sort" + "sync" + "time" + + "github.com/GoogleCloudPlatform/kubernetes/pkg/api" + "github.com/GoogleCloudPlatform/kubernetes/pkg/util" +) + +func init() { + api.AddKnownTypes(ServerOp{}, ServerOpList{}) +} + +// Operation information, as delivered to API clients. +type ServerOp struct { + api.JSONBase `yaml:",inline" json:",inline"` +} + +// Operation list, as delivered to API clients. +type ServerOpList struct { + api.JSONBase `yaml:",inline" json:",inline"` + Items []ServerOp `yaml:"items,omitempty" json:"items,omitempty"` +} + +// Operation represents an ongoing action which the server is performing. +type Operation struct { + ID string + result interface{} + awaiting <-chan interface{} + finished *time.Time + lock sync.Mutex + notify chan bool +} + +// Operations tracks all the ongoing operations. +type Operations struct { + lock sync.Mutex + ops map[string]*Operation + nextID int +} + +// Returns a new Operations repository. +func NewOperations() *Operations { + ops := &Operations{ + ops: map[string]*Operation{}, + } + go util.Forever(func() { ops.expire(10 * time.Minute) }, 5*time.Minute) + return ops +} + +// Add a new operation. +func (ops *Operations) NewOperation(from <-chan interface{}) *Operation { + ops.lock.Lock() + defer ops.lock.Unlock() + id := fmt.Sprintf("%v", ops.nextID) + ops.nextID++ + + op := &Operation{ + ID: id, + awaiting: from, + notify: make(chan bool, 1), + } + go op.wait() + ops.ops[id] = op + return op +} + +// List operations for an API client. +func (ops *Operations) List() ServerOpList { + ops.lock.Lock() + defer ops.lock.Unlock() + + ids := []string{} + for id := range ops.ops { + ids = append(ids, id) + } + sort.StringSlice(ids).Sort() + ol := ServerOpList{} + for _, id := range ids { + ol.Items = append(ol.Items, ServerOp{JSONBase: api.JSONBase{ID: id}}) + } + return ol +} + +// Returns the operation with the given ID, or nil +func (ops *Operations) Get(id string) *Operation { + ops.lock.Lock() + defer ops.lock.Unlock() + return ops.ops[id] +} + +// Garbage collect operations that have finished longer than maxAge ago. +func (ops *Operations) expire(maxAge time.Duration) { + ops.lock.Lock() + defer ops.lock.Unlock() + keep := map[string]*Operation{} + limitTime := time.Now().Add(-maxAge) + for id, op := range ops.ops { + if !op.expired(limitTime) { + keep[id] = op + } + } + ops.ops = keep +} + +// Waits forever for the operation to complete; call via go when +// the operation is created. Sets op.finished when the operation +// does complete. Does not keep op locked while waiting. +func (op *Operation) wait() { + defer util.HandleCrash() + result := <-op.awaiting + + op.lock.Lock() + defer op.lock.Unlock() + op.result = result + finished := time.Now() + op.finished = &finished + op.notify <- true +} + +// Wait for the specified duration, or until the operation finishes, +// whichever happens first. +func (op *Operation) WaitFor(timeout time.Duration) { + select { + case <-time.After(timeout): + case <-op.notify: + // Re-send on this channel in case there are others + // waiting for notification. + op.notify <- true + } +} + +// Returns true if this operation finished before limitTime. +func (op *Operation) expired(limitTime time.Time) bool { + op.lock.Lock() + defer op.lock.Unlock() + if op.finished == nil { + return false + } + return op.finished.Before(limitTime) +} + +// Return status information or the result of the operation if it is complete, +// with a bool indicating true in the latter case. +func (op *Operation) Describe() (description interface{}, finished bool) { + op.lock.Lock() + defer op.lock.Unlock() + + if op.finished == nil { + return api.Status{ + Status: api.StatusWorking, + Details: op.ID, + }, false + } + return op.result, true +} diff --git a/pkg/apiserver/operation_test.go b/pkg/apiserver/operation_test.go new file mode 100644 index 00000000000..df3e64b4aca --- /dev/null +++ b/pkg/apiserver/operation_test.go @@ -0,0 +1,69 @@ +/* +Copyright 2014 Google Inc. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package apiserver + +import ( + "testing" + "time" +) + +func TestOperation(t *testing.T) { + ops := NewOperations() + + c := make(chan interface{}) + op := ops.NewOperation(c) + go func() { + time.Sleep(5 * time.Second) + c <- "All done" + }() + + if op.expired(time.Now().Add(-time.Minute)) { + t.Errorf("Expired before finished: %#v", op) + } + ops.expire(time.Minute) + if tmp := ops.Get(op.ID); tmp == nil { + t.Errorf("expire incorrectly removed the operation %#v", ops) + } + + op.WaitFor(time.Second) + if _, completed := op.Describe(); completed { + t.Errorf("Unexpectedly fast completion") + } + + op.WaitFor(5 * time.Second) + if _, completed := op.Describe(); !completed { + t.Errorf("Unexpectedly slow completion") + } + + time.Sleep(900 * time.Millisecond) + + if op.expired(time.Now().Add(-time.Second)) { + t.Errorf("Should not be expired: %#v", op) + } + if !op.expired(time.Now().Add(-800 * time.Millisecond)) { + t.Errorf("Should be expired: %#v", op) + } + + ops.expire(800 * time.Millisecond) + if tmp := ops.Get(op.ID); tmp != nil { + t.Errorf("expire failed to remove the operation %#v", ops) + } + + if op.result.(string) != "All done" { + t.Errorf("Got unexpected result: %#v", op.result) + } +}