Make Rate limits thread safe (#347)
Signed-off-by: Julien Pivotto <roidelapluie@inuits.eu> Co-authored-by: Andrew Starr-Bochicchio <andrewsomething@users.noreply.github.com>
This commit is contained in:
parent
ce1a90fde7
commit
a51159bebb
16
godo.go
16
godo.go
|
@ -11,6 +11,7 @@ import (
|
||||||
"net/url"
|
"net/url"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/go-querystring/query"
|
"github.com/google/go-querystring/query"
|
||||||
|
@ -40,8 +41,9 @@ type Client struct {
|
||||||
UserAgent string
|
UserAgent string
|
||||||
|
|
||||||
// Rate contains the current rate limit for the client as determined by the most recent
|
// Rate contains the current rate limit for the client as determined by the most recent
|
||||||
// API call.
|
// API call. It is not thread-safe. Please consider using GetRate() instead.
|
||||||
Rate Rate
|
Rate Rate
|
||||||
|
ratemtx sync.Mutex
|
||||||
|
|
||||||
// Services used for communicating with the API
|
// Services used for communicating with the API
|
||||||
Account AccountService
|
Account AccountService
|
||||||
|
@ -288,6 +290,14 @@ func (c *Client) OnRequestCompleted(rc RequestCompletionCallback) {
|
||||||
c.onRequestCompleted = rc
|
c.onRequestCompleted = rc
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetRate returns the current rate limit for the client as determined by the most recent
|
||||||
|
// API call. It is thread-safe.
|
||||||
|
func (c *Client) GetRate() Rate {
|
||||||
|
c.ratemtx.Lock()
|
||||||
|
defer c.ratemtx.Unlock()
|
||||||
|
return c.Rate
|
||||||
|
}
|
||||||
|
|
||||||
// newResponse creates a new Response for the provided http.Response
|
// newResponse creates a new Response for the provided http.Response
|
||||||
func newResponse(r *http.Response) *Response {
|
func newResponse(r *http.Response) *Response {
|
||||||
response := Response{Response: r}
|
response := Response{Response: r}
|
||||||
|
@ -330,7 +340,9 @@ func (c *Client) Do(ctx context.Context, req *http.Request, v interface{}) (*Res
|
||||||
}()
|
}()
|
||||||
|
|
||||||
response := newResponse(resp)
|
response := newResponse(resp)
|
||||||
|
c.ratemtx.Lock()
|
||||||
c.Rate = response.Rate
|
c.Rate = response.Rate
|
||||||
|
c.ratemtx.Unlock()
|
||||||
|
|
||||||
err = CheckResponse(resp)
|
err = CheckResponse(resp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
47
godo_test.go
47
godo_test.go
|
@ -10,6 +10,7 @@ import (
|
||||||
"net/url"
|
"net/url"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
@ -354,6 +355,9 @@ func TestDo_rateLimit(t *testing.T) {
|
||||||
if !client.Rate.Reset.IsZero() {
|
if !client.Rate.Reset.IsZero() {
|
||||||
t.Errorf("Client rate reset not initialized to zero value")
|
t.Errorf("Client rate reset not initialized to zero value")
|
||||||
}
|
}
|
||||||
|
if client.Rate != client.GetRate() {
|
||||||
|
t.Errorf("Client rate is not the same as client.GetRate()")
|
||||||
|
}
|
||||||
|
|
||||||
req, _ := client.NewRequest(ctx, http.MethodGet, "/", nil)
|
req, _ := client.NewRequest(ctx, http.MethodGet, "/", nil)
|
||||||
_, err := client.Do(context.Background(), req, nil)
|
_, err := client.Do(context.Background(), req, nil)
|
||||||
|
@ -371,6 +375,49 @@ func TestDo_rateLimit(t *testing.T) {
|
||||||
if client.Rate.Reset.UTC() != reset {
|
if client.Rate.Reset.UTC() != reset {
|
||||||
t.Errorf("Client rate reset = %v, expected %v", client.Rate.Reset, reset)
|
t.Errorf("Client rate reset = %v, expected %v", client.Rate.Reset, reset)
|
||||||
}
|
}
|
||||||
|
if client.Rate != client.GetRate() {
|
||||||
|
t.Errorf("Client rate is not the same as client.GetRate()")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDo_rateLimitRace(t *testing.T) {
|
||||||
|
setup()
|
||||||
|
defer teardown()
|
||||||
|
|
||||||
|
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Add(headerRateLimit, "60")
|
||||||
|
w.Header().Add(headerRateRemaining, "59")
|
||||||
|
w.Header().Add(headerRateReset, "1372700873")
|
||||||
|
})
|
||||||
|
|
||||||
|
var (
|
||||||
|
wg sync.WaitGroup
|
||||||
|
wait = make(chan struct{})
|
||||||
|
count = 100
|
||||||
|
)
|
||||||
|
wg.Add(count)
|
||||||
|
for i := 0; i < count; i++ {
|
||||||
|
go func() {
|
||||||
|
<-wait
|
||||||
|
req, _ := client.NewRequest(ctx, http.MethodGet, "/", nil)
|
||||||
|
_, err := client.Do(context.Background(), req, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Do(): %v", err)
|
||||||
|
}
|
||||||
|
wg.Done()
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
wg.Add(count)
|
||||||
|
for i := 0; i < count; i++ {
|
||||||
|
go func() {
|
||||||
|
<-wait
|
||||||
|
_ = client.GetRate()
|
||||||
|
wg.Done()
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
close(wait)
|
||||||
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDo_rateLimit_errorResponse(t *testing.T) {
|
func TestDo_rateLimit_errorResponse(t *testing.T) {
|
||||||
|
|
Loading…
Reference in New Issue