Merge pull request #83967 from mgugino-upstream-stage/kubectl-drain-timeout

kubectl drain: avoid leaking goroutines
This commit is contained in:
Kubernetes Prow Robot 2019-10-22 23:19:22 -07:00 committed by GitHub
commit 9a7201c6b1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 33 additions and 16 deletions

View File

@@ -17,6 +17,7 @@ limitations under the License.
package drain package drain
import ( import (
"context"
"fmt" "fmt"
"io" "io"
"math" "math"
@@ -184,7 +185,6 @@ func (d *Helper) DeleteOrEvictPods(pods []corev1.Pod) error {
getPodFn := func(namespace, name string) (*corev1.Pod, error) { getPodFn := func(namespace, name string) (*corev1.Pod, error) {
return d.Client.CoreV1().Pods(namespace).Get(name, metav1.GetOptions{}) return d.Client.CoreV1().Pods(namespace).Get(name, metav1.GetOptions{})
} }
if len(policyGroupVersion) > 0 { if len(policyGroupVersion) > 0 {
return d.evictPods(pods, policyGroupVersion, getPodFn) return d.evictPods(pods, policyGroupVersion, getPodFn)
} }
@@ -194,11 +194,26 @@ func (d *Helper) DeleteOrEvictPods(pods []corev1.Pod) error {
func (d *Helper) evictPods(pods []corev1.Pod, policyGroupVersion string, getPodFn func(namespace, name string) (*corev1.Pod, error)) error { func (d *Helper) evictPods(pods []corev1.Pod, policyGroupVersion string, getPodFn func(namespace, name string) (*corev1.Pod, error)) error {
returnCh := make(chan error, 1) returnCh := make(chan error, 1)
// 0 timeout means infinite, we use MaxInt64 to represent it.
var globalTimeout time.Duration
if d.Timeout == 0 {
globalTimeout = time.Duration(math.MaxInt64)
} else {
globalTimeout = d.Timeout
}
ctx, cancel := context.WithTimeout(context.TODO(), globalTimeout)
defer cancel()
for _, pod := range pods { for _, pod := range pods {
go func(pod corev1.Pod, returnCh chan error) { go func(pod corev1.Pod, returnCh chan error) {
for { for {
fmt.Fprintf(d.Out, "evicting pod %q\n", pod.Name) fmt.Fprintf(d.Out, "evicting pod %q\n", pod.Name)
select {
case <-ctx.Done():
// return here or we'll leak a goroutine.
returnCh <- fmt.Errorf("error when evicting pod %q: global timeout reached: %v", pod.Name, globalTimeout)
return
default:
}
err := d.EvictPod(pod, policyGroupVersion) err := d.EvictPod(pod, policyGroupVersion)
if err == nil { if err == nil {
break break
@@ -213,7 +228,7 @@ func (d *Helper) evictPods(pods []corev1.Pod, policyGroupVersion string, getPodF
return return
} }
} }
_, err := waitForDelete([]corev1.Pod{pod}, 1*time.Second, time.Duration(math.MaxInt64), true, getPodFn, d.OnPodDeletedOrEvicted) _, err := waitForDelete(ctx, []corev1.Pod{pod}, 1*time.Second, time.Duration(math.MaxInt64), true, getPodFn, d.OnPodDeletedOrEvicted, globalTimeout)
if err == nil { if err == nil {
returnCh <- nil returnCh <- nil
} else { } else {
@@ -225,14 +240,6 @@ func (d *Helper) evictPods(pods []corev1.Pod, policyGroupVersion string, getPodF
doneCount := 0 doneCount := 0
var errors []error var errors []error
// 0 timeout means infinite, we use MaxInt64 to represent it.
var globalTimeout time.Duration
if d.Timeout == 0 {
globalTimeout = time.Duration(math.MaxInt64)
} else {
globalTimeout = d.Timeout
}
globalTimeoutCh := time.After(globalTimeout)
numPods := len(pods) numPods := len(pods)
for doneCount < numPods { for doneCount < numPods {
select { select {
@@ -241,10 +248,10 @@ func (d *Helper) evictPods(pods []corev1.Pod, policyGroupVersion string, getPodF
if err != nil { if err != nil {
errors = append(errors, err) errors = append(errors, err)
} }
case <-globalTimeoutCh: default:
return fmt.Errorf("drain did not complete within %v", globalTimeout)
} }
} }
return utilerrors.NewAggregate(errors) return utilerrors.NewAggregate(errors)
} }
@@ -262,11 +269,12 @@ func (d *Helper) deletePods(pods []corev1.Pod, getPodFn func(namespace, name str
return err return err
} }
} }
_, err := waitForDelete(pods, 1*time.Second, globalTimeout, false, getPodFn, d.OnPodDeletedOrEvicted) ctx := context.TODO()
_, err := waitForDelete(ctx, pods, 1*time.Second, globalTimeout, false, getPodFn, d.OnPodDeletedOrEvicted, globalTimeout)
return err return err
} }
func waitForDelete(pods []corev1.Pod, interval, timeout time.Duration, usingEviction bool, getPodFn func(string, string) (*corev1.Pod, error), onDoneFn func(pod *corev1.Pod, usingEviction bool)) ([]corev1.Pod, error) { func waitForDelete(ctx context.Context, pods []corev1.Pod, interval, timeout time.Duration, usingEviction bool, getPodFn func(string, string) (*corev1.Pod, error), onDoneFn func(pod *corev1.Pod, usingEviction bool), globalTimeout time.Duration) ([]corev1.Pod, error) {
err := wait.PollImmediate(interval, timeout, func() (bool, error) { err := wait.PollImmediate(interval, timeout, func() (bool, error) {
pendingPods := []corev1.Pod{} pendingPods := []corev1.Pod{}
for i, pod := range pods { for i, pod := range pods {
@@ -284,6 +292,12 @@ func waitForDelete(pods []corev1.Pod, interval, timeout time.Duration, usingEvic
} }
pods = pendingPods pods = pendingPods
if len(pendingPods) > 0 { if len(pendingPods) > 0 {
select {
case <-ctx.Done():
return false, fmt.Errorf("global timeout reached: %v", globalTimeout)
default:
return false, nil
}
return false, nil return false, nil
} }
return true, nil return true, nil

View File

@@ -17,8 +17,10 @@ limitations under the License.
package drain package drain
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"math"
"os" "os"
"reflect" "reflect"
"sort" "sort"
@@ -105,7 +107,8 @@ func TestDeletePods(t *testing.T) {
for _, test := range tests { for _, test := range tests {
t.Run(test.description, func(t *testing.T) { t.Run(test.description, func(t *testing.T) {
_, pods := createPods(false) _, pods := createPods(false)
pendingPods, err := waitForDelete(pods, test.interval, test.timeout, false, test.getPodFn, nil) ctx := context.TODO()
pendingPods, err := waitForDelete(ctx, pods, test.interval, test.timeout, false, test.getPodFn, nil, time.Duration(math.MaxInt64))
if test.expectError { if test.expectError {
if err == nil { if err == nil {