From bb21419c0e8bab98f1aec596c001c64cc30d9610 Mon Sep 17 00:00:00 2001 From: Mateusz Filipowicz Date: Wed, 12 Feb 2025 00:10:46 +0100 Subject: [PATCH] test: add comprehensive testing of codegen (#14) --- codegen/download.go | 79 ++++++--------- codegen/download_test.go | 141 +++++++++++++++++++++++++++ codegen/fwupdate.go | 19 ++-- codegen/fwupdate_test.go | 201 +++++++++++++++++++++++++++++++++++++++ codegen/generator.go | 5 +- codegen/main.go | 78 ++++++++++----- codegen/main_test.go | 137 ++++++++++++++++++++++++++ codegen/resources.go | 1 - codegen/utils.go | 26 +++++ codegen/version.go | 23 ++++- codegen/version_test.go | 36 ++++--- 11 files changed, 639 insertions(+), 107 deletions(-) create mode 100644 codegen/download_test.go create mode 100644 codegen/fwupdate_test.go create mode 100644 codegen/main_test.go create mode 100644 codegen/utils.go diff --git a/codegen/download.go b/codegen/download.go index 9cb78dc..6076abc 100644 --- a/codegen/download.go +++ b/codegen/download.go @@ -16,24 +16,16 @@ import ( "strings" "github.com/iancoleman/strcase" - log "github.com/sirupsen/logrus" "github.com/ulikunitz/xz" "github.com/xor-gate/ar" ) func DownloadAndExtract(downloadUrl url.URL, outputDir string) error { - targetInfo, err := os.Stat(outputDir) - if err != nil { - if !errors.Is(err, os.ErrNotExist) { - return err - } + // Check if output directory exists, if not create and perform extraction - err = os.MkdirAll(outputDir, 0o755) - if err != nil { - return err - } - - // download fields, create + if created, err := ensurePath(outputDir); err != nil { + return fmt.Errorf("unable to create output directory %s: %w", outputDir, err) + } else if created { log.Debugf("downloading UniFi Controller package from: %s", downloadUrl.String()) jarFile, err := downloadJar(downloadUrl, outputDir) if err != nil { @@ -41,18 +33,19 @@ func DownloadAndExtract(downloadUrl url.URL, outputDir string) error { } log.Debugf("extracting JSON files with API structures from: %s to: %s", jarFile, outputDir) - err = extractJSON(jarFile, outputDir) - if err != nil { + if err = extractJSON(jarFile, outputDir); err != nil { return err } log.Debugf("JSON files extracted to: %s", outputDir) - targetInfo, err = os.Stat(outputDir) + _, err = os.Stat(outputDir) if err != nil { return err } } - if !targetInfo.IsDir() { + if targetInfo, err := os.Stat(outputDir); err != nil { + return err + } else if !targetInfo.IsDir() { return errors.New("fields info isn't a directory") } return nil @@ -68,10 +61,12 @@ func downloadJar(downloadUrl url.URL, outputDir string) (string, error) { if err != nil { return "", fmt.Errorf("unable to download UniFi Controller deb: %w", err) } + if debResp.StatusCode != http.StatusOK { + return "", fmt.Errorf("unable to download UniFi Controller deb: HTTP%d. Probably it does not exist under %s", debResp.StatusCode, downloadUrl.String()) + } defer debResp.Body.Close() var uncompressedReader io.Reader - arReader := ar.NewReader(debResp.Body) for { header, err := arReader.Next() @@ -81,8 +76,6 @@ func downloadJar(downloadUrl url.URL, outputDir string) (string, error) { if err != nil { return "", fmt.Errorf("in ar next: %w", err) } - - // read the data file if header.Name == "data.tar.xz" { uncompressedReader, err = xz.NewReader(arReader) if err != nil { @@ -96,9 +89,7 @@ func downloadJar(downloadUrl url.URL, outputDir string) (string, error) { } tarReader := tar.NewReader(uncompressedReader) - var aceJar *os.File - log.Debugln("extracting ace.jar from downloaded controller package") for { header, err := tarReader.Next() @@ -108,12 +99,9 @@ func downloadJar(downloadUrl url.URL, outputDir string) (string, error) { if err != nil { return "", fmt.Errorf("in next: %w", err) } - if header.Typeflag != tar.TypeReg || header.Name != "./usr/lib/unifi/lib/ace.jar" { - // skipping continue } - aceJar, err = os.Create(filepath.Join(outputDir, "ace.jar")) if err != nil { return "", fmt.Errorf("unable to create temp file: %w", err) @@ -123,34 +111,14 @@ func downloadJar(downloadUrl url.URL, outputDir string) (string, error) { return "", fmt.Errorf("unable to write ace.jar temp file: %w", err) } } - if aceJar == nil { return "", errors.New("unable to find ace.jar") } - defer aceJar.Close() log.Debugf("ace.jar extracted to: %s", aceJar.Name()) return aceJar.Name(), nil } -func sanitizeExtractedPath(filePath, destinationDir string) (string, error) { - absDestinationDir, err := filepath.Abs(destinationDir) - if err != nil { - return "", err - } - - absFilePath, err := filepath.Abs(filepath.Join(destinationDir, filepath.Base(filePath))) - if err != nil { - return "", err - } - - if !strings.HasPrefix(absFilePath, absDestinationDir) { - return "", fmt.Errorf("invalid file path: %s", filePath) - } - - return absFilePath, nil -} - func extractJSON(jarFile, fieldsDir string) error { jarZip, err := zip.OpenReader(jarFile) if err != nil { @@ -161,7 +129,6 @@ func extractJSON(jarFile, fieldsDir string) error { log.Tracef("opened jar %s with %d files", jarFile, len(jarZip.File)) for _, f := range jarZip.File { if !strings.HasPrefix(f.Name, "api/fields/") || path.Ext(f.Name) != ".json" { - // skip file continue } @@ -175,19 +142,16 @@ func extractJSON(jarFile, fieldsDir string) error { if err != nil { return err } - dst, err := os.Create(dstPath) if err != nil { return err } defer dst.Close() - _, err = io.Copy(dst, src) log.Debugf("extracted %s", f.Name) if err != nil { return err } - return nil }() if err != nil { @@ -227,6 +191,23 @@ func extractJSON(jarFile, fieldsDir string) error { log.Tracef("splitted %s into %s", settingKey, fileName) } - // TODO: cleanup JSON return nil } + +func sanitizeExtractedPath(filePath, destinationDir string) (string, error) { + absDestinationDir, err := filepath.Abs(destinationDir) + if err != nil { + return "", err + } + + absFilePath, err := filepath.Abs(filepath.Join(destinationDir, filepath.Base(filePath))) + if err != nil { + return "", err + } + + if !strings.HasPrefix(absFilePath, absDestinationDir) { + return "", fmt.Errorf("invalid file path: %s", filePath) + } + + return absFilePath, nil +} diff --git a/codegen/download_test.go b/codegen/download_test.go new file mode 100644 index 0000000..39fb269 --- /dev/null +++ b/codegen/download_test.go @@ -0,0 +1,141 @@ +package main + +import ( + "archive/zip" + "net/url" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Helper function to create a temporary zip file with given entries. 'entries' maps file names to their content. +func createTempZipFile(t *testing.T, entries map[string]string) string { + t.Helper() + tempDir := t.TempDir() + tempFileName := filepath.Join(tempDir, "test.zip") + tempFile, err := os.Create(tempFileName) + require.NoError(t, err, "Failed to create temp zip file") + // We need to truncate and write zip contents + w := zip.NewWriter(tempFile) + for name, content := range entries { + f, err := w.Create(name) + require.NoError(t, err, "Failed to add entry %s", name) + _, err = f.Write([]byte(content)) + require.NoError(t, err, "Failed to write content for %s", name) + } + err = w.Close() + require.NoError(t, err, "Failed to close zip writer") + err = tempFile.Close() + require.NoError(t, err, "Failed to close temp file") + return tempFile.Name() +} + +// Test when the output directory already exists. In this case, DownloadAndExtract should not call downloadJarFn or extractJSONFn. +func TestDownloadAndExtract_WithExistingDirectory(t *testing.T) { + t.Parallel() + r := require.New(t) + + tempDir := t.TempDir() + testURL, _ := url.Parse("http://example.com/test.deb") + + err := DownloadAndExtract(*testURL, tempDir) + + r.NoError(err, "Expected no error when directory exists") +} + +// // Test when output path is not a directory. +func TestDownloadAndExtract_PathNotDirectory(t *testing.T) { + t.Parallel() + r := require.New(t) + + tempDir := t.TempDir() + tempFilePath := filepath.Join(tempDir, "dummy") + _, err := os.Create(tempFilePath) + r.NoError(err, "Failed to create temp file") + testURL, _ := url.Parse("http://example.com/test.deb") + + err = DownloadAndExtract(*testURL, tempFilePath) + + r.Error(err, "Expected error because tempFilePath is not a directory") + r.ErrorContains(err, tempFilePath+" isn't a directory") +} + +// // Test extractJSON when the jar file cannot be opened. +func TestExtractJSON_OpenJarError(t *testing.T) { + t.Parallel() + r := require.New(t) + + err := extractJSON("nonexisting.jar", t.TempDir()) + + r.Error(err) + r.ErrorContains(err, "unable to open jar") +} + +// Test extractJSON with a valid zip file that contains a JSON file under api/fields/ and no Setting.json (so splitting is skipped). +func TestExtractJSON_NoSettings(t *testing.T) { + t.Parallel() + r := require.New(t) + tempDir := t.TempDir() + jarFile := createTempZipFile(t, map[string]string{"api/fields/dummy.json": "{\"key\": \"value\"}"}) + + err := extractJSON(jarFile, tempDir) + r.NoError(err) + + // Check that dummy.json has been extracted + expectedPath := filepath.Join(tempDir, "dummy.json") + data, err := os.ReadFile(expectedPath) + r.NoError(err, "Expected file %s to exist", expectedPath) + r.JSONEq("{\"key\": \"value\"}", string(data), "Extracted file content mismatch") +} + +// Test extractJSON with Setting.json present, so that it splits settings into individual files. +func TestExtractJSON_WithSettings(t *testing.T) { + t.Parallel() + r := require.New(t) + tempDir := t.TempDir() + entries := map[string]string{"api/fields/Setting.json": "{\"foo\": {\"bar\": 1}}"} + jarFile := createTempZipFile(t, entries) + + err := extractJSON(jarFile, tempDir) + r.NoError(err) + + // Check that the split settings file exists + settingFile := filepath.Join(tempDir, "SettingFoo.json") + data, err := os.ReadFile(settingFile) + r.NoError(err) + r.Contains(string(data), "bar") +} + +// Test sanitizeExtractedPath with valid input. +func TestSanitizeExtractedPath_Valid(t *testing.T) { + t.Parallel() + a := assert.New(t) + r := require.New(t) + + tempDir := t.TempDir() + filePath := "api/fields/dummy.json" + + result, err := sanitizeExtractedPath(filePath, tempDir) + r.NoError(err, "Expected nil error from sanitizeExtractedPath") + + expExpected := filepath.Join(tempDir, "dummy.json") + absExpected, err := filepath.Abs(expExpected) + r.NoError(err, "Failed to get abs path") + a.Equal(absExpected, result, "Sanitized path mismatch") +} + +// Test extractJSON with invalid Setting.json content, expecting an unmarshal error. +func TestExtractJSON_InvalidSettings(t *testing.T) { + t.Parallel() + r := require.New(t) + tempDir := t.TempDir() + jarFile := createTempZipFile(t, map[string]string{"api/fields/Setting.json": "invalid json"}) + + err := extractJSON(jarFile, tempDir) + + r.Error(err) + r.ErrorContains(err, "unable to unmarshal settings") +} diff --git a/codegen/fwupdate.go b/codegen/fwupdate.go index 8422afc..a696669 100644 --- a/codegen/fwupdate.go +++ b/codegen/fwupdate.go @@ -8,7 +8,7 @@ import ( "github.com/hashicorp/go-version" ) -var firmwareUpdateApi = "https://fw-update.ubnt.com/api/firmware-latest" +const defaultFirmwareUpdateApi = "https://fw-update.ubnt.com/api/firmware-latest" const ( debianPlatform = "debian" @@ -55,21 +55,20 @@ func (l *firmwareUpdateApiResponseEmbeddedFirmwareDataLink) MarshalJSON() ([]byt func (l *firmwareUpdateApiResponseEmbeddedFirmwareDataLink) UnmarshalJSON(j []byte) error { var m map[string]interface{} - - err := json.Unmarshal(j, &m) - if err != nil { + if err := json.Unmarshal(j, &m); err != nil { return err } - - if href := m["href"]; href != nil { - url, err := url.Parse(href.(string)) + if href, exists := m["href"]; exists && href != nil { + strHref, ok := href.(string) + if !ok { + return fmt.Errorf("expected string for href, got %T", href) + } + u, err := url.Parse(strHref) if err != nil { return err } - - l.Href = url + l.Href = u } - return nil } diff --git a/codegen/fwupdate_test.go b/codegen/fwupdate_test.go new file mode 100644 index 0000000..e6b8dd6 --- /dev/null +++ b/codegen/fwupdate_test.go @@ -0,0 +1,201 @@ +package main + +import ( + "encoding/json" + "net/url" + "testing" + + "github.com/hashicorp/go-version" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestFirmwareUpdateApiFilter(t *testing.T) { + t.Parallel() + tests := []struct { + name string + key string + value string + expected string + }{ + {"channel", "channel", "release", "eq~~channel~~release"}, + {"product", "product", "unifi-controller", "eq~~product~~unifi-controller"}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + a := assert.New(t) + + filter := firmwareUpdateApiFilter(tc.key, tc.value) + + a.Equal(tc.expected, filter) + }) + } +} + +func TestMarshalJSONDataLink(t *testing.T) { + t.Parallel() + tests := []struct { + name string + link firmwareUpdateApiResponseEmbeddedFirmwareDataLink + expectedJSON string + }{ + { + "nil", + firmwareUpdateApiResponseEmbeddedFirmwareDataLink{Href: nil}, + "{\"href\":\"\"}", + }, + { + "with value", + func() firmwareUpdateApiResponseEmbeddedFirmwareDataLink { + u, err := url.Parse("https://example.com/firmware") + require.NoError(t, err) // error checking in test setup + return firmwareUpdateApiResponseEmbeddedFirmwareDataLink{Href: u} + }(), + "{\"href\":\"https://example.com/firmware\"}", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + a := assert.New(t) + + b, err := json.Marshal(&tc.link) + + require.NoError(t, err) + a.JSONEq(tc.expectedJSON, string(b)) + }) + } +} + +func TestUnmarshalJSONDataLink(t *testing.T) { + t.Parallel() + tests := []struct { + name string + jsonStr string + shouldError bool + errorContains string + expectedHref *string + }{ + { + "valid", + `{"href": "https://example.com/firmware"}`, + false, + "", + func(s string) *string { return &s }("https://example.com/firmware"), + }, + { + "null", + `{"href": null}`, + false, + "", + nil, + }, + { + "missing", + `{}`, + false, + "", + nil, + }, + { + "non-string", + `{"href": 123}`, + true, + "expected string for href", + nil, + }, + { + "invalid json", + `{"href": }`, + true, + "", + nil, + }, + { + "invalid URL", + `{"href": "://missing"}`, + true, + "missing protocol scheme", + nil, + }, + { + "empty", + `{"href": ""}`, + false, + "", + func(s string) *string { return &s }(""), + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + a := assert.New(t) + var link firmwareUpdateApiResponseEmbeddedFirmwareDataLink + + err := json.Unmarshal([]byte(tc.jsonStr), &link) + + if tc.shouldError { + require.Error(t, err) + if tc.errorContains != "" { + require.ErrorContains(t, err, tc.errorContains) + } + } else { + require.NoError(t, err) + if tc.expectedHref == nil { + a.Nil(link.Href) + } else { + require.NotNil(t, link.Href) + a.Equal(*tc.expectedHref, link.Href.String()) + } + } + }) + } +} + +func TestFirmwareUpdateApiResponse_Complete(t *testing.T) { + t.Parallel() + a := assert.New(t) + + ver, err := version.NewVersion("1.2.3") + require.NoError(t, err) + u, err := url.Parse("https://example.com/download") + require.NoError(t, err) + + req := firmwareUpdateApiResponse{ + Embedded: firmwareUpdateApiResponseEmbedded{ + Firmware: []firmwareUpdateApiResponseEmbeddedFirmware{ + { + Channel: "release", + Created: "2020-01-01T00:00:00Z", + Id: "unique-id", + Platform: "debian", + Product: "unifi-controller", + Version: ver, + Links: firmwareUpdateApiResponseEmbeddedFirmwareLinks{ + Data: firmwareUpdateApiResponseEmbeddedFirmwareDataLink{ + Href: u, + }, + }, + }, + }, + }, + } + + jsonBytes, err := json.Marshal(req) + require.NoError(t, err) + + var newReq firmwareUpdateApiResponse + err = json.Unmarshal(jsonBytes, &newReq) + require.NoError(t, err) + + require.Len(t, newReq.Embedded.Firmware, 1) + fw := newReq.Embedded.Firmware[0] + a.Equal("release", fw.Channel) + a.Equal("debian", fw.Platform) + a.NotNil(fw.Version) + a.Equal("https://example.com/download", fw.Links.Data.Href.String()) +} diff --git a/codegen/generator.go b/codegen/generator.go index f2cfc49..960c699 100644 --- a/codegen/generator.go +++ b/codegen/generator.go @@ -11,7 +11,6 @@ import ( "text/template" "github.com/iancoleman/strcase" - log "github.com/sirupsen/logrus" ) // Generatable is the interface for generation sources. @@ -43,6 +42,10 @@ func generateCodeFromTemplate(templateName, templateContent string, toWrite any) // generateCode generates code for each generation source and writes it to file. func generateCode(fieldsDir string, outDir string) error { + if _, err := ensurePath(outDir); err != nil { + return fmt.Errorf("unable to create output directory %s: %w", outDir, err) + } + generators := make([]Generatable, 0) resources, err := buildResourcesFromDownloadedFields(fieldsDir) if err != nil { diff --git a/codegen/main.go b/codegen/main.go index 702a995..cbadd93 100644 --- a/codegen/main.go +++ b/codegen/main.go @@ -9,9 +9,11 @@ import ( "path/filepath" "strings" - log "github.com/sirupsen/logrus" + "github.com/sirupsen/logrus" ) +var log = logrus.New() + func usage() { fmt.Printf("Usage: %s [OPTIONS] version\n", path.Base(os.Args[0])) fmt.Printf("version can be a specific version or '%s' (default) for the latest UniFi Controller version\n", LatestVersionMarker) @@ -19,21 +21,29 @@ func usage() { } func setupLogging(debugEnabled, traceEnabled bool) { - log.SetFormatter(&log.TextFormatter{ + log.SetFormatter(&logrus.TextFormatter{ DisableTimestamp: true, DisableLevelTruncation: true, ForceColors: true, FullTimestamp: false, }) if traceEnabled { - log.SetLevel(log.TraceLevel) + log.SetLevel(logrus.TraceLevel) } else if debugEnabled { - log.SetLevel(log.DebugLevel) + log.SetLevel(logrus.DebugLevel) } else { - log.SetLevel(log.InfoLevel) + log.SetLevel(logrus.InfoLevel) } } +type options struct { + versionBaseDir string + outputDir string + downloadOnly bool + version string + firmwareUpdateApi string +} + func main() { flag.Usage = usage @@ -43,16 +53,31 @@ func main() { debugFlag := flag.Bool("debug", false, "Enable debug logging") traceFlag := flag.Bool("trace", false, "Enable trace logging") + flag.CommandLine.Init(os.Args[0], flag.PanicOnError) // set error handling to panic if parse ends with error flag.Parse() setupLogging(*debugFlag, *traceFlag) specifiedVersion := strings.TrimSpace(flag.Arg(0)) if specifiedVersion == "" { specifiedVersion = LatestVersionMarker // default to latest version } - unifiVersion, err := determineUnifiVersion(specifiedVersion) + err := generate(options{ + versionBaseDir: *versionBaseDirFlag, + outputDir: *outputDirFlag, + downloadOnly: *downloadOnly, + version: specifiedVersion, + firmwareUpdateApi: defaultFirmwareUpdateApi, + }) if err != nil { - log.Fatalf("unable to determine version and download URL for Unifi version %s: %s", specifiedVersion, err) - panic(err) + log.Error(err) + os.Exit(1) + } +} + +func generate(opts options) error { + p := NewUnifiVersionProvider(opts.firmwareUpdateApi) + unifiVersion, err := p.ByVersionMarker(opts.version) + if err != nil { + return fmt.Errorf("unable to determine version and download URL for Unifi version %s: %w", opts.version, err) } log.Infof("UniFi Controller version: %s", unifiVersion.Version) @@ -60,42 +85,49 @@ func main() { wd, err := os.Getwd() if err != nil { - log.Fatalf("unable to determine working directory: %s", err) - panic(err) + return fmt.Errorf("unable to determine working directory: %w", err) } - - structuresDir := filepath.Join(wd, *versionBaseDirFlag, fmt.Sprintf("v%s", unifiVersion.Version)) + var structuresDir string + if path.IsAbs(opts.versionBaseDir) { + structuresDir = opts.versionBaseDir + } else { + structuresDir = filepath.Join(wd, opts.versionBaseDir) + } + structuresDir = filepath.Join(structuresDir, fmt.Sprintf("v%s", unifiVersion.Version)) log.Infoln("Downloading UniFi Controller API structures definitions...") err = DownloadAndExtract(*unifiVersion.DownloadUrl, structuresDir) if err != nil { - log.Fatalf("unable to download and extract UniFi Controller API structures definitions: %s", err) - panic(err) + return fmt.Errorf("unable to download and extract UniFi Controller API structures definitions: %w", err) } log.Infof("Downloaded UniFi Controller API structures definitions in %s", structuresDir) - if *downloadOnly { + if opts.downloadOnly { log.Infoln("Structure JSONs ready!") - os.Exit(0) + return nil } log.Infoln("Generating resources code...") - outDir := filepath.Join(wd, *outputDirFlag) + + var outDir string + if path.IsAbs(opts.outputDir) { + outDir = opts.outputDir + } else { + outDir = filepath.Join(wd, opts.outputDir) + } if err = generateCode(structuresDir, outDir); err != nil { - log.Fatalf("unable to generate resources code: %s", err) - panic(err) + return fmt.Errorf("unable to generate resources code: %w", err) } log.Infof("Writing version file...") if err = writeVersionFile(unifiVersion.Version, outDir); err != nil { - log.Fatalf("failed to write version file to %s: %s", outDir, err) - panic(err) + return fmt.Errorf("failed to write version file to %s: %w", outDir, err) } basepath := filepath.Dir(wd) if err = writeVersionRepoMarkerFile(unifiVersion.Version, basepath); err != nil { - log.Fatalf("failed to write version file to %s: %s", basepath, err) - panic(err) + return fmt.Errorf("failed to write version file to %s: %w", basepath, err) } log.Infof("Generated resources in %s", outDir) + return nil } diff --git a/codegen/main_test.go b/codegen/main_test.go new file mode 100644 index 0000000..9492033 --- /dev/null +++ b/codegen/main_test.go @@ -0,0 +1,137 @@ +package main + +import ( + "os" + "os/exec" + "testing" + + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSetupLogging(t *testing.T) { + t.Parallel() + a := assert.New(t) + + setupLogging(false, false) + a.Equal(logrus.InfoLevel, log.Level) + + setupLogging(true, false) + a.Equal(logrus.DebugLevel, log.Level) + + setupLogging(false, true) + a.Equal(logrus.TraceLevel, log.Level) + + setupLogging(true, true) + a.Equal(logrus.TraceLevel, log.Level) +} + +// integration tests for the CLI +// these test require Internet access + +func execCli(args ...string) (string, error) { + in := []string{"run", "."} + in = append(in, args...) + cmd := exec.Command("go", in...) + output, err := cmd.CombinedOutput() + return string(output), err +} + +func TestHelpFlag(t *testing.T) { + t.Parallel() + + out, err := execCli("-h") + + require.Error(t, err) + assert.Contains(t, out, "Usage: codegen [OPTIONS] version") +} + +func TestInvalidFlag(t *testing.T) { + t.Parallel() + + out, err := execCli("-invalid") + + require.Error(t, err) + assert.Contains(t, out, "flag provided but not defined: -invalid") +} + +func TestDefaultVersion(t *testing.T) { + t.Parallel() + + out, err := execCli("-version-base-dir", t.TempDir(), "-output-dir", t.TempDir()) + + require.NoError(t, err) + assert.Contains(t, out, "UniFi Controller version") +} + +func testGenerate(t *testing.T, opts *options) error { + t.Helper() + + setupLogging(false, false) + if opts.versionBaseDir == "" { + opts.versionBaseDir = t.TempDir() + } + if opts.outputDir == "" { + opts.outputDir = t.TempDir() + } + if opts.firmwareUpdateApi == "" { + opts.firmwareUpdateApi = defaultFirmwareUpdateApi + } + return generate(*opts) +} + +func TestNonExistentVersion(t *testing.T) { + t.Parallel() + + err := testGenerate(t, &options{version: "1.2.3"}) + + require.Error(t, err) +} + +func TestInvalidVersion(t *testing.T) { + t.Parallel() + r := require.New(t) + + err := testGenerate(t, &options{version: "invalid-version"}) + + r.Error(err) + r.ErrorContains(err, "Malformed") + r.ErrorContains(err, "invalid-version") +} + +func TestGenerateLatest(t *testing.T) { + t.Parallel() + r := require.New(t) + + opts := &options{version: LatestVersionMarker} + + err := testGenerate(t, opts) + r.NoError(err) + + files, err := os.ReadDir(opts.versionBaseDir) + r.NoError(err) + assert.NotEmptyf(t, files, "version base dir '%s' should not be empty", opts.versionBaseDir) + + files, err = os.ReadDir(opts.outputDir) + r.NoError(err) + assert.NotEmptyf(t, files, "output dir '%s' should not be empty", opts.outputDir) +} + +func TestGenerateDownloadOnly(t *testing.T) { + t.Parallel() + r := require.New(t) + + opts := &options{version: LatestVersionMarker, downloadOnly: true} + + err := testGenerate(t, opts) + r.NoError(err) + + files, err := os.ReadDir(opts.versionBaseDir) + r.NoError(err) + assert.NotEmptyf(t, files, "version base dir '%s' should not be empty", opts.versionBaseDir) + + files, err = os.ReadDir(opts.outputDir) + r.NoError(err) // test generated dir + assert.Emptyf(t, files, "output dir '%s' should be empty", opts.outputDir) +} diff --git a/codegen/resources.go b/codegen/resources.go index 3030c4a..12a4d83 100644 --- a/codegen/resources.go +++ b/codegen/resources.go @@ -12,7 +12,6 @@ import ( "strings" "github.com/iancoleman/strcase" - log "github.com/sirupsen/logrus" ) type replacement struct { diff --git a/codegen/utils.go b/codegen/utils.go new file mode 100644 index 0000000..cb85e3c --- /dev/null +++ b/codegen/utils.go @@ -0,0 +1,26 @@ +package main + +import ( + "errors" + "fmt" + "os" +) + +// ensurePath checks if a path exists and is a directory, if not it creates the directory. Returns true if the directories were created. +func ensurePath(path string) (bool, error) { + // Check if output directory exists, if not create and perform extraction + targetInfo, err := os.Stat(path) + if err != nil { + if !errors.Is(err, os.ErrNotExist) { + return false, err + } + if err = os.MkdirAll(path, 0o755); err != nil { + return false, err + } + return true, nil + } + if !targetInfo.IsDir() { + return false, fmt.Errorf("%s isn't a directory", path) + } + return false, nil +} diff --git a/codegen/version.go b/codegen/version.go index fa2f48d..87bbaf5 100644 --- a/codegen/version.go +++ b/codegen/version.go @@ -31,8 +31,23 @@ func NewUnifiVersion(unifiVersion *version.Version, downloadUrl *url.URL) *Unifi } } -func latestUnifiVersion() (*UnifiVersion, error) { - url, err := url.Parse(firmwareUpdateApi) +type UnifiVersionProvider interface { + Latest() (*UnifiVersion, error) + ByVersionMarker(versionMarker string) (*UnifiVersion, error) +} + +type defaultUnifiVersionProvider struct { + firmwareUpdateApi string +} + +func NewUnifiVersionProvider(firmwareUpdateApi string) UnifiVersionProvider { //nolint:ireturn + return &defaultUnifiVersionProvider{ + firmwareUpdateApi: firmwareUpdateApi, + } +} + +func (p *defaultUnifiVersionProvider) Latest() (*UnifiVersion, error) { + url, err := url.Parse(p.firmwareUpdateApi) if err != nil { return nil, err } @@ -70,9 +85,9 @@ func latestUnifiVersion() (*UnifiVersion, error) { return nil, errors.New("no Unifi Controller firmware found") } -func determineUnifiVersion(versionMarker string) (*UnifiVersion, error) { +func (p *defaultUnifiVersionProvider) ByVersionMarker(versionMarker string) (*UnifiVersion, error) { if versionMarker == LatestVersionMarker { - return latestUnifiVersion() + return p.Latest() } else { unifiVersion, err := version.NewVersion(versionMarker) if err != nil { diff --git a/codegen/version_test.go b/codegen/version_test.go index 3feb93a..8b3a825 100644 --- a/codegen/version_test.go +++ b/codegen/version_test.go @@ -15,7 +15,7 @@ import ( "github.com/stretchr/testify/require" ) -func assertLatestVersionUsingProvider(t *testing.T, provider func() (*UnifiVersion, error)) { +func assertLatestVersionUsingProvider(t *testing.T, provider func(p UnifiVersionProvider) (*UnifiVersion, error)) { t.Helper() assert := assert.New(t) require := require.New(t) @@ -85,8 +85,9 @@ func assertLatestVersionUsingProvider(t *testing.T, provider func() (*UnifiVersi })) defer server.Close() - firmwareUpdateApi = server.URL - gotVersion, err := provider() + p := NewUnifiVersionProvider(server.URL) + + gotVersion, err := provider(p) require.NoError(err) assert.Equal(fwVersion.Core(), gotVersion.Version) @@ -95,15 +96,15 @@ func assertLatestVersionUsingProvider(t *testing.T, provider func() (*UnifiVersi func TestLatestUnifiVersion(t *testing.T) { t.Parallel() - assertLatestVersionUsingProvider(t, func() (*UnifiVersion, error) { - return latestUnifiVersion() + assertLatestVersionUsingProvider(t, func(p UnifiVersionProvider) (*UnifiVersion, error) { + return p.Latest() }) } func TestDetermineUnifiVersion_latest(t *testing.T) { t.Parallel() - assertLatestVersionUsingProvider(t, func() (*UnifiVersion, error) { - return determineUnifiVersion(LatestVersionMarker) + assertLatestVersionUsingProvider(t, func(p UnifiVersionProvider) (*UnifiVersion, error) { + return p.ByVersionMarker(LatestVersionMarker) }) } @@ -121,7 +122,7 @@ func TestDetermineUnifiVersion_provided(t *testing.T) { t.Parallel() a := assert.New(t) - unifiVersion, err := determineUnifiVersion(providedVersion) + unifiVersion, err := NewUnifiVersionProvider(defaultFirmwareUpdateApi).ByVersionMarker(providedVersion) require.NoError(t, err) a.Equal(expectedVersion, unifiVersion.Version.String()) @@ -143,7 +144,7 @@ func TestDetermineUnifiVersion_invalid(t *testing.T) { for _, providedVersion := range testCases { t.Run(providedVersion, func(t *testing.T) { t.Parallel() - _, err := determineUnifiVersion(providedVersion) + _, err := NewUnifiVersionProvider(defaultFirmwareUpdateApi).ByVersionMarker(providedVersion) require.ErrorContains(t, err, providedVersion) }) } @@ -171,8 +172,7 @@ func TestLatestUnifiVersion_HttpError(t *testing.T) { })) defer server.Close() - firmwareUpdateApi = server.URL - _, err := latestUnifiVersion() + _, err := NewUnifiVersionProvider(server.URL).Latest() require.Error(t, err) } @@ -187,8 +187,8 @@ func TestLatestUnifiVersion_InvalidJson(t *testing.T) { })) defer server.Close() - firmwareUpdateApi = server.URL - _, err := latestUnifiVersion() + _, err := NewUnifiVersionProvider(server.URL).Latest() + require.Error(t, err) require.ErrorContains(t, err, "invalid") } @@ -224,8 +224,8 @@ func TestLatestUnifiVersion_NoDebianFirmware(t *testing.T) { })) defer server.Close() - firmwareUpdateApi = server.URL - _, err = latestUnifiVersion() + _, err = NewUnifiVersionProvider(server.URL).Latest() + require.Error(t, err) require.ErrorContains(t, err, "no Unifi Controller firmware found") } @@ -265,8 +265,7 @@ func TestWriteVersionRepoMarkerFile(t *testing.T) { func TestLatestUnifiVersion_InvalidUrl(t *testing.T) { t.Parallel() - firmwareUpdateApi = ":\\invalid" - _, err := latestUnifiVersion() + _, err := NewUnifiVersionProvider(":\\invalid").Latest() require.Error(t, err) require.ErrorContains(t, err, "invalid") } @@ -321,8 +320,7 @@ func TestLatestUnifiVersion_NilVersion(t *testing.T) { })) defer server.Close() - firmwareUpdateApi = server.URL - _, err := latestUnifiVersion() + _, err := NewUnifiVersionProvider(server.URL).Latest() require.Error(t, err) }