Skip to content

Commit fb3be44

Browse files
authored
Merge pull request #14 from vearne/develop
Develop
2 parents 0f3a0c8 + ce91e7d commit fb3be44

25 files changed

+630
-9006
lines changed

README.md

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ go get github.com/vearne/ratelimit
4040
```
4141
### Usage
4242
#### 1. create redis.Client
43-
with "github.com/go-redis/redis"
43+
with "github.com/redis/go-redis"
4444
Supports both redis master-slave mode and cluster mode
4545
```
4646
client := redis.NewClient(&redis.Options{
@@ -132,15 +132,17 @@ package main
132132
import (
133133
"context"
134134
"fmt"
135-
"github.com/go-redis/redis/v8"
135+
"github.com/redis/go-redis/v9"
136136
"github.com/vearne/ratelimit"
137+
"github.com/vearne/ratelimit/counter"
138+
"github.com/vearne/ratelimit/tokenbucket"
137139
slog "github.com/vearne/simplelog"
138140
"sync"
139141
"time"
140142
)
141143
142144
func consume(r ratelimit.Limiter, group *sync.WaitGroup,
143-
c *ratelimit.Counter, targetCount int) {
145+
c *counter.Counter, targetCount int) {
144146
defer group.Done()
145147
var ok bool
146148
for {
@@ -168,7 +170,7 @@ func main() {
168170
DB: 0, // use default DB
169171
})
170172
171-
limiter, err := ratelimit.NewTokenBucketRateLimiter(
173+
limiter, err := tokenbucket.NewTokenBucketRateLimiter(
172174
context.Background(),
173175
client,
174176
"key:token",
@@ -184,7 +186,7 @@ func main() {
184186
185187
var wg sync.WaitGroup
186188
total := 50
187-
counter := ratelimit.NewCounter()
189+
counter := counter.NewCounter()
188190
start := time.Now()
189191
for i := 0; i < 10; i++ {
190192
wg.Add(1)
@@ -197,7 +199,7 @@ func main() {
197199
```
198200

199201
### Dependency
200-
[go-redis/redis](https://github.com/go-redis/redis)
202+
[redis/go-redis](https://github.com/redis/go-redis)
201203

202204
### Thanks
203205
The development of the module was inspired by the Reference 1.

README_zh.md

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ go get github.com/vearne/ratelimit
4242
```
4343
## 用法
4444
### 1. 创建 redis.Client
45-
依赖 "github.com/go-redis/redis"
45+
依赖 "github.com/redis/go-redis"
4646
同时支持redis 主从模式和cluster模式
4747
```
4848
client := redis.NewClient(&redis.Options{
@@ -136,15 +136,17 @@ package main
136136
import (
137137
"context"
138138
"fmt"
139-
"github.com/go-redis/redis/v8"
139+
"github.com/redis/go-redis/v9"
140140
"github.com/vearne/ratelimit"
141+
"github.com/vearne/ratelimit/counter"
142+
"github.com/vearne/ratelimit/tokenbucket"
141143
slog "github.com/vearne/simplelog"
142144
"sync"
143145
"time"
144146
)
145147
146148
func consume(r ratelimit.Limiter, group *sync.WaitGroup,
147-
c *ratelimit.Counter, targetCount int) {
149+
c *counter.Counter, targetCount int) {
148150
defer group.Done()
149151
var ok bool
150152
for {
@@ -172,7 +174,7 @@ func main() {
172174
DB: 0, // use default DB
173175
})
174176
175-
limiter, err := ratelimit.NewTokenBucketRateLimiter(
177+
limiter, err := tokenbucket.NewTokenBucketRateLimiter(
176178
context.Background(),
177179
client,
178180
"key:token",
@@ -188,7 +190,7 @@ func main() {
188190
189191
var wg sync.WaitGroup
190192
total := 50
191-
counter := ratelimit.NewCounter()
193+
counter := counter.NewCounter()
192194
start := time.Now()
193195
for i := 0; i < 10; i++ {
194196
wg.Add(1)
@@ -200,7 +202,7 @@ func main() {
200202
}
201203
```
202204
### 依赖
203-
[go-redis/redis](https://github.com/go-redis/redis)
205+
[redis/go-redis](https://github.com/redis/go-redis)
204206

205207
### 致谢
206208
模块的开发受到了资料1的启发,在此表示感谢

alg.go

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ return increment
3737
key ->
3838
token_count -> {token_count}
3939
updateTime -> {lastUpdateTime}* 1000000 + {microsecond}
40-
*/
40+
*/
4141

4242
const TokenBucketScript = `
4343
local bucket = KEYS[1]
@@ -85,14 +85,12 @@ end
8585
return count
8686
`
8787

88-
8988
/*
90-
key Type: string
91-
92-
// updateTime
93-
key -> {lastUpdateTime}* 1000000 + {microsecond}
89+
key Type: string
9490
95-
*/
91+
// updateTime
92+
key -> {lastUpdateTime}* 1000000 + {microsecond}
93+
*/
9694
const LeakyBucketScript = `
9795
local bucket = KEYS[1]
9896
local interval = tonumber(ARGV[1])
@@ -117,15 +115,13 @@ end
117115
return count
118116
`
119117

120-
121118
var (
122-
algMap map[int]string
119+
AlgMap map[int]string
123120
)
124121

125-
126-
func init(){
127-
algMap = make(map[int]string)
128-
algMap[CounterAlg] = counterScript
129-
algMap[TokenBucketAlg] = TokenBucketScript
130-
algMap[LeakyBucketAlg] = LeakyBucketScript
131-
}
122+
func init() {
123+
AlgMap = make(map[int]string)
124+
AlgMap[CounterAlg] = counterScript
125+
AlgMap[TokenBucketAlg] = TokenBucketScript
126+
AlgMap[LeakyBucketAlg] = LeakyBucketScript
127+
}

counter.go renamed to counter/counter.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
package ratelimit
1+
package counter
22

33
import "sync"
44

counter_limiter.go renamed to counter/counter_limiter.go

Lines changed: 33 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1-
package ratelimit
1+
package counter
22

33
import (
44
"context"
55
"crypto/sha1"
66
"errors"
77
"fmt"
8-
"github.com/go-redis/redis/v8"
8+
"github.com/redis/go-redis/v9"
9+
"github.com/vearne/ratelimit"
910
slog "github.com/vearne/simplelog"
1011
"golang.org/x/sync/singleflight"
1112
"golang.org/x/time/rate"
@@ -14,7 +15,7 @@ import (
1415

1516
//nolint:govet
1617
type CounterLimiter struct {
17-
BaseRateLimiter
18+
ratelimit.BaseRateLimiter
1819
duration time.Duration
1920
throughput int
2021
batchSize int
@@ -29,9 +30,11 @@ type CounterLimiter struct {
2930
antiDDoSLimiter *rate.Limiter
3031
}
3132

33+
type Option func(*CounterLimiter)
34+
3235
func NewCounterRateLimiter(ctx context.Context, client redis.Cmdable, key string, duration time.Duration,
3336
throughput int,
34-
batchSize int) (Limiter, error) {
37+
batchSize int, opts ...Option) (ratelimit.Limiter, error) {
3538

3639
_, err := client.Ping(ctx).Result()
3740
if err != nil {
@@ -50,23 +53,35 @@ func NewCounterRateLimiter(ctx context.Context, client redis.Cmdable, key string
5053
return nil, errors.New("batchSize must greater than 0")
5154
}
5255

53-
script := algMap[CounterAlg]
56+
script := ratelimit.AlgMap[ratelimit.CounterAlg]
5457
scriptSHA1 := fmt.Sprintf("%x", sha1.Sum([]byte(script)))
5558

5659
r := CounterLimiter{
57-
BaseRateLimiter: BaseRateLimiter{redisClient: client, scriptSHA1: scriptSHA1, key: key},
60+
BaseRateLimiter: ratelimit.BaseRateLimiter{RedisClient: client, ScriptSHA1: scriptSHA1, Key: key},
5861
duration: duration,
5962
throughput: throughput,
6063
batchSize: batchSize,
6164
N: 0,
6265
AntiDDoS: true,
6366
}
64-
r.interval = duration / time.Duration(throughput)
67+
r.Interval = duration / time.Duration(throughput)
6568

66-
if !r.redisClient.ScriptExists(ctx, r.scriptSHA1).Val()[0] {
67-
r.redisClient.ScriptLoad(ctx, script).Val()
69+
// Loop through each option
70+
for _, opt := range opts {
71+
// Call the option giving the instantiated
72+
opt(&r)
6873
}
6974

75+
values, err := r.RedisClient.ScriptExists(ctx, r.ScriptSHA1).Result()
76+
if err != nil {
77+
return nil, err
78+
}
79+
if !values[0] {
80+
_, err = r.RedisClient.ScriptLoad(ctx, script).Result()
81+
if err != nil {
82+
return nil, err
83+
}
84+
}
7085
// 2x throughput
7186
throughputPerSec := int(float64(throughput) / float64(duration/time.Second))
7287
r.antiDDoSLimiter = rate.NewLimiter(rate.Limit(throughputPerSec*2), throughputPerSec*2)
@@ -75,8 +90,10 @@ func NewCounterRateLimiter(ctx context.Context, client redis.Cmdable, key string
7590
}
7691

7792
// just for test
78-
func (r *CounterLimiter) WithAntiDDos(antiDDoS bool) {
79-
r.AntiDDoS = antiDDoS
93+
func WithAntiDDos(antiDDoS bool) Option {
94+
return func(r *CounterLimiter) {
95+
r.AntiDDoS = antiDDoS
96+
}
8097
}
8198

8299
func (r *CounterLimiter) tryTakeFromLocal() bool {
@@ -101,7 +118,7 @@ func (r *CounterLimiter) Wait(ctx context.Context) (err error) {
101118
}
102119

103120
deadline, ok := ctx.Deadline()
104-
minWaitTime := r.interval
121+
minWaitTime := r.Interval
105122

106123
slog.Debug("minWaitTime:%v", minWaitTime)
107124
if ok {
@@ -143,11 +160,11 @@ func (r *CounterLimiter) Take(ctx context.Context) (bool, error) {
143160
}
144161

145162
// 2. try to get from redis
146-
_, err, _ := r.g.Do(r.key, func() (interface{}, error) {
147-
x, err := r.redisClient.EvalSha(
163+
_, err, _ := r.g.Do(r.Key, func() (interface{}, error) {
164+
x, err := r.RedisClient.EvalSha(
148165
ctx,
149-
r.scriptSHA1,
150-
[]string{r.key},
166+
r.ScriptSHA1,
167+
[]string{r.Key},
151168
int(r.duration/time.Microsecond),
152169
r.throughput,
153170
r.batchSize,

counter/ratelimit_test.go

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
package counter
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"github.com/go-redis/redismock/v9"
7+
"github.com/stretchr/testify/assert"
8+
"log"
9+
"testing"
10+
"time"
11+
)
12+
13+
const (
14+
key = "key:count"
15+
hashVal = "bdbede5669d5e48d6e6c2967aeed2f72f03868ac"
16+
)
17+
18+
func MyMatch(expected, actual []interface{}) error {
19+
expectedStr := fmt.Sprintf("%v", expected)
20+
actualStr := fmt.Sprintf("%v", actual)
21+
if expectedStr == actualStr {
22+
return nil
23+
}
24+
log.Printf("expectedStr:%v, actualStr:%v", expectedStr, actualStr)
25+
return fmt.Errorf("not equal, expectedStr:%s, actualStr:%s", expectedStr, actualStr)
26+
}
27+
28+
func TestTakeFail(t *testing.T) {
29+
db, mock := redismock.NewClientMock()
30+
31+
mock = mock.CustomMatch(MyMatch)
32+
mock.ExpectPing().SetVal("PONG")
33+
34+
mock.ExpectScriptExists(hashVal).SetVal([]bool{true})
35+
mock.ExpectEvalSha(hashVal, []string{key}, 1000000, 3, 2).SetVal(int64(0))
36+
37+
limiter, err := NewCounterRateLimiter(context.Background(), db, key, time.Second,
38+
3,
39+
2,
40+
WithAntiDDos(false))
41+
if err != nil {
42+
t.Errorf("unexpected error, %v", err)
43+
return
44+
}
45+
46+
ok, err := limiter.Take(context.Background())
47+
if err != nil {
48+
t.Errorf("unexpected error, %v", err)
49+
return
50+
}
51+
if !ok {
52+
assert.Equal(t, ok, false)
53+
}
54+
}
55+
56+
func TestTakeSuccess(t *testing.T) {
57+
db, mock := redismock.NewClientMock()
58+
59+
mock = mock.CustomMatch(MyMatch)
60+
mock.ExpectPing().SetVal("PONG")
61+
62+
mock.ExpectScriptExists(hashVal).SetVal([]bool{true})
63+
mock.ExpectEvalSha(hashVal, []string{key}, 1000000, 3, 2).SetVal(int64(1))
64+
65+
limiter, err := NewCounterRateLimiter(context.Background(), db, key, time.Second,
66+
3,
67+
2,
68+
WithAntiDDos(false))
69+
if err != nil {
70+
t.Errorf("unexpected error, %v", err)
71+
return
72+
}
73+
74+
ok, err := limiter.Take(context.Background())
75+
if err != nil {
76+
t.Errorf("unexpected error, %v", err)
77+
return
78+
}
79+
if !ok {
80+
assert.Equal(t, ok, true)
81+
}
82+
}
83+
84+
func TestContextTimeOut(t *testing.T) {
85+
db, mock := redismock.NewClientMock()
86+
mock = mock.CustomMatch(MyMatch)
87+
mock.ExpectPing().SetVal("PONG")
88+
89+
mock.ExpectScriptExists(hashVal).SetVal([]bool{true, true})
90+
for i := 0; i < 1000; i++ {
91+
mock.ExpectEvalSha(hashVal, []string{key}, 1000000, 3, 2).SetVal(int64(0))
92+
}
93+
94+
limiter, err := NewCounterRateLimiter(context.Background(), db, key, time.Second,
95+
3,
96+
2,
97+
WithAntiDDos(false))
98+
if err != nil {
99+
t.Errorf("unexpected error, %v", err)
100+
return
101+
}
102+
103+
waitCtx, cancel := context.WithTimeout(context.Background(), time.Second)
104+
defer cancel()
105+
err = limiter.Wait(waitCtx)
106+
assert.Contains(t, err.Error(), "timeout")
107+
}

0 commit comments

Comments
 (0)