Make Rate limits thread safe (#347)

Signed-off-by: Julien Pivotto <roidelapluie@inuits.eu>

Co-authored-by: Andrew Starr-Bochicchio <andrewsomething@users.noreply.github.com>
This commit is contained in:
Julien Pivotto 2020-07-15 00:41:16 +02:00 committed by GitHub
parent ce1a90fde7
commit a51159bebb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 61 additions and 2 deletions

16
godo.go
View File

@ -11,6 +11,7 @@ import (
"net/url"
"reflect"
"strconv"
"sync"
"time"
"github.com/google/go-querystring/query"
@ -40,8 +41,9 @@ type Client struct {
UserAgent string
// Rate contains the current rate limit for the client as determined by the most recent
// API call.
Rate Rate
// API call. It is not thread-safe. Please consider using GetRate() instead.
Rate Rate
ratemtx sync.Mutex
// Services used for communicating with the API
Account AccountService
@ -288,6 +290,14 @@ func (c *Client) OnRequestCompleted(rc RequestCompletionCallback) {
c.onRequestCompleted = rc
}
// GetRate returns the current rate limit for the client as determined by the most recent
// API call. It is thread-safe.
func (c *Client) GetRate() Rate {
c.ratemtx.Lock()
defer c.ratemtx.Unlock()
return c.Rate
}
// newResponse creates a new Response for the provided http.Response
func newResponse(r *http.Response) *Response {
response := Response{Response: r}
@ -330,7 +340,9 @@ func (c *Client) Do(ctx context.Context, req *http.Request, v interface{}) (*Res
}()
response := newResponse(resp)
c.ratemtx.Lock()
c.Rate = response.Rate
c.ratemtx.Unlock()
err = CheckResponse(resp)
if err != nil {

View File

@ -10,6 +10,7 @@ import (
"net/url"
"reflect"
"strings"
"sync"
"testing"
"time"
)
@ -354,6 +355,9 @@ func TestDo_rateLimit(t *testing.T) {
if !client.Rate.Reset.IsZero() {
t.Errorf("Client rate reset not initialized to zero value")
}
if client.Rate != client.GetRate() {
t.Errorf("Client rate is not the same as client.GetRate()")
}
req, _ := client.NewRequest(ctx, http.MethodGet, "/", nil)
_, err := client.Do(context.Background(), req, nil)
@ -371,6 +375,49 @@ func TestDo_rateLimit(t *testing.T) {
if client.Rate.Reset.UTC() != reset {
t.Errorf("Client rate reset = %v, expected %v", client.Rate.Reset, reset)
}
if client.Rate != client.GetRate() {
t.Errorf("Client rate is not the same as client.GetRate()")
}
}
func TestDo_rateLimitRace(t *testing.T) {
setup()
defer teardown()
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
w.Header().Add(headerRateLimit, "60")
w.Header().Add(headerRateRemaining, "59")
w.Header().Add(headerRateReset, "1372700873")
})
var (
wg sync.WaitGroup
wait = make(chan struct{})
count = 100
)
wg.Add(count)
for i := 0; i < count; i++ {
go func() {
<-wait
req, _ := client.NewRequest(ctx, http.MethodGet, "/", nil)
_, err := client.Do(context.Background(), req, nil)
if err != nil {
t.Fatalf("Do(): %v", err)
}
wg.Done()
}()
}
wg.Add(count)
for i := 0; i < count; i++ {
go func() {
<-wait
_ = client.GetRate()
wg.Done()
}()
}
close(wait)
wg.Wait()
}
func TestDo_rateLimit_errorResponse(t *testing.T) {