diff --git a/pkg/kubectl/cmd/drain.go b/pkg/kubectl/cmd/drain.go index a25b2476226..461b3efdfec 100644 --- a/pkg/kubectl/cmd/drain.go +++ b/pkg/kubectl/cmd/drain.go @@ -20,8 +20,8 @@ import ( "errors" "fmt" "io" + "k8s.io/apimachinery/pkg/util/json" "math" - "reflect" "strings" "time" @@ -33,8 +33,11 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/fields" "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/types" + "k8s.io/apimachinery/pkg/util/strategicpatch" "k8s.io/apimachinery/pkg/util/wait" restclient "k8s.io/client-go/rest" + "k8s.io/kubernetes/pkg/api" "k8s.io/kubernetes/pkg/apis/policy" "k8s.io/kubernetes/pkg/client/clientset_generated/internalclientset" @@ -43,7 +46,6 @@ import ( "k8s.io/kubernetes/pkg/kubectl/cmd/templates" cmdutil "k8s.io/kubernetes/pkg/kubectl/cmd/util" "k8s.io/kubernetes/pkg/kubectl/resource" - "k8s.io/kubernetes/pkg/kubelet/types" "k8s.io/kubernetes/pkg/util/i18n" ) @@ -84,7 +86,6 @@ const ( kLocalStorageWarning = "Deleting pods with local storage" kUnmanagedFatal = "pods not managed by ReplicationController, ReplicaSet, Job, DaemonSet or StatefulSet (use --force to override)" kUnmanagedWarning = "Deleting pods not managed by ReplicationController, ReplicaSet, Job, DaemonSet or StatefulSet" - kMaxNodeUpdateRetry = 10 ) var ( @@ -353,7 +354,7 @@ func (o *DrainOptions) daemonsetFilter(pod api.Pod) (bool, *warning, *fatal) { } func mirrorPodFilter(pod api.Pod) (bool, *warning, *fatal) { - if _, found := pod.ObjectMeta.Annotations[types.ConfigMirrorAnnotationKey]; found { + if _, found := pod.ObjectMeta.Annotations[corev1.MirrorPodAnnotationKey]; found { return false, nil, nil } return true, nil, nil @@ -621,27 +622,28 @@ func (o *DrainOptions) RunCordonOrUncordon(desired bool) error { } if o.nodeInfo.Mapping.GroupVersionKind.Kind == "Node" { - unsched := reflect.ValueOf(o.nodeInfo.Object).Elem().FieldByName("Spec").FieldByName("Unschedulable") - if unsched.Bool() == desired { + obj, err := o.nodeInfo.Mapping.ConvertToVersion(o.nodeInfo.Object, o.nodeInfo.Mapping.GroupVersionKind.GroupVersion()) + if err != nil { + return err + } + oldData, err := json.Marshal(obj) + node, ok := obj.(*corev1.Node) + if !ok { + return fmt.Errorf("unexpected Type%T, expected Node", obj) + } + unsched := node.Spec.Unschedulable + if unsched == desired { cmdutil.PrintSuccess(o.mapper, false, o.Out, o.nodeInfo.Mapping.Resource, o.nodeInfo.Name, false, already(desired)) } else { helper := resource.NewHelper(o.restClient, o.nodeInfo.Mapping) - unsched.SetBool(desired) + node.Spec.Unschedulable = desired var err error - for i := 0; i < kMaxNodeUpdateRetry; i++ { - // We don't care about what previous versions may exist, we always want - // to overwrite, and Replace always sets current ResourceVersion if version is "". - helper.Versioner.SetResourceVersion(o.nodeInfo.Object, "") - _, err = helper.Replace(cmdNamespace, o.nodeInfo.Name, true, o.nodeInfo.Object) - if err != nil { - if !apierrors.IsConflict(err) { - return err - } - } else { - break - } - // It's a race, no need to sleep + newData, err := json.Marshal(obj) + patchBytes, err := strategicpatch.CreateTwoWayMergePatch(oldData, newData, obj) + if err != nil { + return err } + _, err = helper.Patch(cmdNamespace, o.nodeInfo.Name, types.StrategicMergePatchType, patchBytes) if err != nil { return err } diff --git a/pkg/kubectl/cmd/drain_test.go b/pkg/kubectl/cmd/drain_test.go index 36ea363f48e..7a28cf62f40 100644 --- a/pkg/kubectl/cmd/drain_test.go +++ b/pkg/kubectl/cmd/drain_test.go @@ -33,11 +33,13 @@ import ( "github.com/spf13/cobra" + "k8s.io/api/core/v1" apierrors "k8s.io/apimachinery/pkg/api/errors" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/apimachinery/pkg/types" + "k8s.io/apimachinery/pkg/util/strategicpatch" "k8s.io/apimachinery/pkg/util/wait" "k8s.io/client-go/rest/fake" "k8s.io/kubernetes/pkg/api" @@ -55,27 +57,27 @@ const ( DeleteMethod = "Delete" ) -var node *api.Node -var cordoned_node *api.Node +var node *v1.Node +var cordoned_node *v1.Node func boolptr(b bool) *bool { return &b } func TestMain(m *testing.M) { // Create a node. - node = &api.Node{ + node = &v1.Node{ ObjectMeta: metav1.ObjectMeta{ Name: "node", CreationTimestamp: metav1.Time{Time: time.Now()}, }, - Spec: api.NodeSpec{ + Spec: v1.NodeSpec{ ExternalID: "node", }, - Status: api.NodeStatus{}, + Status: v1.NodeStatus{}, } clone, _ := api.Scheme.DeepCopy(node) // A copy of the same node, but cordoned. - cordoned_node = clone.(*api.Node) + cordoned_node = clone.(*v1.Node) cordoned_node.Spec.Unschedulable = true os.Exit(m.Run()) } @@ -83,8 +85,8 @@ func TestMain(m *testing.M) { func TestCordon(t *testing.T) { tests := []struct { description string - node *api.Node - expected *api.Node + node *v1.Node + expected *v1.Node cmd func(cmdutil.Factory, io.Writer) *cobra.Command arg string expectFatal bool @@ -149,7 +151,7 @@ func TestCordon(t *testing.T) { for _, test := range tests { f, tf, codec, ns := cmdtesting.NewAPIFactory() - new_node := &api.Node{} + new_node := &v1.Node{} updated := false tf.Client = &fake.RESTClient{ APIRegistry: api.Registry, @@ -161,17 +163,25 @@ func TestCordon(t *testing.T) { return &http.Response{StatusCode: 200, Header: defaultHeader(), Body: objBody(codec, test.node)}, nil case m.isFor("GET", "/nodes/bar"): return &http.Response{StatusCode: 404, Header: defaultHeader(), Body: stringBody("nope")}, nil - case m.isFor("PUT", "/nodes/node"): + case m.isFor("PATCH", "/nodes/node"): data, err := ioutil.ReadAll(req.Body) if err != nil { t.Fatalf("%s: unexpected error: %v", test.description, err) } defer req.Body.Close() - if err := runtime.DecodeInto(codec, data, new_node); err != nil { + oldJSON, err := runtime.Encode(codec, node) + if err != nil { + t.Fatalf("%s: unexpected error: %v", test.description, err) + } + appliedPatch, err := strategicpatch.StrategicMergePatch(oldJSON, data, &v1.Node{}) + if err != nil { + t.Fatalf("%s: unexpected error: %v", test.description, err) + } + if err := runtime.DecodeInto(codec, appliedPatch, new_node); err != nil { t.Fatalf("%s: unexpected error: %v", test.description, err) } if !reflect.DeepEqual(test.expected.Spec, new_node.Spec) { - t.Fatalf("%s: expected:\n%v\nsaw:\n%v\n", test.description, test.expected.Spec, new_node.Spec) + t.Fatalf("%s: expected:\n%v\nsaw:\n%v\n", test.description, test.expected.Spec.Unschedulable, new_node.Spec.Unschedulable) } updated = true return &http.Response{StatusCode: 200, Header: defaultHeader(), Body: objBody(codec, new_node)}, nil @@ -443,8 +453,8 @@ func TestDrain(t *testing.T) { tests := []struct { description string - node *api.Node - expected *api.Node + node *v1.Node + expected *v1.Node pods []api.Pod rcs []api.ReplicationController replicaSets []extensions.ReplicaSet @@ -582,7 +592,7 @@ func TestDrain(t *testing.T) { currMethod = DeleteMethod } for _, test := range tests { - new_node := &api.Node{} + new_node := &v1.Node{} deleted := false evicted := false f, tf, codec, ns := cmdtesting.NewAPIFactory() @@ -649,13 +659,21 @@ func TestDrain(t *testing.T) { return &http.Response{StatusCode: 200, Header: defaultHeader(), Body: objBody(codec, &api.PodList{Items: test.pods})}, nil case m.isFor("GET", "/replicationcontrollers"): return &http.Response{StatusCode: 200, Header: defaultHeader(), Body: objBody(codec, &api.ReplicationControllerList{Items: test.rcs})}, nil - case m.isFor("PUT", "/nodes/node"): + case m.isFor("PATCH", "/nodes/node"): data, err := ioutil.ReadAll(req.Body) if err != nil { t.Fatalf("%s: unexpected error: %v", test.description, err) } defer req.Body.Close() - if err := runtime.DecodeInto(codec, data, new_node); err != nil { + oldJSON, err := runtime.Encode(codec, node) + if err != nil { + t.Fatalf("%s: unexpected error: %v", test.description, err) + } + appliedPatch, err := strategicpatch.StrategicMergePatch(oldJSON, data, &v1.Node{}) + if err != nil { + t.Fatalf("%s: unexpected error: %v", test.description, err) + } + if err := runtime.DecodeInto(codec, appliedPatch, new_node); err != nil { t.Fatalf("%s: unexpected error: %v", test.description, err) } if !reflect.DeepEqual(test.expected.Spec, new_node.Spec) { @@ -692,7 +710,6 @@ func TestDrain(t *testing.T) { cmd.SetArgs(test.args) cmd.Execute() }() - if test.expectFatal { if !saw_fatal { t.Fatalf("%s: unexpected non-error when using %s", test.description, currMethod)