diff --git a/README.md b/README.md index ca31bc8..be2c327 100644 --- a/README.md +++ b/README.md @@ -94,7 +94,7 @@ c, err := unifi.NewClient(&unifi.ClientConfig{ You can use interceptors to modify requests and responses. This gives you more control over the client behavior and flexibility to add custom logic. -To use interceptor logic, you need to create a struct implementing [ClientInterceptor](https://pkg.go.dev/github.com/filipowm/go-unifi/unifi#ClientInterceptor). interface. +To use interceptor logic, you need to create a struct implementing [ClientInterceptor](https://pkg.go.dev/github.com/filipowm/go-unifi/unifi#ClientInterceptor) interface. For example, you can use interceptors to log requests and responses: ```go diff --git a/go.mod b/go.mod index 8c33e3c..ba0c420 100644 --- a/go.mod +++ b/go.mod @@ -1,8 +1,8 @@ module github.com/filipowm/go-unifi -go 1.22.1 +go 1.23 -toolchain go1.23.1 +toolchain go1.23.5 require ( github.com/golangci/golangci-lint v1.63.4 diff --git a/unifi/unifi.go b/unifi/unifi.go index 06f144c..77b938f 100644 --- a/unifi/unifi.go +++ b/unifi/unifi.go @@ -36,6 +36,12 @@ const ( logoutPath = "/api/logout" defaultUserAgent = "go-unifi/0.0.1" + + ApiKeyHeader = "X-API-Key" + CsrfHeader = "X-Csrf-Token" + UserAgentHeader = "User-Agent" + AcceptHeader = "Accept" + ContentTypeHeader = "Content-Type" ) var ( @@ -140,15 +146,15 @@ type ClientInterceptor interface { InterceptRequest(req *http.Request) error InterceptResponse(resp *http.Response) error } -type ApiTokenAuthInterceptor struct { +type ApiKeyAuthInterceptor struct { apiKey string } -func (a *ApiTokenAuthInterceptor) InterceptRequest(req *http.Request) error { - req.Header.Set("X-API-Key", a.apiKey) +func (a *ApiKeyAuthInterceptor) InterceptRequest(req *http.Request) error { + req.Header.Set(ApiKeyHeader, a.apiKey) return nil } -func (a *ApiTokenAuthInterceptor) InterceptResponse(_ *http.Response) error { +func (a *ApiKeyAuthInterceptor) InterceptResponse(_ *http.Response) error { return nil } @@ -158,13 +164,13 @@ type CsrfInterceptor struct { func (c *CsrfInterceptor) InterceptRequest(req *http.Request) error { if c.csrfToken != "" { - req.Header.Set("X-Csrf-Token", c.csrfToken) + req.Header.Set(CsrfHeader, c.csrfToken) } return nil } func (c *CsrfInterceptor) InterceptResponse(resp *http.Response) error { - if csrf := resp.Header.Get("X-Csrf-Token"); csrf != "" { + if csrf := resp.Header.Get(CsrfHeader); csrf != "" { c.csrfToken = csrf } return nil @@ -303,7 +309,7 @@ func newUnifi(config *ClientConfig) (*Client, error) { var interceptors []ClientInterceptor if config.APIKey != "" { - interceptors = append(interceptors, &ApiTokenAuthInterceptor{apiKey: config.APIKey}) + interceptors = append(interceptors, &ApiKeyAuthInterceptor{apiKey: config.APIKey}) } else { // CSRF is only needed for user/pass auth interceptors = append(interceptors, &CsrfInterceptor{}) @@ -312,9 +318,9 @@ func newUnifi(config *ClientConfig) (*Client, error) { config.UserAgent = defaultUserAgent } interceptors = append(interceptors, &DefaultHeadersInterceptor{headers: map[string]string{ - "User-Agent": config.UserAgent, - "Accept": "application/json", - "Content-Type": "application/json; charset=utf-8", + UserAgentHeader: config.UserAgent, + AcceptHeader: "application/json", + ContentTypeHeader: "application/json; charset=utf-8", }}) var errorHandler ResponseErrorHandler @@ -340,6 +346,7 @@ func newUnifi(config *ClientConfig) (*Client, error) { } // Login is a helper method. It can be called to grab a new authentication cookie. +// Only useful if you are using user/pass auth. func (c *Client) Login() error { if c.config.APIKey != "" { // no need to login on api-key auth @@ -362,7 +369,7 @@ func (c *Client) Login() error { return nil } -// Logout closes the current session. +// Logout closes the current session. Only useful if you are using user/pass auth. func (c *Client) Logout() error { if c.config.APIKey != "" { // no need to logout on api-key auth @@ -433,7 +440,7 @@ func (c *Client) determineApiStyle() error { return nil } -// GetServerInfo sets the controller's version and UUID. Only call this if you +// GetServerInfo reads the controller's version and UUID. Only call this if you // previously called Login and suspect the controller version has changed. func (c *Client) GetServerInfo() (*ServerInfo, error) { ctx, cancel := c.createRequestContext() @@ -474,6 +481,7 @@ func (c *Client) createRequestURL(apiPath string) (*url.URL, error) { return c.BaseURL.ResolveReference(reqURL), nil } +// Do performs a request to the given API path with the given method. func (c *Client) Do(ctx context.Context, method, apiPath string, reqBody interface{}, respBody interface{}) error { reqReader, err := marshalRequest(reqBody) if err != nil { @@ -531,18 +539,22 @@ func (c *Client) Do(ctx context.Context, method, apiPath string, reqBody interfa return nil } +// Get performs a GET request to the given API path. func (c *Client) Get(context context.Context, apiPath string, reqBody interface{}, respBody interface{}) error { return c.Do(context, http.MethodGet, apiPath, reqBody, respBody) } +// Post performs a POST request to the given API path. func (c *Client) Post(context context.Context, apiPath string, reqBody interface{}, respBody interface{}) error { return c.Do(context, http.MethodPost, apiPath, reqBody, respBody) } +// Put performs a PUT request to the given API path. func (c *Client) Put(context context.Context, apiPath string, reqBody interface{}, respBody interface{}) error { return c.Do(context, http.MethodPut, apiPath, reqBody, respBody) } +// Delete performs a DELETE request to the given API path. func (c *Client) Delete(context context.Context, apiPath string, reqBody interface{}, respBody interface{}) error { return c.Do(context, http.MethodDelete, apiPath, reqBody, respBody) } diff --git a/unifi/unifi_test.go b/unifi/unifi_test.go new file mode 100644 index 0000000..63227d8 --- /dev/null +++ b/unifi/unifi_test.go @@ -0,0 +1,484 @@ +package unifi + +import ( + "context" + "encoding/json" + "fmt" + "github.com/stretchr/testify/assert" + "io" + "net/http" + "net/http/httptest" + "reflect" + "slices" + "strings" + "testing" +) + +const ( + localUrl = "http://127.0.0.1:64431" + testUrl = "http://test.url" +) + +func verifyContainsInterceptors(a *assert.Assertions, c *Client, interceptors ...interface{}) { + var ( + expectedTypes []reflect.Type + matchingTypes []reflect.Type + ) + for _, i := range interceptors { + expectedTypes = append(expectedTypes, reflect.TypeOf(i)) + } + for _, i := range c.interceptors { + actualType := reflect.TypeOf(i) + if slices.Contains(expectedTypes, actualType) { + matchingTypes = append(matchingTypes, actualType) + } + } + if len(matchingTypes) != len(expectedTypes) { + a.Fail(fmt.Sprintf("interceptors not found; expected: %v, found: %v", expectedTypes, matchingTypes)) + } +} + +func verifyDoesNotContainInterceptors(a *assert.Assertions, c *Client, interceptors ...interface{}) { + var ( + expectedTypes []reflect.Type + matchingTypes []reflect.Type + ) + for _, i := range interceptors { + expectedTypes = append(expectedTypes, reflect.TypeOf(i)) + } + for _, i := range c.interceptors { + actualType := reflect.TypeOf(i) + if slices.Contains(expectedTypes, actualType) { + matchingTypes = append(matchingTypes, actualType) + } + } + if len(matchingTypes) != 0 { + a.Fail(fmt.Sprintf("interceptors found; expected to be not present: %v, found: %v", expectedTypes, matchingTypes)) + } +} + +func TestNewClient(t *testing.T) { + t.Parallel() + a := assert.New(t) + c, err := NewClient(&ClientConfig{ + URL: localUrl, + User: "admin", + Pass: "password", + VerifySSL: false, + }) + a.NotNil(err) + a.EqualValues(localUrl, c.BaseURL.String()) + a.Contains(err.Error(), "connection refused", "an invalid destination should produce a connection error.") + verifyContainsInterceptors(a, c, &CsrfInterceptor{}, &DefaultHeadersInterceptor{}) + verifyDoesNotContainInterceptors(a, c, &ApiKeyAuthInterceptor{}) +} + +func TestNewClientWithApiKey(t *testing.T) { + t.Parallel() + a := assert.New(t) + // when + c, err := NewClient(&ClientConfig{ + URL: localUrl, + APIKey: "test", + VerifySSL: false, + }) + + // then + a.NotNil(err) + a.EqualValues(localUrl, c.BaseURL.String()) + a.Contains(err.Error(), "connection refused", "an invalid destination should produce a connection error.") + verifyContainsInterceptors(a, c, &ApiKeyAuthInterceptor{}, &DefaultHeadersInterceptor{}) + verifyDoesNotContainInterceptors(a, c, &CsrfInterceptor{}) +} + +func TestCustomizeHttpClient(t *testing.T) { + t.Parallel() + a := assert.New(t) + // given + called := false + + // when + NewClient(&ClientConfig{ + URL: localUrl, + HttpCustomizer: func(transport *http.Transport) error { + called = true + return nil + }, + }) + + // then + a.True(called, "http customizer not called") +} + +type TestInterceptor struct { + request *http.Request + response *http.Response + failOnRequest bool +} + +func (i *TestInterceptor) IsRequestIntercepted() bool { + return i.request != nil +} + +func (i *TestInterceptor) IsResponseIntercepted() bool { + return i.response != nil +} + +func (i *TestInterceptor) InterceptRequest(req *http.Request) error { + i.request = req + if i.failOnRequest { + return fmt.Errorf("request interceptor failed") + } + return nil +} +func (i *TestInterceptor) InterceptResponse(resp *http.Response) error { + i.response = resp + return nil +} + +func (i *TestInterceptor) RequestHeader(key string) string { + return i.request.Header.Get(key) +} + +func (i *TestInterceptor) ResponseHeader(key string) string { + return i.response.Header.Get(key) +} + +func (i *TestInterceptor) Method() string { + return i.request.Method +} + +func NewTestInterceptor() *TestInterceptor { + return &TestInterceptor{} +} + +func (i *TestInterceptor) AsList() []ClientInterceptor { + return []ClientInterceptor{i} +} + +func NewTestClientWithInterceptor() (*Client, *TestInterceptor) { + interceptor := NewTestInterceptor() + c, _ := NewClient(&ClientConfig{ + URL: testUrl, + APIKey: "test-key", + Interceptors: interceptor.AsList(), + }) + c.apiPaths = &NewStyleAPI + return c, interceptor +} + +func TestInterceptors(t *testing.T) { + t.Parallel() + a := assert.New(t) + // given + c, interceptor := NewTestClientWithInterceptor() + + // when + c.Get(context.Background(), "/", nil, nil) + + // then + a.True(interceptor.IsRequestIntercepted(), "request interceptor not called") + a.False(interceptor.IsResponseIntercepted(), "response interceptor called, but should not because of failed request") +} + +func TestNoSendRequestWhenRequestInterceptorReturnsError(t *testing.T) { + t.Parallel() + a := assert.New(t) + // given + c, interceptor := NewTestClientWithInterceptor() + interceptor.failOnRequest = true + + // when + err := c.Get(context.Background(), "/", nil, nil) + + // then + a.NotNil(err) + a.Contains(err.Error(), "request interceptor failed") +} + +func TestProperRequestUrl(t *testing.T) { + t.Parallel() + a := assert.New(t) + testCases := []struct { + path string + expected string + }{ + {"", testUrl + NewStyleAPI.ApiPath}, + {"test", testUrl + NewStyleAPI.ApiPath + "/test"}, + {"test/", testUrl + NewStyleAPI.ApiPath + "/test"}, + {"test/test", testUrl + NewStyleAPI.ApiPath + "/test/test"}, + {"/test/", testUrl + "/test/"}, + {"/test", testUrl + "/test"}, + {"/test/test", testUrl + "/test/test"}, + } + // given + c, interceptor := NewTestClientWithInterceptor() + + for _, tc := range testCases { + t.Run(tc.path, func(t *testing.T) { + // when + c.Get(context.Background(), tc.path, nil, nil) + + // then + a.EqualValues(tc.expected, interceptor.request.URL.String()) + }) + } +} + +func TestApiKeyAddedToRequest(t *testing.T) { + t.Parallel() + a := assert.New(t) + // given + c, interceptor := NewTestClientWithInterceptor() + + // when + c.Get(context.Background(), "/", nil, nil) + + // then + a.EqualValues("test-key", interceptor.RequestHeader(ApiKeyHeader)) +} + +func TestDefaultHeadersAddedToRequest(t *testing.T) { + t.Parallel() + a := assert.New(t) + // given + c, interceptor := NewTestClientWithInterceptor() + + // when + c.Get(context.Background(), "/", nil, nil) + + // then + a.EqualValues("application/json", interceptor.RequestHeader(AcceptHeader)) + a.EqualValues("application/json; charset=utf-8", interceptor.RequestHeader(ContentTypeHeader)) + a.EqualValues(defaultUserAgent, interceptor.RequestHeader(UserAgentHeader)) +} + +type TestData struct { + Data string `json:"data"` +} + +func TestRequestSentWithJson(t *testing.T) { + t.Parallel() + a := assert.New(t) + // given + c, interceptor := NewTestClientWithInterceptor() + data := &TestData{ + Data: "test", + } + + // when + c.Get(context.Background(), "/", data, nil) + + // then + body := &TestData{} + err := json.NewDecoder(interceptor.request.Body).Decode(body) + + a.Nil(err) + a.Equal(data, body) +} + +func TestRequestMethod(t *testing.T) { + t.Parallel() + a := assert.New(t) + testCases := []string{ + http.MethodGet, http.MethodPost, http.MethodPut, http.MethodDelete, http.MethodPatch, http.MethodOptions, http.MethodHead, http.MethodTrace, http.MethodConnect, + } + // given + c, interceptor := NewTestClientWithInterceptor() + + // when + c.Post(context.Background(), "/", nil, nil) + + // then + for _, tc := range testCases { + t.Run(tc, func(t *testing.T) { + // when + c.Do(context.Background(), tc, "", nil, nil) + + // then + a.EqualValues(tc, interceptor.Method()) + }) + } +} + +func TestGetRequest(t *testing.T) { + t.Parallel() + a := assert.New(t) + // given + c, interceptor := NewTestClientWithInterceptor() + + // when + c.Get(context.Background(), "/", nil, nil) + + // then + a.EqualValues(http.MethodGet, interceptor.Method()) +} + +func TestPostRequest(t *testing.T) { + t.Parallel() + a := assert.New(t) + // given + c, interceptor := NewTestClientWithInterceptor() + + // when + c.Post(context.Background(), "/", nil, nil) + + // then + a.EqualValues(http.MethodPost, interceptor.Method()) +} + +func TestPutRequest(t *testing.T) { + t.Parallel() + a := assert.New(t) + // given + c, interceptor := NewTestClientWithInterceptor() + + // when + c.Put(context.Background(), "/", nil, nil) + + // then + a.EqualValues(http.MethodPut, interceptor.Method()) +} + +func TestDeleteRequest(t *testing.T) { + t.Parallel() + a := assert.New(t) + // given + c, interceptor := NewTestClientWithInterceptor() + + // when + c.Delete(context.Background(), "/", nil, nil) + + // then + a.EqualValues(http.MethodDelete, interceptor.Method()) +} + +func RunTestServer(path string, requestBody interface{}) *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Add(CsrfHeader, "csrf-token") + if !strings.EqualFold(r.URL.Path, path) { + w.WriteHeader(http.StatusNotFound) + return + } + w.WriteHeader(http.StatusOK) + data, err := io.ReadAll(r.Body) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + fmt.Printf("error reading body:%v", err) + return + } + err = json.Unmarshal(data, &requestBody) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + fmt.Printf("error decoding body: %s: %s", string(data), err) + return + } + resp := TestData{ + Data: "test", + } + respData, err := json.Marshal(resp) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + fmt.Printf("error encoding response: %s", err) + return + } + _, err = w.Write(respData) + if err != nil { + fmt.Printf("error writing response: %s", err) + } + })) +} + +func TestUnifiIntegrationUserPassInjected(t *testing.T) { + t.Parallel() + a := assert.New(t) + // given + type userPass struct { + Username string `json:"username"` + Password string `json:"password"` + } + srv := RunTestServer(NewStyleAPI.LoginPath, userPass{}) + interceptor := NewTestInterceptor() + c, _ := NewClient(&ClientConfig{ + URL: srv.URL, + User: "test-user", + Pass: "test-pass", + Interceptors: interceptor.AsList(), + }) + c.apiPaths = &NewStyleAPI + + // when + err := c.Login() + + // then + a.Nil(err, "user/pass login must not produce an error") + a.EqualValues(http.MethodPost, interceptor.Method()) + a.EqualValues(http.StatusOK, interceptor.response.StatusCode) +} + +func TestResponseDataHandling(t *testing.T) { + t.Parallel() + a := assert.New(t) + // given + reqData := TestData{ + Data: "request", + } + srv := RunTestServer(NewStyleAPI.ApiPath+"/test", TestData{}) + c, _ := NewClient(&ClientConfig{ + URL: srv.URL, + }) + c.apiPaths = &NewStyleAPI + var data TestData + + // when + err := c.Get(context.Background(), "test", reqData, &data) + + // then + a.Nil(err) + a.EqualValues("test", data.Data) +} + +func TestCsrfHandling(t *testing.T) { + t.Parallel() + a := assert.New(t) + // given + srv := RunTestServer("", struct{}{}) + interceptor := NewTestInterceptor() + c, _ := NewClient(&ClientConfig{ + URL: srv.URL, + Interceptors: interceptor.AsList(), + }) + c.apiPaths = &NewStyleAPI + + // when + c.Get(context.Background(), "", nil, nil) + + // then + a.EqualValues("", interceptor.RequestHeader(CsrfHeader)) + a.EqualValues("csrf-token", interceptor.ResponseHeader(CsrfHeader)) + + // when + c.Get(context.Background(), "", nil, nil) + + // then + a.EqualValues("csrf-token", interceptor.RequestHeader(CsrfHeader)) +} + +func TestOverrideUserAgent(t *testing.T) { + t.Parallel() + a := assert.New(t) + // given + interceptor := NewTestInterceptor() + c, _ := NewClient(&ClientConfig{ + URL: testUrl, + Interceptors: interceptor.AsList(), + UserAgent: "test-agent", + }) + c.apiPaths = &NewStyleAPI + + // when + c.Get(context.Background(), "", nil, nil) + + // then + a.EqualValues("test-agent", interceptor.RequestHeader(UserAgentHeader)) +}