feat: add validation of ClientConfig fields for improved data integrity (#5)
* feat: add validation of ClientConfig fields for improved data integrity * chore: add tests for client config validation
This commit is contained in:
committed by
GitHub
parent
e99645cf93
commit
c7e81e2b18
@@ -85,10 +85,10 @@ func (m *Meta) error() error {
|
||||
}
|
||||
|
||||
type ClientConfig struct {
|
||||
User string
|
||||
Pass string
|
||||
APIKey string
|
||||
URL string
|
||||
URL string `validate:"required,http_url"`
|
||||
APIKey string `validate:"required_without_all=User Pass"`
|
||||
User string `validate:"excluded_with=APIKey,required_with=Pass"`
|
||||
Pass string `validate:"excluded_with=APIKey,required_with=User"`
|
||||
Timeout time.Duration // how long to wait for replies, default: forever.
|
||||
VerifySSL bool
|
||||
Interceptors []ClientInterceptor
|
||||
@@ -107,6 +107,7 @@ type Client struct {
|
||||
interceptors []ClientInterceptor
|
||||
errorHandler ResponseErrorHandler
|
||||
lock sync.Mutex
|
||||
validator *validator
|
||||
}
|
||||
|
||||
type ApiPaths struct {
|
||||
@@ -240,7 +241,14 @@ func (d *DefaultResponseErrorHandler) HandleError(resp *http.Response) error {
|
||||
// Used to make additional, authenticated requests to the APIs.
|
||||
// Start here.
|
||||
func NewClient(config *ClientConfig) (*Client, error) {
|
||||
u, err := newUnifi(config)
|
||||
v, err := newValidator()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed creating validator: %w", err)
|
||||
}
|
||||
if err := v.Validate(config); err != nil {
|
||||
return nil, fmt.Errorf("failed validating config: %w", err)
|
||||
}
|
||||
u, err := newUnifi(config, v)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed creating unifi client: %w", err)
|
||||
}
|
||||
@@ -275,7 +283,7 @@ func parseBaseUrl(base string) (*url.URL, error) {
|
||||
return baseURL, nil
|
||||
}
|
||||
|
||||
func newUnifi(config *ClientConfig) (*Client, error) {
|
||||
func newUnifi(config *ClientConfig, v *validator) (*Client, error) {
|
||||
var err error
|
||||
|
||||
config.URL = strings.TrimRight(config.URL, "/")
|
||||
@@ -337,6 +345,7 @@ func newUnifi(config *ClientConfig) (*Client, error) {
|
||||
interceptors: interceptors,
|
||||
errorHandler: errorHandler,
|
||||
lock: sync.Mutex{},
|
||||
validator: v,
|
||||
}
|
||||
for _, interceptor := range config.Interceptors {
|
||||
// add any custom interceptors and ensure no duplicates
|
||||
|
||||
@@ -102,7 +102,8 @@ func TestCustomizeHttpClient(t *testing.T) {
|
||||
|
||||
// when
|
||||
_, err := NewClient(&ClientConfig{
|
||||
URL: localUrl,
|
||||
URL: localUrl,
|
||||
APIKey: "test-key",
|
||||
HttpCustomizer: func(transport *http.Transport) error {
|
||||
called = true
|
||||
return nil
|
||||
@@ -439,7 +440,8 @@ func TestResponseDataHandling(t *testing.T) {
|
||||
}
|
||||
srv := RunTestServer(NewStyleAPI.ApiPath+"/test", TestData{})
|
||||
c, _ := NewClient(&ClientConfig{
|
||||
URL: srv.URL,
|
||||
URL: srv.URL,
|
||||
APIKey: "test-key",
|
||||
})
|
||||
c.apiPaths = &NewStyleAPI
|
||||
var data TestData
|
||||
@@ -460,6 +462,8 @@ func TestCsrfHandling(t *testing.T) {
|
||||
interceptor := NewTestInterceptor()
|
||||
c, _ := NewClient(&ClientConfig{
|
||||
URL: srv.URL,
|
||||
User: "test-user",
|
||||
Pass: "test-pass",
|
||||
Interceptors: interceptor.AsList(),
|
||||
})
|
||||
c.apiPaths = &NewStyleAPI
|
||||
@@ -487,6 +491,7 @@ func TestOverrideUserAgent(t *testing.T) {
|
||||
interceptor := NewTestInterceptor()
|
||||
c, _ := NewClient(&ClientConfig{
|
||||
URL: testUrl,
|
||||
APIKey: "test-key",
|
||||
Interceptors: interceptor.AsList(),
|
||||
UserAgent: "test-agent",
|
||||
})
|
||||
@@ -499,3 +504,78 @@ func TestOverrideUserAgent(t *testing.T) {
|
||||
require.Error(t, err)
|
||||
a.EqualValues("test-agent", interceptor.RequestHeader(UserAgentHeader))
|
||||
}
|
||||
|
||||
func TestAuthConfigurationValidation(t *testing.T) {
|
||||
t.Parallel()
|
||||
testCases := []struct {
|
||||
User, Pass, APIKey string
|
||||
shouldFail bool
|
||||
}{
|
||||
{"", "", "", true},
|
||||
{"", "", "test", false},
|
||||
{"", "test", "", true},
|
||||
{"", "test", "test", true},
|
||||
{"test", "", "", true},
|
||||
{"test", "", "test", true},
|
||||
{"test", "test", "", false},
|
||||
{"test", "test", "test", true},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(fmt.Sprintf("user:%s-pass:%s-apikey:%s", tc.User, tc.Pass, tc.APIKey), func(t *testing.T) {
|
||||
t.Parallel()
|
||||
// given
|
||||
_, err := NewClient(&ClientConfig{
|
||||
URL: testUrl,
|
||||
User: tc.User,
|
||||
Pass: tc.Pass,
|
||||
APIKey: tc.APIKey,
|
||||
})
|
||||
|
||||
// then
|
||||
if tc.shouldFail {
|
||||
require.ErrorContains(t, err, "validation failed")
|
||||
return
|
||||
}
|
||||
require.ErrorContains(t, err, "dial tcp") // error will anyway exist, but it will be not related to config
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUrlValidation(t *testing.T) {
|
||||
t.Parallel()
|
||||
testCases := []struct {
|
||||
URL string
|
||||
shouldFail bool
|
||||
errorString string
|
||||
}{
|
||||
{"", true, "required"},
|
||||
{"http://test.url", false, ""},
|
||||
{"http://test.url:3999", false, ""},
|
||||
{"https://test.url:3999", false, ""},
|
||||
{"ftp://test.url", true, "http"},
|
||||
{"test.url", true, "http"},
|
||||
{"http://127.0.0.1", false, ""},
|
||||
{"http://127.0.0.1:3999", false, ""},
|
||||
{"test", true, "http"},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.URL, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
// given
|
||||
_, err := NewClient(&ClientConfig{
|
||||
URL: tc.URL,
|
||||
APIKey: "test-key",
|
||||
})
|
||||
|
||||
// then
|
||||
if tc.shouldFail {
|
||||
require.ErrorContains(t, err, "validation failed")
|
||||
require.ErrorContains(t, err, tc.errorString)
|
||||
return
|
||||
}
|
||||
require.ErrorContains(t, err, "dial tcp") // error will anyway exist, but it will be not related to config
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
77
unifi/validation.go
Normal file
77
unifi/validation.go
Normal file
@@ -0,0 +1,77 @@
|
||||
package unifi
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/go-playground/locales/en"
|
||||
ut "github.com/go-playground/universal-translator"
|
||||
vd "github.com/go-playground/validator/v10"
|
||||
en_translations "github.com/go-playground/validator/v10/translations/en"
|
||||
)
|
||||
|
||||
// ValidationError is a custom error type for validation errors.
|
||||
type ValidationError struct {
|
||||
Root error
|
||||
Messages map[string]string
|
||||
}
|
||||
|
||||
// Error returns the error message with combined all validation error messages.
|
||||
func (v *ValidationError) Error() string {
|
||||
err := "validation failed: \n"
|
||||
for field, message := range v.Messages {
|
||||
err += fmt.Sprintf("%s: %s\n", field, message)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Validator is the interface for the validator. Use it to validate structs. You can register structure-level validations
|
||||
// with RegisterStructValidation.
|
||||
type Validator interface {
|
||||
// Validate validates the given struct and returns an error if the struct is not valid.
|
||||
Validate(i interface{}) error
|
||||
// RegisterStructValidation registers a structure-level validation function for a given struct type.
|
||||
RegisterStructValidation(fn vd.StructLevelFunc, i interface{})
|
||||
// RegisterTranslation registers a custom translation for a given tag.
|
||||
RegisterTranslation(tag string, registerFn vd.RegisterTranslationsFunc, translationFn vd.TranslationFunc) (err error)
|
||||
}
|
||||
|
||||
type validator struct {
|
||||
validate *vd.Validate
|
||||
trans ut.Translator
|
||||
}
|
||||
|
||||
func (v *validator) Validate(i interface{}) error {
|
||||
if err := v.validate.Struct(i); err != nil {
|
||||
var errs vd.ValidationErrors
|
||||
errors.As(err, &errs)
|
||||
messages := errs.Translate(v.trans)
|
||||
|
||||
return &ValidationError{Root: err, Messages: messages}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (v *validator) RegisterStructValidation(f vd.StructLevelFunc, s interface{}) {
|
||||
v.validate.RegisterStructValidation(f, s)
|
||||
}
|
||||
|
||||
func (v *validator) RegisterTranslation(tag string, registerFn vd.RegisterTranslationsFunc, translationFn vd.TranslationFunc) error {
|
||||
return v.validate.RegisterTranslation(tag, v.trans, registerFn, translationFn)
|
||||
}
|
||||
|
||||
func newValidator() (*validator, error) {
|
||||
validate := vd.New(vd.WithRequiredStructEnabled())
|
||||
enLocale := en.New()
|
||||
uni := ut.New(enLocale, enLocale)
|
||||
trans, _ := uni.GetTranslator(enLocale.Locale())
|
||||
err := en_translations.RegisterDefaultTranslations(validate, trans)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &validator{
|
||||
validate: validate,
|
||||
trans: trans,
|
||||
}, nil
|
||||
}
|
||||
Reference in New Issue
Block a user