diff --git a/databases.go b/databases.go index ccffff8..50c006e 100644 --- a/databases.go +++ b/databases.go @@ -56,7 +56,7 @@ type DatabasesService interface { DeleteReplica(context.Context, string, string) (*Response, error) GetEvictionPolicy(context.Context, string) (string, *Response, error) SetEvictionPolicy(context.Context, string, string) (*Response, error) - GetFirewallRules(context.Context, string) (*Response, error) + GetFirewallRules(context.Context, string) ([]DatabaseFirewallRule, *Response, error) UpdateFirewallRules(context.Context, string, *DatabaseUpdateFirewallRulesRequest) (*Response, error) } @@ -277,7 +277,7 @@ type evictionPolicyRoot struct { } type databaseFirewallRuleRoot struct { - Rules []*DatabaseFirewallRule `json:"rules"` + Rules []DatabaseFirewallRule `json:"rules"` } func (d Database) URN() string { @@ -692,14 +692,20 @@ func (svc *DatabasesServiceOp) SetEvictionPolicy(ctx context.Context, databaseID } // GetFirewallRules loads the inbound sources for a given cluster. -func (svc *DatabasesServiceOp) GetFirewallRules(ctx context.Context, databaseID string) (*Response, error) { +func (svc *DatabasesServiceOp) GetFirewallRules(ctx context.Context, databaseID string) ([]DatabaseFirewallRule, *Response, error) { path := fmt.Sprintf(databaseFirewallRulesPath, databaseID) root := new(databaseFirewallRuleRoot) req, err := svc.client.NewRequest(ctx, http.MethodGet, path, nil) if err != nil { - return nil, err + return nil, nil, err } - return svc.client.Do(ctx, req, root) + + resp, err := svc.client.Do(ctx, req, root) + if err != nil { + return nil, resp, err + } + + return root.Rules, resp, nil } // UpdateFirewallRules sets the inbound sources for a given cluster. diff --git a/databases_test.go b/databases_test.go index 42177aa..c6e9e5c 100644 --- a/databases_test.go +++ b/databases_test.go @@ -1192,6 +1192,15 @@ func TestDatabases_GetFirewallRules(t *testing.T) { path := fmt.Sprintf("/v2/databases/%s/firewall", dbID) + want := []DatabaseFirewallRule{ + { + Type: "ip_addr", + Value: "192.168.1.1", + UUID: "deadbeef-dead-4aa5-beef-deadbeef347d", + ClusterUUID: "deadbeef-dead-4aa5-beef-deadbeef347d", + }, + } + body := ` {"rules": [{ "type": "ip_addr", "value": "192.168.1.1", @@ -1204,8 +1213,9 @@ func TestDatabases_GetFirewallRules(t *testing.T) { fmt.Fprint(w, body) }) - _, err := client.Databases.GetFirewallRules(ctx, dbID) + got, _, err := client.Databases.GetFirewallRules(ctx, dbID) require.NoError(t, err) + require.Equal(t, want, got) } func TestDatabases_UpdateFirewallRules(t *testing.T) {