diff --git a/pkg/kubectl/cmd/rollingupdate.go b/pkg/kubectl/cmd/rollingupdate.go index 76ab39cf557..9f9b0973147 100644 --- a/pkg/kubectl/cmd/rollingupdate.go +++ b/pkg/kubectl/cmd/rollingupdate.go @@ -251,6 +251,15 @@ func addDeploymentKeyToReplicationController(oldRc *api.ReplicationController, c if err != nil { return err } + // First, update the template label. This ensures that any newly created pods will have the new label + if oldRc.Spec.Template.Labels == nil { + oldRc.Spec.Template.Labels = map[string]string{} + } + oldRc.Spec.Template.Labels[deploymentKey] = oldHash + if _, err := client.ReplicationControllers(namespace).Update(oldRc); err != nil { + return err + } + // Update all labels to include the new hash, so they are correctly adopted // TODO: extract the code from the label command and re-use it here. podList, err := client.Pods(namespace).List(labels.SelectorFromSet(oldRc.Spec.Selector), fields.Everything()) @@ -282,21 +291,30 @@ func addDeploymentKeyToReplicationController(oldRc *api.ReplicationController, c return err } } + if oldRc.Spec.Selector == nil { oldRc.Spec.Selector = map[string]string{} } - if oldRc.Spec.Template.Labels == nil { - oldRc.Spec.Template.Labels = map[string]string{} + // Copy the old selector, so that we can scrub out any orphaned pods + selectorCopy := map[string]string{} + for k, v := range oldRc.Spec.Selector { + selectorCopy[k] = v } oldRc.Spec.Selector[deploymentKey] = oldHash - oldRc.Spec.Template.Labels[deploymentKey] = oldHash if _, err := client.ReplicationControllers(namespace).Update(oldRc); err != nil { return err } - // Note there is still a race here, if a pod was created during the update phase. - // It's unlikely, but it could happen, and if it does, we'll create extra pods. - // TODO: Clean up orphaned pods here. + + podList, err = client.Pods(namespace).List(labels.SelectorFromSet(selectorCopy), fields.Everything()) + for ix := range podList.Items { + pod := &podList.Items[ix] + if value, found := pod.Labels[deploymentKey]; !found || value != oldHash { + if err := client.Pods(namespace).Delete(pod.Name); err != nil { + return err + } + } + } return nil } diff --git a/pkg/kubectl/cmd/rollingupdate_test.go b/pkg/kubectl/cmd/rollingupdate_test.go index 9598374549a..61af573dbe9 100644 --- a/pkg/kubectl/cmd/rollingupdate_test.go +++ b/pkg/kubectl/cmd/rollingupdate_test.go @@ -18,12 +18,14 @@ package cmd import ( "bytes" + "io/ioutil" "net/http" "testing" "github.com/GoogleCloudPlatform/kubernetes/pkg/api" "github.com/GoogleCloudPlatform/kubernetes/pkg/api/latest" "github.com/GoogleCloudPlatform/kubernetes/pkg/client" + "github.com/GoogleCloudPlatform/kubernetes/pkg/runtime" "github.com/GoogleCloudPlatform/kubernetes/pkg/util" ) @@ -137,12 +139,18 @@ func TestAddDeploymentHash(t *testing.T) { return &http.Response{StatusCode: 200, Body: objBody(codec, podList)}, nil case p == "/api/v1beta3/namespaces/default/pods/foo" && m == "PUT": seen.Insert("foo") + obj := readOrDie(t, req, codec) + podList.Items[0] = *(obj.(*api.Pod)) return &http.Response{StatusCode: 200, Body: objBody(codec, &podList.Items[0])}, nil case p == "/api/v1beta3/namespaces/default/pods/bar" && m == "PUT": seen.Insert("bar") + obj := readOrDie(t, req, codec) + podList.Items[1] = *(obj.(*api.Pod)) return &http.Response{StatusCode: 200, Body: objBody(codec, &podList.Items[1])}, nil case p == "/api/v1beta3/namespaces/default/pods/baz" && m == "PUT": seen.Insert("baz") + obj := readOrDie(t, req, codec) + podList.Items[2] = *(obj.(*api.Pod)) return &http.Response{StatusCode: 200, Body: objBody(codec, &podList.Items[2])}, nil case p == "/api/v1beta3/namespaces/default/replicationcontrollers/rc" && m == "PUT": updatedRc = true @@ -175,3 +183,17 @@ func TestAddDeploymentHash(t *testing.T) { t.Errorf("Failed to update replication controller with new labels") } } + +func readOrDie(t *testing.T, req *http.Request, codec runtime.Codec) runtime.Object { + data, err := ioutil.ReadAll(req.Body) + if err != nil { + t.Errorf("Error reading: %v", err) + t.FailNow() + } + obj, err := codec.Decode(data) + if err != nil { + t.Errorf("error decoding: %v", err) + t.FailNow() + } + return obj +}