refactor: major refactor of unifi.go to split it into multiple files for cohesiveness and readability.

Even more documentation
This commit is contained in:
Mateusz Filipowicz
2025-02-10 22:59:39 +01:00
parent 473a5d0d5c
commit c6e20b675c
7 changed files with 762 additions and 574 deletions

100
unifi/api_paths.go Normal file
View File

@@ -0,0 +1,100 @@
package unifi
import (
"errors"
"fmt"
"io"
"net/http"
)
const (
apiPath = "/api"
apiV2Path = "/v2/api"
apiPathNew = "/proxy/network/api"
apiV2PathNew = "/proxy/network/v2/api"
loginPath = "/api/login"
loginPathNew = "/api/auth/login"
statusPath = "/status"
statusPathNew = "/proxy/network/status"
logoutPath = "/api/logout"
defaultUserAgent = "go-unifi/0.0.1"
ApiKeyHeader = "X-Api-Key"
CsrfHeader = "X-Csrf-Token"
UserAgentHeader = "User-Agent"
AcceptHeader = "Accept"
ContentTypeHeader = "Content-Type"
)
// APIPaths defines the URL paths used by the client.
type APIPaths struct {
ApiPath string
ApiV2Path string
LoginPath string
StatusPath string
LogoutPath string
}
var (
OldStyleAPI = APIPaths{
ApiPath: apiPath,
ApiV2Path: apiV2Path,
LoginPath: loginPath,
StatusPath: statusPath,
LogoutPath: logoutPath,
}
NewStyleAPI = APIPaths{
ApiPath: apiPathNew,
ApiV2Path: apiV2PathNew,
LoginPath: loginPathNew,
StatusPath: statusPathNew,
LogoutPath: logoutPath,
}
)
// determineApiStyle checks the base URL to decide which API style to use and sets the apiPaths accordingly.
func (c *Client) determineApiStyle() error {
ctx, cancel := c.newRequestContext()
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.BaseURL.String(), nil)
if err != nil {
return err
}
client := &http.Client{
CheckRedirect: func(_ *http.Request, _ []*http.Request) error {
return http.ErrUseLastResponse
},
Transport: c.http.Transport,
}
resp, err := client.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
// Discard response body to avoid leaks
_, _ = io.Copy(io.Discard, resp.Body)
switch resp.StatusCode {
case http.StatusOK:
c.apiPaths = &NewStyleAPI
case http.StatusFound:
c.apiPaths = &OldStyleAPI
default:
return fmt.Errorf("expected 200 or 302 status code, but got: %d", resp.StatusCode)
}
if c.apiPaths == &OldStyleAPI && c.credentials.IsAPIKey() {
return errors.New("unable to use API key authentication with old style API. Switch to user/pass authentication or update controller to latest version")
}
return nil
}

298
unifi/client.go Normal file
View File

@@ -0,0 +1,298 @@
package unifi
import (
"crypto/tls"
"fmt"
"net/http"
"net/http/cookiejar"
"net/url"
"slices"
"strings"
"sync"
"time"
"golang.org/x/net/publicsuffix"
)
// validationMode represents the mode for request validation.
// It may be set to "soft", "hard", or "disable". The default is "soft".
type validationMode string
const (
// SoftValidation indicates that validation errors are logged as warnings but do not prevent the request from proceeding.
SoftValidation validationMode = "soft"
// HardValidation indicates that validation errors are treated as fatal and will cause the request to be rejected.
HardValidation validationMode = "hard"
// DisableValidation indicates that no validation is performed on the request body.
DisableValidation validationMode = "disable"
// DefaultValidation is the default validation mode used if none is specified.
// Currently set to SoftValidation, but may change to HardValidation in a future major version.
DefaultValidation validationMode = SoftValidation // TODO: change to hard in next major version
)
// HttpCustomizer is a function type for customizing the HTTP transport.
// It receives a pointer to an http.Transport and returns an error if customization fails.
type HttpCustomizer func(transport *http.Transport) error
// ResponseErrorHandler defines a method for handling HTTP response errors.
// HandleError processes the HTTP response and returns an error if the response indicates failure.
type ResponseErrorHandler interface {
// HandleError processes the HTTP response and returns an error if the response signals a failure.
HandleError(resp *http.Response) error
}
/*
ClientConfig holds configuration parameters for creating a UniFi Client.
Fields:
URL: The base URL of the UniFi controller. Must be a valid URL and should not include the `/api` suffix.
APIKey: An API key used for authentication. Provide this if user/password credentials are not used.
User: The username for user/password authentication. Must be provided with Pass if APIKey is not used.
Pass: The password for user/password authentication. Must be provided with User if APIKey is not used.
Timeout: The maximum duration to wait for responses; default is no timeout.
VerifySSL: When false, disables SSL certificate verification.
Interceptors: A slice of ClientInterceptor implementations that can modify requests and responses.
HttpCustomizer:An optional function to customize the HTTP transport (e.g., for custom TLS settings).
UserAgent: The User-Agent header string for outgoing HTTP requests.
ErrorHandler: A custom handler for processing HTTP response errors.
UseLocking: If true, enables internal locking for concurrent request processing.
ValidationMode:The mode for validating request bodies. Can be "soft", "hard", or "disable".
*/
type ClientConfig struct {
URL string `validate:"required,http_url"`
APIKey string `validate:"required_without_all=User Pass"`
User string `validate:"excluded_with=APIKey,required_with=Pass"`
Pass string `validate:"excluded_with=APIKey,required_with=User"`
Timeout time.Duration // How long to wait for replies, default: forever.
VerifySSL bool
Interceptors []ClientInterceptor
HttpCustomizer HttpCustomizer
UserAgent string
ErrorHandler ResponseErrorHandler
UseLocking bool
ValidationMode validationMode `validate:"omitempty,oneof=soft hard disable"`
}
// Credentials abstracts authentication credentials.
// It defines methods to determine the type of credentials and retrieve the associated values.
type Credentials interface {
// IsAPIKey returns true if the credentials represent an API key.
IsAPIKey() bool
// GetAPIKey returns the API key; returns an empty string if not applicable.
GetAPIKey() string
// GetUser returns the username for authentication; returns an empty string if not applicable.
GetUser() string
// GetPass returns the password for authentication; returns an empty string if not applicable.
GetPass() string
}
// APIKeyCredentials holds API key authentication details.
type APIKeyCredentials struct {
APIKey string
}
func (a APIKeyCredentials) IsAPIKey() bool { return true }
func (a APIKeyCredentials) GetAPIKey() string { return a.APIKey }
func (a APIKeyCredentials) GetUser() string { return "" }
func (a APIKeyCredentials) GetPass() string { return "" }
// UserPassCredentials holds user/password authentication.
type UserPassCredentials struct {
User string
Pass string
}
func (u UserPassCredentials) IsAPIKey() bool { return false }
func (u UserPassCredentials) GetAPIKey() string { return "" }
func (u UserPassCredentials) GetUser() string { return u.User }
func (u UserPassCredentials) GetPass() string { return u.Pass }
// Client represents a UniFi client.
type Client struct {
BaseURL *url.URL
SysInfo *SysInfo
apiPaths *APIPaths
timeout time.Duration
credentials Credentials
validationMode validationMode
useLocking bool
http *http.Client
interceptors []ClientInterceptor
errorHandler ResponseErrorHandler
lock sync.Mutex
validator *validator
}
// AddInterceptor adds a ClientInterceptor to the client's interceptor list if it is not already present.
// It appends the interceptor only if it is not already included in the list.
func (c *Client) AddInterceptor(interceptor *ClientInterceptor) {
if !slices.Contains(c.interceptors, *interceptor) {
c.interceptors = append(c.interceptors, *interceptor)
}
}
func parseBaseURL(base string) (*url.URL, error) {
baseURL, err := url.Parse(base)
if err != nil {
return nil, err
}
// Check if base URL's path is "/api" (deprecated usage now in api_paths.go)
if strings.TrimSuffix(baseURL.Path, "/") == "/api" {
return nil, fmt.Errorf("expected a base URL without the `/api`, got: %q", baseURL)
}
return baseURL, nil
}
func newClientFromConfig(config *ClientConfig, v *validator) (*Client, error) {
var err error
config.URL = strings.TrimRight(config.URL, "/")
transport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
TLSClientConfig: &tls.Config{InsecureSkipVerify: !config.VerifySSL},
}
if config.HttpCustomizer != nil {
if err = config.HttpCustomizer(transport); err != nil {
return nil, fmt.Errorf("failed customizing HTTP transport: %w", err)
}
}
client := &http.Client{
Timeout: config.Timeout,
Transport: transport,
}
if config.APIKey == "" {
jar, err := cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List})
if err != nil {
return nil, fmt.Errorf("failed creating cookiejar: %w", err)
}
client.Jar = jar
}
baseURL, err := parseBaseURL(config.URL)
if err != nil {
return nil, fmt.Errorf("failed parsing base URL: %w", err)
}
var interceptors []ClientInterceptor
var credentials Credentials
if config.APIKey != "" {
credentials = APIKeyCredentials{APIKey: config.APIKey}
interceptors = append(interceptors, &APIKeyAuthInterceptor{apiKey: config.APIKey})
} else {
credentials = UserPassCredentials{User: config.User, Pass: config.Pass}
interceptors = append(interceptors, &CSRFInterceptor{})
}
if len(config.UserAgent) == 0 {
config.UserAgent = defaultUserAgent
}
interceptors = append(interceptors, &DefaultHeadersInterceptor{headers: map[string]string{
UserAgentHeader: config.UserAgent,
AcceptHeader: "application/json",
ContentTypeHeader: "application/json; charset=utf-8",
}})
var errorHandler ResponseErrorHandler
if config.ErrorHandler != nil {
errorHandler = config.ErrorHandler
} else {
errorHandler = &DefaultResponseErrorHandler{}
}
if config.ValidationMode == "" {
config.ValidationMode = DefaultValidation
}
u := &Client{
BaseURL: baseURL,
timeout: config.Timeout,
credentials: credentials,
validationMode: config.ValidationMode,
useLocking: config.UseLocking,
http: client,
interceptors: interceptors,
errorHandler: errorHandler,
lock: sync.Mutex{},
validator: v,
}
for _, interceptor := range config.Interceptors {
u.AddInterceptor(&interceptor)
}
return u, nil
}
// NewClient creates and initializes a new UniFi client based on the provided ClientConfig.
// It validates the configuration, determines the API style, performs login if necessary,
// and retrieves system information from the UniFi controller.
// On success, it returns a pointer to a Client; otherwise, it returns an error.
func NewClient(config *ClientConfig) (*Client, error) {
c, err := NewBareClient(config)
if err != nil {
return c, err
}
if err = c.Login(); err != nil {
return c, fmt.Errorf("failed logging in: %w", err)
}
if sysInfo, err := c.GetSystemInformation(); err != nil {
return c, fmt.Errorf("failed getting server info: %w", err)
} else {
c.SysInfo = sysInfo
}
return c, nil
}
// NewBareClient creates a new UniFi client without performing login or system information retrieval.
// When user/pass authentication is used, you must call Login before making requests.
// It validates the configuration, determines the API style, and returns a pointer to the client on success.
func NewBareClient(config *ClientConfig) (*Client, error) {
v, err := newValidator()
if err != nil {
return nil, fmt.Errorf("failed creating validator: %w", err)
}
if err = v.Validate(config); err != nil {
return nil, fmt.Errorf("failed validating client configuration: %w", err)
}
c, err := newClientFromConfig(config, v)
if err != nil {
return nil, fmt.Errorf("failed creating unifi client: %w", err)
}
if err = c.determineApiStyle(); err != nil {
return c, fmt.Errorf("failed determining API style: %w", err)
}
return c, nil
}
// Login authenticates the client using user/pass credentials.
// For API key authentication, Login does nothing.
// It returns an error if the authentication process fails.
func (c *Client) Login() error {
if c.credentials.IsAPIKey() {
return nil
}
ctx, cancel := c.newRequestContext()
defer cancel()
err := c.Post(ctx, c.apiPaths.LoginPath, &struct {
Username string `json:"username"`
Password string `json:"password"`
}{
Username: c.credentials.GetUser(),
Password: c.credentials.GetPass(),
}, nil)
if err != nil {
return err
}
return nil
}
// Logout terminates the client's session for user/pass authentication.
// For API key authentication, Logout does nothing.
// It returns an error if the logout process fails.
func (c *Client) Logout() error {
if c.credentials.IsAPIKey() {
return nil
}
ctx, cancel := c.newRequestContext()
defer cancel()
err := c.Post(ctx, c.apiPaths.LogoutPath, nil, nil)
return err
}

71
unifi/interceptors.go Normal file
View File

@@ -0,0 +1,71 @@
package unifi
import "net/http"
// ClientInterceptor defines the interface for interceptors.
// An interceptor can modify HTTP requests and responses.
type ClientInterceptor interface {
InterceptRequest(req *http.Request) error
InterceptResponse(resp *http.Response) error
}
// APIKeyAuthInterceptor adds an API key to outgoing requests.
// It implements the ClientInterceptor interface.
type APIKeyAuthInterceptor struct {
apiKey string
}
// InterceptRequest sets the API key header on the given HTTP request.
// It adds the header defined by ApiKeyHeader with the stored API key and returns nil.
func (a *APIKeyAuthInterceptor) InterceptRequest(req *http.Request) error {
req.Header.Set(ApiKeyHeader, a.apiKey)
return nil
}
// InterceptResponse does not modify the HTTP response and always returns nil.
func (a *APIKeyAuthInterceptor) InterceptResponse(_ *http.Response) error {
return nil
}
// CSRFInterceptor manages CSRF tokens when using user/pass authentication.
// It implements the ClientInterceptor interface.
type CSRFInterceptor struct {
CSRFToken string
}
// InterceptRequest adds the CSRF token to the HTTP request header if it is set.
// It returns nil on success.
func (c *CSRFInterceptor) InterceptRequest(req *http.Request) error {
if c.CSRFToken != "" {
req.Header.Set(CsrfHeader, c.CSRFToken)
}
return nil
}
// InterceptResponse extracts the CSRF token from the HTTP response header, if present, and stores it for future requests.
func (c *CSRFInterceptor) InterceptResponse(resp *http.Response) error {
if token := resp.Header.Get(CsrfHeader); token != "" {
c.CSRFToken = token
}
return nil
}
// DefaultHeadersInterceptor sets default HTTP headers for requests.
// It implements the ClientInterceptor interface.
type DefaultHeadersInterceptor struct {
headers map[string]string
}
// InterceptRequest sets default HTTP headers on the request as specified in the interceptor's headers map.
// It returns nil on success.
func (d *DefaultHeadersInterceptor) InterceptRequest(req *http.Request) error {
for key, value := range d.headers {
req.Header.Set(key, value)
}
return nil
}
// InterceptResponse does not modify the HTTP response and always returns nil.
func (d *DefaultHeadersInterceptor) InterceptResponse(_ *http.Response) error {
return nil
}

148
unifi/requests.go Normal file
View File

@@ -0,0 +1,148 @@
package unifi
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"path"
"strings"
)
// marshalRequest marshals the request body to an io.Reader. Returns nil if reqBody is nil.
func marshalRequest(reqBody interface{}) (io.Reader, error) {
if reqBody == nil {
return nil, nil //nolint: nilnil
}
reqBytes, err := json.Marshal(reqBody)
if err != nil {
return nil, err
}
return bytes.NewReader(reqBytes), nil
}
// buildRequestURL constructs the full URL for a given apiPath using the client's BaseURL and apiPaths.
func (c *Client) buildRequestURL(apiPath string) (*url.URL, error) {
reqURL, err := url.Parse(apiPath)
if err != nil {
return nil, err
}
if !strings.HasPrefix(apiPath, "/") && !reqURL.IsAbs() {
reqURL.Path = path.Join(c.apiPaths.ApiPath, reqURL.Path)
}
return c.BaseURL.ResolveReference(reqURL), nil
}
// validateRequestBody validates the request body if validation is enabled.
func (c *Client) validateRequestBody(reqBody interface{}) error {
if reqBody != nil && c.validationMode != DisableValidation {
if err := c.validator.Validate(reqBody); err != nil {
err = fmt.Errorf("failed validating request body: %w", err)
if c.validationMode == HardValidation {
return err
} else {
fmt.Println(err)
}
}
}
return nil
}
// newRequestContext creates a new context for the request with a timeout if specified.
func (c *Client) newRequestContext() (context.Context, context.CancelFunc) {
ctx := context.Background()
if c.timeout != 0 {
return context.WithTimeout(ctx, c.timeout)
}
return ctx, func() {}
}
// Do performs an HTTP request using the given method, apiPath, request body, and decodes the response into respBody.
// It validates the request body, applies interceptors, and decodes the HTTP response into respBody if provided.
// It returns an error if the request or response handling fails.
func (c *Client) Do(ctx context.Context, method, apiPath string, reqBody interface{}, respBody interface{}) error {
if err := c.validateRequestBody(reqBody); err != nil {
return err
}
reqReader, err := marshalRequest(reqBody)
if err != nil {
return fmt.Errorf("unable to marshal request: %w", err)
}
url, err := c.buildRequestURL(apiPath)
if err != nil {
return fmt.Errorf("unable to create request URL: %w", err)
}
req, err := http.NewRequestWithContext(ctx, method, url.String(), reqReader)
if err != nil {
return fmt.Errorf("unable to create request: %s %s %w", method, apiPath, err)
}
if c.useLocking {
c.lock.Lock()
defer c.lock.Unlock()
}
for _, interceptor := range c.interceptors {
if err := interceptor.InterceptRequest(req); err != nil {
return err
}
}
resp, err := c.http.Do(req)
if err != nil {
return fmt.Errorf("unable to perform request: %s %s %w", method, apiPath, err)
}
defer resp.Body.Close()
for _, interceptor := range c.interceptors {
if err := interceptor.InterceptResponse(resp); err != nil {
return err
}
}
if err := c.errorHandler.HandleError(resp); err != nil {
return err
}
if respBody == nil || resp.ContentLength == 0 {
return nil
}
err = json.NewDecoder(resp.Body).Decode(respBody)
if err != nil {
return fmt.Errorf("unable to decode body: %s %s %w", method, apiPath, err)
}
return nil
}
// Get sends an HTTP GET request to the specified API path with the provided request body,
// and decodes the HTTP response into respBody.
// It is a convenience wrapper around Do.
func (c *Client) Get(ctx context.Context, apiPath string, reqBody interface{}, respBody interface{}) error {
return c.Do(ctx, http.MethodGet, apiPath, reqBody, respBody)
}
// Post sends an HTTP POST request to the specified API path with the provided request body,
// and decodes the HTTP response into respBody.
// It is a convenience wrapper around Do.
func (c *Client) Post(ctx context.Context, apiPath string, reqBody interface{}, respBody interface{}) error {
return c.Do(ctx, http.MethodPost, apiPath, reqBody, respBody)
}
// Put sends an HTTP PUT request to the specified API path with the provided request body,
// and decodes the HTTP response into respBody.
// It is a convenience wrapper around Do.
func (c *Client) Put(ctx context.Context, apiPath string, reqBody interface{}, respBody interface{}) error {
return c.Do(ctx, http.MethodPut, apiPath, reqBody, respBody)
}
// Delete sends an HTTP DELETE request to the specified API path with the provided request body,
// and decodes the HTTP response into respBody.
// It is a convenience wrapper around Do.
func (c *Client) Delete(ctx context.Context, apiPath string, reqBody interface{}, respBody interface{}) error {
return c.Do(ctx, http.MethodDelete, apiPath, reqBody, respBody)
}

View File

@@ -2,9 +2,11 @@ package unifi
import (
"context"
"errors"
"fmt"
)
// SysInfo represents detailed system information from the UniFi controller.
type SysInfo struct {
Timezone string `json:"timezone"`
Version string `json:"version"`
@@ -71,6 +73,7 @@ type SysInfo struct {
*/
}
// GetSystemInfo retrieves system info using the new API.
func (c *Client) GetSystemInfo(ctx context.Context, id string) (*SysInfo, error) {
var respBody struct {
Meta Meta `json:"Meta"`
@@ -88,3 +91,56 @@ func (c *Client) GetSystemInfo(ctx context.Context, id string) (*SysInfo, error)
return &respBody.Data[0], nil
}
// serverInfo represents basic server info from old API .
type serverInfo struct {
Up bool `json:"up"`
ServerVersion string `json:"server_version"`
UUID string `json:"uuid"`
}
// getOldSysInfo retrieves system information using the old API style.
func (c *Client) getOldSysInfo(ctx context.Context) (*SysInfo, error) {
var response struct {
Data serverInfo `json:"Meta"`
}
err := c.Get(ctx, c.apiPaths.StatusPath, nil, &response)
if err != nil {
return nil, err
}
d := response.Data
return &SysInfo{
Version: d.ServerVersion,
}, nil
}
// GetSystemInformation retrieves system information, trying the new API first and falling back to the old API if necessary.
func (c *Client) GetSystemInformation() (*SysInfo, error) {
ctx, cancel := c.newRequestContext()
defer cancel()
var resultingError error
info, err := c.GetSystemInfo(ctx, "default")
if err != nil {
resultingError = err
} else if info == nil || info.Version == "" {
resultingError = errors.New("new API returned empty server info")
}
if resultingError != nil {
info, err = c.getOldSysInfo(ctx)
if err != nil {
resultingError = errors.Join(resultingError, err)
} else if info == nil || info.Version == "" {
resultingError = errors.Join(resultingError, errors.New("old API returned empty server info"))
} else {
resultingError = nil
}
}
if resultingError != nil {
return nil, resultingError
}
return info, nil
}

View File

@@ -1,550 +1,10 @@
package unifi
import (
"bytes"
"context"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/http/cookiejar"
"net/url"
"path"
"slices"
"strings"
"sync"
"time"
"golang.org/x/net/publicsuffix"
)
const (
apiPath = "/api"
apiV2Path = "/v2/api"
apiPathNew = "/proxy/network/api"
apiV2PathNew = "/proxy/network/v2/api"
loginPath = "/api/login"
loginPathNew = "/api/auth/login"
statusPath = "/status"
statusPathNew = "/proxy/network/status"
logoutPath = "/api/logout"
defaultUserAgent = "go-unifi/0.0.1"
ApiKeyHeader = "X-Api-Key"
CsrfHeader = "X-Csrf-Token"
UserAgentHeader = "User-Agent"
AcceptHeader = "Accept"
ContentTypeHeader = "Content-Type"
SoftValidation validationMode = "soft"
HardValidation validationMode = "hard"
DisableValidation validationMode = "disable"
DefaultValidation validationMode = SoftValidation // TODO: change to hard in next major version
)
type validationMode string
type ClientConfig struct {
URL string `validate:"required,http_url"`
APIKey string `validate:"required_without_all=User Pass"`
User string `validate:"excluded_with=APIKey,required_with=Pass"`
Pass string `validate:"excluded_with=APIKey,required_with=User"`
Timeout time.Duration // how long to wait for replies, default: forever.
VerifySSL bool
Interceptors []ClientInterceptor
HttpCustomizer HttpCustomizer
UserAgent string
ErrorHandler ResponseErrorHandler
UseLocking bool
ValidationMode validationMode `validate:"omitempty,oneof=soft hard disable"`
}
type Client struct {
BaseURL *url.URL
SysInfo *SysInfo
apiPaths *ApiPaths
config *ClientConfig
http *http.Client
interceptors []ClientInterceptor
errorHandler ResponseErrorHandler
lock sync.Mutex
validator *validator
}
type ApiPaths struct {
ApiPath string
ApiV2Path string
LoginPath string
StatusPath string
LogoutPath string
}
var (
OldStyleAPI = ApiPaths{
ApiPath: apiPath,
ApiV2Path: apiV2Path,
LoginPath: loginPath,
StatusPath: statusPath,
LogoutPath: logoutPath,
}
NewStyleAPI = ApiPaths{
ApiPath: apiPathNew,
ApiV2Path: apiV2PathNew,
LoginPath: loginPathNew,
StatusPath: statusPathNew,
LogoutPath: logoutPath,
}
)
type ServerInfo struct {
Up bool `json:"up"`
ServerVersion string `json:"server_version"`
UUID string `json:"uuid"`
}
type HttpCustomizer func(transport *http.Transport) error
type ClientInterceptor interface {
InterceptRequest(req *http.Request) error
InterceptResponse(resp *http.Response) error
}
type ApiKeyAuthInterceptor struct {
apiKey string
}
func (a *ApiKeyAuthInterceptor) InterceptRequest(req *http.Request) error {
req.Header.Set(ApiKeyHeader, a.apiKey)
return nil
}
func (a *ApiKeyAuthInterceptor) InterceptResponse(_ *http.Response) error {
return nil
}
type CsrfInterceptor struct {
csrfToken string
}
func (c *CsrfInterceptor) InterceptRequest(req *http.Request) error {
if c.csrfToken != "" {
req.Header.Set(CsrfHeader, c.csrfToken)
}
return nil
}
func (c *CsrfInterceptor) InterceptResponse(resp *http.Response) error {
if csrf := resp.Header.Get(CsrfHeader); csrf != "" {
c.csrfToken = csrf
}
return nil
}
type DefaultHeadersInterceptor struct {
headers map[string]string
}
func (d *DefaultHeadersInterceptor) InterceptRequest(req *http.Request) error {
for key, value := range d.headers {
req.Header.Set(key, value)
}
return nil
}
func (d *DefaultHeadersInterceptor) InterceptResponse(_ *http.Response) error {
return nil
}
func (c *Client) RegisterInterceptor(interceptor *ClientInterceptor) {
// ensure no duplicate interceptors
if !slices.Contains(c.interceptors, *interceptor) {
c.interceptors = append(c.interceptors, *interceptor)
}
}
type ResponseErrorHandler interface {
HandleError(resp *http.Response) error
}
// NewClient creates a http.Client with authenticated cookies.
// Used to make additional, authenticated requests to the APIs.
// Start here.
func NewClient(config *ClientConfig) (*Client, error) {
v, err := newValidator()
if err != nil {
return nil, fmt.Errorf("failed creating validator: %w", err)
}
if err := v.Validate(config); err != nil {
return nil, fmt.Errorf("failed validating config: %w", err)
}
u, err := newUnifi(config, v)
if err != nil {
return nil, fmt.Errorf("failed creating unifi client: %w", err)
}
if err = u.determineApiStyle(); err != nil {
return u, fmt.Errorf("failed determining API style: %w", err)
}
if err = u.Login(); err != nil {
return u, fmt.Errorf("failed logging in: %w", err)
}
if sysInfo, err := u.getSystemInformation(); err != nil {
return u, fmt.Errorf("failed getting server info: %w", err)
} else {
u.SysInfo = sysInfo
}
return u, nil
}
func parseBaseUrl(base string) (*url.URL, error) {
var err error
baseURL, err := url.Parse(base)
if err != nil {
return nil, err
}
// error for people who are still passing hard coded old paths
if path := strings.TrimSuffix(baseURL.Path, "/"); path == apiPath {
return nil, fmt.Errorf("expected a base URL without the `/api`, got: %q", baseURL)
}
return baseURL, nil
}
func newUnifi(config *ClientConfig, v *validator) (*Client, error) {
var err error
config.URL = strings.TrimRight(config.URL, "/")
transport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
TLSClientConfig: &tls.Config{InsecureSkipVerify: !config.VerifySSL}, //nolint: gosec
}
if config.HttpCustomizer != nil {
if err = config.HttpCustomizer(transport); err != nil {
return nil, fmt.Errorf("failed customizing HTTP transport: %w", err)
}
}
client := &http.Client{
Timeout: config.Timeout,
Transport: transport,
}
if config.APIKey == "" {
// old user/pass style use the cookie jar
jar, err := cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List})
if err != nil {
return nil, fmt.Errorf("failed creating cookiejar: %w", err)
}
client.Jar = jar
}
baseURL, err := parseBaseUrl(config.URL)
if err != nil {
return nil, fmt.Errorf("failed parsing base URL: %w", err)
}
var interceptors []ClientInterceptor
if config.APIKey != "" {
interceptors = append(interceptors, &ApiKeyAuthInterceptor{apiKey: config.APIKey})
} else {
// CSRF is only needed for user/pass auth
interceptors = append(interceptors, &CsrfInterceptor{})
}
if len(config.UserAgent) == 0 {
config.UserAgent = defaultUserAgent
}
interceptors = append(interceptors, &DefaultHeadersInterceptor{headers: map[string]string{
UserAgentHeader: config.UserAgent,
AcceptHeader: "application/json",
ContentTypeHeader: "application/json; charset=utf-8",
}})
var errorHandler ResponseErrorHandler
if config.ErrorHandler != nil {
errorHandler = config.ErrorHandler
} else {
errorHandler = &DefaultResponseErrorHandler{}
}
if config.ValidationMode == "" {
config.ValidationMode = DefaultValidation
}
u := &Client{
BaseURL: baseURL,
config: config,
http: client,
interceptors: interceptors,
errorHandler: errorHandler,
lock: sync.Mutex{},
validator: v,
}
for _, interceptor := range config.Interceptors {
// add any custom interceptors and ensure no duplicates
u.RegisterInterceptor(&interceptor)
}
return u, nil
}
// Login is a helper method. It can be called to grab a new authentication cookie.
// Only useful if you are using user/pass auth.
func (c *Client) Login() error {
if c.config.APIKey != "" {
// no need to login on api-key auth
return nil
}
ctx, cancel := c.createRequestContext()
defer cancel()
err := c.Post(ctx, c.apiPaths.LoginPath, &struct {
Username string `json:"username"`
Password string `json:"password"`
}{
Username: c.config.User,
Password: c.config.Pass,
}, nil)
if err != nil {
return err
}
return nil
}
// Logout closes the current session. Only useful if you are using user/pass auth.
func (c *Client) Logout() error {
if c.config.APIKey != "" {
// no need to logout on api-key auth
return nil
}
ctx, cancel := c.createRequestContext()
defer cancel()
// a post is needed for logout
err := c.Post(ctx, c.apiPaths.LogoutPath, nil, nil)
return err
}
func (c *Client) createRequestContext() (context.Context, context.CancelFunc) {
var (
ctx = context.Background()
cancel = func() {}
)
if c.config.Timeout != 0 {
ctx, cancel = context.WithTimeout(ctx, c.config.Timeout)
}
return ctx, cancel
}
// with the release of controller version 5.12.55 on UDM in Jan 2020 the api paths
// changed and broke this library. This function runs when `NewClient()` is called to
// check if this is a newer controller or not. If it is, we set new to true.
// Setting new to true makes the path() method return different (new) paths.
func (c *Client) determineApiStyle() error {
ctx, cancel := c.createRequestContext()
defer cancel()
// c.DebugLog("Requesting %s/ to determine API paths", c.URL)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.BaseURL.String(), nil)
if err != nil {
return err
}
// We can't share these cookies with other requests, so make a new client.
// Checking the return code on the first request so don't follow a redirect.
client := &http.Client{
CheckRedirect: func(_ *http.Request, _ []*http.Request) error {
return http.ErrUseLastResponse
},
Transport: c.http.Transport,
}
resp, err := client.Do(req)
if err != nil {
return err
}
defer resp.Body.Close() // we need no data here.
_, _ = io.Copy(io.Discard, resp.Body) // avoid leaking.
switch resp.StatusCode {
case http.StatusOK:
c.apiPaths = &NewStyleAPI // The new version returns a "200" for a / request.
case http.StatusFound:
c.apiPaths = &OldStyleAPI // The old version returns a "302" (to /manage) for a / request.
default:
return fmt.Errorf("expected 200 or 302 status code, but got: %d", resp.StatusCode)
}
if c.apiPaths == &OldStyleAPI && c.config.APIKey != "" {
return errors.New("unable to use API key authentication with old style API. Switch to user/pass authentication or update controller to latest version")
}
return nil
}
func (c *Client) getOldSysInfo(ctx context.Context) (*SysInfo, error) {
var response struct {
Data ServerInfo `json:"Meta"`
}
err := c.Get(ctx, c.apiPaths.StatusPath, nil, &response)
if err != nil {
return nil, err
}
data := response.Data
return &SysInfo{
Version: data.ServerVersion,
}, nil
}
// getSystemInformation reads the controller's version and UUID. Only call this if you
// previously called Login and suspect the controller version has changed.
func (c *Client) getSystemInformation() (*SysInfo, error) {
ctx, cancel := c.createRequestContext()
defer cancel()
var resultingError error
info, err := c.GetSystemInfo(ctx, "default") // get for default site which must exist
if err != nil {
resultingError = err
} else if info == nil || info.Version == "" {
resultingError = errors.New("new API returned empty server info")
}
if resultingError != nil {
info, err = c.getOldSysInfo(ctx)
if err != nil {
resultingError = errors.Join(resultingError, err)
} else if info == nil || info.Version == "" {
resultingError = errors.Join(resultingError, errors.New("old API returned empty server info"))
} else {
resultingError = nil
}
}
if resultingError != nil {
return nil, resultingError
}
return info, nil
}
func marshalRequest(reqBody interface{}) (io.Reader, error) {
if reqBody == nil {
return nil, nil //nolint: nilnil
}
reqBytes, err := json.Marshal(reqBody)
if err != nil {
return nil, err
}
return bytes.NewReader(reqBytes), nil
}
func (c *Client) createRequestURL(apiPath string) (*url.URL, error) {
reqURL, err := url.Parse(apiPath)
if err != nil {
return nil, err
}
if !strings.HasPrefix(apiPath, "/") && !reqURL.IsAbs() {
reqURL.Path = path.Join(c.apiPaths.ApiPath, reqURL.Path)
}
return c.BaseURL.ResolveReference(reqURL), nil
}
func (c *Client) validateRequestBody(reqBody interface{}) error {
if reqBody != nil && c.config.ValidationMode != DisableValidation {
if err := c.validator.Validate(reqBody); err != nil {
err = fmt.Errorf("failed validating request body: %w", err)
if c.config.ValidationMode == HardValidation {
return err
} else {
fmt.Println(err)
}
}
}
return nil
}
// Do performs a request to the given API path with the given method.
func (c *Client) Do(ctx context.Context, method, apiPath string, reqBody interface{}, respBody interface{}) error {
if err := c.validateRequestBody(reqBody); err != nil {
return err
}
reqReader, err := marshalRequest(reqBody)
if err != nil {
return fmt.Errorf("unable to marshal request: %w", err)
}
url, err := c.createRequestURL(apiPath)
if err != nil {
return fmt.Errorf("unable to create request URL: %w", err)
}
req, err := http.NewRequestWithContext(ctx, method, url.String(), reqReader)
if err != nil {
return fmt.Errorf("unable to create request: %s %s %w", method, apiPath, err)
}
if c.config.UseLocking {
c.lock.Lock()
defer c.lock.Unlock()
}
for _, interceptor := range c.interceptors {
if err := interceptor.InterceptRequest(req); err != nil {
return err
}
}
resp, err := c.http.Do(req)
if err != nil {
return fmt.Errorf("unable to perform request: %s %s %w", method, apiPath, err)
}
defer resp.Body.Close()
for _, interceptor := range c.interceptors {
if err := interceptor.InterceptResponse(resp); err != nil {
return err
}
}
if err := c.errorHandler.HandleError(resp); err != nil {
return err
}
if respBody == nil || resp.ContentLength == 0 {
return nil
}
err = json.NewDecoder(resp.Body).Decode(respBody)
if err != nil {
return fmt.Errorf("unable to decode body: %s %s %w", method, apiPath, err)
}
return nil
}
// Get performs a GET request to the given API path.
func (c *Client) Get(context context.Context, apiPath string, reqBody interface{}, respBody interface{}) error {
return c.Do(context, http.MethodGet, apiPath, reqBody, respBody)
}
// Post performs a POST request to the given API path.
func (c *Client) Post(context context.Context, apiPath string, reqBody interface{}, respBody interface{}) error {
return c.Do(context, http.MethodPost, apiPath, reqBody, respBody)
}
// Put performs a PUT request to the given API path.
func (c *Client) Put(context context.Context, apiPath string, reqBody interface{}, respBody interface{}) error {
return c.Do(context, http.MethodPut, apiPath, reqBody, respBody)
}
// Delete performs a DELETE request to the given API path.
func (c *Client) Delete(context context.Context, apiPath string, reqBody interface{}, respBody interface{}) error {
return c.Do(context, http.MethodDelete, apiPath, reqBody, respBody)
}
// DEPRECATED: The content of this file has been refactored and split into separate files to improve maintainability and readability.
// Please refer to the following files for the actual implementation:
// - client.go: contains Client, ClientConfig, Credentials, newClientInternal, NewClient, NewBareClient etc.
// - api_paths.go: contains API path constants and ApiPaths struct definitions.
// - interceptors.go: contains interceptor interface and implementations (ApiKeyAuthInterceptor, CsrfInterceptor, DefaultHeadersInterceptor).
// - requests.go: contains request handling functions (marshalRequest, createRequestURL, Do, Get, Post, Put, Delete, etc.).
// This file is left intentionally empty to avoid duplicate declarations. Will be removed soon.

View File

@@ -46,10 +46,10 @@ func verifyInterceptorPresence(a *assert.Assertions, c *Client, interceptors []i
}
}
func TestNewClient(t *testing.T) {
func TestNewBareClient(t *testing.T) {
t.Parallel()
a := assert.New(t)
c, err := NewClient(&ClientConfig{
c, err := NewBareClient(&ClientConfig{
URL: localUrl,
User: "admin",
Pass: "password",
@@ -58,8 +58,8 @@ func TestNewClient(t *testing.T) {
require.Error(t, err)
a.EqualValues(localUrl, c.BaseURL.String())
a.Contains(err.Error(), "connection refused", "an invalid destination should produce a connection error.")
verifyInterceptorPresence(a, c, []interface{}{&CsrfInterceptor{}, &DefaultHeadersInterceptor{}}, true)
verifyInterceptorPresence(a, c, []interface{}{&ApiKeyAuthInterceptor{}}, false)
verifyInterceptorPresence(a, c, []interface{}{&CSRFInterceptor{}, &DefaultHeadersInterceptor{}}, true)
verifyInterceptorPresence(a, c, []interface{}{&APIKeyAuthInterceptor{}}, false)
}
func TestNewClientWithApiKey(t *testing.T) {
@@ -76,8 +76,8 @@ func TestNewClientWithApiKey(t *testing.T) {
require.Error(t, err)
a.EqualValues(localUrl, c.BaseURL.String())
a.Contains(err.Error(), "connection refused", "an invalid destination should produce a connection error.")
verifyInterceptorPresence(a, c, []interface{}{&ApiKeyAuthInterceptor{}, &DefaultHeadersInterceptor{}}, true)
verifyInterceptorPresence(a, c, []interface{}{&CsrfInterceptor{}}, false)
verifyInterceptorPresence(a, c, []interface{}{&APIKeyAuthInterceptor{}, &DefaultHeadersInterceptor{}}, true)
verifyInterceptorPresence(a, c, []interface{}{&CSRFInterceptor{}}, false)
}
func TestCustomizeHttpClient(t *testing.T) {
@@ -418,23 +418,27 @@ func TestAuthConfigurationValidation(t *testing.T) {
{"test", "test", "test", true},
}
v, err := newValidator()
require.NoError(t, err)
for _, tc := range testCases {
t.Run(fmt.Sprintf("user:%s-pass:%s-apikey:%s", tc.User, tc.Pass, tc.APIKey), func(t *testing.T) {
t.Parallel()
// given
_, err := NewClient(&ClientConfig{
cc := &ClientConfig{
URL: testUrl,
User: tc.User,
Pass: tc.Pass,
APIKey: tc.APIKey,
})
}
// when
err = v.Validate(cc)
// then
if tc.shouldFail {
require.ErrorContains(t, err, "validation failed")
return
}
require.ErrorContains(t, err, "dial tcp") // error will anyway exist, but it will be not related to config
require.NoError(t, err)
})
}
}
@@ -461,22 +465,75 @@ func TestUrlValidation(t *testing.T) {
t.Run(tc.URL, func(t *testing.T) {
t.Parallel()
// given
_, err := NewClient(&ClientConfig{
cc := &ClientConfig{
URL: tc.URL,
APIKey: "test-key",
})
}
v, err := newValidator()
require.NoError(t, err)
// when
err = v.Validate(cc)
// then
if tc.shouldFail {
require.ErrorContains(t, err, "validation failed")
require.ErrorContains(t, err, tc.errorString)
return
}
require.ErrorContains(t, err, "dial tcp") // error will anyway exist, but it will be not related to config
require.NoError(t, err)
})
}
}
func TestValidationModeValidation(t *testing.T) {
t.Parallel()
testCases := []struct {
validationMode validationMode
expectedError string
}{
{SoftValidation, ""},
{HardValidation, ""},
{DisableValidation, ""},
{"invalid", "must be one of"},
}
for _, tc := range testCases {
t.Run(string(tc.validationMode), func(t *testing.T) {
t.Parallel()
// given
cc := &ClientConfig{
URL: testUrl,
APIKey: "test-key",
ValidationMode: tc.validationMode,
}
v, err := newValidator()
require.NoError(t, err)
// when
err = v.Validate(cc)
// then
if tc.expectedError != "" {
require.ErrorContains(t, err, tc.expectedError)
return
}
require.NoError(t, err)
})
}
}
func TestClientConfigValidationExecutedOnNewClient(t *testing.T) {
t.Parallel()
a := assert.New(t)
// given
cc := &ClientConfig{URL: "invalid URL"}
// when
c, err := NewClient(cc)
// then
require.ErrorContains(t, err, "validation failed")
a.Nil(c)
}
type validateableBody struct {
Data string `json:"data" validate:"required"`
}
@@ -584,7 +641,7 @@ func TestGetSystemInformation(t *testing.T) {
VerifySSL: false,
})
sysInfo, err := c.getSystemInformation()
sysInfo, err := c.GetSystemInformation()
if tc.expectedError != "" {
require.ErrorContains(t, err, tc.expectedError)
@@ -602,17 +659,17 @@ func TestParseBaseUrl(t *testing.T) {
a := assert.New(t)
// Valid URL without /api in the path.
base, err := parseBaseUrl("http://localhost")
base, err := parseBaseURL("http://localhost")
require.NoError(t, err)
a.Equal("http", base.Scheme)
a.Equal("", base.Path)
// URL with trailing slash /api/
_, err = parseBaseUrl("http://localhost/api/")
_, err = parseBaseURL("http://localhost/api/")
require.ErrorContains(t, err, "expected a base URL without the `/api`")
// URL with /api in path (no trailing slash).
_, err = parseBaseUrl("http://localhost/api")
_, err = parseBaseURL("http://localhost/api")
require.ErrorContains(t, err, "expected a base URL without the `/api`")
}
@@ -642,10 +699,10 @@ func TestRegisterInterceptor(t *testing.T) {
// Create a dummy interceptor (using TestInterceptor already defined in the file).
var dummy ClientInterceptor = &TestInterceptor{}
initialCount := len(client.interceptors)
client.RegisterInterceptor(&dummy)
client.AddInterceptor(&dummy)
assert.Len(t, client.interceptors, initialCount+1)
// Attempt to add the same interceptor again.
client.RegisterInterceptor(&dummy)
client.AddInterceptor(&dummy)
assert.Len(t, client.interceptors, initialCount+1)
}
@@ -730,7 +787,7 @@ func TestCreateRequestURLInvalid(t *testing.T) {
BaseURL: &url.URL{Scheme: "http", Host: "localhost"},
apiPaths: &NewStyleAPI,
}
_, err := c.createRequestURL("://bad-url")
_, err := c.buildRequestURL("://bad-url")
require.Error(t, err)
assert.Contains(t, err.Error(), "parse")
}
@@ -741,7 +798,7 @@ func TestCreateRequestURLAbsolute(t *testing.T) {
BaseURL: &url.URL{Scheme: "http", Host: "localhost"},
apiPaths: &NewStyleAPI,
}
reqURL, err := c.createRequestURL("http://example.com/test")
reqURL, err := c.buildRequestURL("http://example.com/test")
require.NoError(t, err)
assert.Equal(t, "http://example.com/test", reqURL.String())
}
@@ -749,9 +806,9 @@ func TestCreateRequestURLAbsolute(t *testing.T) {
func TestCreateRequestContextTimeout(t *testing.T) {
t.Parallel()
c := &Client{
config: &ClientConfig{Timeout: 100 * time.Millisecond},
timeout: 100 * time.Millisecond,
}
ctx, cancel := c.createRequestContext()
ctx, cancel := c.newRequestContext()
defer cancel()
_, ok := ctx.Deadline()
require.True(t, ok)
@@ -787,9 +844,7 @@ func TestLoginWithAPIKeyDirect(t *testing.T) {
t.Parallel()
// Create a client manually with the APIKey set.
c := &Client{
config: &ClientConfig{
APIKey: "abc",
},
credentials: APIKeyCredentials{APIKey: "abc"},
}
err := c.Login()
assert.NoError(t, err)