Skip to content

Commit 426a192

Browse files
authored
feat: added semaphore http client (#3)
1 parent 8b00782 commit 426a192

File tree

4 files changed

+181
-5
lines changed

4 files changed

+181
-5
lines changed

gosrm.go

+23-3
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,17 @@ import (
99
)
1010

1111
type (
12+
// HTTPClient is the interface that can be used to do HTTP calls.
13+
HTTPClient interface {
14+
Do(req *http.Request) (*http.Response, error)
15+
}
16+
1217
// OSRMClient is the base type with helper methods to call OSRM APIs.
1318
// It only holds the base OSRM URL.
1419
OSRMClient struct {
1520
baseURL *url.URL
21+
22+
client HTTPClient
1623
}
1724

1825
// Request is the OSRM's request structure.
@@ -29,12 +36,25 @@ type (
2936

3037
// New returns a new OSRM client.
3138
func New(baseURL string) (OSRMClient, error) {
39+
var client OSRMClient
40+
3241
u, err := url.Parse(baseURL)
3342
if err != nil {
34-
return OSRMClient{}, err
43+
return client, err
3544
}
3645

37-
return OSRMClient{baseURL: u}, nil
46+
client.baseURL = u
47+
client.SetHTTPClient(NewHTTPClient(HTTPClientConfig{}))
48+
49+
return client, nil
50+
}
51+
52+
// SetHTTPClient sets the HTTP client that will be used to call OSRM.
53+
func (osrm *OSRMClient) SetHTTPClient(client HTTPClient) {
54+
if client == nil {
55+
panic("http client can't be nil")
56+
}
57+
osrm.client = client
3858
}
3959

4060
// get calls the given URL and parses the response.
@@ -45,7 +65,7 @@ func (osrm OSRMClient) get(ctx context.Context, url string, out any) error {
4565
}
4666
req = req.WithContext(ctx)
4767

48-
res, err := http.DefaultClient.Do(req)
68+
res, err := osrm.client.Do(req)
4969
if err != nil {
5070
return err
5171
}

gosrm_test.go

+20-2
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@ import (
1313

1414
const invalidURL string = "postgres://user:abc{[email protected]:5432/db?sslmode=require"
1515

16+
func newOSRMClient() OSRMClient {
17+
return OSRMClient{client: NewHTTPClient(HTTPClientConfig{})}
18+
}
19+
1620
func TestNew(t *testing.T) {
1721
testCases := []struct {
1822
name, baseURL string
@@ -38,8 +42,22 @@ func TestNew(t *testing.T) {
3842
}
3943
}
4044

45+
func TestOSRMClient_SetHTTPClient(t *testing.T) {
46+
osrm := newOSRMClient()
47+
48+
assert.PanicsWithValue(t, "http client can't be nil", func() {
49+
osrm.SetHTTPClient(nil)
50+
})
51+
52+
client := NewHTTPClient(HTTPClientConfig{MaxConcurrency: 10})
53+
54+
osrm.SetHTTPClient(client)
55+
56+
assert.Equal(t, client, osrm.client)
57+
}
58+
4159
func TestOSRMClient_get(t *testing.T) {
42-
osrm := OSRMClient{}
60+
osrm := newOSRMClient()
4361
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
4462
w.Write([]byte("{\"message\": \"Ok\"}"))
4563
}))
@@ -59,7 +77,7 @@ func TestOSRMClient_get(t *testing.T) {
5977
}
6078

6179
func TestOSRMClient_applyOpts(t *testing.T) {
62-
osrm := OSRMClient{}
80+
osrm := newOSRMClient()
6381
u := url.URL{}
6482

6583
osrm.applyOpts(&u, []Option{

semaphore_client.go

+64
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
package gosrm
2+
3+
import "net/http"
4+
5+
type (
6+
// httpClient is the default implementation of HTTPClient interface.
7+
httpClient struct {
8+
client *http.Client
9+
pool chan struct{}
10+
}
11+
12+
// HTTPClientConfig is the config used to customize http client.
13+
HTTPClientConfig struct {
14+
// MaxConcurrency is the max number of concurrent requests.
15+
// If it's 0 then there is no limit.
16+
//
17+
// Defaults to 0.
18+
MaxConcurrency uint
19+
20+
// HTTPClient is the client which will be used to do HTTP calls.
21+
//
22+
// Defaults to http.DefaultClient
23+
HTTPClient *http.Client
24+
}
25+
)
26+
27+
// acquire acquires a spot in the pool.
28+
func (c httpClient) acquire() {
29+
if cap(c.pool) == 0 {
30+
return
31+
}
32+
c.pool <- struct{}{}
33+
}
34+
35+
// release releases a spot from the pool.
36+
func (c httpClient) release() {
37+
if cap(c.pool) == 0 {
38+
return
39+
}
40+
<-c.pool
41+
}
42+
43+
// Do does the HTTP call.
44+
func (c httpClient) Do(req *http.Request) (*http.Response, error) {
45+
c.acquire()
46+
defer c.release()
47+
48+
return c.client.Do(req)
49+
}
50+
51+
// NewHTTPClient returns a new HTTP client.
52+
func NewHTTPClient(cfg HTTPClientConfig) HTTPClient {
53+
var c httpClient
54+
55+
if cfg.HTTPClient != nil {
56+
c.client = cfg.HTTPClient
57+
} else {
58+
c.client = http.DefaultClient
59+
}
60+
61+
c.pool = make(chan struct{}, cfg.MaxConcurrency)
62+
63+
return c
64+
}

semaphore_client_test.go

+74
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
package gosrm
2+
3+
import (
4+
"net/http"
5+
"net/http/httptest"
6+
"testing"
7+
8+
"github.com/stretchr/testify/assert"
9+
)
10+
11+
func TestNewHTTPClient(t *testing.T) {
12+
cfg := HTTPClientConfig{}
13+
14+
c := NewHTTPClient(cfg).(httpClient)
15+
assert.Equal(t, http.DefaultClient, c.client)
16+
assert.Equal(t, 0, cap(c.pool))
17+
18+
cfg.MaxConcurrency = 100
19+
20+
c = NewHTTPClient(cfg).(httpClient)
21+
assert.Equal(t, 100, cap(c.pool))
22+
23+
cfg.HTTPClient = &http.Client{}
24+
c = NewHTTPClient(cfg).(httpClient)
25+
assert.Equal(t, cfg.HTTPClient, c.client)
26+
}
27+
28+
func TestHTTPClient_acquire_and_release(t *testing.T) {
29+
cfg := HTTPClientConfig{}
30+
c := NewHTTPClient(cfg).(httpClient)
31+
32+
for i := 0; i < 5; i++ {
33+
c.acquire()
34+
}
35+
36+
// chan is not used since it's disabled.
37+
assert.Len(t, c.pool, 0)
38+
39+
c.release()
40+
assert.Len(t, c.pool, 0)
41+
42+
cfg.MaxConcurrency = 2
43+
c = NewHTTPClient(cfg).(httpClient)
44+
45+
for i := 0; i < int(cfg.MaxConcurrency); i++ {
46+
c.acquire()
47+
}
48+
49+
// chan is full.
50+
assert.Len(t, c.pool, 2)
51+
52+
for i := cfg.MaxConcurrency; i > 0; i-- {
53+
c.release()
54+
assert.Len(t, c.pool, int(i-1))
55+
}
56+
57+
// chan is empty.
58+
assert.Len(t, c.pool, 0)
59+
}
60+
61+
func TestHTTPClient_Do(t *testing.T) {
62+
client := NewHTTPClient(HTTPClientConfig{MaxConcurrency: 1})
63+
64+
testsrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
65+
w.WriteHeader(201)
66+
}))
67+
68+
req, err := http.NewRequest("GET", testsrv.URL, nil)
69+
assert.NoError(t, err)
70+
71+
res, err := client.Do(req)
72+
assert.NoError(t, err)
73+
assert.Equal(t, 201, res.StatusCode)
74+
}

0 commit comments

Comments
 (0)