test: add comprehensive testing of codegen (#14)

This commit is contained in:
Mateusz Filipowicz
2025-02-12 00:10:46 +01:00
committed by GitHub
parent 6016a3d34a
commit bb21419c0e
11 changed files with 639 additions and 107 deletions

View File

@@ -16,24 +16,16 @@ import (
"strings"
"github.com/iancoleman/strcase"
log "github.com/sirupsen/logrus"
"github.com/ulikunitz/xz"
"github.com/xor-gate/ar"
)
func DownloadAndExtract(downloadUrl url.URL, outputDir string) error {
targetInfo, err := os.Stat(outputDir)
if err != nil {
if !errors.Is(err, os.ErrNotExist) {
return err
}
// Check if output directory exists, if not create and perform extraction
err = os.MkdirAll(outputDir, 0o755)
if err != nil {
return err
}
// download fields, create
if created, err := ensurePath(outputDir); err != nil {
return fmt.Errorf("unable to create output directory %s: %w", outputDir, err)
} else if created {
log.Debugf("downloading UniFi Controller package from: %s", downloadUrl.String())
jarFile, err := downloadJar(downloadUrl, outputDir)
if err != nil {
@@ -41,18 +33,19 @@ func DownloadAndExtract(downloadUrl url.URL, outputDir string) error {
}
log.Debugf("extracting JSON files with API structures from: %s to: %s", jarFile, outputDir)
err = extractJSON(jarFile, outputDir)
if err != nil {
if err = extractJSON(jarFile, outputDir); err != nil {
return err
}
log.Debugf("JSON files extracted to: %s", outputDir)
targetInfo, err = os.Stat(outputDir)
_, err = os.Stat(outputDir)
if err != nil {
return err
}
}
if !targetInfo.IsDir() {
if targetInfo, err := os.Stat(outputDir); err != nil {
return err
} else if !targetInfo.IsDir() {
return errors.New("fields info isn't a directory")
}
return nil
@@ -68,10 +61,12 @@ func downloadJar(downloadUrl url.URL, outputDir string) (string, error) {
if err != nil {
return "", fmt.Errorf("unable to download UniFi Controller deb: %w", err)
}
if debResp.StatusCode != http.StatusOK {
return "", fmt.Errorf("unable to download UniFi Controller deb: HTTP%d. Probably it does not exist under %s", debResp.StatusCode, downloadUrl.String())
}
defer debResp.Body.Close()
var uncompressedReader io.Reader
arReader := ar.NewReader(debResp.Body)
for {
header, err := arReader.Next()
@@ -81,8 +76,6 @@ func downloadJar(downloadUrl url.URL, outputDir string) (string, error) {
if err != nil {
return "", fmt.Errorf("in ar next: %w", err)
}
// read the data file
if header.Name == "data.tar.xz" {
uncompressedReader, err = xz.NewReader(arReader)
if err != nil {
@@ -96,9 +89,7 @@ func downloadJar(downloadUrl url.URL, outputDir string) (string, error) {
}
tarReader := tar.NewReader(uncompressedReader)
var aceJar *os.File
log.Debugln("extracting ace.jar from downloaded controller package")
for {
header, err := tarReader.Next()
@@ -108,12 +99,9 @@ func downloadJar(downloadUrl url.URL, outputDir string) (string, error) {
if err != nil {
return "", fmt.Errorf("in next: %w", err)
}
if header.Typeflag != tar.TypeReg || header.Name != "./usr/lib/unifi/lib/ace.jar" {
// skipping
continue
}
aceJar, err = os.Create(filepath.Join(outputDir, "ace.jar"))
if err != nil {
return "", fmt.Errorf("unable to create temp file: %w", err)
@@ -123,34 +111,14 @@ func downloadJar(downloadUrl url.URL, outputDir string) (string, error) {
return "", fmt.Errorf("unable to write ace.jar temp file: %w", err)
}
}
if aceJar == nil {
return "", errors.New("unable to find ace.jar")
}
defer aceJar.Close()
log.Debugf("ace.jar extracted to: %s", aceJar.Name())
return aceJar.Name(), nil
}
func sanitizeExtractedPath(filePath, destinationDir string) (string, error) {
absDestinationDir, err := filepath.Abs(destinationDir)
if err != nil {
return "", err
}
absFilePath, err := filepath.Abs(filepath.Join(destinationDir, filepath.Base(filePath)))
if err != nil {
return "", err
}
if !strings.HasPrefix(absFilePath, absDestinationDir) {
return "", fmt.Errorf("invalid file path: %s", filePath)
}
return absFilePath, nil
}
func extractJSON(jarFile, fieldsDir string) error {
jarZip, err := zip.OpenReader(jarFile)
if err != nil {
@@ -161,7 +129,6 @@ func extractJSON(jarFile, fieldsDir string) error {
log.Tracef("opened jar %s with %d files", jarFile, len(jarZip.File))
for _, f := range jarZip.File {
if !strings.HasPrefix(f.Name, "api/fields/") || path.Ext(f.Name) != ".json" {
// skip file
continue
}
@@ -175,19 +142,16 @@ func extractJSON(jarFile, fieldsDir string) error {
if err != nil {
return err
}
dst, err := os.Create(dstPath)
if err != nil {
return err
}
defer dst.Close()
_, err = io.Copy(dst, src)
log.Debugf("extracted %s", f.Name)
if err != nil {
return err
}
return nil
}()
if err != nil {
@@ -227,6 +191,23 @@ func extractJSON(jarFile, fieldsDir string) error {
log.Tracef("splitted %s into %s", settingKey, fileName)
}
// TODO: cleanup JSON
return nil
}
func sanitizeExtractedPath(filePath, destinationDir string) (string, error) {
absDestinationDir, err := filepath.Abs(destinationDir)
if err != nil {
return "", err
}
absFilePath, err := filepath.Abs(filepath.Join(destinationDir, filepath.Base(filePath)))
if err != nil {
return "", err
}
if !strings.HasPrefix(absFilePath, absDestinationDir) {
return "", fmt.Errorf("invalid file path: %s", filePath)
}
return absFilePath, nil
}

141
codegen/download_test.go Normal file
View File

@@ -0,0 +1,141 @@
package main
import (
"archive/zip"
"net/url"
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// Helper function to create a temporary zip file with given entries. 'entries' maps file names to their content.
func createTempZipFile(t *testing.T, entries map[string]string) string {
t.Helper()
tempDir := t.TempDir()
tempFileName := filepath.Join(tempDir, "test.zip")
tempFile, err := os.Create(tempFileName)
require.NoError(t, err, "Failed to create temp zip file")
// We need to truncate and write zip contents
w := zip.NewWriter(tempFile)
for name, content := range entries {
f, err := w.Create(name)
require.NoError(t, err, "Failed to add entry %s", name)
_, err = f.Write([]byte(content))
require.NoError(t, err, "Failed to write content for %s", name)
}
err = w.Close()
require.NoError(t, err, "Failed to close zip writer")
err = tempFile.Close()
require.NoError(t, err, "Failed to close temp file")
return tempFile.Name()
}
// Test when the output directory already exists. In this case, DownloadAndExtract should not call downloadJarFn or extractJSONFn.
func TestDownloadAndExtract_WithExistingDirectory(t *testing.T) {
t.Parallel()
r := require.New(t)
tempDir := t.TempDir()
testURL, _ := url.Parse("http://example.com/test.deb")
err := DownloadAndExtract(*testURL, tempDir)
r.NoError(err, "Expected no error when directory exists")
}
// // Test when output path is not a directory.
func TestDownloadAndExtract_PathNotDirectory(t *testing.T) {
t.Parallel()
r := require.New(t)
tempDir := t.TempDir()
tempFilePath := filepath.Join(tempDir, "dummy")
_, err := os.Create(tempFilePath)
r.NoError(err, "Failed to create temp file")
testURL, _ := url.Parse("http://example.com/test.deb")
err = DownloadAndExtract(*testURL, tempFilePath)
r.Error(err, "Expected error because tempFilePath is not a directory")
r.ErrorContains(err, tempFilePath+" isn't a directory")
}
// // Test extractJSON when the jar file cannot be opened.
func TestExtractJSON_OpenJarError(t *testing.T) {
t.Parallel()
r := require.New(t)
err := extractJSON("nonexisting.jar", t.TempDir())
r.Error(err)
r.ErrorContains(err, "unable to open jar")
}
// Test extractJSON with a valid zip file that contains a JSON file under api/fields/ and no Setting.json (so splitting is skipped).
func TestExtractJSON_NoSettings(t *testing.T) {
t.Parallel()
r := require.New(t)
tempDir := t.TempDir()
jarFile := createTempZipFile(t, map[string]string{"api/fields/dummy.json": "{\"key\": \"value\"}"})
err := extractJSON(jarFile, tempDir)
r.NoError(err)
// Check that dummy.json has been extracted
expectedPath := filepath.Join(tempDir, "dummy.json")
data, err := os.ReadFile(expectedPath)
r.NoError(err, "Expected file %s to exist", expectedPath)
r.JSONEq("{\"key\": \"value\"}", string(data), "Extracted file content mismatch")
}
// Test extractJSON with Setting.json present, so that it splits settings into individual files.
func TestExtractJSON_WithSettings(t *testing.T) {
t.Parallel()
r := require.New(t)
tempDir := t.TempDir()
entries := map[string]string{"api/fields/Setting.json": "{\"foo\": {\"bar\": 1}}"}
jarFile := createTempZipFile(t, entries)
err := extractJSON(jarFile, tempDir)
r.NoError(err)
// Check that the split settings file exists
settingFile := filepath.Join(tempDir, "SettingFoo.json")
data, err := os.ReadFile(settingFile)
r.NoError(err)
r.Contains(string(data), "bar")
}
// Test sanitizeExtractedPath with valid input.
func TestSanitizeExtractedPath_Valid(t *testing.T) {
t.Parallel()
a := assert.New(t)
r := require.New(t)
tempDir := t.TempDir()
filePath := "api/fields/dummy.json"
result, err := sanitizeExtractedPath(filePath, tempDir)
r.NoError(err, "Expected nil error from sanitizeExtractedPath")
expExpected := filepath.Join(tempDir, "dummy.json")
absExpected, err := filepath.Abs(expExpected)
r.NoError(err, "Failed to get abs path")
a.Equal(absExpected, result, "Sanitized path mismatch")
}
// Test extractJSON with invalid Setting.json content, expecting an unmarshal error.
func TestExtractJSON_InvalidSettings(t *testing.T) {
t.Parallel()
r := require.New(t)
tempDir := t.TempDir()
jarFile := createTempZipFile(t, map[string]string{"api/fields/Setting.json": "invalid json"})
err := extractJSON(jarFile, tempDir)
r.Error(err)
r.ErrorContains(err, "unable to unmarshal settings")
}

View File

@@ -8,7 +8,7 @@ import (
"github.com/hashicorp/go-version"
)
var firmwareUpdateApi = "https://fw-update.ubnt.com/api/firmware-latest"
const defaultFirmwareUpdateApi = "https://fw-update.ubnt.com/api/firmware-latest"
const (
debianPlatform = "debian"
@@ -55,21 +55,20 @@ func (l *firmwareUpdateApiResponseEmbeddedFirmwareDataLink) MarshalJSON() ([]byt
func (l *firmwareUpdateApiResponseEmbeddedFirmwareDataLink) UnmarshalJSON(j []byte) error {
var m map[string]interface{}
err := json.Unmarshal(j, &m)
if err := json.Unmarshal(j, &m); err != nil {
return err
}
if href, exists := m["href"]; exists && href != nil {
strHref, ok := href.(string)
if !ok {
return fmt.Errorf("expected string for href, got %T", href)
}
u, err := url.Parse(strHref)
if err != nil {
return err
}
if href := m["href"]; href != nil {
url, err := url.Parse(href.(string))
if err != nil {
return err
l.Href = u
}
l.Href = url
}
return nil
}

201
codegen/fwupdate_test.go Normal file
View File

@@ -0,0 +1,201 @@
package main
import (
"encoding/json"
"net/url"
"testing"
"github.com/hashicorp/go-version"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestFirmwareUpdateApiFilter(t *testing.T) {
t.Parallel()
tests := []struct {
name string
key string
value string
expected string
}{
{"channel", "channel", "release", "eq~~channel~~release"},
{"product", "product", "unifi-controller", "eq~~product~~unifi-controller"},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
a := assert.New(t)
filter := firmwareUpdateApiFilter(tc.key, tc.value)
a.Equal(tc.expected, filter)
})
}
}
func TestMarshalJSONDataLink(t *testing.T) {
t.Parallel()
tests := []struct {
name string
link firmwareUpdateApiResponseEmbeddedFirmwareDataLink
expectedJSON string
}{
{
"nil",
firmwareUpdateApiResponseEmbeddedFirmwareDataLink{Href: nil},
"{\"href\":\"\"}",
},
{
"with value",
func() firmwareUpdateApiResponseEmbeddedFirmwareDataLink {
u, err := url.Parse("https://example.com/firmware")
require.NoError(t, err) // error checking in test setup
return firmwareUpdateApiResponseEmbeddedFirmwareDataLink{Href: u}
}(),
"{\"href\":\"https://example.com/firmware\"}",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
a := assert.New(t)
b, err := json.Marshal(&tc.link)
require.NoError(t, err)
a.JSONEq(tc.expectedJSON, string(b))
})
}
}
func TestUnmarshalJSONDataLink(t *testing.T) {
t.Parallel()
tests := []struct {
name string
jsonStr string
shouldError bool
errorContains string
expectedHref *string
}{
{
"valid",
`{"href": "https://example.com/firmware"}`,
false,
"",
func(s string) *string { return &s }("https://example.com/firmware"),
},
{
"null",
`{"href": null}`,
false,
"",
nil,
},
{
"missing",
`{}`,
false,
"",
nil,
},
{
"non-string",
`{"href": 123}`,
true,
"expected string for href",
nil,
},
{
"invalid json",
`{"href": }`,
true,
"",
nil,
},
{
"invalid URL",
`{"href": "://missing"}`,
true,
"missing protocol scheme",
nil,
},
{
"empty",
`{"href": ""}`,
false,
"",
func(s string) *string { return &s }(""),
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
a := assert.New(t)
var link firmwareUpdateApiResponseEmbeddedFirmwareDataLink
err := json.Unmarshal([]byte(tc.jsonStr), &link)
if tc.shouldError {
require.Error(t, err)
if tc.errorContains != "" {
require.ErrorContains(t, err, tc.errorContains)
}
} else {
require.NoError(t, err)
if tc.expectedHref == nil {
a.Nil(link.Href)
} else {
require.NotNil(t, link.Href)
a.Equal(*tc.expectedHref, link.Href.String())
}
}
})
}
}
func TestFirmwareUpdateApiResponse_Complete(t *testing.T) {
t.Parallel()
a := assert.New(t)
ver, err := version.NewVersion("1.2.3")
require.NoError(t, err)
u, err := url.Parse("https://example.com/download")
require.NoError(t, err)
req := firmwareUpdateApiResponse{
Embedded: firmwareUpdateApiResponseEmbedded{
Firmware: []firmwareUpdateApiResponseEmbeddedFirmware{
{
Channel: "release",
Created: "2020-01-01T00:00:00Z",
Id: "unique-id",
Platform: "debian",
Product: "unifi-controller",
Version: ver,
Links: firmwareUpdateApiResponseEmbeddedFirmwareLinks{
Data: firmwareUpdateApiResponseEmbeddedFirmwareDataLink{
Href: u,
},
},
},
},
},
}
jsonBytes, err := json.Marshal(req)
require.NoError(t, err)
var newReq firmwareUpdateApiResponse
err = json.Unmarshal(jsonBytes, &newReq)
require.NoError(t, err)
require.Len(t, newReq.Embedded.Firmware, 1)
fw := newReq.Embedded.Firmware[0]
a.Equal("release", fw.Channel)
a.Equal("debian", fw.Platform)
a.NotNil(fw.Version)
a.Equal("https://example.com/download", fw.Links.Data.Href.String())
}

View File

@@ -11,7 +11,6 @@ import (
"text/template"
"github.com/iancoleman/strcase"
log "github.com/sirupsen/logrus"
)
// Generatable is the interface for generation sources.
@@ -43,6 +42,10 @@ func generateCodeFromTemplate(templateName, templateContent string, toWrite any)
// generateCode generates code for each generation source and writes it to file.
func generateCode(fieldsDir string, outDir string) error {
if _, err := ensurePath(outDir); err != nil {
return fmt.Errorf("unable to create output directory %s: %w", outDir, err)
}
generators := make([]Generatable, 0)
resources, err := buildResourcesFromDownloadedFields(fieldsDir)
if err != nil {

View File

@@ -9,9 +9,11 @@ import (
"path/filepath"
"strings"
log "github.com/sirupsen/logrus"
"github.com/sirupsen/logrus"
)
var log = logrus.New()
func usage() {
fmt.Printf("Usage: %s [OPTIONS] version\n", path.Base(os.Args[0]))
fmt.Printf("version can be a specific version or '%s' (default) for the latest UniFi Controller version\n", LatestVersionMarker)
@@ -19,21 +21,29 @@ func usage() {
}
func setupLogging(debugEnabled, traceEnabled bool) {
log.SetFormatter(&log.TextFormatter{
log.SetFormatter(&logrus.TextFormatter{
DisableTimestamp: true,
DisableLevelTruncation: true,
ForceColors: true,
FullTimestamp: false,
})
if traceEnabled {
log.SetLevel(log.TraceLevel)
log.SetLevel(logrus.TraceLevel)
} else if debugEnabled {
log.SetLevel(log.DebugLevel)
log.SetLevel(logrus.DebugLevel)
} else {
log.SetLevel(log.InfoLevel)
log.SetLevel(logrus.InfoLevel)
}
}
type options struct {
versionBaseDir string
outputDir string
downloadOnly bool
version string
firmwareUpdateApi string
}
func main() {
flag.Usage = usage
@@ -43,16 +53,31 @@ func main() {
debugFlag := flag.Bool("debug", false, "Enable debug logging")
traceFlag := flag.Bool("trace", false, "Enable trace logging")
flag.CommandLine.Init(os.Args[0], flag.PanicOnError) // set error handling to panic if parse ends with error
flag.Parse()
setupLogging(*debugFlag, *traceFlag)
specifiedVersion := strings.TrimSpace(flag.Arg(0))
if specifiedVersion == "" {
specifiedVersion = LatestVersionMarker // default to latest version
}
unifiVersion, err := determineUnifiVersion(specifiedVersion)
err := generate(options{
versionBaseDir: *versionBaseDirFlag,
outputDir: *outputDirFlag,
downloadOnly: *downloadOnly,
version: specifiedVersion,
firmwareUpdateApi: defaultFirmwareUpdateApi,
})
if err != nil {
log.Fatalf("unable to determine version and download URL for Unifi version %s: %s", specifiedVersion, err)
panic(err)
log.Error(err)
os.Exit(1)
}
}
func generate(opts options) error {
p := NewUnifiVersionProvider(opts.firmwareUpdateApi)
unifiVersion, err := p.ByVersionMarker(opts.version)
if err != nil {
return fmt.Errorf("unable to determine version and download URL for Unifi version %s: %w", opts.version, err)
}
log.Infof("UniFi Controller version: %s", unifiVersion.Version)
@@ -60,42 +85,49 @@ func main() {
wd, err := os.Getwd()
if err != nil {
log.Fatalf("unable to determine working directory: %s", err)
panic(err)
return fmt.Errorf("unable to determine working directory: %w", err)
}
structuresDir := filepath.Join(wd, *versionBaseDirFlag, fmt.Sprintf("v%s", unifiVersion.Version))
var structuresDir string
if path.IsAbs(opts.versionBaseDir) {
structuresDir = opts.versionBaseDir
} else {
structuresDir = filepath.Join(wd, opts.versionBaseDir)
}
structuresDir = filepath.Join(structuresDir, fmt.Sprintf("v%s", unifiVersion.Version))
log.Infoln("Downloading UniFi Controller API structures definitions...")
err = DownloadAndExtract(*unifiVersion.DownloadUrl, structuresDir)
if err != nil {
log.Fatalf("unable to download and extract UniFi Controller API structures definitions: %s", err)
panic(err)
return fmt.Errorf("unable to download and extract UniFi Controller API structures definitions: %w", err)
}
log.Infof("Downloaded UniFi Controller API structures definitions in %s", structuresDir)
if *downloadOnly {
if opts.downloadOnly {
log.Infoln("Structure JSONs ready!")
os.Exit(0)
return nil
}
log.Infoln("Generating resources code...")
outDir := filepath.Join(wd, *outputDirFlag)
var outDir string
if path.IsAbs(opts.outputDir) {
outDir = opts.outputDir
} else {
outDir = filepath.Join(wd, opts.outputDir)
}
if err = generateCode(structuresDir, outDir); err != nil {
log.Fatalf("unable to generate resources code: %s", err)
panic(err)
return fmt.Errorf("unable to generate resources code: %w", err)
}
log.Infof("Writing version file...")
if err = writeVersionFile(unifiVersion.Version, outDir); err != nil {
log.Fatalf("failed to write version file to %s: %s", outDir, err)
panic(err)
return fmt.Errorf("failed to write version file to %s: %w", outDir, err)
}
basepath := filepath.Dir(wd)
if err = writeVersionRepoMarkerFile(unifiVersion.Version, basepath); err != nil {
log.Fatalf("failed to write version file to %s: %s", basepath, err)
panic(err)
return fmt.Errorf("failed to write version file to %s: %w", basepath, err)
}
log.Infof("Generated resources in %s", outDir)
return nil
}

137
codegen/main_test.go Normal file
View File

@@ -0,0 +1,137 @@
package main
import (
"os"
"os/exec"
"testing"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestSetupLogging(t *testing.T) {
t.Parallel()
a := assert.New(t)
setupLogging(false, false)
a.Equal(logrus.InfoLevel, log.Level)
setupLogging(true, false)
a.Equal(logrus.DebugLevel, log.Level)
setupLogging(false, true)
a.Equal(logrus.TraceLevel, log.Level)
setupLogging(true, true)
a.Equal(logrus.TraceLevel, log.Level)
}
// integration tests for the CLI
// these test require Internet access
func execCli(args ...string) (string, error) {
in := []string{"run", "."}
in = append(in, args...)
cmd := exec.Command("go", in...)
output, err := cmd.CombinedOutput()
return string(output), err
}
func TestHelpFlag(t *testing.T) {
t.Parallel()
out, err := execCli("-h")
require.Error(t, err)
assert.Contains(t, out, "Usage: codegen [OPTIONS] version")
}
func TestInvalidFlag(t *testing.T) {
t.Parallel()
out, err := execCli("-invalid")
require.Error(t, err)
assert.Contains(t, out, "flag provided but not defined: -invalid")
}
func TestDefaultVersion(t *testing.T) {
t.Parallel()
out, err := execCli("-version-base-dir", t.TempDir(), "-output-dir", t.TempDir())
require.NoError(t, err)
assert.Contains(t, out, "UniFi Controller version")
}
func testGenerate(t *testing.T, opts *options) error {
t.Helper()
setupLogging(false, false)
if opts.versionBaseDir == "" {
opts.versionBaseDir = t.TempDir()
}
if opts.outputDir == "" {
opts.outputDir = t.TempDir()
}
if opts.firmwareUpdateApi == "" {
opts.firmwareUpdateApi = defaultFirmwareUpdateApi
}
return generate(*opts)
}
func TestNonExistentVersion(t *testing.T) {
t.Parallel()
err := testGenerate(t, &options{version: "1.2.3"})
require.Error(t, err)
}
func TestInvalidVersion(t *testing.T) {
t.Parallel()
r := require.New(t)
err := testGenerate(t, &options{version: "invalid-version"})
r.Error(err)
r.ErrorContains(err, "Malformed")
r.ErrorContains(err, "invalid-version")
}
func TestGenerateLatest(t *testing.T) {
t.Parallel()
r := require.New(t)
opts := &options{version: LatestVersionMarker}
err := testGenerate(t, opts)
r.NoError(err)
files, err := os.ReadDir(opts.versionBaseDir)
r.NoError(err)
assert.NotEmptyf(t, files, "version base dir '%s' should not be empty", opts.versionBaseDir)
files, err = os.ReadDir(opts.outputDir)
r.NoError(err)
assert.NotEmptyf(t, files, "output dir '%s' should not be empty", opts.outputDir)
}
func TestGenerateDownloadOnly(t *testing.T) {
t.Parallel()
r := require.New(t)
opts := &options{version: LatestVersionMarker, downloadOnly: true}
err := testGenerate(t, opts)
r.NoError(err)
files, err := os.ReadDir(opts.versionBaseDir)
r.NoError(err)
assert.NotEmptyf(t, files, "version base dir '%s' should not be empty", opts.versionBaseDir)
files, err = os.ReadDir(opts.outputDir)
r.NoError(err) // test generated dir
assert.Emptyf(t, files, "output dir '%s' should be empty", opts.outputDir)
}

View File

@@ -12,7 +12,6 @@ import (
"strings"
"github.com/iancoleman/strcase"
log "github.com/sirupsen/logrus"
)
type replacement struct {

26
codegen/utils.go Normal file
View File

@@ -0,0 +1,26 @@
package main
import (
"errors"
"fmt"
"os"
)
// ensurePath checks if a path exists and is a directory, if not it creates the directory. Returns true if the directories were created.
func ensurePath(path string) (bool, error) {
// Check if output directory exists, if not create and perform extraction
targetInfo, err := os.Stat(path)
if err != nil {
if !errors.Is(err, os.ErrNotExist) {
return false, err
}
if err = os.MkdirAll(path, 0o755); err != nil {
return false, err
}
return true, nil
}
if !targetInfo.IsDir() {
return false, fmt.Errorf("%s isn't a directory", path)
}
return false, nil
}

View File

@@ -31,8 +31,23 @@ func NewUnifiVersion(unifiVersion *version.Version, downloadUrl *url.URL) *Unifi
}
}
func latestUnifiVersion() (*UnifiVersion, error) {
url, err := url.Parse(firmwareUpdateApi)
type UnifiVersionProvider interface {
Latest() (*UnifiVersion, error)
ByVersionMarker(versionMarker string) (*UnifiVersion, error)
}
type defaultUnifiVersionProvider struct {
firmwareUpdateApi string
}
func NewUnifiVersionProvider(firmwareUpdateApi string) UnifiVersionProvider { //nolint:ireturn
return &defaultUnifiVersionProvider{
firmwareUpdateApi: firmwareUpdateApi,
}
}
func (p *defaultUnifiVersionProvider) Latest() (*UnifiVersion, error) {
url, err := url.Parse(p.firmwareUpdateApi)
if err != nil {
return nil, err
}
@@ -70,9 +85,9 @@ func latestUnifiVersion() (*UnifiVersion, error) {
return nil, errors.New("no Unifi Controller firmware found")
}
func determineUnifiVersion(versionMarker string) (*UnifiVersion, error) {
func (p *defaultUnifiVersionProvider) ByVersionMarker(versionMarker string) (*UnifiVersion, error) {
if versionMarker == LatestVersionMarker {
return latestUnifiVersion()
return p.Latest()
} else {
unifiVersion, err := version.NewVersion(versionMarker)
if err != nil {

View File

@@ -15,7 +15,7 @@ import (
"github.com/stretchr/testify/require"
)
func assertLatestVersionUsingProvider(t *testing.T, provider func() (*UnifiVersion, error)) {
func assertLatestVersionUsingProvider(t *testing.T, provider func(p UnifiVersionProvider) (*UnifiVersion, error)) {
t.Helper()
assert := assert.New(t)
require := require.New(t)
@@ -85,8 +85,9 @@ func assertLatestVersionUsingProvider(t *testing.T, provider func() (*UnifiVersi
}))
defer server.Close()
firmwareUpdateApi = server.URL
gotVersion, err := provider()
p := NewUnifiVersionProvider(server.URL)
gotVersion, err := provider(p)
require.NoError(err)
assert.Equal(fwVersion.Core(), gotVersion.Version)
@@ -95,15 +96,15 @@ func assertLatestVersionUsingProvider(t *testing.T, provider func() (*UnifiVersi
func TestLatestUnifiVersion(t *testing.T) {
t.Parallel()
assertLatestVersionUsingProvider(t, func() (*UnifiVersion, error) {
return latestUnifiVersion()
assertLatestVersionUsingProvider(t, func(p UnifiVersionProvider) (*UnifiVersion, error) {
return p.Latest()
})
}
func TestDetermineUnifiVersion_latest(t *testing.T) {
t.Parallel()
assertLatestVersionUsingProvider(t, func() (*UnifiVersion, error) {
return determineUnifiVersion(LatestVersionMarker)
assertLatestVersionUsingProvider(t, func(p UnifiVersionProvider) (*UnifiVersion, error) {
return p.ByVersionMarker(LatestVersionMarker)
})
}
@@ -121,7 +122,7 @@ func TestDetermineUnifiVersion_provided(t *testing.T) {
t.Parallel()
a := assert.New(t)
unifiVersion, err := determineUnifiVersion(providedVersion)
unifiVersion, err := NewUnifiVersionProvider(defaultFirmwareUpdateApi).ByVersionMarker(providedVersion)
require.NoError(t, err)
a.Equal(expectedVersion, unifiVersion.Version.String())
@@ -143,7 +144,7 @@ func TestDetermineUnifiVersion_invalid(t *testing.T) {
for _, providedVersion := range testCases {
t.Run(providedVersion, func(t *testing.T) {
t.Parallel()
_, err := determineUnifiVersion(providedVersion)
_, err := NewUnifiVersionProvider(defaultFirmwareUpdateApi).ByVersionMarker(providedVersion)
require.ErrorContains(t, err, providedVersion)
})
}
@@ -171,8 +172,7 @@ func TestLatestUnifiVersion_HttpError(t *testing.T) {
}))
defer server.Close()
firmwareUpdateApi = server.URL
_, err := latestUnifiVersion()
_, err := NewUnifiVersionProvider(server.URL).Latest()
require.Error(t, err)
}
@@ -187,8 +187,8 @@ func TestLatestUnifiVersion_InvalidJson(t *testing.T) {
}))
defer server.Close()
firmwareUpdateApi = server.URL
_, err := latestUnifiVersion()
_, err := NewUnifiVersionProvider(server.URL).Latest()
require.Error(t, err)
require.ErrorContains(t, err, "invalid")
}
@@ -224,8 +224,8 @@ func TestLatestUnifiVersion_NoDebianFirmware(t *testing.T) {
}))
defer server.Close()
firmwareUpdateApi = server.URL
_, err = latestUnifiVersion()
_, err = NewUnifiVersionProvider(server.URL).Latest()
require.Error(t, err)
require.ErrorContains(t, err, "no Unifi Controller firmware found")
}
@@ -265,8 +265,7 @@ func TestWriteVersionRepoMarkerFile(t *testing.T) {
func TestLatestUnifiVersion_InvalidUrl(t *testing.T) {
t.Parallel()
firmwareUpdateApi = ":\\invalid"
_, err := latestUnifiVersion()
_, err := NewUnifiVersionProvider(":\\invalid").Latest()
require.Error(t, err)
require.ErrorContains(t, err, "invalid")
}
@@ -321,8 +320,7 @@ func TestLatestUnifiVersion_NilVersion(t *testing.T) {
}))
defer server.Close()
firmwareUpdateApi = server.URL
_, err := latestUnifiVersion()
_, err := NewUnifiVersionProvider(server.URL).Latest()
require.Error(t, err)
}