Switch token bucket rate limiter to github.com/juju/ratelimit

This commit is contained in:
Jordan Liggitt 2015-06-30 09:59:37 -04:00
parent 800303c843
commit f265d5c5ee
2 changed files with 13 additions and 81 deletions

View File

@ -16,10 +16,7 @@ limitations under the License.
package util package util
import ( import "github.com/juju/ratelimit"
"sync"
"time"
)
type RateLimiter interface { type RateLimiter interface {
// CanAccept returns true if the rate is below the limit, false otherwise // CanAccept returns true if the rate is below the limit, false otherwise
@ -31,24 +28,17 @@ type RateLimiter interface {
} }
type tickRateLimiter struct { type tickRateLimiter struct {
lock sync.Mutex limiter *ratelimit.Bucket
tokens chan bool
ticker <-chan time.Time
stop chan bool
} }
// NewTokenBucketRateLimiter creates a rate limiter which implements a token bucket approach. // NewTokenBucketRateLimiter creates a rate limiter which implements a token bucket approach.
// The rate limiter allows bursts of up to 'burst' to exceed the QPS, while still maintaining a // The rate limiter allows bursts of up to 'burst' to exceed the QPS, while still maintaining a
// smoothed qps rate of 'qps'. // smoothed qps rate of 'qps'.
// The bucket is initially filled with 'burst' tokens, the rate limiter spawns a go routine // The bucket is initially filled with 'burst' tokens, and refills at a rate of 'qps'.
// which refills the bucket with one token at a rate of 'qps'. The maximum number of tokens in // The maximum number of tokens in the bucket is capped at 'burst'.
// the bucket is capped at 'burst'.
// When done with the limiter, Stop() must be called to halt the associated goroutine.
func NewTokenBucketRateLimiter(qps float32, burst int) RateLimiter { func NewTokenBucketRateLimiter(qps float32, burst int) RateLimiter {
ticker := time.Tick(time.Duration(float32(time.Second) / qps)) limiter := ratelimit.NewBucketWithRate(float64(qps), int64(burst))
rate := newTokenBucketRateLimiterFromTicker(ticker, burst) return &tickRateLimiter{limiter}
go rate.run()
return rate
} }
type fakeRateLimiter struct{} type fakeRateLimiter struct{}
@ -57,63 +47,16 @@ func NewFakeRateLimiter() RateLimiter {
return &fakeRateLimiter{} return &fakeRateLimiter{}
} }
func newTokenBucketRateLimiterFromTicker(ticker <-chan time.Time, burst int) *tickRateLimiter {
if burst < 1 {
panic("burst must be a positive integer")
}
rate := &tickRateLimiter{
tokens: make(chan bool, burst),
ticker: ticker,
stop: make(chan bool),
}
for i := 0; i < burst; i++ {
rate.tokens <- true
}
return rate
}
func (t *tickRateLimiter) CanAccept() bool { func (t *tickRateLimiter) CanAccept() bool {
select { return t.limiter.TakeAvailable(1) == 1
case <-t.tokens:
return true
default:
return false
}
} }
// Accept will block until a token becomes available // Accept will block until a token becomes available
func (t *tickRateLimiter) Accept() { func (t *tickRateLimiter) Accept() {
<-t.tokens t.limiter.Wait(1)
} }
func (t *tickRateLimiter) Stop() { func (t *tickRateLimiter) Stop() {
close(t.stop)
}
func (r *tickRateLimiter) run() {
for {
if !r.step() {
break
}
}
}
func (r *tickRateLimiter) step() bool {
select {
case <-r.ticker:
r.increment()
return true
case <-r.stop:
return false
}
}
func (t *tickRateLimiter) increment() {
// non-blocking send
select {
case t.tokens <- true:
default:
}
} }
func (t *fakeRateLimiter) CanAccept() bool { func (t *fakeRateLimiter) CanAccept() bool {

View File

@ -22,8 +22,7 @@ import (
) )
func TestBasicThrottle(t *testing.T) { func TestBasicThrottle(t *testing.T) {
ticker := make(chan time.Time, 1) r := NewTokenBucketRateLimiter(1, 3)
r := newTokenBucketRateLimiterFromTicker(ticker, 3)
for i := 0; i < 3; i++ { for i := 0; i < 3; i++ {
if !r.CanAccept() { if !r.CanAccept() {
t.Error("unexpected false accept") t.Error("unexpected false accept")
@ -35,32 +34,22 @@ func TestBasicThrottle(t *testing.T) {
} }
func TestIncrementThrottle(t *testing.T) { func TestIncrementThrottle(t *testing.T) {
ticker := make(chan time.Time, 1) r := NewTokenBucketRateLimiter(1, 1)
r := newTokenBucketRateLimiterFromTicker(ticker, 1)
if !r.CanAccept() { if !r.CanAccept() {
t.Error("unexpected false accept") t.Error("unexpected false accept")
} }
if r.CanAccept() { if r.CanAccept() {
t.Error("unexpected true accept") t.Error("unexpected true accept")
} }
ticker <- time.Now()
r.step() // Allow to refill
time.Sleep(2 * time.Second)
if !r.CanAccept() { if !r.CanAccept() {
t.Error("unexpected false accept") t.Error("unexpected false accept")
} }
} }
func TestOverBurst(t *testing.T) {
ticker := make(chan time.Time, 1)
r := newTokenBucketRateLimiterFromTicker(ticker, 3)
for i := 0; i < 4; i++ {
ticker <- time.Now()
r.step()
}
}
func TestThrottle(t *testing.T) { func TestThrottle(t *testing.T) {
r := NewTokenBucketRateLimiter(10, 5) r := NewTokenBucketRateLimiter(10, 5)