diff --git a/godo.go b/godo.go index 7648f07..b4a4c9a 100644 --- a/godo.go +++ b/godo.go @@ -11,6 +11,7 @@ import ( "net/url" "reflect" "strconv" + "sync" "time" "github.com/google/go-querystring/query" @@ -40,8 +41,9 @@ type Client struct { UserAgent string // Rate contains the current rate limit for the client as determined by the most recent - // API call. - Rate Rate + // API call. It is not thread-safe. Please consider using GetRate() instead. + Rate Rate + ratemtx sync.Mutex // Services used for communicating with the API Account AccountService @@ -288,6 +290,14 @@ func (c *Client) OnRequestCompleted(rc RequestCompletionCallback) { c.onRequestCompleted = rc } +// GetRate returns the current rate limit for the client as determined by the most recent +// API call. It is thread-safe. +func (c *Client) GetRate() Rate { + c.ratemtx.Lock() + defer c.ratemtx.Unlock() + return c.Rate +} + // newResponse creates a new Response for the provided http.Response func newResponse(r *http.Response) *Response { response := Response{Response: r} @@ -330,7 +340,9 @@ func (c *Client) Do(ctx context.Context, req *http.Request, v interface{}) (*Res }() response := newResponse(resp) + c.ratemtx.Lock() c.Rate = response.Rate + c.ratemtx.Unlock() err = CheckResponse(resp) if err != nil { diff --git a/godo_test.go b/godo_test.go index b82fe42..289b3da 100644 --- a/godo_test.go +++ b/godo_test.go @@ -10,6 +10,7 @@ import ( "net/url" "reflect" "strings" + "sync" "testing" "time" ) @@ -354,6 +355,9 @@ func TestDo_rateLimit(t *testing.T) { if !client.Rate.Reset.IsZero() { t.Errorf("Client rate reset not initialized to zero value") } + if client.Rate != client.GetRate() { + t.Errorf("Client rate is not the same as client.GetRate()") + } req, _ := client.NewRequest(ctx, http.MethodGet, "/", nil) _, err := client.Do(context.Background(), req, nil) @@ -371,6 +375,49 @@ func TestDo_rateLimit(t *testing.T) { if client.Rate.Reset.UTC() != reset { t.Errorf("Client rate reset = %v, expected %v", client.Rate.Reset, reset) } + if client.Rate != client.GetRate() { + t.Errorf("Client rate is not the same as client.GetRate()") + } +} + +func TestDo_rateLimitRace(t *testing.T) { + setup() + defer teardown() + + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + w.Header().Add(headerRateLimit, "60") + w.Header().Add(headerRateRemaining, "59") + w.Header().Add(headerRateReset, "1372700873") + }) + + var ( + wg sync.WaitGroup + wait = make(chan struct{}) + count = 100 + ) + wg.Add(count) + for i := 0; i < count; i++ { + go func() { + <-wait + req, _ := client.NewRequest(ctx, http.MethodGet, "/", nil) + _, err := client.Do(context.Background(), req, nil) + if err != nil { + t.Fatalf("Do(): %v", err) + } + wg.Done() + }() + } + wg.Add(count) + for i := 0; i < count; i++ { + go func() { + <-wait + _ = client.GetRate() + wg.Done() + }() + } + + close(wait) + wg.Wait() } func TestDo_rateLimit_errorResponse(t *testing.T) {