diff --git a/domains.go b/domains.go index 43c0424..5037013 100644 --- a/domains.go +++ b/domains.go @@ -18,6 +18,9 @@ type DomainsService interface { Delete(context.Context, string) (*Response, error) Records(context.Context, string, *ListOptions) ([]DomainRecord, *Response, error) + RecordsByType(context.Context, string, string, *ListOptions) ([]DomainRecord, *Response, error) + RecordsByName(context.Context, string, string, *ListOptions) ([]DomainRecord, *Response, error) + RecordsByTypeAndName(context.Context, string, string, string, *ListOptions) ([]DomainRecord, *Response, error) Record(context.Context, string, int) (*DomainRecord, *Response, error) DeleteRecord(context.Context, string, int) (*Response, error) EditRecord(context.Context, string, int, *DomainRecordEditRequest) (*DomainRecord, *Response, error) @@ -201,7 +204,7 @@ func (d DomainRecordEditRequest) String() string { return Stringify(d) } -// Records returns a slice of DomainRecords for a domain +// Records returns a slice of DomainRecord for a domain. func (s *DomainsServiceOp) Records(ctx context.Context, domain string, opt *ListOptions) ([]DomainRecord, *Response, error) { if len(domain) < 1 { return nil, nil, NewArgError("domain", "cannot be an empty string") @@ -213,21 +216,68 @@ func (s *DomainsServiceOp) Records(ctx context.Context, domain string, opt *List return nil, nil, err } - req, err := s.client.NewRequest(ctx, http.MethodGet, path, nil) + return s.records(ctx, path) +} + +// RecordsByType returns a slice of DomainRecord for a domain matched by record type. +func (s *DomainsServiceOp) RecordsByType(ctx context.Context, domain, ofType string, opt *ListOptions) ([]DomainRecord, *Response, error) { + if len(domain) < 1 { + return nil, nil, NewArgError("domain", "cannot be an empty string") + } + + if len(ofType) < 1 { + return nil, nil, NewArgError("type", "cannot be an empty string") + } + + path := fmt.Sprintf("%s/%s/records?type=%s", domainsBasePath, domain, ofType) + path, err := addOptions(path, opt) if err != nil { return nil, nil, err } - root := new(domainRecordsRoot) - resp, err := s.client.Do(ctx, req, root) - if err != nil { - return nil, resp, err - } - if l := root.Links; l != nil { - resp.Links = l + return s.records(ctx, path) +} + +// RecordsByName returns a slice of DomainRecord for a domain matched by record name. +func (s *DomainsServiceOp) RecordsByName(ctx context.Context, domain, name string, opt *ListOptions) ([]DomainRecord, *Response, error) { + if len(domain) < 1 { + return nil, nil, NewArgError("domain", "cannot be an empty string") } - return root.DomainRecords, resp, err + if len(name) < 1 { + return nil, nil, NewArgError("name", "cannot be an empty string") + } + + path := fmt.Sprintf("%s/%s/records?name=%s", domainsBasePath, domain, name) + path, err := addOptions(path, opt) + if err != nil { + return nil, nil, err + } + + return s.records(ctx, path) +} + +// RecordsByTypeAndName returns a slice of DomainRecord for a domain matched by record type and name. +func (s *DomainsServiceOp) RecordsByTypeAndName(ctx context.Context, domain, ofType, name string, opt *ListOptions) ([]DomainRecord, *Response, error) { + if len(domain) < 1 { + return nil, nil, NewArgError("domain", "cannot be an empty string") + } + + if len(ofType) < 1 { + return nil, nil, NewArgError("type", "cannot be an empty string") + } + + if len(name) < 1 { + return nil, nil, NewArgError("name", "cannot be an empty string") + } + + path := fmt.Sprintf("%s/%s/records?type=%s&name=%s", domainsBasePath, domain, ofType, name) + path, err := addOptions(path, opt) + if err != nil { + return nil, nil, err + } + + return s.records(ctx, path) } // Record returns the record id from a domain @@ -339,3 +389,22 @@ func (s *DomainsServiceOp) CreateRecord(ctx context.Context, return d.DomainRecord, resp, err } + +// Performs a domain records request given a path. +func (s *DomainsServiceOp) records(ctx context.Context, path string) ([]DomainRecord, *Response, error) { + req, err := s.client.NewRequest(ctx, http.MethodGet, path, nil) + if err != nil { + return nil, nil, err + } + + root := new(domainRecordsRoot) + resp, err := s.client.Do(ctx, req, root) + if err != nil { + return nil, resp, err + } + if l := root.Links; l != nil { + resp.Links = l + } + + return root.DomainRecords, resp, err +} diff --git a/domains_test.go b/domains_test.go index f79e24e..f1f1367 100644 --- a/domains_test.go +++ b/domains_test.go @@ -5,7 +5,11 @@ import ( "fmt" "net/http" "reflect" + "strconv" "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestDomains_ListDomains(t *testing.T) { @@ -181,7 +185,7 @@ func TestDomains_AllRecordsForDomainName(t *testing.T) { expected := []DomainRecord{{ID: 1}, {ID: 2}} if !reflect.DeepEqual(records, expected) { - t.Errorf("Domains.List returned %+v, expected %+v", records, expected) + t.Errorf("Domains.Records returned %+v, expected %+v", records, expected) } } @@ -206,7 +210,167 @@ func TestDomains_AllRecordsForDomainName_PerPage(t *testing.T) { expected := []DomainRecord{{ID: 1}, {ID: 2}} if !reflect.DeepEqual(records, expected) { - t.Errorf("Domains.List returned %+v, expected %+v", records, expected) + t.Errorf("Domains.Records returned %+v, expected %+v", records, expected) + } +} + +func TestDomains_RecordsByType(t *testing.T) { + tests := []struct { + name string + recordType string + pagination *ListOptions + expectedErr *ArgError + }{ + { + name: "success", + recordType: "CNAME", + }, + { + name: "when record type is empty it returns argument error", + expectedErr: &ArgError{arg: "type", reason: "cannot be an empty string"}, + }, + { + name: "with pagination", + recordType: "CNAME", + pagination: &ListOptions{Page: 1, PerPage: 10}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + setup() + defer teardown() + + mux.HandleFunc("/v2/domains/example.com/records", func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, tt.recordType, r.URL.Query().Get("type")) + if tt.pagination != nil { + require.Equal(t, strconv.Itoa(tt.pagination.Page), r.URL.Query().Get("page")) + require.Equal(t, strconv.Itoa(tt.pagination.PerPage), r.URL.Query().Get("per_page")) + } + testMethod(t, r, http.MethodGet) + fmt.Fprint(w, `{"domain_records":[{"id":1},{"id":2}]}`) + }) + + records, _, err := client.Domains.RecordsByType(ctx, "example.com", tt.recordType, tt.pagination) + if tt.expectedErr != nil { + assert.Equal(t, tt.expectedErr, err) + } else { + expected := []DomainRecord{{ID: 1}, {ID: 2}} + if !reflect.DeepEqual(records, expected) { + t.Errorf("Domains.RecordsByType returned %+v, expected %+v", records, expected) + } + } + }) + } +} + +func TestDomains_RecordsByName(t *testing.T) { + tests := []struct { + name string + recordName string + pagination *ListOptions + expectedErr *ArgError + }{ + { + name: "success", + recordName: "foo.com", + }, + { + name: "when record name is empty it returns argument error", + expectedErr: &ArgError{arg: "name", reason: "cannot be an empty string"}, + }, + { + name: "with pagination", + recordName: "foo.com", + pagination: &ListOptions{Page: 2, PerPage: 1}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + setup() + defer teardown() + + mux.HandleFunc("/v2/domains/example.com/records", func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, tt.recordName, r.URL.Query().Get("name")) + if tt.pagination != nil { + require.Equal(t, strconv.Itoa(tt.pagination.Page), r.URL.Query().Get("page")) + require.Equal(t, strconv.Itoa(tt.pagination.PerPage), r.URL.Query().Get("per_page")) + } + testMethod(t, r, http.MethodGet) + fmt.Fprint(w, `{"domain_records":[{"id":1},{"id":2}]}`) + }) + + records, _, err := client.Domains.RecordsByName(ctx, "example.com", tt.recordName, tt.pagination) + if tt.expectedErr != nil { + assert.Equal(t, tt.expectedErr, err) + } else { + expected := []DomainRecord{{ID: 1}, {ID: 2}} + if !reflect.DeepEqual(records, expected) { + t.Errorf("Domains.RecordsByName returned %+v, expected %+v", records, expected) + } + } + }) + } +} + +func TestDomains_RecordsByTypeAndName(t *testing.T) { + tests := []struct { + name string + recordType string + recordName string + pagination *ListOptions + expectedErr *ArgError + }{ + { + name: "success", + recordType: "NS", + recordName: "foo.com", + }, + { + name: "when record type is empty it returns argument error", + recordName: "foo.com", + expectedErr: &ArgError{arg: "type", reason: "cannot be an empty string"}, + }, + { + name: "when record name is empty it returns argument error", + recordType: "NS", + expectedErr: &ArgError{arg: "name", reason: "cannot be an empty string"}, + }, + { + name: "with pagination", + recordType: "CNAME", + recordName: "foo.com", + pagination: &ListOptions{Page: 1, PerPage: 1}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + setup() + defer teardown() + + mux.HandleFunc("/v2/domains/example.com/records", func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, tt.recordType, r.URL.Query().Get("type")) + require.Equal(t, tt.recordName, r.URL.Query().Get("name")) + if tt.pagination != nil { + require.Equal(t, strconv.Itoa(tt.pagination.Page), r.URL.Query().Get("page")) + require.Equal(t, strconv.Itoa(tt.pagination.PerPage), r.URL.Query().Get("per_page")) + } + testMethod(t, r, http.MethodGet) + fmt.Fprint(w, `{"domain_records":[{"id":1},{"id":2}]}`) + }) + + records, _, err := client.Domains.RecordsByTypeAndName(ctx, "example.com", tt.recordType, tt.recordName, tt.pagination) + if tt.expectedErr != nil { + assert.Equal(t, tt.expectedErr, err) + } else { + expected := []DomainRecord{{ID: 1}, {ID: 2}} + if !reflect.DeepEqual(records, expected) { + t.Errorf("Domains.RecordsByTypeAndName returned %+v, expected %+v", records, expected) + } + } + }) } }