Make Rate limits thread safe (#347)
Signed-off-by: Julien Pivotto <roidelapluie@inuits.eu> Co-authored-by: Andrew Starr-Bochicchio <andrewsomething@users.noreply.github.com>
This commit is contained in:
parent
ce1a90fde7
commit
a51159bebb
16
godo.go
16
godo.go
|
@ -11,6 +11,7 @@ import (
|
|||
"net/url"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/go-querystring/query"
|
||||
|
@ -40,8 +41,9 @@ type Client struct {
|
|||
UserAgent string
|
||||
|
||||
// Rate contains the current rate limit for the client as determined by the most recent
|
||||
// API call.
|
||||
Rate Rate
|
||||
// API call. It is not thread-safe. Please consider using GetRate() instead.
|
||||
Rate Rate
|
||||
ratemtx sync.Mutex
|
||||
|
||||
// Services used for communicating with the API
|
||||
Account AccountService
|
||||
|
@ -288,6 +290,14 @@ func (c *Client) OnRequestCompleted(rc RequestCompletionCallback) {
|
|||
c.onRequestCompleted = rc
|
||||
}
|
||||
|
||||
// GetRate returns the current rate limit for the client as determined by the most recent
|
||||
// API call. It is thread-safe.
|
||||
func (c *Client) GetRate() Rate {
|
||||
c.ratemtx.Lock()
|
||||
defer c.ratemtx.Unlock()
|
||||
return c.Rate
|
||||
}
|
||||
|
||||
// newResponse creates a new Response for the provided http.Response
|
||||
func newResponse(r *http.Response) *Response {
|
||||
response := Response{Response: r}
|
||||
|
@ -330,7 +340,9 @@ func (c *Client) Do(ctx context.Context, req *http.Request, v interface{}) (*Res
|
|||
}()
|
||||
|
||||
response := newResponse(resp)
|
||||
c.ratemtx.Lock()
|
||||
c.Rate = response.Rate
|
||||
c.ratemtx.Unlock()
|
||||
|
||||
err = CheckResponse(resp)
|
||||
if err != nil {
|
||||
|
|
47
godo_test.go
47
godo_test.go
|
@ -10,6 +10,7 @@ import (
|
|||
"net/url"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
@ -354,6 +355,9 @@ func TestDo_rateLimit(t *testing.T) {
|
|||
if !client.Rate.Reset.IsZero() {
|
||||
t.Errorf("Client rate reset not initialized to zero value")
|
||||
}
|
||||
if client.Rate != client.GetRate() {
|
||||
t.Errorf("Client rate is not the same as client.GetRate()")
|
||||
}
|
||||
|
||||
req, _ := client.NewRequest(ctx, http.MethodGet, "/", nil)
|
||||
_, err := client.Do(context.Background(), req, nil)
|
||||
|
@ -371,6 +375,49 @@ func TestDo_rateLimit(t *testing.T) {
|
|||
if client.Rate.Reset.UTC() != reset {
|
||||
t.Errorf("Client rate reset = %v, expected %v", client.Rate.Reset, reset)
|
||||
}
|
||||
if client.Rate != client.GetRate() {
|
||||
t.Errorf("Client rate is not the same as client.GetRate()")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDo_rateLimitRace(t *testing.T) {
|
||||
setup()
|
||||
defer teardown()
|
||||
|
||||
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Add(headerRateLimit, "60")
|
||||
w.Header().Add(headerRateRemaining, "59")
|
||||
w.Header().Add(headerRateReset, "1372700873")
|
||||
})
|
||||
|
||||
var (
|
||||
wg sync.WaitGroup
|
||||
wait = make(chan struct{})
|
||||
count = 100
|
||||
)
|
||||
wg.Add(count)
|
||||
for i := 0; i < count; i++ {
|
||||
go func() {
|
||||
<-wait
|
||||
req, _ := client.NewRequest(ctx, http.MethodGet, "/", nil)
|
||||
_, err := client.Do(context.Background(), req, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Do(): %v", err)
|
||||
}
|
||||
wg.Done()
|
||||
}()
|
||||
}
|
||||
wg.Add(count)
|
||||
for i := 0; i < count; i++ {
|
||||
go func() {
|
||||
<-wait
|
||||
_ = client.GetRate()
|
||||
wg.Done()
|
||||
}()
|
||||
}
|
||||
|
||||
close(wait)
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestDo_rateLimit_errorResponse(t *testing.T) {
|
||||
|
|
Loading…
Reference in New Issue