test: add comprehensive testing of codegen (#14)
This commit is contained in:
committed by
GitHub
parent
6016a3d34a
commit
bb21419c0e
@@ -16,24 +16,16 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/iancoleman/strcase"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/ulikunitz/xz"
|
||||
"github.com/xor-gate/ar"
|
||||
)
|
||||
|
||||
func DownloadAndExtract(downloadUrl url.URL, outputDir string) error {
|
||||
targetInfo, err := os.Stat(outputDir)
|
||||
if err != nil {
|
||||
if !errors.Is(err, os.ErrNotExist) {
|
||||
return err
|
||||
}
|
||||
// Check if output directory exists, if not create and perform extraction
|
||||
|
||||
err = os.MkdirAll(outputDir, 0o755)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// download fields, create
|
||||
if created, err := ensurePath(outputDir); err != nil {
|
||||
return fmt.Errorf("unable to create output directory %s: %w", outputDir, err)
|
||||
} else if created {
|
||||
log.Debugf("downloading UniFi Controller package from: %s", downloadUrl.String())
|
||||
jarFile, err := downloadJar(downloadUrl, outputDir)
|
||||
if err != nil {
|
||||
@@ -41,18 +33,19 @@ func DownloadAndExtract(downloadUrl url.URL, outputDir string) error {
|
||||
}
|
||||
|
||||
log.Debugf("extracting JSON files with API structures from: %s to: %s", jarFile, outputDir)
|
||||
err = extractJSON(jarFile, outputDir)
|
||||
if err != nil {
|
||||
if err = extractJSON(jarFile, outputDir); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debugf("JSON files extracted to: %s", outputDir)
|
||||
targetInfo, err = os.Stat(outputDir)
|
||||
_, err = os.Stat(outputDir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if !targetInfo.IsDir() {
|
||||
if targetInfo, err := os.Stat(outputDir); err != nil {
|
||||
return err
|
||||
} else if !targetInfo.IsDir() {
|
||||
return errors.New("fields info isn't a directory")
|
||||
}
|
||||
return nil
|
||||
@@ -68,10 +61,12 @@ func downloadJar(downloadUrl url.URL, outputDir string) (string, error) {
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("unable to download UniFi Controller deb: %w", err)
|
||||
}
|
||||
if debResp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("unable to download UniFi Controller deb: HTTP%d. Probably it does not exist under %s", debResp.StatusCode, downloadUrl.String())
|
||||
}
|
||||
defer debResp.Body.Close()
|
||||
|
||||
var uncompressedReader io.Reader
|
||||
|
||||
arReader := ar.NewReader(debResp.Body)
|
||||
for {
|
||||
header, err := arReader.Next()
|
||||
@@ -81,8 +76,6 @@ func downloadJar(downloadUrl url.URL, outputDir string) (string, error) {
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("in ar next: %w", err)
|
||||
}
|
||||
|
||||
// read the data file
|
||||
if header.Name == "data.tar.xz" {
|
||||
uncompressedReader, err = xz.NewReader(arReader)
|
||||
if err != nil {
|
||||
@@ -96,9 +89,7 @@ func downloadJar(downloadUrl url.URL, outputDir string) (string, error) {
|
||||
}
|
||||
|
||||
tarReader := tar.NewReader(uncompressedReader)
|
||||
|
||||
var aceJar *os.File
|
||||
|
||||
log.Debugln("extracting ace.jar from downloaded controller package")
|
||||
for {
|
||||
header, err := tarReader.Next()
|
||||
@@ -108,12 +99,9 @@ func downloadJar(downloadUrl url.URL, outputDir string) (string, error) {
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("in next: %w", err)
|
||||
}
|
||||
|
||||
if header.Typeflag != tar.TypeReg || header.Name != "./usr/lib/unifi/lib/ace.jar" {
|
||||
// skipping
|
||||
continue
|
||||
}
|
||||
|
||||
aceJar, err = os.Create(filepath.Join(outputDir, "ace.jar"))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("unable to create temp file: %w", err)
|
||||
@@ -123,34 +111,14 @@ func downloadJar(downloadUrl url.URL, outputDir string) (string, error) {
|
||||
return "", fmt.Errorf("unable to write ace.jar temp file: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if aceJar == nil {
|
||||
return "", errors.New("unable to find ace.jar")
|
||||
}
|
||||
|
||||
defer aceJar.Close()
|
||||
log.Debugf("ace.jar extracted to: %s", aceJar.Name())
|
||||
return aceJar.Name(), nil
|
||||
}
|
||||
|
||||
func sanitizeExtractedPath(filePath, destinationDir string) (string, error) {
|
||||
absDestinationDir, err := filepath.Abs(destinationDir)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
absFilePath, err := filepath.Abs(filepath.Join(destinationDir, filepath.Base(filePath)))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(absFilePath, absDestinationDir) {
|
||||
return "", fmt.Errorf("invalid file path: %s", filePath)
|
||||
}
|
||||
|
||||
return absFilePath, nil
|
||||
}
|
||||
|
||||
func extractJSON(jarFile, fieldsDir string) error {
|
||||
jarZip, err := zip.OpenReader(jarFile)
|
||||
if err != nil {
|
||||
@@ -161,7 +129,6 @@ func extractJSON(jarFile, fieldsDir string) error {
|
||||
log.Tracef("opened jar %s with %d files", jarFile, len(jarZip.File))
|
||||
for _, f := range jarZip.File {
|
||||
if !strings.HasPrefix(f.Name, "api/fields/") || path.Ext(f.Name) != ".json" {
|
||||
// skip file
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -175,19 +142,16 @@ func extractJSON(jarFile, fieldsDir string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dst, err := os.Create(dstPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer dst.Close()
|
||||
|
||||
_, err = io.Copy(dst, src)
|
||||
log.Debugf("extracted %s", f.Name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}()
|
||||
if err != nil {
|
||||
@@ -227,6 +191,23 @@ func extractJSON(jarFile, fieldsDir string) error {
|
||||
log.Tracef("splitted %s into %s", settingKey, fileName)
|
||||
}
|
||||
|
||||
// TODO: cleanup JSON
|
||||
return nil
|
||||
}
|
||||
|
||||
func sanitizeExtractedPath(filePath, destinationDir string) (string, error) {
|
||||
absDestinationDir, err := filepath.Abs(destinationDir)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
absFilePath, err := filepath.Abs(filepath.Join(destinationDir, filepath.Base(filePath)))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(absFilePath, absDestinationDir) {
|
||||
return "", fmt.Errorf("invalid file path: %s", filePath)
|
||||
}
|
||||
|
||||
return absFilePath, nil
|
||||
}
|
||||
|
||||
141
codegen/download_test.go
Normal file
141
codegen/download_test.go
Normal file
@@ -0,0 +1,141 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"archive/zip"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// Helper function to create a temporary zip file with given entries. 'entries' maps file names to their content.
|
||||
func createTempZipFile(t *testing.T, entries map[string]string) string {
|
||||
t.Helper()
|
||||
tempDir := t.TempDir()
|
||||
tempFileName := filepath.Join(tempDir, "test.zip")
|
||||
tempFile, err := os.Create(tempFileName)
|
||||
require.NoError(t, err, "Failed to create temp zip file")
|
||||
// We need to truncate and write zip contents
|
||||
w := zip.NewWriter(tempFile)
|
||||
for name, content := range entries {
|
||||
f, err := w.Create(name)
|
||||
require.NoError(t, err, "Failed to add entry %s", name)
|
||||
_, err = f.Write([]byte(content))
|
||||
require.NoError(t, err, "Failed to write content for %s", name)
|
||||
}
|
||||
err = w.Close()
|
||||
require.NoError(t, err, "Failed to close zip writer")
|
||||
err = tempFile.Close()
|
||||
require.NoError(t, err, "Failed to close temp file")
|
||||
return tempFile.Name()
|
||||
}
|
||||
|
||||
// Test when the output directory already exists. In this case, DownloadAndExtract should not call downloadJarFn or extractJSONFn.
|
||||
func TestDownloadAndExtract_WithExistingDirectory(t *testing.T) {
|
||||
t.Parallel()
|
||||
r := require.New(t)
|
||||
|
||||
tempDir := t.TempDir()
|
||||
testURL, _ := url.Parse("http://example.com/test.deb")
|
||||
|
||||
err := DownloadAndExtract(*testURL, tempDir)
|
||||
|
||||
r.NoError(err, "Expected no error when directory exists")
|
||||
}
|
||||
|
||||
// // Test when output path is not a directory.
|
||||
func TestDownloadAndExtract_PathNotDirectory(t *testing.T) {
|
||||
t.Parallel()
|
||||
r := require.New(t)
|
||||
|
||||
tempDir := t.TempDir()
|
||||
tempFilePath := filepath.Join(tempDir, "dummy")
|
||||
_, err := os.Create(tempFilePath)
|
||||
r.NoError(err, "Failed to create temp file")
|
||||
testURL, _ := url.Parse("http://example.com/test.deb")
|
||||
|
||||
err = DownloadAndExtract(*testURL, tempFilePath)
|
||||
|
||||
r.Error(err, "Expected error because tempFilePath is not a directory")
|
||||
r.ErrorContains(err, tempFilePath+" isn't a directory")
|
||||
}
|
||||
|
||||
// // Test extractJSON when the jar file cannot be opened.
|
||||
func TestExtractJSON_OpenJarError(t *testing.T) {
|
||||
t.Parallel()
|
||||
r := require.New(t)
|
||||
|
||||
err := extractJSON("nonexisting.jar", t.TempDir())
|
||||
|
||||
r.Error(err)
|
||||
r.ErrorContains(err, "unable to open jar")
|
||||
}
|
||||
|
||||
// Test extractJSON with a valid zip file that contains a JSON file under api/fields/ and no Setting.json (so splitting is skipped).
|
||||
func TestExtractJSON_NoSettings(t *testing.T) {
|
||||
t.Parallel()
|
||||
r := require.New(t)
|
||||
tempDir := t.TempDir()
|
||||
jarFile := createTempZipFile(t, map[string]string{"api/fields/dummy.json": "{\"key\": \"value\"}"})
|
||||
|
||||
err := extractJSON(jarFile, tempDir)
|
||||
r.NoError(err)
|
||||
|
||||
// Check that dummy.json has been extracted
|
||||
expectedPath := filepath.Join(tempDir, "dummy.json")
|
||||
data, err := os.ReadFile(expectedPath)
|
||||
r.NoError(err, "Expected file %s to exist", expectedPath)
|
||||
r.JSONEq("{\"key\": \"value\"}", string(data), "Extracted file content mismatch")
|
||||
}
|
||||
|
||||
// Test extractJSON with Setting.json present, so that it splits settings into individual files.
|
||||
func TestExtractJSON_WithSettings(t *testing.T) {
|
||||
t.Parallel()
|
||||
r := require.New(t)
|
||||
tempDir := t.TempDir()
|
||||
entries := map[string]string{"api/fields/Setting.json": "{\"foo\": {\"bar\": 1}}"}
|
||||
jarFile := createTempZipFile(t, entries)
|
||||
|
||||
err := extractJSON(jarFile, tempDir)
|
||||
r.NoError(err)
|
||||
|
||||
// Check that the split settings file exists
|
||||
settingFile := filepath.Join(tempDir, "SettingFoo.json")
|
||||
data, err := os.ReadFile(settingFile)
|
||||
r.NoError(err)
|
||||
r.Contains(string(data), "bar")
|
||||
}
|
||||
|
||||
// Test sanitizeExtractedPath with valid input.
|
||||
func TestSanitizeExtractedPath_Valid(t *testing.T) {
|
||||
t.Parallel()
|
||||
a := assert.New(t)
|
||||
r := require.New(t)
|
||||
|
||||
tempDir := t.TempDir()
|
||||
filePath := "api/fields/dummy.json"
|
||||
|
||||
result, err := sanitizeExtractedPath(filePath, tempDir)
|
||||
r.NoError(err, "Expected nil error from sanitizeExtractedPath")
|
||||
|
||||
expExpected := filepath.Join(tempDir, "dummy.json")
|
||||
absExpected, err := filepath.Abs(expExpected)
|
||||
r.NoError(err, "Failed to get abs path")
|
||||
a.Equal(absExpected, result, "Sanitized path mismatch")
|
||||
}
|
||||
|
||||
// Test extractJSON with invalid Setting.json content, expecting an unmarshal error.
|
||||
func TestExtractJSON_InvalidSettings(t *testing.T) {
|
||||
t.Parallel()
|
||||
r := require.New(t)
|
||||
tempDir := t.TempDir()
|
||||
jarFile := createTempZipFile(t, map[string]string{"api/fields/Setting.json": "invalid json"})
|
||||
|
||||
err := extractJSON(jarFile, tempDir)
|
||||
|
||||
r.Error(err)
|
||||
r.ErrorContains(err, "unable to unmarshal settings")
|
||||
}
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
"github.com/hashicorp/go-version"
|
||||
)
|
||||
|
||||
var firmwareUpdateApi = "https://fw-update.ubnt.com/api/firmware-latest"
|
||||
const defaultFirmwareUpdateApi = "https://fw-update.ubnt.com/api/firmware-latest"
|
||||
|
||||
const (
|
||||
debianPlatform = "debian"
|
||||
@@ -55,21 +55,20 @@ func (l *firmwareUpdateApiResponseEmbeddedFirmwareDataLink) MarshalJSON() ([]byt
|
||||
|
||||
func (l *firmwareUpdateApiResponseEmbeddedFirmwareDataLink) UnmarshalJSON(j []byte) error {
|
||||
var m map[string]interface{}
|
||||
|
||||
err := json.Unmarshal(j, &m)
|
||||
if err != nil {
|
||||
if err := json.Unmarshal(j, &m); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if href := m["href"]; href != nil {
|
||||
url, err := url.Parse(href.(string))
|
||||
if href, exists := m["href"]; exists && href != nil {
|
||||
strHref, ok := href.(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("expected string for href, got %T", href)
|
||||
}
|
||||
u, err := url.Parse(strHref)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
l.Href = url
|
||||
l.Href = u
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
201
codegen/fwupdate_test.go
Normal file
201
codegen/fwupdate_test.go
Normal file
@@ -0,0 +1,201 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/go-version"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestFirmwareUpdateApiFilter(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := []struct {
|
||||
name string
|
||||
key string
|
||||
value string
|
||||
expected string
|
||||
}{
|
||||
{"channel", "channel", "release", "eq~~channel~~release"},
|
||||
{"product", "product", "unifi-controller", "eq~~product~~unifi-controller"},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
a := assert.New(t)
|
||||
|
||||
filter := firmwareUpdateApiFilter(tc.key, tc.value)
|
||||
|
||||
a.Equal(tc.expected, filter)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshalJSONDataLink(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := []struct {
|
||||
name string
|
||||
link firmwareUpdateApiResponseEmbeddedFirmwareDataLink
|
||||
expectedJSON string
|
||||
}{
|
||||
{
|
||||
"nil",
|
||||
firmwareUpdateApiResponseEmbeddedFirmwareDataLink{Href: nil},
|
||||
"{\"href\":\"\"}",
|
||||
},
|
||||
{
|
||||
"with value",
|
||||
func() firmwareUpdateApiResponseEmbeddedFirmwareDataLink {
|
||||
u, err := url.Parse("https://example.com/firmware")
|
||||
require.NoError(t, err) // error checking in test setup
|
||||
return firmwareUpdateApiResponseEmbeddedFirmwareDataLink{Href: u}
|
||||
}(),
|
||||
"{\"href\":\"https://example.com/firmware\"}",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
a := assert.New(t)
|
||||
|
||||
b, err := json.Marshal(&tc.link)
|
||||
|
||||
require.NoError(t, err)
|
||||
a.JSONEq(tc.expectedJSON, string(b))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnmarshalJSONDataLink(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := []struct {
|
||||
name string
|
||||
jsonStr string
|
||||
shouldError bool
|
||||
errorContains string
|
||||
expectedHref *string
|
||||
}{
|
||||
{
|
||||
"valid",
|
||||
`{"href": "https://example.com/firmware"}`,
|
||||
false,
|
||||
"",
|
||||
func(s string) *string { return &s }("https://example.com/firmware"),
|
||||
},
|
||||
{
|
||||
"null",
|
||||
`{"href": null}`,
|
||||
false,
|
||||
"",
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"missing",
|
||||
`{}`,
|
||||
false,
|
||||
"",
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"non-string",
|
||||
`{"href": 123}`,
|
||||
true,
|
||||
"expected string for href",
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"invalid json",
|
||||
`{"href": }`,
|
||||
true,
|
||||
"",
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"invalid URL",
|
||||
`{"href": "://missing"}`,
|
||||
true,
|
||||
"missing protocol scheme",
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"empty",
|
||||
`{"href": ""}`,
|
||||
false,
|
||||
"",
|
||||
func(s string) *string { return &s }(""),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
a := assert.New(t)
|
||||
var link firmwareUpdateApiResponseEmbeddedFirmwareDataLink
|
||||
|
||||
err := json.Unmarshal([]byte(tc.jsonStr), &link)
|
||||
|
||||
if tc.shouldError {
|
||||
require.Error(t, err)
|
||||
if tc.errorContains != "" {
|
||||
require.ErrorContains(t, err, tc.errorContains)
|
||||
}
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
if tc.expectedHref == nil {
|
||||
a.Nil(link.Href)
|
||||
} else {
|
||||
require.NotNil(t, link.Href)
|
||||
a.Equal(*tc.expectedHref, link.Href.String())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFirmwareUpdateApiResponse_Complete(t *testing.T) {
|
||||
t.Parallel()
|
||||
a := assert.New(t)
|
||||
|
||||
ver, err := version.NewVersion("1.2.3")
|
||||
require.NoError(t, err)
|
||||
u, err := url.Parse("https://example.com/download")
|
||||
require.NoError(t, err)
|
||||
|
||||
req := firmwareUpdateApiResponse{
|
||||
Embedded: firmwareUpdateApiResponseEmbedded{
|
||||
Firmware: []firmwareUpdateApiResponseEmbeddedFirmware{
|
||||
{
|
||||
Channel: "release",
|
||||
Created: "2020-01-01T00:00:00Z",
|
||||
Id: "unique-id",
|
||||
Platform: "debian",
|
||||
Product: "unifi-controller",
|
||||
Version: ver,
|
||||
Links: firmwareUpdateApiResponseEmbeddedFirmwareLinks{
|
||||
Data: firmwareUpdateApiResponseEmbeddedFirmwareDataLink{
|
||||
Href: u,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
jsonBytes, err := json.Marshal(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var newReq firmwareUpdateApiResponse
|
||||
err = json.Unmarshal(jsonBytes, &newReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Len(t, newReq.Embedded.Firmware, 1)
|
||||
fw := newReq.Embedded.Firmware[0]
|
||||
a.Equal("release", fw.Channel)
|
||||
a.Equal("debian", fw.Platform)
|
||||
a.NotNil(fw.Version)
|
||||
a.Equal("https://example.com/download", fw.Links.Data.Href.String())
|
||||
}
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
"text/template"
|
||||
|
||||
"github.com/iancoleman/strcase"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// Generatable is the interface for generation sources.
|
||||
@@ -43,6 +42,10 @@ func generateCodeFromTemplate(templateName, templateContent string, toWrite any)
|
||||
|
||||
// generateCode generates code for each generation source and writes it to file.
|
||||
func generateCode(fieldsDir string, outDir string) error {
|
||||
if _, err := ensurePath(outDir); err != nil {
|
||||
return fmt.Errorf("unable to create output directory %s: %w", outDir, err)
|
||||
}
|
||||
|
||||
generators := make([]Generatable, 0)
|
||||
resources, err := buildResourcesFromDownloadedFields(fieldsDir)
|
||||
if err != nil {
|
||||
|
||||
@@ -9,9 +9,11 @@ import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var log = logrus.New()
|
||||
|
||||
func usage() {
|
||||
fmt.Printf("Usage: %s [OPTIONS] version\n", path.Base(os.Args[0]))
|
||||
fmt.Printf("version can be a specific version or '%s' (default) for the latest UniFi Controller version\n", LatestVersionMarker)
|
||||
@@ -19,21 +21,29 @@ func usage() {
|
||||
}
|
||||
|
||||
func setupLogging(debugEnabled, traceEnabled bool) {
|
||||
log.SetFormatter(&log.TextFormatter{
|
||||
log.SetFormatter(&logrus.TextFormatter{
|
||||
DisableTimestamp: true,
|
||||
DisableLevelTruncation: true,
|
||||
ForceColors: true,
|
||||
FullTimestamp: false,
|
||||
})
|
||||
if traceEnabled {
|
||||
log.SetLevel(log.TraceLevel)
|
||||
log.SetLevel(logrus.TraceLevel)
|
||||
} else if debugEnabled {
|
||||
log.SetLevel(log.DebugLevel)
|
||||
log.SetLevel(logrus.DebugLevel)
|
||||
} else {
|
||||
log.SetLevel(log.InfoLevel)
|
||||
log.SetLevel(logrus.InfoLevel)
|
||||
}
|
||||
}
|
||||
|
||||
type options struct {
|
||||
versionBaseDir string
|
||||
outputDir string
|
||||
downloadOnly bool
|
||||
version string
|
||||
firmwareUpdateApi string
|
||||
}
|
||||
|
||||
func main() {
|
||||
flag.Usage = usage
|
||||
|
||||
@@ -43,16 +53,31 @@ func main() {
|
||||
debugFlag := flag.Bool("debug", false, "Enable debug logging")
|
||||
traceFlag := flag.Bool("trace", false, "Enable trace logging")
|
||||
|
||||
flag.CommandLine.Init(os.Args[0], flag.PanicOnError) // set error handling to panic if parse ends with error
|
||||
flag.Parse()
|
||||
setupLogging(*debugFlag, *traceFlag)
|
||||
specifiedVersion := strings.TrimSpace(flag.Arg(0))
|
||||
if specifiedVersion == "" {
|
||||
specifiedVersion = LatestVersionMarker // default to latest version
|
||||
}
|
||||
unifiVersion, err := determineUnifiVersion(specifiedVersion)
|
||||
err := generate(options{
|
||||
versionBaseDir: *versionBaseDirFlag,
|
||||
outputDir: *outputDirFlag,
|
||||
downloadOnly: *downloadOnly,
|
||||
version: specifiedVersion,
|
||||
firmwareUpdateApi: defaultFirmwareUpdateApi,
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatalf("unable to determine version and download URL for Unifi version %s: %s", specifiedVersion, err)
|
||||
panic(err)
|
||||
log.Error(err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func generate(opts options) error {
|
||||
p := NewUnifiVersionProvider(opts.firmwareUpdateApi)
|
||||
unifiVersion, err := p.ByVersionMarker(opts.version)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to determine version and download URL for Unifi version %s: %w", opts.version, err)
|
||||
}
|
||||
|
||||
log.Infof("UniFi Controller version: %s", unifiVersion.Version)
|
||||
@@ -60,42 +85,49 @@ func main() {
|
||||
|
||||
wd, err := os.Getwd()
|
||||
if err != nil {
|
||||
log.Fatalf("unable to determine working directory: %s", err)
|
||||
panic(err)
|
||||
return fmt.Errorf("unable to determine working directory: %w", err)
|
||||
}
|
||||
|
||||
structuresDir := filepath.Join(wd, *versionBaseDirFlag, fmt.Sprintf("v%s", unifiVersion.Version))
|
||||
var structuresDir string
|
||||
if path.IsAbs(opts.versionBaseDir) {
|
||||
structuresDir = opts.versionBaseDir
|
||||
} else {
|
||||
structuresDir = filepath.Join(wd, opts.versionBaseDir)
|
||||
}
|
||||
structuresDir = filepath.Join(structuresDir, fmt.Sprintf("v%s", unifiVersion.Version))
|
||||
log.Infoln("Downloading UniFi Controller API structures definitions...")
|
||||
err = DownloadAndExtract(*unifiVersion.DownloadUrl, structuresDir)
|
||||
if err != nil {
|
||||
log.Fatalf("unable to download and extract UniFi Controller API structures definitions: %s", err)
|
||||
panic(err)
|
||||
return fmt.Errorf("unable to download and extract UniFi Controller API structures definitions: %w", err)
|
||||
}
|
||||
log.Infof("Downloaded UniFi Controller API structures definitions in %s", structuresDir)
|
||||
|
||||
if *downloadOnly {
|
||||
if opts.downloadOnly {
|
||||
log.Infoln("Structure JSONs ready!")
|
||||
os.Exit(0)
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Infoln("Generating resources code...")
|
||||
outDir := filepath.Join(wd, *outputDirFlag)
|
||||
|
||||
var outDir string
|
||||
if path.IsAbs(opts.outputDir) {
|
||||
outDir = opts.outputDir
|
||||
} else {
|
||||
outDir = filepath.Join(wd, opts.outputDir)
|
||||
}
|
||||
if err = generateCode(structuresDir, outDir); err != nil {
|
||||
log.Fatalf("unable to generate resources code: %s", err)
|
||||
panic(err)
|
||||
return fmt.Errorf("unable to generate resources code: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("Writing version file...")
|
||||
if err = writeVersionFile(unifiVersion.Version, outDir); err != nil {
|
||||
log.Fatalf("failed to write version file to %s: %s", outDir, err)
|
||||
panic(err)
|
||||
return fmt.Errorf("failed to write version file to %s: %w", outDir, err)
|
||||
}
|
||||
|
||||
basepath := filepath.Dir(wd)
|
||||
if err = writeVersionRepoMarkerFile(unifiVersion.Version, basepath); err != nil {
|
||||
log.Fatalf("failed to write version file to %s: %s", basepath, err)
|
||||
panic(err)
|
||||
return fmt.Errorf("failed to write version file to %s: %w", basepath, err)
|
||||
}
|
||||
|
||||
log.Infof("Generated resources in %s", outDir)
|
||||
return nil
|
||||
}
|
||||
|
||||
137
codegen/main_test.go
Normal file
137
codegen/main_test.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"os"
|
||||
"os/exec"
|
||||
"testing"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSetupLogging(t *testing.T) {
|
||||
t.Parallel()
|
||||
a := assert.New(t)
|
||||
|
||||
setupLogging(false, false)
|
||||
a.Equal(logrus.InfoLevel, log.Level)
|
||||
|
||||
setupLogging(true, false)
|
||||
a.Equal(logrus.DebugLevel, log.Level)
|
||||
|
||||
setupLogging(false, true)
|
||||
a.Equal(logrus.TraceLevel, log.Level)
|
||||
|
||||
setupLogging(true, true)
|
||||
a.Equal(logrus.TraceLevel, log.Level)
|
||||
}
|
||||
|
||||
// integration tests for the CLI
|
||||
// these test require Internet access
|
||||
|
||||
func execCli(args ...string) (string, error) {
|
||||
in := []string{"run", "."}
|
||||
in = append(in, args...)
|
||||
cmd := exec.Command("go", in...)
|
||||
output, err := cmd.CombinedOutput()
|
||||
return string(output), err
|
||||
}
|
||||
|
||||
func TestHelpFlag(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
out, err := execCli("-h")
|
||||
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, out, "Usage: codegen [OPTIONS] version")
|
||||
}
|
||||
|
||||
func TestInvalidFlag(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
out, err := execCli("-invalid")
|
||||
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, out, "flag provided but not defined: -invalid")
|
||||
}
|
||||
|
||||
func TestDefaultVersion(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
out, err := execCli("-version-base-dir", t.TempDir(), "-output-dir", t.TempDir())
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, out, "UniFi Controller version")
|
||||
}
|
||||
|
||||
func testGenerate(t *testing.T, opts *options) error {
|
||||
t.Helper()
|
||||
|
||||
setupLogging(false, false)
|
||||
if opts.versionBaseDir == "" {
|
||||
opts.versionBaseDir = t.TempDir()
|
||||
}
|
||||
if opts.outputDir == "" {
|
||||
opts.outputDir = t.TempDir()
|
||||
}
|
||||
if opts.firmwareUpdateApi == "" {
|
||||
opts.firmwareUpdateApi = defaultFirmwareUpdateApi
|
||||
}
|
||||
return generate(*opts)
|
||||
}
|
||||
|
||||
func TestNonExistentVersion(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
err := testGenerate(t, &options{version: "1.2.3"})
|
||||
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestInvalidVersion(t *testing.T) {
|
||||
t.Parallel()
|
||||
r := require.New(t)
|
||||
|
||||
err := testGenerate(t, &options{version: "invalid-version"})
|
||||
|
||||
r.Error(err)
|
||||
r.ErrorContains(err, "Malformed")
|
||||
r.ErrorContains(err, "invalid-version")
|
||||
}
|
||||
|
||||
func TestGenerateLatest(t *testing.T) {
|
||||
t.Parallel()
|
||||
r := require.New(t)
|
||||
|
||||
opts := &options{version: LatestVersionMarker}
|
||||
|
||||
err := testGenerate(t, opts)
|
||||
r.NoError(err)
|
||||
|
||||
files, err := os.ReadDir(opts.versionBaseDir)
|
||||
r.NoError(err)
|
||||
assert.NotEmptyf(t, files, "version base dir '%s' should not be empty", opts.versionBaseDir)
|
||||
|
||||
files, err = os.ReadDir(opts.outputDir)
|
||||
r.NoError(err)
|
||||
assert.NotEmptyf(t, files, "output dir '%s' should not be empty", opts.outputDir)
|
||||
}
|
||||
|
||||
func TestGenerateDownloadOnly(t *testing.T) {
|
||||
t.Parallel()
|
||||
r := require.New(t)
|
||||
|
||||
opts := &options{version: LatestVersionMarker, downloadOnly: true}
|
||||
|
||||
err := testGenerate(t, opts)
|
||||
r.NoError(err)
|
||||
|
||||
files, err := os.ReadDir(opts.versionBaseDir)
|
||||
r.NoError(err)
|
||||
assert.NotEmptyf(t, files, "version base dir '%s' should not be empty", opts.versionBaseDir)
|
||||
|
||||
files, err = os.ReadDir(opts.outputDir)
|
||||
r.NoError(err) // test generated dir
|
||||
assert.Emptyf(t, files, "output dir '%s' should be empty", opts.outputDir)
|
||||
}
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/iancoleman/strcase"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type replacement struct {
|
||||
|
||||
26
codegen/utils.go
Normal file
26
codegen/utils.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
)
|
||||
|
||||
// ensurePath checks if a path exists and is a directory, if not it creates the directory. Returns true if the directories were created.
|
||||
func ensurePath(path string) (bool, error) {
|
||||
// Check if output directory exists, if not create and perform extraction
|
||||
targetInfo, err := os.Stat(path)
|
||||
if err != nil {
|
||||
if !errors.Is(err, os.ErrNotExist) {
|
||||
return false, err
|
||||
}
|
||||
if err = os.MkdirAll(path, 0o755); err != nil {
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
if !targetInfo.IsDir() {
|
||||
return false, fmt.Errorf("%s isn't a directory", path)
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
@@ -31,8 +31,23 @@ func NewUnifiVersion(unifiVersion *version.Version, downloadUrl *url.URL) *Unifi
|
||||
}
|
||||
}
|
||||
|
||||
func latestUnifiVersion() (*UnifiVersion, error) {
|
||||
url, err := url.Parse(firmwareUpdateApi)
|
||||
type UnifiVersionProvider interface {
|
||||
Latest() (*UnifiVersion, error)
|
||||
ByVersionMarker(versionMarker string) (*UnifiVersion, error)
|
||||
}
|
||||
|
||||
type defaultUnifiVersionProvider struct {
|
||||
firmwareUpdateApi string
|
||||
}
|
||||
|
||||
func NewUnifiVersionProvider(firmwareUpdateApi string) UnifiVersionProvider { //nolint:ireturn
|
||||
return &defaultUnifiVersionProvider{
|
||||
firmwareUpdateApi: firmwareUpdateApi,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *defaultUnifiVersionProvider) Latest() (*UnifiVersion, error) {
|
||||
url, err := url.Parse(p.firmwareUpdateApi)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -70,9 +85,9 @@ func latestUnifiVersion() (*UnifiVersion, error) {
|
||||
return nil, errors.New("no Unifi Controller firmware found")
|
||||
}
|
||||
|
||||
func determineUnifiVersion(versionMarker string) (*UnifiVersion, error) {
|
||||
func (p *defaultUnifiVersionProvider) ByVersionMarker(versionMarker string) (*UnifiVersion, error) {
|
||||
if versionMarker == LatestVersionMarker {
|
||||
return latestUnifiVersion()
|
||||
return p.Latest()
|
||||
} else {
|
||||
unifiVersion, err := version.NewVersion(versionMarker)
|
||||
if err != nil {
|
||||
|
||||
@@ -15,7 +15,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func assertLatestVersionUsingProvider(t *testing.T, provider func() (*UnifiVersion, error)) {
|
||||
func assertLatestVersionUsingProvider(t *testing.T, provider func(p UnifiVersionProvider) (*UnifiVersion, error)) {
|
||||
t.Helper()
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
@@ -85,8 +85,9 @@ func assertLatestVersionUsingProvider(t *testing.T, provider func() (*UnifiVersi
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
firmwareUpdateApi = server.URL
|
||||
gotVersion, err := provider()
|
||||
p := NewUnifiVersionProvider(server.URL)
|
||||
|
||||
gotVersion, err := provider(p)
|
||||
require.NoError(err)
|
||||
|
||||
assert.Equal(fwVersion.Core(), gotVersion.Version)
|
||||
@@ -95,15 +96,15 @@ func assertLatestVersionUsingProvider(t *testing.T, provider func() (*UnifiVersi
|
||||
|
||||
func TestLatestUnifiVersion(t *testing.T) {
|
||||
t.Parallel()
|
||||
assertLatestVersionUsingProvider(t, func() (*UnifiVersion, error) {
|
||||
return latestUnifiVersion()
|
||||
assertLatestVersionUsingProvider(t, func(p UnifiVersionProvider) (*UnifiVersion, error) {
|
||||
return p.Latest()
|
||||
})
|
||||
}
|
||||
|
||||
func TestDetermineUnifiVersion_latest(t *testing.T) {
|
||||
t.Parallel()
|
||||
assertLatestVersionUsingProvider(t, func() (*UnifiVersion, error) {
|
||||
return determineUnifiVersion(LatestVersionMarker)
|
||||
assertLatestVersionUsingProvider(t, func(p UnifiVersionProvider) (*UnifiVersion, error) {
|
||||
return p.ByVersionMarker(LatestVersionMarker)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -121,7 +122,7 @@ func TestDetermineUnifiVersion_provided(t *testing.T) {
|
||||
t.Parallel()
|
||||
a := assert.New(t)
|
||||
|
||||
unifiVersion, err := determineUnifiVersion(providedVersion)
|
||||
unifiVersion, err := NewUnifiVersionProvider(defaultFirmwareUpdateApi).ByVersionMarker(providedVersion)
|
||||
require.NoError(t, err)
|
||||
|
||||
a.Equal(expectedVersion, unifiVersion.Version.String())
|
||||
@@ -143,7 +144,7 @@ func TestDetermineUnifiVersion_invalid(t *testing.T) {
|
||||
for _, providedVersion := range testCases {
|
||||
t.Run(providedVersion, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, err := determineUnifiVersion(providedVersion)
|
||||
_, err := NewUnifiVersionProvider(defaultFirmwareUpdateApi).ByVersionMarker(providedVersion)
|
||||
require.ErrorContains(t, err, providedVersion)
|
||||
})
|
||||
}
|
||||
@@ -171,8 +172,7 @@ func TestLatestUnifiVersion_HttpError(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
firmwareUpdateApi = server.URL
|
||||
_, err := latestUnifiVersion()
|
||||
_, err := NewUnifiVersionProvider(server.URL).Latest()
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
@@ -187,8 +187,8 @@ func TestLatestUnifiVersion_InvalidJson(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
firmwareUpdateApi = server.URL
|
||||
_, err := latestUnifiVersion()
|
||||
_, err := NewUnifiVersionProvider(server.URL).Latest()
|
||||
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, "invalid")
|
||||
}
|
||||
@@ -224,8 +224,8 @@ func TestLatestUnifiVersion_NoDebianFirmware(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
firmwareUpdateApi = server.URL
|
||||
_, err = latestUnifiVersion()
|
||||
_, err = NewUnifiVersionProvider(server.URL).Latest()
|
||||
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, "no Unifi Controller firmware found")
|
||||
}
|
||||
@@ -265,8 +265,7 @@ func TestWriteVersionRepoMarkerFile(t *testing.T) {
|
||||
func TestLatestUnifiVersion_InvalidUrl(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
firmwareUpdateApi = ":\\invalid"
|
||||
_, err := latestUnifiVersion()
|
||||
_, err := NewUnifiVersionProvider(":\\invalid").Latest()
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, "invalid")
|
||||
}
|
||||
@@ -321,8 +320,7 @@ func TestLatestUnifiVersion_NilVersion(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
firmwareUpdateApi = server.URL
|
||||
_, err := latestUnifiVersion()
|
||||
_, err := NewUnifiVersionProvider(server.URL).Latest()
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user