From c6e20b675ce8f3bd0fc2acd3faf084451e678336 Mon Sep 17 00:00:00 2001 From: Mateusz Filipowicz Date: Mon, 10 Feb 2025 22:59:39 +0100 Subject: [PATCH] refactor: major refactor of unifi.go to split it into multiple files for cohesiveness and readability. Even more documentation --- unifi/api_paths.go | 100 ++++++++ unifi/client.go | 298 ++++++++++++++++++++++ unifi/interceptors.go | 71 ++++++ unifi/requests.go | 148 +++++++++++ unifi/sysinfo.go | 56 +++++ unifi/unifi.go | 556 +----------------------------------------- unifi/unifi_test.go | 107 ++++++-- 7 files changed, 762 insertions(+), 574 deletions(-) create mode 100644 unifi/api_paths.go create mode 100644 unifi/client.go create mode 100644 unifi/interceptors.go create mode 100644 unifi/requests.go diff --git a/unifi/api_paths.go b/unifi/api_paths.go new file mode 100644 index 0000000..2e94ee2 --- /dev/null +++ b/unifi/api_paths.go @@ -0,0 +1,100 @@ +package unifi + +import ( + "errors" + "fmt" + "io" + "net/http" +) + +const ( + apiPath = "/api" + apiV2Path = "/v2/api" + + apiPathNew = "/proxy/network/api" + apiV2PathNew = "/proxy/network/v2/api" + + loginPath = "/api/login" + loginPathNew = "/api/auth/login" + + statusPath = "/status" + statusPathNew = "/proxy/network/status" + + logoutPath = "/api/logout" + + defaultUserAgent = "go-unifi/0.0.1" + + ApiKeyHeader = "X-Api-Key" + CsrfHeader = "X-Csrf-Token" + UserAgentHeader = "User-Agent" + AcceptHeader = "Accept" + ContentTypeHeader = "Content-Type" +) + +// APIPaths defines the URL paths used by the client. +type APIPaths struct { + ApiPath string + ApiV2Path string + LoginPath string + StatusPath string + LogoutPath string +} + +var ( + OldStyleAPI = APIPaths{ + ApiPath: apiPath, + ApiV2Path: apiV2Path, + LoginPath: loginPath, + StatusPath: statusPath, + LogoutPath: logoutPath, + } + NewStyleAPI = APIPaths{ + ApiPath: apiPathNew, + ApiV2Path: apiV2PathNew, + LoginPath: loginPathNew, + StatusPath: statusPathNew, + LogoutPath: logoutPath, + } +) + +// determineApiStyle checks the base URL to decide which API style to use and sets the apiPaths accordingly. +func (c *Client) determineApiStyle() error { + ctx, cancel := c.newRequestContext() + defer cancel() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.BaseURL.String(), nil) + if err != nil { + return err + } + + client := &http.Client{ + CheckRedirect: func(_ *http.Request, _ []*http.Request) error { + return http.ErrUseLastResponse + }, + Transport: c.http.Transport, + } + + resp, err := client.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + + // Discard response body to avoid leaks + _, _ = io.Copy(io.Discard, resp.Body) + + switch resp.StatusCode { + case http.StatusOK: + c.apiPaths = &NewStyleAPI + case http.StatusFound: + c.apiPaths = &OldStyleAPI + default: + return fmt.Errorf("expected 200 or 302 status code, but got: %d", resp.StatusCode) + } + + if c.apiPaths == &OldStyleAPI && c.credentials.IsAPIKey() { + return errors.New("unable to use API key authentication with old style API. Switch to user/pass authentication or update controller to latest version") + } + + return nil +} diff --git a/unifi/client.go b/unifi/client.go new file mode 100644 index 0000000..2f2aea2 --- /dev/null +++ b/unifi/client.go @@ -0,0 +1,298 @@ +package unifi + +import ( + "crypto/tls" + "fmt" + "net/http" + "net/http/cookiejar" + "net/url" + "slices" + "strings" + "sync" + "time" + + "golang.org/x/net/publicsuffix" +) + +// validationMode represents the mode for request validation. +// It may be set to "soft", "hard", or "disable". The default is "soft". +type validationMode string + +const ( + // SoftValidation indicates that validation errors are logged as warnings but do not prevent the request from proceeding. + SoftValidation validationMode = "soft" + // HardValidation indicates that validation errors are treated as fatal and will cause the request to be rejected. + HardValidation validationMode = "hard" + // DisableValidation indicates that no validation is performed on the request body. + DisableValidation validationMode = "disable" + // DefaultValidation is the default validation mode used if none is specified. + // Currently set to SoftValidation, but may change to HardValidation in a future major version. + DefaultValidation validationMode = SoftValidation // TODO: change to hard in next major version +) + +// HttpCustomizer is a function type for customizing the HTTP transport. +// It receives a pointer to an http.Transport and returns an error if customization fails. +type HttpCustomizer func(transport *http.Transport) error + +// ResponseErrorHandler defines a method for handling HTTP response errors. +// HandleError processes the HTTP response and returns an error if the response indicates failure. +type ResponseErrorHandler interface { + // HandleError processes the HTTP response and returns an error if the response signals a failure. + HandleError(resp *http.Response) error +} + +/* +ClientConfig holds configuration parameters for creating a UniFi Client. + +Fields: + + URL: The base URL of the UniFi controller. Must be a valid URL and should not include the `/api` suffix. + APIKey: An API key used for authentication. Provide this if user/password credentials are not used. + User: The username for user/password authentication. Must be provided with Pass if APIKey is not used. + Pass: The password for user/password authentication. Must be provided with User if APIKey is not used. + Timeout: The maximum duration to wait for responses; default is no timeout. + VerifySSL: When false, disables SSL certificate verification. + Interceptors: A slice of ClientInterceptor implementations that can modify requests and responses. + HttpCustomizer:An optional function to customize the HTTP transport (e.g., for custom TLS settings). + UserAgent: The User-Agent header string for outgoing HTTP requests. + ErrorHandler: A custom handler for processing HTTP response errors. + UseLocking: If true, enables internal locking for concurrent request processing. + ValidationMode:The mode for validating request bodies. Can be "soft", "hard", or "disable". +*/ +type ClientConfig struct { + URL string `validate:"required,http_url"` + APIKey string `validate:"required_without_all=User Pass"` + User string `validate:"excluded_with=APIKey,required_with=Pass"` + Pass string `validate:"excluded_with=APIKey,required_with=User"` + Timeout time.Duration // How long to wait for replies, default: forever. + VerifySSL bool + Interceptors []ClientInterceptor + HttpCustomizer HttpCustomizer + UserAgent string + ErrorHandler ResponseErrorHandler + UseLocking bool + ValidationMode validationMode `validate:"omitempty,oneof=soft hard disable"` +} + +// Credentials abstracts authentication credentials. +// It defines methods to determine the type of credentials and retrieve the associated values. +type Credentials interface { + // IsAPIKey returns true if the credentials represent an API key. + IsAPIKey() bool + // GetAPIKey returns the API key; returns an empty string if not applicable. + GetAPIKey() string + // GetUser returns the username for authentication; returns an empty string if not applicable. + GetUser() string + // GetPass returns the password for authentication; returns an empty string if not applicable. + GetPass() string +} + +// APIKeyCredentials holds API key authentication details. +type APIKeyCredentials struct { + APIKey string +} + +func (a APIKeyCredentials) IsAPIKey() bool { return true } +func (a APIKeyCredentials) GetAPIKey() string { return a.APIKey } +func (a APIKeyCredentials) GetUser() string { return "" } +func (a APIKeyCredentials) GetPass() string { return "" } + +// UserPassCredentials holds user/password authentication. +type UserPassCredentials struct { + User string + Pass string +} + +func (u UserPassCredentials) IsAPIKey() bool { return false } +func (u UserPassCredentials) GetAPIKey() string { return "" } +func (u UserPassCredentials) GetUser() string { return u.User } +func (u UserPassCredentials) GetPass() string { return u.Pass } + +// Client represents a UniFi client. +type Client struct { + BaseURL *url.URL + SysInfo *SysInfo + apiPaths *APIPaths + timeout time.Duration + credentials Credentials + validationMode validationMode + useLocking bool + + http *http.Client + interceptors []ClientInterceptor + errorHandler ResponseErrorHandler + lock sync.Mutex + validator *validator +} + +// AddInterceptor adds a ClientInterceptor to the client's interceptor list if it is not already present. +// It appends the interceptor only if it is not already included in the list. +func (c *Client) AddInterceptor(interceptor *ClientInterceptor) { + if !slices.Contains(c.interceptors, *interceptor) { + c.interceptors = append(c.interceptors, *interceptor) + } +} + +func parseBaseURL(base string) (*url.URL, error) { + baseURL, err := url.Parse(base) + if err != nil { + return nil, err + } + // Check if base URL's path is "/api" (deprecated usage now in api_paths.go) + if strings.TrimSuffix(baseURL.Path, "/") == "/api" { + return nil, fmt.Errorf("expected a base URL without the `/api`, got: %q", baseURL) + } + return baseURL, nil +} + +func newClientFromConfig(config *ClientConfig, v *validator) (*Client, error) { + var err error + config.URL = strings.TrimRight(config.URL, "/") + transport := &http.Transport{ + Proxy: http.ProxyFromEnvironment, + TLSClientConfig: &tls.Config{InsecureSkipVerify: !config.VerifySSL}, + } + if config.HttpCustomizer != nil { + if err = config.HttpCustomizer(transport); err != nil { + return nil, fmt.Errorf("failed customizing HTTP transport: %w", err) + } + } + client := &http.Client{ + Timeout: config.Timeout, + Transport: transport, + } + if config.APIKey == "" { + jar, err := cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List}) + if err != nil { + return nil, fmt.Errorf("failed creating cookiejar: %w", err) + } + client.Jar = jar + } + baseURL, err := parseBaseURL(config.URL) + if err != nil { + return nil, fmt.Errorf("failed parsing base URL: %w", err) + } + var interceptors []ClientInterceptor + var credentials Credentials + + if config.APIKey != "" { + credentials = APIKeyCredentials{APIKey: config.APIKey} + interceptors = append(interceptors, &APIKeyAuthInterceptor{apiKey: config.APIKey}) + } else { + credentials = UserPassCredentials{User: config.User, Pass: config.Pass} + interceptors = append(interceptors, &CSRFInterceptor{}) + } + if len(config.UserAgent) == 0 { + config.UserAgent = defaultUserAgent + } + interceptors = append(interceptors, &DefaultHeadersInterceptor{headers: map[string]string{ + UserAgentHeader: config.UserAgent, + AcceptHeader: "application/json", + ContentTypeHeader: "application/json; charset=utf-8", + }}) + var errorHandler ResponseErrorHandler + if config.ErrorHandler != nil { + errorHandler = config.ErrorHandler + } else { + errorHandler = &DefaultResponseErrorHandler{} + } + if config.ValidationMode == "" { + config.ValidationMode = DefaultValidation + } + u := &Client{ + BaseURL: baseURL, + timeout: config.Timeout, + credentials: credentials, + validationMode: config.ValidationMode, + useLocking: config.UseLocking, + http: client, + interceptors: interceptors, + errorHandler: errorHandler, + lock: sync.Mutex{}, + validator: v, + } + for _, interceptor := range config.Interceptors { + u.AddInterceptor(&interceptor) + } + return u, nil +} + +// NewClient creates and initializes a new UniFi client based on the provided ClientConfig. +// It validates the configuration, determines the API style, performs login if necessary, +// and retrieves system information from the UniFi controller. +// On success, it returns a pointer to a Client; otherwise, it returns an error. +func NewClient(config *ClientConfig) (*Client, error) { + c, err := NewBareClient(config) + if err != nil { + return c, err + } + if err = c.Login(); err != nil { + return c, fmt.Errorf("failed logging in: %w", err) + } + if sysInfo, err := c.GetSystemInformation(); err != nil { + return c, fmt.Errorf("failed getting server info: %w", err) + } else { + c.SysInfo = sysInfo + } + return c, nil +} + +// NewBareClient creates a new UniFi client without performing login or system information retrieval. +// When user/pass authentication is used, you must call Login before making requests. +// It validates the configuration, determines the API style, and returns a pointer to the client on success. +func NewBareClient(config *ClientConfig) (*Client, error) { + v, err := newValidator() + if err != nil { + return nil, fmt.Errorf("failed creating validator: %w", err) + } + if err = v.Validate(config); err != nil { + return nil, fmt.Errorf("failed validating client configuration: %w", err) + } + c, err := newClientFromConfig(config, v) + if err != nil { + return nil, fmt.Errorf("failed creating unifi client: %w", err) + } + if err = c.determineApiStyle(); err != nil { + return c, fmt.Errorf("failed determining API style: %w", err) + } + return c, nil +} + +// Login authenticates the client using user/pass credentials. +// For API key authentication, Login does nothing. +// It returns an error if the authentication process fails. +func (c *Client) Login() error { + if c.credentials.IsAPIKey() { + return nil + } + + ctx, cancel := c.newRequestContext() + defer cancel() + + err := c.Post(ctx, c.apiPaths.LoginPath, &struct { + Username string `json:"username"` + Password string `json:"password"` + }{ + Username: c.credentials.GetUser(), + Password: c.credentials.GetPass(), + }, nil) + if err != nil { + return err + } + return nil +} + +// Logout terminates the client's session for user/pass authentication. +// For API key authentication, Logout does nothing. +// It returns an error if the logout process fails. +func (c *Client) Logout() error { + if c.credentials.IsAPIKey() { + return nil + } + + ctx, cancel := c.newRequestContext() + defer cancel() + + err := c.Post(ctx, c.apiPaths.LogoutPath, nil, nil) + return err +} diff --git a/unifi/interceptors.go b/unifi/interceptors.go new file mode 100644 index 0000000..9565b48 --- /dev/null +++ b/unifi/interceptors.go @@ -0,0 +1,71 @@ +package unifi + +import "net/http" + +// ClientInterceptor defines the interface for interceptors. +// An interceptor can modify HTTP requests and responses. +type ClientInterceptor interface { + InterceptRequest(req *http.Request) error + InterceptResponse(resp *http.Response) error +} + +// APIKeyAuthInterceptor adds an API key to outgoing requests. +// It implements the ClientInterceptor interface. +type APIKeyAuthInterceptor struct { + apiKey string +} + +// InterceptRequest sets the API key header on the given HTTP request. +// It adds the header defined by ApiKeyHeader with the stored API key and returns nil. +func (a *APIKeyAuthInterceptor) InterceptRequest(req *http.Request) error { + req.Header.Set(ApiKeyHeader, a.apiKey) + return nil +} + +// InterceptResponse does not modify the HTTP response and always returns nil. +func (a *APIKeyAuthInterceptor) InterceptResponse(_ *http.Response) error { + return nil +} + +// CSRFInterceptor manages CSRF tokens when using user/pass authentication. +// It implements the ClientInterceptor interface. +type CSRFInterceptor struct { + CSRFToken string +} + +// InterceptRequest adds the CSRF token to the HTTP request header if it is set. +// It returns nil on success. +func (c *CSRFInterceptor) InterceptRequest(req *http.Request) error { + if c.CSRFToken != "" { + req.Header.Set(CsrfHeader, c.CSRFToken) + } + return nil +} + +// InterceptResponse extracts the CSRF token from the HTTP response header, if present, and stores it for future requests. +func (c *CSRFInterceptor) InterceptResponse(resp *http.Response) error { + if token := resp.Header.Get(CsrfHeader); token != "" { + c.CSRFToken = token + } + return nil +} + +// DefaultHeadersInterceptor sets default HTTP headers for requests. +// It implements the ClientInterceptor interface. +type DefaultHeadersInterceptor struct { + headers map[string]string +} + +// InterceptRequest sets default HTTP headers on the request as specified in the interceptor's headers map. +// It returns nil on success. +func (d *DefaultHeadersInterceptor) InterceptRequest(req *http.Request) error { + for key, value := range d.headers { + req.Header.Set(key, value) + } + return nil +} + +// InterceptResponse does not modify the HTTP response and always returns nil. +func (d *DefaultHeadersInterceptor) InterceptResponse(_ *http.Response) error { + return nil +} diff --git a/unifi/requests.go b/unifi/requests.go new file mode 100644 index 0000000..f019a4a --- /dev/null +++ b/unifi/requests.go @@ -0,0 +1,148 @@ +package unifi + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "path" + "strings" +) + +// marshalRequest marshals the request body to an io.Reader. Returns nil if reqBody is nil. +func marshalRequest(reqBody interface{}) (io.Reader, error) { + if reqBody == nil { + return nil, nil //nolint: nilnil + } + reqBytes, err := json.Marshal(reqBody) + if err != nil { + return nil, err + } + return bytes.NewReader(reqBytes), nil +} + +// buildRequestURL constructs the full URL for a given apiPath using the client's BaseURL and apiPaths. +func (c *Client) buildRequestURL(apiPath string) (*url.URL, error) { + reqURL, err := url.Parse(apiPath) + if err != nil { + return nil, err + } + if !strings.HasPrefix(apiPath, "/") && !reqURL.IsAbs() { + reqURL.Path = path.Join(c.apiPaths.ApiPath, reqURL.Path) + } + return c.BaseURL.ResolveReference(reqURL), nil +} + +// validateRequestBody validates the request body if validation is enabled. +func (c *Client) validateRequestBody(reqBody interface{}) error { + if reqBody != nil && c.validationMode != DisableValidation { + if err := c.validator.Validate(reqBody); err != nil { + err = fmt.Errorf("failed validating request body: %w", err) + if c.validationMode == HardValidation { + return err + } else { + fmt.Println(err) + } + } + } + return nil +} + +// newRequestContext creates a new context for the request with a timeout if specified. +func (c *Client) newRequestContext() (context.Context, context.CancelFunc) { + ctx := context.Background() + if c.timeout != 0 { + return context.WithTimeout(ctx, c.timeout) + } + return ctx, func() {} +} + +// Do performs an HTTP request using the given method, apiPath, request body, and decodes the response into respBody. +// It validates the request body, applies interceptors, and decodes the HTTP response into respBody if provided. +// It returns an error if the request or response handling fails. +func (c *Client) Do(ctx context.Context, method, apiPath string, reqBody interface{}, respBody interface{}) error { + if err := c.validateRequestBody(reqBody); err != nil { + return err + } + reqReader, err := marshalRequest(reqBody) + if err != nil { + return fmt.Errorf("unable to marshal request: %w", err) + } + + url, err := c.buildRequestURL(apiPath) + if err != nil { + return fmt.Errorf("unable to create request URL: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, method, url.String(), reqReader) + if err != nil { + return fmt.Errorf("unable to create request: %s %s %w", method, apiPath, err) + } + + if c.useLocking { + c.lock.Lock() + defer c.lock.Unlock() + } + for _, interceptor := range c.interceptors { + if err := interceptor.InterceptRequest(req); err != nil { + return err + } + } + + resp, err := c.http.Do(req) + if err != nil { + return fmt.Errorf("unable to perform request: %s %s %w", method, apiPath, err) + } + defer resp.Body.Close() + + for _, interceptor := range c.interceptors { + if err := interceptor.InterceptResponse(resp); err != nil { + return err + } + } + + if err := c.errorHandler.HandleError(resp); err != nil { + return err + } + + if respBody == nil || resp.ContentLength == 0 { + return nil + } + + err = json.NewDecoder(resp.Body).Decode(respBody) + if err != nil { + return fmt.Errorf("unable to decode body: %s %s %w", method, apiPath, err) + } + return nil +} + +// Get sends an HTTP GET request to the specified API path with the provided request body, +// and decodes the HTTP response into respBody. +// It is a convenience wrapper around Do. +func (c *Client) Get(ctx context.Context, apiPath string, reqBody interface{}, respBody interface{}) error { + return c.Do(ctx, http.MethodGet, apiPath, reqBody, respBody) +} + +// Post sends an HTTP POST request to the specified API path with the provided request body, +// and decodes the HTTP response into respBody. +// It is a convenience wrapper around Do. +func (c *Client) Post(ctx context.Context, apiPath string, reqBody interface{}, respBody interface{}) error { + return c.Do(ctx, http.MethodPost, apiPath, reqBody, respBody) +} + +// Put sends an HTTP PUT request to the specified API path with the provided request body, +// and decodes the HTTP response into respBody. +// It is a convenience wrapper around Do. +func (c *Client) Put(ctx context.Context, apiPath string, reqBody interface{}, respBody interface{}) error { + return c.Do(ctx, http.MethodPut, apiPath, reqBody, respBody) +} + +// Delete sends an HTTP DELETE request to the specified API path with the provided request body, +// and decodes the HTTP response into respBody. +// It is a convenience wrapper around Do. +func (c *Client) Delete(ctx context.Context, apiPath string, reqBody interface{}, respBody interface{}) error { + return c.Do(ctx, http.MethodDelete, apiPath, reqBody, respBody) +} diff --git a/unifi/sysinfo.go b/unifi/sysinfo.go index d83c2ae..8cb1efe 100644 --- a/unifi/sysinfo.go +++ b/unifi/sysinfo.go @@ -2,9 +2,11 @@ package unifi import ( "context" + "errors" "fmt" ) +// SysInfo represents detailed system information from the UniFi controller. type SysInfo struct { Timezone string `json:"timezone"` Version string `json:"version"` @@ -71,6 +73,7 @@ type SysInfo struct { */ } +// GetSystemInfo retrieves system info using the new API. func (c *Client) GetSystemInfo(ctx context.Context, id string) (*SysInfo, error) { var respBody struct { Meta Meta `json:"Meta"` @@ -88,3 +91,56 @@ func (c *Client) GetSystemInfo(ctx context.Context, id string) (*SysInfo, error) return &respBody.Data[0], nil } + +// serverInfo represents basic server info from old API . +type serverInfo struct { + Up bool `json:"up"` + ServerVersion string `json:"server_version"` + UUID string `json:"uuid"` +} + +// getOldSysInfo retrieves system information using the old API style. +func (c *Client) getOldSysInfo(ctx context.Context) (*SysInfo, error) { + var response struct { + Data serverInfo `json:"Meta"` + } + + err := c.Get(ctx, c.apiPaths.StatusPath, nil, &response) + if err != nil { + return nil, err + } + d := response.Data + return &SysInfo{ + Version: d.ServerVersion, + }, nil +} + +// GetSystemInformation retrieves system information, trying the new API first and falling back to the old API if necessary. +func (c *Client) GetSystemInformation() (*SysInfo, error) { + ctx, cancel := c.newRequestContext() + defer cancel() + + var resultingError error + info, err := c.GetSystemInfo(ctx, "default") + if err != nil { + resultingError = err + } else if info == nil || info.Version == "" { + resultingError = errors.New("new API returned empty server info") + } + + if resultingError != nil { + info, err = c.getOldSysInfo(ctx) + if err != nil { + resultingError = errors.Join(resultingError, err) + } else if info == nil || info.Version == "" { + resultingError = errors.Join(resultingError, errors.New("old API returned empty server info")) + } else { + resultingError = nil + } + } + + if resultingError != nil { + return nil, resultingError + } + return info, nil +} diff --git a/unifi/unifi.go b/unifi/unifi.go index 4be1da3..62ebe32 100644 --- a/unifi/unifi.go +++ b/unifi/unifi.go @@ -1,550 +1,10 @@ package unifi -import ( - "bytes" - "context" - "crypto/tls" - "encoding/json" - "errors" - "fmt" - "io" - "net/http" - "net/http/cookiejar" - "net/url" - "path" - "slices" - "strings" - "sync" - "time" - - "golang.org/x/net/publicsuffix" -) - -const ( - apiPath = "/api" - apiV2Path = "/v2/api" - - apiPathNew = "/proxy/network/api" - apiV2PathNew = "/proxy/network/v2/api" - - loginPath = "/api/login" - loginPathNew = "/api/auth/login" - - statusPath = "/status" - statusPathNew = "/proxy/network/status" - - logoutPath = "/api/logout" - - defaultUserAgent = "go-unifi/0.0.1" - - ApiKeyHeader = "X-Api-Key" - CsrfHeader = "X-Csrf-Token" - UserAgentHeader = "User-Agent" - AcceptHeader = "Accept" - ContentTypeHeader = "Content-Type" - - SoftValidation validationMode = "soft" - HardValidation validationMode = "hard" - DisableValidation validationMode = "disable" - DefaultValidation validationMode = SoftValidation // TODO: change to hard in next major version -) - -type validationMode string - -type ClientConfig struct { - URL string `validate:"required,http_url"` - APIKey string `validate:"required_without_all=User Pass"` - User string `validate:"excluded_with=APIKey,required_with=Pass"` - Pass string `validate:"excluded_with=APIKey,required_with=User"` - Timeout time.Duration // how long to wait for replies, default: forever. - VerifySSL bool - Interceptors []ClientInterceptor - HttpCustomizer HttpCustomizer - UserAgent string - ErrorHandler ResponseErrorHandler - UseLocking bool - ValidationMode validationMode `validate:"omitempty,oneof=soft hard disable"` -} - -type Client struct { - BaseURL *url.URL - SysInfo *SysInfo - apiPaths *ApiPaths - config *ClientConfig - http *http.Client - interceptors []ClientInterceptor - errorHandler ResponseErrorHandler - lock sync.Mutex - validator *validator -} - -type ApiPaths struct { - ApiPath string - ApiV2Path string - LoginPath string - StatusPath string - LogoutPath string -} - -var ( - OldStyleAPI = ApiPaths{ - ApiPath: apiPath, - ApiV2Path: apiV2Path, - LoginPath: loginPath, - StatusPath: statusPath, - LogoutPath: logoutPath, - } - NewStyleAPI = ApiPaths{ - ApiPath: apiPathNew, - ApiV2Path: apiV2PathNew, - LoginPath: loginPathNew, - StatusPath: statusPathNew, - LogoutPath: logoutPath, - } -) - -type ServerInfo struct { - Up bool `json:"up"` - ServerVersion string `json:"server_version"` - UUID string `json:"uuid"` -} - -type HttpCustomizer func(transport *http.Transport) error - -type ClientInterceptor interface { - InterceptRequest(req *http.Request) error - InterceptResponse(resp *http.Response) error -} -type ApiKeyAuthInterceptor struct { - apiKey string -} - -func (a *ApiKeyAuthInterceptor) InterceptRequest(req *http.Request) error { - req.Header.Set(ApiKeyHeader, a.apiKey) - return nil -} - -func (a *ApiKeyAuthInterceptor) InterceptResponse(_ *http.Response) error { - return nil -} - -type CsrfInterceptor struct { - csrfToken string -} - -func (c *CsrfInterceptor) InterceptRequest(req *http.Request) error { - if c.csrfToken != "" { - req.Header.Set(CsrfHeader, c.csrfToken) - } - return nil -} - -func (c *CsrfInterceptor) InterceptResponse(resp *http.Response) error { - if csrf := resp.Header.Get(CsrfHeader); csrf != "" { - c.csrfToken = csrf - } - return nil -} - -type DefaultHeadersInterceptor struct { - headers map[string]string -} - -func (d *DefaultHeadersInterceptor) InterceptRequest(req *http.Request) error { - for key, value := range d.headers { - req.Header.Set(key, value) - } - return nil -} - -func (d *DefaultHeadersInterceptor) InterceptResponse(_ *http.Response) error { - return nil -} - -func (c *Client) RegisterInterceptor(interceptor *ClientInterceptor) { - // ensure no duplicate interceptors - if !slices.Contains(c.interceptors, *interceptor) { - c.interceptors = append(c.interceptors, *interceptor) - } -} - -type ResponseErrorHandler interface { - HandleError(resp *http.Response) error -} - -// NewClient creates a http.Client with authenticated cookies. -// Used to make additional, authenticated requests to the APIs. -// Start here. -func NewClient(config *ClientConfig) (*Client, error) { - v, err := newValidator() - if err != nil { - return nil, fmt.Errorf("failed creating validator: %w", err) - } - - if err := v.Validate(config); err != nil { - return nil, fmt.Errorf("failed validating config: %w", err) - } - - u, err := newUnifi(config, v) - if err != nil { - return nil, fmt.Errorf("failed creating unifi client: %w", err) - } - - if err = u.determineApiStyle(); err != nil { - return u, fmt.Errorf("failed determining API style: %w", err) - } - - if err = u.Login(); err != nil { - return u, fmt.Errorf("failed logging in: %w", err) - } - - if sysInfo, err := u.getSystemInformation(); err != nil { - return u, fmt.Errorf("failed getting server info: %w", err) - } else { - u.SysInfo = sysInfo - } - return u, nil -} - -func parseBaseUrl(base string) (*url.URL, error) { - var err error - baseURL, err := url.Parse(base) - if err != nil { - return nil, err - } - - // error for people who are still passing hard coded old paths - if path := strings.TrimSuffix(baseURL.Path, "/"); path == apiPath { - return nil, fmt.Errorf("expected a base URL without the `/api`, got: %q", baseURL) - } - - return baseURL, nil -} - -func newUnifi(config *ClientConfig, v *validator) (*Client, error) { - var err error - - config.URL = strings.TrimRight(config.URL, "/") - transport := &http.Transport{ - Proxy: http.ProxyFromEnvironment, - TLSClientConfig: &tls.Config{InsecureSkipVerify: !config.VerifySSL}, //nolint: gosec - } - - if config.HttpCustomizer != nil { - if err = config.HttpCustomizer(transport); err != nil { - return nil, fmt.Errorf("failed customizing HTTP transport: %w", err) - } - } - - client := &http.Client{ - Timeout: config.Timeout, - Transport: transport, - } - - if config.APIKey == "" { - // old user/pass style use the cookie jar - jar, err := cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List}) - if err != nil { - return nil, fmt.Errorf("failed creating cookiejar: %w", err) - } - client.Jar = jar - } - baseURL, err := parseBaseUrl(config.URL) - if err != nil { - return nil, fmt.Errorf("failed parsing base URL: %w", err) - } - var interceptors []ClientInterceptor - - if config.APIKey != "" { - interceptors = append(interceptors, &ApiKeyAuthInterceptor{apiKey: config.APIKey}) - } else { - // CSRF is only needed for user/pass auth - interceptors = append(interceptors, &CsrfInterceptor{}) - } - if len(config.UserAgent) == 0 { - config.UserAgent = defaultUserAgent - } - interceptors = append(interceptors, &DefaultHeadersInterceptor{headers: map[string]string{ - UserAgentHeader: config.UserAgent, - AcceptHeader: "application/json", - ContentTypeHeader: "application/json; charset=utf-8", - }}) - - var errorHandler ResponseErrorHandler - if config.ErrorHandler != nil { - errorHandler = config.ErrorHandler - } else { - errorHandler = &DefaultResponseErrorHandler{} - } - if config.ValidationMode == "" { - config.ValidationMode = DefaultValidation - } - u := &Client{ - BaseURL: baseURL, - config: config, - http: client, - interceptors: interceptors, - errorHandler: errorHandler, - lock: sync.Mutex{}, - validator: v, - } - for _, interceptor := range config.Interceptors { - // add any custom interceptors and ensure no duplicates - u.RegisterInterceptor(&interceptor) - } - - return u, nil -} - -// Login is a helper method. It can be called to grab a new authentication cookie. -// Only useful if you are using user/pass auth. -func (c *Client) Login() error { - if c.config.APIKey != "" { - // no need to login on api-key auth - return nil - } - - ctx, cancel := c.createRequestContext() - defer cancel() - - err := c.Post(ctx, c.apiPaths.LoginPath, &struct { - Username string `json:"username"` - Password string `json:"password"` - }{ - Username: c.config.User, - Password: c.config.Pass, - }, nil) - if err != nil { - return err - } - return nil -} - -// Logout closes the current session. Only useful if you are using user/pass auth. -func (c *Client) Logout() error { - if c.config.APIKey != "" { - // no need to logout on api-key auth - return nil - } - ctx, cancel := c.createRequestContext() - defer cancel() - - // a post is needed for logout - err := c.Post(ctx, c.apiPaths.LogoutPath, nil, nil) - - return err -} - -func (c *Client) createRequestContext() (context.Context, context.CancelFunc) { - var ( - ctx = context.Background() - cancel = func() {} - ) - if c.config.Timeout != 0 { - ctx, cancel = context.WithTimeout(ctx, c.config.Timeout) - } - return ctx, cancel -} - -// with the release of controller version 5.12.55 on UDM in Jan 2020 the api paths -// changed and broke this library. This function runs when `NewClient()` is called to -// check if this is a newer controller or not. If it is, we set new to true. -// Setting new to true makes the path() method return different (new) paths. -func (c *Client) determineApiStyle() error { - ctx, cancel := c.createRequestContext() - defer cancel() - - // c.DebugLog("Requesting %s/ to determine API paths", c.URL) - - req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.BaseURL.String(), nil) - if err != nil { - return err - } - - // We can't share these cookies with other requests, so make a new client. - // Checking the return code on the first request so don't follow a redirect. - client := &http.Client{ - CheckRedirect: func(_ *http.Request, _ []*http.Request) error { - return http.ErrUseLastResponse - }, - Transport: c.http.Transport, - } - - resp, err := client.Do(req) - if err != nil { - return err - } - defer resp.Body.Close() // we need no data here. - _, _ = io.Copy(io.Discard, resp.Body) // avoid leaking. - - switch resp.StatusCode { - case http.StatusOK: - c.apiPaths = &NewStyleAPI // The new version returns a "200" for a / request. - case http.StatusFound: - c.apiPaths = &OldStyleAPI // The old version returns a "302" (to /manage) for a / request. - default: - return fmt.Errorf("expected 200 or 302 status code, but got: %d", resp.StatusCode) - } - if c.apiPaths == &OldStyleAPI && c.config.APIKey != "" { - return errors.New("unable to use API key authentication with old style API. Switch to user/pass authentication or update controller to latest version") - } - return nil -} - -func (c *Client) getOldSysInfo(ctx context.Context) (*SysInfo, error) { - var response struct { - Data ServerInfo `json:"Meta"` - } - - err := c.Get(ctx, c.apiPaths.StatusPath, nil, &response) - if err != nil { - return nil, err - } - data := response.Data - return &SysInfo{ - Version: data.ServerVersion, - }, nil -} - -// getSystemInformation reads the controller's version and UUID. Only call this if you -// previously called Login and suspect the controller version has changed. -func (c *Client) getSystemInformation() (*SysInfo, error) { - ctx, cancel := c.createRequestContext() - defer cancel() - - var resultingError error - info, err := c.GetSystemInfo(ctx, "default") // get for default site which must exist - if err != nil { - resultingError = err - } else if info == nil || info.Version == "" { - resultingError = errors.New("new API returned empty server info") - } - - if resultingError != nil { - info, err = c.getOldSysInfo(ctx) - if err != nil { - resultingError = errors.Join(resultingError, err) - } else if info == nil || info.Version == "" { - resultingError = errors.Join(resultingError, errors.New("old API returned empty server info")) - } else { - resultingError = nil - } - } - - if resultingError != nil { - return nil, resultingError - } - return info, nil -} - -func marshalRequest(reqBody interface{}) (io.Reader, error) { - if reqBody == nil { - return nil, nil //nolint: nilnil - } - reqBytes, err := json.Marshal(reqBody) - if err != nil { - return nil, err - } - return bytes.NewReader(reqBytes), nil -} - -func (c *Client) createRequestURL(apiPath string) (*url.URL, error) { - reqURL, err := url.Parse(apiPath) - if err != nil { - return nil, err - } - if !strings.HasPrefix(apiPath, "/") && !reqURL.IsAbs() { - reqURL.Path = path.Join(c.apiPaths.ApiPath, reqURL.Path) - } - - return c.BaseURL.ResolveReference(reqURL), nil -} - -func (c *Client) validateRequestBody(reqBody interface{}) error { - if reqBody != nil && c.config.ValidationMode != DisableValidation { - if err := c.validator.Validate(reqBody); err != nil { - err = fmt.Errorf("failed validating request body: %w", err) - if c.config.ValidationMode == HardValidation { - return err - } else { - fmt.Println(err) - } - } - } - return nil -} - -// Do performs a request to the given API path with the given method. -func (c *Client) Do(ctx context.Context, method, apiPath string, reqBody interface{}, respBody interface{}) error { - if err := c.validateRequestBody(reqBody); err != nil { - return err - } - reqReader, err := marshalRequest(reqBody) - if err != nil { - return fmt.Errorf("unable to marshal request: %w", err) - } - - url, err := c.createRequestURL(apiPath) - if err != nil { - return fmt.Errorf("unable to create request URL: %w", err) - } - req, err := http.NewRequestWithContext(ctx, method, url.String(), reqReader) - if err != nil { - return fmt.Errorf("unable to create request: %s %s %w", method, apiPath, err) - } - if c.config.UseLocking { - c.lock.Lock() - defer c.lock.Unlock() - } - - for _, interceptor := range c.interceptors { - if err := interceptor.InterceptRequest(req); err != nil { - return err - } - } - - resp, err := c.http.Do(req) - if err != nil { - return fmt.Errorf("unable to perform request: %s %s %w", method, apiPath, err) - } - defer resp.Body.Close() - - for _, interceptor := range c.interceptors { - if err := interceptor.InterceptResponse(resp); err != nil { - return err - } - } - if err := c.errorHandler.HandleError(resp); err != nil { - return err - } - if respBody == nil || resp.ContentLength == 0 { - return nil - } - - err = json.NewDecoder(resp.Body).Decode(respBody) - if err != nil { - return fmt.Errorf("unable to decode body: %s %s %w", method, apiPath, err) - } - - return nil -} - -// Get performs a GET request to the given API path. -func (c *Client) Get(context context.Context, apiPath string, reqBody interface{}, respBody interface{}) error { - return c.Do(context, http.MethodGet, apiPath, reqBody, respBody) -} - -// Post performs a POST request to the given API path. -func (c *Client) Post(context context.Context, apiPath string, reqBody interface{}, respBody interface{}) error { - return c.Do(context, http.MethodPost, apiPath, reqBody, respBody) -} - -// Put performs a PUT request to the given API path. -func (c *Client) Put(context context.Context, apiPath string, reqBody interface{}, respBody interface{}) error { - return c.Do(context, http.MethodPut, apiPath, reqBody, respBody) -} - -// Delete performs a DELETE request to the given API path. -func (c *Client) Delete(context context.Context, apiPath string, reqBody interface{}, respBody interface{}) error { - return c.Do(context, http.MethodDelete, apiPath, reqBody, respBody) -} +// DEPRECATED: The content of this file has been refactored and split into separate files to improve maintainability and readability. +// Please refer to the following files for the actual implementation: +// - client.go: contains Client, ClientConfig, Credentials, newClientInternal, NewClient, NewBareClient etc. +// - api_paths.go: contains API path constants and ApiPaths struct definitions. +// - interceptors.go: contains interceptor interface and implementations (ApiKeyAuthInterceptor, CsrfInterceptor, DefaultHeadersInterceptor). +// - requests.go: contains request handling functions (marshalRequest, createRequestURL, Do, Get, Post, Put, Delete, etc.). + +// This file is left intentionally empty to avoid duplicate declarations. Will be removed soon. diff --git a/unifi/unifi_test.go b/unifi/unifi_test.go index 13049bc..ecd62df 100644 --- a/unifi/unifi_test.go +++ b/unifi/unifi_test.go @@ -46,10 +46,10 @@ func verifyInterceptorPresence(a *assert.Assertions, c *Client, interceptors []i } } -func TestNewClient(t *testing.T) { +func TestNewBareClient(t *testing.T) { t.Parallel() a := assert.New(t) - c, err := NewClient(&ClientConfig{ + c, err := NewBareClient(&ClientConfig{ URL: localUrl, User: "admin", Pass: "password", @@ -58,8 +58,8 @@ func TestNewClient(t *testing.T) { require.Error(t, err) a.EqualValues(localUrl, c.BaseURL.String()) a.Contains(err.Error(), "connection refused", "an invalid destination should produce a connection error.") - verifyInterceptorPresence(a, c, []interface{}{&CsrfInterceptor{}, &DefaultHeadersInterceptor{}}, true) - verifyInterceptorPresence(a, c, []interface{}{&ApiKeyAuthInterceptor{}}, false) + verifyInterceptorPresence(a, c, []interface{}{&CSRFInterceptor{}, &DefaultHeadersInterceptor{}}, true) + verifyInterceptorPresence(a, c, []interface{}{&APIKeyAuthInterceptor{}}, false) } func TestNewClientWithApiKey(t *testing.T) { @@ -76,8 +76,8 @@ func TestNewClientWithApiKey(t *testing.T) { require.Error(t, err) a.EqualValues(localUrl, c.BaseURL.String()) a.Contains(err.Error(), "connection refused", "an invalid destination should produce a connection error.") - verifyInterceptorPresence(a, c, []interface{}{&ApiKeyAuthInterceptor{}, &DefaultHeadersInterceptor{}}, true) - verifyInterceptorPresence(a, c, []interface{}{&CsrfInterceptor{}}, false) + verifyInterceptorPresence(a, c, []interface{}{&APIKeyAuthInterceptor{}, &DefaultHeadersInterceptor{}}, true) + verifyInterceptorPresence(a, c, []interface{}{&CSRFInterceptor{}}, false) } func TestCustomizeHttpClient(t *testing.T) { @@ -418,23 +418,27 @@ func TestAuthConfigurationValidation(t *testing.T) { {"test", "test", "test", true}, } + v, err := newValidator() + require.NoError(t, err) for _, tc := range testCases { t.Run(fmt.Sprintf("user:%s-pass:%s-apikey:%s", tc.User, tc.Pass, tc.APIKey), func(t *testing.T) { t.Parallel() // given - _, err := NewClient(&ClientConfig{ + cc := &ClientConfig{ URL: testUrl, User: tc.User, Pass: tc.Pass, APIKey: tc.APIKey, - }) + } + // when + err = v.Validate(cc) // then if tc.shouldFail { require.ErrorContains(t, err, "validation failed") return } - require.ErrorContains(t, err, "dial tcp") // error will anyway exist, but it will be not related to config + require.NoError(t, err) }) } } @@ -461,22 +465,75 @@ func TestUrlValidation(t *testing.T) { t.Run(tc.URL, func(t *testing.T) { t.Parallel() // given - _, err := NewClient(&ClientConfig{ + cc := &ClientConfig{ URL: tc.URL, APIKey: "test-key", - }) + } + v, err := newValidator() + require.NoError(t, err) + + // when + err = v.Validate(cc) // then if tc.shouldFail { require.ErrorContains(t, err, "validation failed") - require.ErrorContains(t, err, tc.errorString) return } - require.ErrorContains(t, err, "dial tcp") // error will anyway exist, but it will be not related to config + require.NoError(t, err) }) } } +func TestValidationModeValidation(t *testing.T) { + t.Parallel() + testCases := []struct { + validationMode validationMode + expectedError string + }{ + {SoftValidation, ""}, + {HardValidation, ""}, + {DisableValidation, ""}, + {"invalid", "must be one of"}, + } + + for _, tc := range testCases { + t.Run(string(tc.validationMode), func(t *testing.T) { + t.Parallel() + // given + cc := &ClientConfig{ + URL: testUrl, + APIKey: "test-key", + ValidationMode: tc.validationMode, + } + v, err := newValidator() + require.NoError(t, err) + + // when + err = v.Validate(cc) + + // then + if tc.expectedError != "" { + require.ErrorContains(t, err, tc.expectedError) + return + } + require.NoError(t, err) + }) + } +} + +func TestClientConfigValidationExecutedOnNewClient(t *testing.T) { + t.Parallel() + a := assert.New(t) + // given + cc := &ClientConfig{URL: "invalid URL"} + // when + c, err := NewClient(cc) + // then + require.ErrorContains(t, err, "validation failed") + a.Nil(c) +} + type validateableBody struct { Data string `json:"data" validate:"required"` } @@ -584,7 +641,7 @@ func TestGetSystemInformation(t *testing.T) { VerifySSL: false, }) - sysInfo, err := c.getSystemInformation() + sysInfo, err := c.GetSystemInformation() if tc.expectedError != "" { require.ErrorContains(t, err, tc.expectedError) @@ -602,17 +659,17 @@ func TestParseBaseUrl(t *testing.T) { a := assert.New(t) // Valid URL without /api in the path. - base, err := parseBaseUrl("http://localhost") + base, err := parseBaseURL("http://localhost") require.NoError(t, err) a.Equal("http", base.Scheme) a.Equal("", base.Path) // URL with trailing slash /api/ - _, err = parseBaseUrl("http://localhost/api/") + _, err = parseBaseURL("http://localhost/api/") require.ErrorContains(t, err, "expected a base URL without the `/api`") // URL with /api in path (no trailing slash). - _, err = parseBaseUrl("http://localhost/api") + _, err = parseBaseURL("http://localhost/api") require.ErrorContains(t, err, "expected a base URL without the `/api`") } @@ -642,10 +699,10 @@ func TestRegisterInterceptor(t *testing.T) { // Create a dummy interceptor (using TestInterceptor already defined in the file). var dummy ClientInterceptor = &TestInterceptor{} initialCount := len(client.interceptors) - client.RegisterInterceptor(&dummy) + client.AddInterceptor(&dummy) assert.Len(t, client.interceptors, initialCount+1) // Attempt to add the same interceptor again. - client.RegisterInterceptor(&dummy) + client.AddInterceptor(&dummy) assert.Len(t, client.interceptors, initialCount+1) } @@ -730,7 +787,7 @@ func TestCreateRequestURLInvalid(t *testing.T) { BaseURL: &url.URL{Scheme: "http", Host: "localhost"}, apiPaths: &NewStyleAPI, } - _, err := c.createRequestURL("://bad-url") + _, err := c.buildRequestURL("://bad-url") require.Error(t, err) assert.Contains(t, err.Error(), "parse") } @@ -741,7 +798,7 @@ func TestCreateRequestURLAbsolute(t *testing.T) { BaseURL: &url.URL{Scheme: "http", Host: "localhost"}, apiPaths: &NewStyleAPI, } - reqURL, err := c.createRequestURL("http://example.com/test") + reqURL, err := c.buildRequestURL("http://example.com/test") require.NoError(t, err) assert.Equal(t, "http://example.com/test", reqURL.String()) } @@ -749,9 +806,9 @@ func TestCreateRequestURLAbsolute(t *testing.T) { func TestCreateRequestContextTimeout(t *testing.T) { t.Parallel() c := &Client{ - config: &ClientConfig{Timeout: 100 * time.Millisecond}, + timeout: 100 * time.Millisecond, } - ctx, cancel := c.createRequestContext() + ctx, cancel := c.newRequestContext() defer cancel() _, ok := ctx.Deadline() require.True(t, ok) @@ -787,9 +844,7 @@ func TestLoginWithAPIKeyDirect(t *testing.T) { t.Parallel() // Create a client manually with the APIKey set. c := &Client{ - config: &ClientConfig{ - APIKey: "abc", - }, + credentials: APIKeyCredentials{APIKey: "abc"}, } err := c.Login() assert.NoError(t, err)