From 3fda37369e46ea38280b6866496d4a5d814be8c1 Mon Sep 17 00:00:00 2001 From: Mateusz Filipowicz Date: Tue, 7 Jan 2025 02:10:23 +0100 Subject: [PATCH] feat: removed -latest flag and instead support latest version marker, refactor (#233) * feat: remove -latest flag and instead support latest version marker * refactor: move determining unifi version, download URL and downloading resources and field definitions to dedicated files * refactor: move writing version file to version.go file * refactor: move code generation code to dedicated generator.go file * chore: fix linting issues * chore: add version tests --- fields/{extract.go => download.go} | 42 +- fields/generator.go | 514 +++++++++++++++++ fields/{main_test.go => generator_test.go} | 0 fields/main.go | 620 +-------------------- fields/version.go | 72 ++- fields/version_test.go | 68 ++- unifi/fields.go | 2 +- 7 files changed, 705 insertions(+), 613 deletions(-) rename fields/{extract.go => download.go} (78%) create mode 100644 fields/generator.go rename fields/{main_test.go => generator_test.go} (100%) diff --git a/fields/extract.go b/fields/download.go similarity index 78% rename from fields/extract.go rename to fields/download.go index 962d21f..a5bdb5c 100644 --- a/fields/extract.go +++ b/fields/download.go @@ -20,15 +20,49 @@ import ( "github.com/xor-gate/ar" ) -func downloadJar(url *url.URL, outputDir string) (string, error) { - req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, url.String(), nil) +func DownloadAndExtract(downloadUrl url.URL, outputDir string) error { + targetInfo, err := os.Stat(outputDir) if err != nil { - return "", fmt.Errorf("unable to download deb: %w", err) + if !errors.Is(err, os.ErrNotExist) { + return err + } + + err = os.MkdirAll(outputDir, 0o755) + if err != nil { + return err + } + + // download fields, create + jarFile, err := downloadJar(downloadUrl, outputDir) + if err != nil { + return err + } + + err = extractJSON(jarFile, outputDir) + if err != nil { + return err + } + + targetInfo, err = os.Stat(outputDir) + if err != nil { + return err + } + } + if !targetInfo.IsDir() { + return errors.New("fields info isn't a directory") + } + return nil +} + +func downloadJar(downloadUrl url.URL, outputDir string) (string, error) { + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, downloadUrl.String(), nil) + if err != nil { + return "", fmt.Errorf("unable to download UniFi Controller deb: %w", err) } debResp, err := http.DefaultClient.Do(req) if err != nil { - return "", fmt.Errorf("unable to download deb: %w", err) + return "", fmt.Errorf("unable to download UniFi Controller deb: %w", err) } defer debResp.Body.Close() diff --git a/fields/generator.go b/fields/generator.go new file mode 100644 index 0000000..9ad580a --- /dev/null +++ b/fields/generator.go @@ -0,0 +1,514 @@ +package main + +import ( + "bytes" + _ "embed" + "encoding/json" + "fmt" + "go/format" + "io" + "os" + "path/filepath" + "regexp" + "strconv" + "strings" + "text/template" + + "github.com/iancoleman/strcase" + log "github.com/sirupsen/logrus" +) + +type replacement struct { + Old string + New string +} + +var fieldReps = []replacement{ + {"Dhcpdv6", "DHCPDV6"}, + + {"Dhcpd", "DHCPD"}, + {"Idx", "IDX"}, + {"Ipsec", "IPSec"}, + {"Ipv6", "IPV6"}, + {"Openvpn", "OpenVPN"}, + {"Tftp", "TFTP"}, + {"Wlangroup", "WLANGroup"}, + + {"Bc", "Broadcast"}, + {"Dhcp", "DHCP"}, + {"Dns", "DNS"}, + {"Dpi", "DPI"}, + {"Dtim", "DTIM"}, + {"Firewallgroup", "FirewallGroup"}, + {"Fixedip", "FixedIP"}, + {"Icmp", "ICMP"}, + {"Id", "ID"}, + {"Igmp", "IGMP"}, + {"Ip", "IP"}, + {"Leasetime", "LeaseTime"}, + {"Mac", "MAC"}, + {"Mcastenhance", "MulticastEnhance"}, + {"Minrssi", "MinRSSI"}, + {"Monthdays", "MonthDays"}, + {"Nat", "NAT"}, + {"Networkconf", "Network"}, + {"Networkgroup", "NetworkGroup"}, + {"Pd", "PD"}, + {"Pmf", "PMF"}, + {"Portconf", "PortProfile"}, + {"Qos", "QOS"}, + {"Radiusprofile", "RADIUSProfile"}, + {"Radius", "RADIUS"}, + {"Ssid", "SSID"}, + {"Startdate", "StartDate"}, + {"Starttime", "StartTime"}, + {"Stopdate", "StopDate"}, + {"Stoptime", "StopTime"}, + {"Tcp", "TCP"}, + {"Udp", "UDP"}, + {"Usergroup", "UserGroup"}, + {"Utc", "UTC"}, + {"Vlan", "VLAN"}, + {"Vpn", "VPN"}, + {"Wan", "WAN"}, + {"Wep", "WEP"}, + {"Wlan", "WLAN"}, + {"Wpa", "WPA"}, +} + +var fileReps = []replacement{ + {"WlanConf", "WLAN"}, + {"Dhcp", "DHCP"}, + {"Wlan", "WLAN"}, + {"NetworkConf", "Network"}, + {"PortConf", "PortProfile"}, + {"RadiusProfile", "RADIUSProfile"}, + {"ApGroups", "APGroup"}, +} + +type Resource struct { + StructName string + ResourcePath string + Types map[string]*FieldInfo + FieldProcessor func(name string, f *FieldInfo) error +} + +type FieldInfo struct { + FieldName string + JSONName string + FieldType string + FieldValidation string + OmitEmpty bool + IsArray bool + Fields map[string]*FieldInfo + CustomUnmarshalType string + CustomUnmarshalFunc string +} + +func NewResource(structName string, resourcePath string) *Resource { + baseType := NewFieldInfo(structName, resourcePath, "struct", "", false, false, "") + resource := &Resource{ + StructName: structName, + ResourcePath: resourcePath, + Types: map[string]*FieldInfo{ + structName: baseType, + }, + FieldProcessor: func(name string, f *FieldInfo) error { return nil }, + } + + // Since template files iterate through map keys in sorted order, these initial fields + // are named such that they stay at the top for consistency. The spacer items create a + // blank line in the resulting generated file. + // + // This hack is here for stability of the generated code, but can be removed if desired. + baseType.Fields = map[string]*FieldInfo{ + " ID": NewFieldInfo("ID", "_id", "string", "", true, false, ""), + " SiteID": NewFieldInfo("SiteID", "site_id", "string", "", true, false, ""), + " _Spacer": nil, + + " Hidden": NewFieldInfo("Hidden", "attr_hidden", "bool", "", true, false, ""), + " HiddenID": NewFieldInfo("HiddenID", "attr_hidden_id", "string", "", true, false, ""), + " NoDelete": NewFieldInfo("NoDelete", "attr_no_delete", "bool", "", true, false, ""), + " NoEdit": NewFieldInfo("NoEdit", "attr_no_edit", "bool", "", true, false, ""), + " _Spacer": nil, + + " _Spacer": nil, + } + + switch { + case resource.IsSetting(): + resource.ResourcePath = strcase.ToSnake(strings.TrimPrefix(structName, "Setting")) + baseType.Fields[" Key"] = NewFieldInfo("Key", "key", "string", "", false, false, "") + + if resource.StructName == "SettingUsg" { + // Removed in v7, retaining for backwards compatibility + baseType.Fields["MdnsEnabled"] = NewFieldInfo("MdnsEnabled", "mdns_enabled", "bool", "", false, false, "") + } + case resource.StructName == "Device": + baseType.Fields[" MAC"] = NewFieldInfo("MAC", "mac", "string", "", true, false, "") + baseType.Fields["Adopted"] = NewFieldInfo("Adopted", "adopted", "bool", "", false, false, "") + baseType.Fields["Model"] = NewFieldInfo("Model", "model", "string", "", true, false, "") + baseType.Fields["State"] = NewFieldInfo("State", "state", "DeviceState", "", false, false, "") + baseType.Fields["Type"] = NewFieldInfo("Type", "type", "string", "", true, false, "") + case resource.StructName == "User": + baseType.Fields[" IP"] = NewFieldInfo("IP", "ip", "string", "non-generated field", true, false, "") + baseType.Fields[" DevIdOverride"] = NewFieldInfo("DevIdOverride", "dev_id_override", "int", "non-generated field", true, false, "") + case resource.StructName == "WLAN": + // this field removed in v6, retaining for backwards compatibility + baseType.Fields["WLANGroupID"] = NewFieldInfo("WLANGroupID", "wlangroup_id", "string", "", false, false, "") + } + + return resource +} + +func NewFieldInfo(fieldName string, jsonName string, fieldType string, fieldValidation string, omitempty bool, isArray bool, customUnmarshalType string) *FieldInfo { + return &FieldInfo{ + FieldName: fieldName, + JSONName: jsonName, + FieldType: fieldType, + FieldValidation: fieldValidation, + OmitEmpty: omitempty, + IsArray: isArray, + CustomUnmarshalType: customUnmarshalType, + } +} + +func cleanName(name string, reps []replacement) string { + for _, rep := range reps { + name = strings.ReplaceAll(name, rep.Old, rep.New) + } + + return name +} + +func (r *Resource) IsSetting() bool { + return strings.HasPrefix(r.StructName, "Setting") +} + +func (r *Resource) processFields(fields map[string]interface{}) { + t := r.Types[r.StructName] + for name, validation := range fields { + fieldInfo, err := r.fieldInfoFromValidation(name, validation) + if err != nil { + continue + } + + t.Fields[fieldInfo.FieldName] = fieldInfo + } +} + +func (r *Resource) fieldInfoFromValidation(name string, validation interface{}) (*FieldInfo, error) { + fieldName := strcase.ToCamel(name) + fieldName = cleanName(fieldName, fieldReps) + + empty := &FieldInfo{} + var fieldInfo *FieldInfo + + switch validation := validation.(type) { + case []interface{}: + if len(validation) == 0 { + fieldInfo = NewFieldInfo(fieldName, name, "string", "", false, true, "") + err := r.FieldProcessor(fieldName, fieldInfo) + return fieldInfo, err + } + if len(validation) > 1 { + return empty, fmt.Errorf("unknown validation %#v", validation) + } + + fieldInfo, err := r.fieldInfoFromValidation(name, validation[0]) + if err != nil { + return empty, err + } + + fieldInfo.OmitEmpty = true + fieldInfo.IsArray = true + + err = r.FieldProcessor(fieldName, fieldInfo) + return fieldInfo, err + + case map[string]interface{}: + typeName := r.StructName + fieldName + + result := NewFieldInfo(fieldName, name, typeName, "", true, false, "") + result.Fields = make(map[string]*FieldInfo) + + for name, fv := range validation { + child, err := r.fieldInfoFromValidation(name, fv) + if err != nil { + return empty, err + } + + result.Fields[child.FieldName] = child + } + + err := r.FieldProcessor(fieldName, result) + r.Types[typeName] = result + return result, err + + case string: + fieldValidation := validation + normalized := normalizeValidation(validation) + + omitEmpty := false + + switch { + case normalized == "falsetrue" || normalized == "truefalse": + fieldInfo = NewFieldInfo(fieldName, name, "bool", "", omitEmpty, false, "") + return fieldInfo, r.FieldProcessor(fieldName, fieldInfo) + default: + if _, err := strconv.ParseFloat(normalized, 64); err == nil { + if normalized == "09" || normalized == "09.09" { + fieldValidation = "" + } + + if strings.Contains(normalized, ".") { + if strings.Contains(validation, "\\.){3}") { + break + } + + omitEmpty = true + fieldInfo = NewFieldInfo(fieldName, name, "float64", fieldValidation, omitEmpty, false, "") + return fieldInfo, r.FieldProcessor(fieldName, fieldInfo) + } + + omitEmpty = true + fieldInfo = NewFieldInfo(fieldName, name, "int", fieldValidation, omitEmpty, false, "") + fieldInfo.CustomUnmarshalType = "emptyStringInt" + return fieldInfo, r.FieldProcessor(fieldName, fieldInfo) + } + } + if validation != "" && normalized != "" { + log.Debugf("normalize %q to %q", validation, normalized) + } + + omitEmpty = omitEmpty || (!strings.Contains(validation, "^$") && !strings.HasSuffix(fieldName, "ID")) + fieldInfo = NewFieldInfo(fieldName, name, "string", fieldValidation, omitEmpty, false, "") + return fieldInfo, r.FieldProcessor(fieldName, fieldInfo) + } + + return empty, fmt.Errorf("unable to determine type from validation %q", validation) +} + +func (r *Resource) processJSON(b []byte) error { + var fields map[string]interface{} + err := json.Unmarshal(b, &fields) + if err != nil { + return err + } + + r.processFields(fields) + + return nil +} + +//go:embed api.go.tmpl +var apiGoTemplate string + +func (r *Resource) generateCode() (string, error) { + var err error + var buf bytes.Buffer + writer := io.Writer(&buf) + + tpl := template.Must(template.New("api.go.tmpl").Parse(apiGoTemplate)) + + err = tpl.Execute(writer, r) + if err != nil { + return "", fmt.Errorf("failed to render template: %w", err) + } + + src, err := format.Source(buf.Bytes()) + if err != nil { + return "", fmt.Errorf("failed to format source: %w", err) + } + + return string(src), err +} + +func normalizeValidation(re string) string { + re = strings.ReplaceAll(re, "\\d", "[0-9]") + re = strings.ReplaceAll(re, "[-+]?", "") + re = strings.ReplaceAll(re, "[+-]?", "") + re = strings.ReplaceAll(re, "[-]?", "") + re = strings.ReplaceAll(re, "\\.", ".") + re = strings.ReplaceAll(re, "[.]?", ".") + + quants := regexp.MustCompile(`\{\d*,?\d*\}|\*|\+|\?`) + re = quants.ReplaceAllString(re, "") + + control := regexp.MustCompile(`[\(\[\]\)\|\-\$\^]`) + re = control.ReplaceAllString(re, "") + + re = strings.TrimPrefix(re, "^") + re = strings.TrimSuffix(re, "$") + + return re +} + +func generateCode(fieldsDir string, outDir string) error { + fieldsFiles, err := os.ReadDir(fieldsDir) + if err != nil { + return fmt.Errorf("unable to read fields directory %s: %w", fieldsDir, err) + } + for _, fieldsFile := range fieldsFiles { + name := fieldsFile.Name() + ext := filepath.Ext(name) + + switch name { + case "AuthenticationRequest.json", "Setting.json", "Wall.json": + continue + } + + if filepath.Ext(name) != ".json" { + continue + } + + log.Debugf("Processing %s...", fieldsFile.Name()) + name = name[:len(name)-len(ext)] + + urlPath := strings.ToLower(name) + structName := cleanName(name, fileReps) + + goFile := strcase.ToSnake(structName) + ".generated.go" + fieldsFilePath := filepath.Join(fieldsDir, fieldsFile.Name()) + b, err := os.ReadFile(fieldsFilePath) + if err != nil { + log.Warnf("skipping file %s: %s", fieldsFile.Name(), err) + continue + } + + resource := NewResource(structName, urlPath) + + switch resource.StructName { + case "Account": + resource.FieldProcessor = func(name string, f *FieldInfo) error { + switch name { + case "IP", "NetworkID": + f.OmitEmpty = true + } + return nil + } + case "ChannelPlan": + resource.FieldProcessor = func(name string, f *FieldInfo) error { + switch name { + case "Channel", "BackupChannel", "TxPower": + if f.FieldType == "string" { + f.CustomUnmarshalType = "numberOrString" + } + } + return nil + } + case "Device": + resource.FieldProcessor = func(name string, f *FieldInfo) error { + switch name { + case "X", "Y": + f.FieldType = "float64" + case "StpPriority": + f.FieldType = "string" + f.CustomUnmarshalType = "numberOrString" + case "Ht": + f.FieldType = "int" + case "Channel", "BackupChannel", "TxPower": + if f.FieldType == "string" { + f.CustomUnmarshalType = "numberOrString" + } + case "LteExtAnt", "LtePoe": + f.CustomUnmarshalType = "booleanishString" + } + + f.OmitEmpty = true + switch name { + case "PortOverrides": + f.OmitEmpty = false + } + + return nil + } + case "Network": + resource.FieldProcessor = func(name string, f *FieldInfo) error { + switch name { + case "InternetAccessEnabled", "IntraNetworkAccessEnabled": + if f.FieldType == "bool" { + f.CustomUnmarshalType = "*bool" + f.CustomUnmarshalFunc = "emptyBoolToTrue" + } + } + return nil + } + case "SettingGlobalAp": + resource.FieldProcessor = func(name string, f *FieldInfo) error { + if strings.HasPrefix(name, "6E") { + f.FieldName = strings.Replace(f.FieldName, "6E", "SixE", 1) + } + + return nil + } + case "SettingMgmt": + sshKeyField := NewFieldInfo(resource.StructName+"XSshKeys", "x_ssh_keys", "struct", "", false, false, "") + sshKeyField.Fields = map[string]*FieldInfo{ + "name": NewFieldInfo("Name", "name", "string", "", false, false, ""), + "keyType": NewFieldInfo("KeyType", "type", "string", "", false, false, ""), + "key": NewFieldInfo("Key", "key", "string", "", false, false, ""), + "comment": NewFieldInfo("Comment", "comment", "string", "", false, false, ""), + "date": NewFieldInfo("Date", "date", "string", "", false, false, ""), + "fingerprint": NewFieldInfo("Fingerprint", "fingerprint", "string", "", false, false, ""), + } + resource.Types[sshKeyField.FieldName] = sshKeyField + + resource.FieldProcessor = func(name string, f *FieldInfo) error { + if name == "XSshKeys" { + f.FieldType = sshKeyField.FieldName + } + return nil + } + case "SettingUsg": + resource.FieldProcessor = func(name string, f *FieldInfo) error { + if strings.HasSuffix(name, "Timeout") && name != "ArpCacheTimeout" { + f.FieldType = "int" + f.CustomUnmarshalType = "emptyStringInt" + } + return nil + } + case "User": + resource.FieldProcessor = func(name string, f *FieldInfo) error { + switch name { + case "Blocked": + f.FieldType = "bool" + case "LastSeen": + f.FieldType = "int" + f.CustomUnmarshalType = "emptyStringInt" + } + return nil + } + case "WLAN": + resource.FieldProcessor = func(name string, f *FieldInfo) error { + switch name { + case "ScheduleWithDuration": + // always send schedule, so we can empty it if we want to + f.OmitEmpty = false + } + return nil + } + } + + err = resource.processJSON(b) + if err != nil { + log.Warnf("skipping file %s: %s", fieldsFile.Name(), err) + continue + } + + var code string + if code, err = resource.generateCode(); err != nil { + log.Errorf("failed to generate code for %s: %s", fieldsFile.Name(), err) + continue + } + + goFilePath := filepath.Join(outDir, goFile) + _ = os.Remove(goFilePath) + if err := os.WriteFile(goFile, ([]byte)(code), 0o644); err != nil { + log.Errorf("failed to write file %s: %s", goFile, err) + continue + } + log.Debugf("Generated %s with resource %s\n\n", goFile, structName) + } + return nil +} diff --git a/fields/main_test.go b/fields/generator_test.go similarity index 100% rename from fields/main_test.go rename to fields/generator_test.go diff --git a/fields/main.go b/fields/main.go index 151c11b..f0bd3ef 100644 --- a/fields/main.go +++ b/fields/main.go @@ -1,193 +1,20 @@ package main import ( - "bytes" _ "embed" - "encoding/json" - "errors" "flag" "fmt" - "go/format" - "io" - "net/url" "os" "path" "path/filepath" - "regexp" - "strconv" "strings" - "text/template" - "github.com/hashicorp/go-version" - "github.com/iancoleman/strcase" log "github.com/sirupsen/logrus" ) -type replacement struct { - Old string - New string -} - -var fieldReps = []replacement{ - {"Dhcpdv6", "DHCPDV6"}, - - {"Dhcpd", "DHCPD"}, - {"Idx", "IDX"}, - {"Ipsec", "IPSec"}, - {"Ipv6", "IPV6"}, - {"Openvpn", "OpenVPN"}, - {"Tftp", "TFTP"}, - {"Wlangroup", "WLANGroup"}, - - {"Bc", "Broadcast"}, - {"Dhcp", "DHCP"}, - {"Dns", "DNS"}, - {"Dpi", "DPI"}, - {"Dtim", "DTIM"}, - {"Firewallgroup", "FirewallGroup"}, - {"Fixedip", "FixedIP"}, - {"Icmp", "ICMP"}, - {"Id", "ID"}, - {"Igmp", "IGMP"}, - {"Ip", "IP"}, - {"Leasetime", "LeaseTime"}, - {"Mac", "MAC"}, - {"Mcastenhance", "MulticastEnhance"}, - {"Minrssi", "MinRSSI"}, - {"Monthdays", "MonthDays"}, - {"Nat", "NAT"}, - {"Networkconf", "Network"}, - {"Networkgroup", "NetworkGroup"}, - {"Pd", "PD"}, - {"Pmf", "PMF"}, - {"Portconf", "PortProfile"}, - {"Qos", "QOS"}, - {"Radiusprofile", "RADIUSProfile"}, - {"Radius", "RADIUS"}, - {"Ssid", "SSID"}, - {"Startdate", "StartDate"}, - {"Starttime", "StartTime"}, - {"Stopdate", "StopDate"}, - {"Stoptime", "StopTime"}, - {"Tcp", "TCP"}, - {"Udp", "UDP"}, - {"Usergroup", "UserGroup"}, - {"Utc", "UTC"}, - {"Vlan", "VLAN"}, - {"Vpn", "VPN"}, - {"Wan", "WAN"}, - {"Wep", "WEP"}, - {"Wlan", "WLAN"}, - {"Wpa", "WPA"}, -} - -var fileReps = []replacement{ - {"WlanConf", "WLAN"}, - {"Dhcp", "DHCP"}, - {"Wlan", "WLAN"}, - {"NetworkConf", "Network"}, - {"PortConf", "PortProfile"}, - {"RadiusProfile", "RADIUSProfile"}, - {"ApGroups", "APGroup"}, -} - -type Resource struct { - StructName string - ResourcePath string - Types map[string]*FieldInfo - FieldProcessor func(name string, f *FieldInfo) error -} - -type FieldInfo struct { - FieldName string - JSONName string - FieldType string - FieldValidation string - OmitEmpty bool - IsArray bool - Fields map[string]*FieldInfo - CustomUnmarshalType string - CustomUnmarshalFunc string -} - -func NewResource(structName string, resourcePath string) *Resource { - baseType := NewFieldInfo(structName, resourcePath, "struct", "", false, false, "") - resource := &Resource{ - StructName: structName, - ResourcePath: resourcePath, - Types: map[string]*FieldInfo{ - structName: baseType, - }, - FieldProcessor: func(name string, f *FieldInfo) error { return nil }, - } - - // Since template files iterate through map keys in sorted order, these initial fields - // are named such that they stay at the top for consistency. The spacer items create a - // blank line in the resulting generated file. - // - // This hack is here for stability of the generated code, but can be removed if desired. - baseType.Fields = map[string]*FieldInfo{ - " ID": NewFieldInfo("ID", "_id", "string", "", true, false, ""), - " SiteID": NewFieldInfo("SiteID", "site_id", "string", "", true, false, ""), - " _Spacer": nil, - - " Hidden": NewFieldInfo("Hidden", "attr_hidden", "bool", "", true, false, ""), - " HiddenID": NewFieldInfo("HiddenID", "attr_hidden_id", "string", "", true, false, ""), - " NoDelete": NewFieldInfo("NoDelete", "attr_no_delete", "bool", "", true, false, ""), - " NoEdit": NewFieldInfo("NoEdit", "attr_no_edit", "bool", "", true, false, ""), - " _Spacer": nil, - - " _Spacer": nil, - } - - switch { - case resource.IsSetting(): - resource.ResourcePath = strcase.ToSnake(strings.TrimPrefix(structName, "Setting")) - baseType.Fields[" Key"] = NewFieldInfo("Key", "key", "string", "", false, false, "") - - if resource.StructName == "SettingUsg" { - // Removed in v7, retaining for backwards compatibility - baseType.Fields["MdnsEnabled"] = NewFieldInfo("MdnsEnabled", "mdns_enabled", "bool", "", false, false, "") - } - case resource.StructName == "Device": - baseType.Fields[" MAC"] = NewFieldInfo("MAC", "mac", "string", "", true, false, "") - baseType.Fields["Adopted"] = NewFieldInfo("Adopted", "adopted", "bool", "", false, false, "") - baseType.Fields["Model"] = NewFieldInfo("Model", "model", "string", "", true, false, "") - baseType.Fields["State"] = NewFieldInfo("State", "state", "DeviceState", "", false, false, "") - baseType.Fields["Type"] = NewFieldInfo("Type", "type", "string", "", true, false, "") - case resource.StructName == "User": - baseType.Fields[" IP"] = NewFieldInfo("IP", "ip", "string", "non-generated field", true, false, "") - baseType.Fields[" DevIdOverride"] = NewFieldInfo("DevIdOverride", "dev_id_override", "int", "non-generated field", true, false, "") - case resource.StructName == "WLAN": - // this field removed in v6, retaining for backwards compatibility - baseType.Fields["WLANGroupID"] = NewFieldInfo("WLANGroupID", "wlangroup_id", "string", "", false, false, "") - } - - return resource -} - -func NewFieldInfo(fieldName string, jsonName string, fieldType string, fieldValidation string, omitempty bool, isArray bool, customUnmarshalType string) *FieldInfo { - return &FieldInfo{ - FieldName: fieldName, - JSONName: jsonName, - FieldType: fieldType, - FieldValidation: fieldValidation, - OmitEmpty: omitempty, - IsArray: isArray, - CustomUnmarshalType: customUnmarshalType, - } -} - -func cleanName(name string, reps []replacement) string { - for _, rep := range reps { - name = strings.ReplaceAll(name, rep.Old, rep.New) - } - - return name -} - func usage() { fmt.Printf("Usage: %s [OPTIONS] version\n", path.Base(os.Args[0])) + fmt.Printf("version can be a specific version or '%s' (default) for the latest UniFi Controller version\n", LatestVersionMarker) flag.PrintDefaults() } @@ -210,48 +37,22 @@ func main() { versionBaseDirFlag := flag.String("version-base-dir", ".", "The base directory for version JSON files") outputDirFlag := flag.String("output-dir", ".", "The output directory of the generated Go code") downloadOnly := flag.Bool("download-only", false, "Only download and build the fields JSON directory, do not generate") - useLatestVersion := flag.Bool("latest", false, "Use the latest available version") debugFlag := flag.Bool("debug", false, "Enable debug logging") flag.Parse() setupLogging(*debugFlag) - specifiedVersion := flag.Arg(0) - if specifiedVersion != "" && *useLatestVersion { - log.Error("cannot specify version with latest\n\n") - usage() - os.Exit(1) - } else if specifiedVersion == "" && !*useLatestVersion { - log.Error("must specify version or latest\n\n") - usage() - os.Exit(1) + specifiedVersion := strings.TrimSpace(flag.Arg(0)) + if specifiedVersion == "" { + specifiedVersion = LatestVersionMarker // default to latest version + } + unifiVersion, err := determineUnifiVersion(specifiedVersion) + if err != nil { + log.Fatalf("unable to determine version and download URL for Unifi version %s", specifiedVersion) + panic(err) } - var unifiVersion *version.Version - var unifiDownloadUrl *url.URL - var err error - - if *useLatestVersion { - unifiVersion, unifiDownloadUrl, err = latestUnifiVersion() - if err != nil { - log.Fatalln("unable to determine latest UniFi Controller version") - panic(err) - } - } else { - unifiVersion, err = version.NewVersion(specifiedVersion) - if err != nil { - log.Errorf("invalid version %s: %s", specifiedVersion, err) - os.Exit(1) - } - - unifiDownloadUrl, err = url.Parse(fmt.Sprintf("https://dl.ui.com/unifi/%s/unifi_sysvinit_all.deb", unifiVersion)) - if err != nil { - log.Fatalln("unable to parse download URL") - panic(err) - } - } - - log.Infof("UniFi Controller version: %s\n", unifiVersion) - log.Infof("UniFi Controller download URL: %s\n", unifiDownloadUrl) + log.Infof("UniFi Controller version: %s", unifiVersion.Version) + log.Infof("UniFi Controller download URL: %s", unifiVersion.DownloadUrl.String()) wd, err := os.Getwd() if err != nil { @@ -259,401 +60,32 @@ func main() { panic(err) } - fieldsDir := filepath.Join(wd, *versionBaseDirFlag, fmt.Sprintf("v%s", unifiVersion)) - outDir := filepath.Join(wd, *outputDirFlag) - - fieldsInfo, err := os.Stat(fieldsDir) + fieldsDir := filepath.Join(wd, *versionBaseDirFlag, fmt.Sprintf("v%s", unifiVersion.Version)) + log.Infoln("Downloading UniFi Controller field definitions...") + err = DownloadAndExtract(*unifiVersion.DownloadUrl, fieldsDir) if err != nil { - if !errors.Is(err, os.ErrNotExist) { - panic(err) - } - - err = os.MkdirAll(fieldsDir, 0o755) - if err != nil { - panic(err) - } - - log.Infoln("Downloading UniFi Controller JAR...") - // download fields, create - jarFile, err := downloadJar(unifiDownloadUrl, fieldsDir) - if err != nil { - log.Fatalf("unable to download UniFi Controller JAR from %s", unifiDownloadUrl) - panic(err) - } - - log.Debugln("Extracting fields JSONs...") - err = extractJSON(jarFile, fieldsDir) - if err != nil { - log.Fatalf("unable to extract fields JSONs from %s", jarFile) - panic(err) - } - - fieldsInfo, err = os.Stat(fieldsDir) - if err != nil { - panic(err) - } - } - if !fieldsInfo.IsDir() { - log.Errorln("version info isn't a directory") - os.Exit(1) + log.Fatalln("unable to download and extract UniFi Controller field definitions") + panic(err) } + log.Infof("Downloaded UniFi Controller field definitions in %s", fieldsDir) if *downloadOnly { log.Infoln("Fields JSON ready!") os.Exit(0) } - fieldsFiles, err := os.ReadDir(fieldsDir) - if err != nil { - log.Fatalf("unable to read fields directory %s", fieldsDir) + log.Infoln("Generating resources code...") + outDir := filepath.Join(wd, *outputDirFlag) + if err = generateCode(fieldsDir, outDir); err != nil { + log.Fatalln("unable to generate resources code") panic(err) } - log.Infoln("Generating resources...") - for _, fieldsFile := range fieldsFiles { - name := fieldsFile.Name() - ext := filepath.Ext(name) - - switch name { - case "AuthenticationRequest.json", "Setting.json", "Wall.json": - continue - } - - if filepath.Ext(name) != ".json" { - continue - } - - log.Debugf("Processing %s...", fieldsFile.Name()) - name = name[:len(name)-len(ext)] - - urlPath := strings.ToLower(name) - structName := cleanName(name, fileReps) - - goFile := strcase.ToSnake(structName) + ".generated.go" - fieldsFilePath := filepath.Join(fieldsDir, fieldsFile.Name()) - b, err := os.ReadFile(fieldsFilePath) - if err != nil { - log.Warnf("skipping file %s: %s", fieldsFile.Name(), err) - continue - } - - resource := NewResource(structName, urlPath) - - switch resource.StructName { - case "Account": - resource.FieldProcessor = func(name string, f *FieldInfo) error { - switch name { - case "IP", "NetworkID": - f.OmitEmpty = true - } - return nil - } - case "ChannelPlan": - resource.FieldProcessor = func(name string, f *FieldInfo) error { - switch name { - case "Channel", "BackupChannel", "TxPower": - if f.FieldType == "string" { - f.CustomUnmarshalType = "numberOrString" - } - } - return nil - } - case "Device": - resource.FieldProcessor = func(name string, f *FieldInfo) error { - switch name { - case "X", "Y": - f.FieldType = "float64" - case "StpPriority": - f.FieldType = "string" - f.CustomUnmarshalType = "numberOrString" - case "Ht": - f.FieldType = "int" - case "Channel", "BackupChannel", "TxPower": - if f.FieldType == "string" { - f.CustomUnmarshalType = "numberOrString" - } - case "LteExtAnt", "LtePoe": - f.CustomUnmarshalType = "booleanishString" - } - - f.OmitEmpty = true - switch name { - case "PortOverrides": - f.OmitEmpty = false - } - - return nil - } - case "Network": - resource.FieldProcessor = func(name string, f *FieldInfo) error { - switch name { - case "InternetAccessEnabled", "IntraNetworkAccessEnabled": - if f.FieldType == "bool" { - f.CustomUnmarshalType = "*bool" - f.CustomUnmarshalFunc = "emptyBoolToTrue" - } - } - return nil - } - case "SettingGlobalAp": - resource.FieldProcessor = func(name string, f *FieldInfo) error { - if strings.HasPrefix(name, "6E") { - f.FieldName = strings.Replace(f.FieldName, "6E", "SixE", 1) - } - - return nil - } - case "SettingMgmt": - sshKeyField := NewFieldInfo(resource.StructName+"XSshKeys", "x_ssh_keys", "struct", "", false, false, "") - sshKeyField.Fields = map[string]*FieldInfo{ - "name": NewFieldInfo("Name", "name", "string", "", false, false, ""), - "keyType": NewFieldInfo("KeyType", "type", "string", "", false, false, ""), - "key": NewFieldInfo("Key", "key", "string", "", false, false, ""), - "comment": NewFieldInfo("Comment", "comment", "string", "", false, false, ""), - "date": NewFieldInfo("Date", "date", "string", "", false, false, ""), - "fingerprint": NewFieldInfo("Fingerprint", "fingerprint", "string", "", false, false, ""), - } - resource.Types[sshKeyField.FieldName] = sshKeyField - - resource.FieldProcessor = func(name string, f *FieldInfo) error { - if name == "XSshKeys" { - f.FieldType = sshKeyField.FieldName - } - return nil - } - case "SettingUsg": - resource.FieldProcessor = func(name string, f *FieldInfo) error { - if strings.HasSuffix(name, "Timeout") && name != "ArpCacheTimeout" { - f.FieldType = "int" - f.CustomUnmarshalType = "emptyStringInt" - } - return nil - } - case "User": - resource.FieldProcessor = func(name string, f *FieldInfo) error { - switch name { - case "Blocked": - f.FieldType = "bool" - case "LastSeen": - f.FieldType = "int" - f.CustomUnmarshalType = "emptyStringInt" - } - return nil - } - case "WLAN": - resource.FieldProcessor = func(name string, f *FieldInfo) error { - switch name { - case "ScheduleWithDuration": - // always send schedule, so we can empty it if we want to - f.OmitEmpty = false - } - return nil - } - } - - err = resource.processJSON(b) - if err != nil { - log.Warnf("skipping file %s: %s", fieldsFile.Name(), err) - continue - } - - var code string - if code, err = resource.generateCode(); err != nil { - log.Errorf("failed to generate code for %s: %s", fieldsFile.Name(), err) - continue - } - - goFilePath := filepath.Join(outDir, goFile) - _ = os.Remove(goFilePath) - if err := os.WriteFile(goFile, ([]byte)(code), 0o644); err != nil { - log.Errorf("failed to write file %s: %s", goFile, err) - continue - } - log.Debugf("Generated %s with resource %s\n\n", goFile, structName) + log.Infof("Writing version file...") + if err = writeVersionFile(unifiVersion.Version, outDir); err != nil { + log.Fatalf("failed to write version file to %s", outDir) + panic(err) } - // Write version file. - versionGo := []byte(fmt.Sprintf(` -// Generated code. DO NOT EDIT. - -package unifi - -const UnifiVersion = %q -`, unifiVersion)) - - versionGo, err = format.Source(versionGo) - if err != nil { - log.Errorf("failed to format version file: %s", err) - os.Exit(1) - } - - if err := os.WriteFile(filepath.Join(outDir, "version.generated.go"), versionGo, 0o644); err != nil { - log.Errorf("failed to write version file: %s", err) - os.Exit(1) - } - - log.Infof("Generated resources in %s\n", outDir) -} - -func (r *Resource) IsSetting() bool { - return strings.HasPrefix(r.StructName, "Setting") -} - -func (r *Resource) processFields(fields map[string]interface{}) { - t := r.Types[r.StructName] - for name, validation := range fields { - fieldInfo, err := r.fieldInfoFromValidation(name, validation) - if err != nil { - continue - } - - t.Fields[fieldInfo.FieldName] = fieldInfo - } -} - -func (r *Resource) fieldInfoFromValidation(name string, validation interface{}) (*FieldInfo, error) { - fieldName := strcase.ToCamel(name) - fieldName = cleanName(fieldName, fieldReps) - - empty := &FieldInfo{} - var fieldInfo *FieldInfo - - switch validation := validation.(type) { - case []interface{}: - if len(validation) == 0 { - fieldInfo = NewFieldInfo(fieldName, name, "string", "", false, true, "") - err := r.FieldProcessor(fieldName, fieldInfo) - return fieldInfo, err - } - if len(validation) > 1 { - return empty, fmt.Errorf("unknown validation %#v", validation) - } - - fieldInfo, err := r.fieldInfoFromValidation(name, validation[0]) - if err != nil { - return empty, err - } - - fieldInfo.OmitEmpty = true - fieldInfo.IsArray = true - - err = r.FieldProcessor(fieldName, fieldInfo) - return fieldInfo, err - - case map[string]interface{}: - typeName := r.StructName + fieldName - - result := NewFieldInfo(fieldName, name, typeName, "", true, false, "") - result.Fields = make(map[string]*FieldInfo) - - for name, fv := range validation { - child, err := r.fieldInfoFromValidation(name, fv) - if err != nil { - return empty, err - } - - result.Fields[child.FieldName] = child - } - - err := r.FieldProcessor(fieldName, result) - r.Types[typeName] = result - return result, err - - case string: - fieldValidation := validation - normalized := normalizeValidation(validation) - - omitEmpty := false - - switch { - case normalized == "falsetrue" || normalized == "truefalse": - fieldInfo = NewFieldInfo(fieldName, name, "bool", "", omitEmpty, false, "") - return fieldInfo, r.FieldProcessor(fieldName, fieldInfo) - default: - if _, err := strconv.ParseFloat(normalized, 64); err == nil { - if normalized == "09" || normalized == "09.09" { - fieldValidation = "" - } - - if strings.Contains(normalized, ".") { - if strings.Contains(validation, "\\.){3}") { - break - } - - omitEmpty = true - fieldInfo = NewFieldInfo(fieldName, name, "float64", fieldValidation, omitEmpty, false, "") - return fieldInfo, r.FieldProcessor(fieldName, fieldInfo) - } - - omitEmpty = true - fieldInfo = NewFieldInfo(fieldName, name, "int", fieldValidation, omitEmpty, false, "") - fieldInfo.CustomUnmarshalType = "emptyStringInt" - return fieldInfo, r.FieldProcessor(fieldName, fieldInfo) - } - } - if validation != "" && normalized != "" { - log.Debugf("normalize %q to %q", validation, normalized) - } - - omitEmpty = omitEmpty || (!strings.Contains(validation, "^$") && !strings.HasSuffix(fieldName, "ID")) - fieldInfo = NewFieldInfo(fieldName, name, "string", fieldValidation, omitEmpty, false, "") - return fieldInfo, r.FieldProcessor(fieldName, fieldInfo) - } - - return empty, fmt.Errorf("unable to determine type from validation %q", validation) -} - -func (r *Resource) processJSON(b []byte) error { - var fields map[string]interface{} - err := json.Unmarshal(b, &fields) - if err != nil { - return err - } - - r.processFields(fields) - - return nil -} - -//go:embed api.go.tmpl -var apiGoTemplate string - -func (r *Resource) generateCode() (string, error) { - var err error - var buf bytes.Buffer - writer := io.Writer(&buf) - - tpl := template.Must(template.New("api.go.tmpl").Parse(apiGoTemplate)) - - err = tpl.Execute(writer, r) - if err != nil { - return "", fmt.Errorf("failed to render template: %w", err) - } - - src, err := format.Source(buf.Bytes()) - if err != nil { - return "", fmt.Errorf("failed to format source: %w", err) - } - - return string(src), err -} - -func normalizeValidation(re string) string { - re = strings.ReplaceAll(re, "\\d", "[0-9]") - re = strings.ReplaceAll(re, "[-+]?", "") - re = strings.ReplaceAll(re, "[+-]?", "") - re = strings.ReplaceAll(re, "[-]?", "") - re = strings.ReplaceAll(re, "\\.", ".") - re = strings.ReplaceAll(re, "[.]?", ".") - - quants := regexp.MustCompile(`\{\d*,?\d*\}|\*|\+|\?`) - re = quants.ReplaceAllString(re, "") - - control := regexp.MustCompile(`[\(\[\]\)\|\-\$\^]`) - re = control.ReplaceAllString(re, "") - - re = strings.TrimPrefix(re, "^") - re = strings.TrimSuffix(re, "$") - - return re + log.Infof("Generated resources in %s", outDir) } diff --git a/fields/version.go b/fields/version.go index db1ddc1..5ee894a 100644 --- a/fields/version.go +++ b/fields/version.go @@ -3,16 +3,38 @@ package main import ( "context" "encoding/json" + "errors" + "fmt" + "go/format" "net/http" "net/url" + "os" + "path/filepath" "github.com/hashicorp/go-version" ) -func latestUnifiVersion() (*version.Version, *url.URL, error) { +const ( + LatestVersionMarker = "latest" + baseDownloadUrl = "https://dl.ui.com/unifi/%s/unifi_sysvinit_all.deb" +) + +type UnifiVersion struct { + Version *version.Version + DownloadUrl *url.URL +} + +func NewUnifiVersion(unifiVersion *version.Version, downloadUrl *url.URL) *UnifiVersion { + return &UnifiVersion{ + Version: unifiVersion, + DownloadUrl: downloadUrl, + } +} + +func latestUnifiVersion() (*UnifiVersion, error) { url, err := url.Parse(firmwareUpdateApi) if err != nil { - return nil, nil, err + return nil, err } query := url.Query() @@ -22,29 +44,63 @@ func latestUnifiVersion() (*version.Version, *url.URL, error) { req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, url.String(), nil) if err != nil { - return nil, nil, err + return nil, err } client := &http.Client{} resp, err := client.Do(req) if err != nil { - return nil, nil, err + return nil, err } defer resp.Body.Close() var respData firmwareUpdateApiResponse err = json.NewDecoder(resp.Body).Decode(&respData) if err != nil { - return nil, nil, err + return nil, err } for _, firmware := range respData.Embedded.Firmware { if firmware.Platform != debianPlatform { continue } - - return firmware.Version.Core(), firmware.Links.Data.Href, nil + return NewUnifiVersion(firmware.Version.Core(), firmware.Links.Data.Href), nil } - return nil, nil, nil + return nil, errors.New("no Unifi Controller firmware found") +} + +func determineUnifiVersion(versionMarker string) (*UnifiVersion, error) { + if versionMarker == LatestVersionMarker { + return latestUnifiVersion() + } else { + unifiVersion, err := version.NewVersion(versionMarker) + if err != nil { + return nil, err + } + unifiVersion = unifiVersion.Core() + downloadUrl := fmt.Sprintf(baseDownloadUrl, unifiVersion) + unifiDownloadUrl, err := url.Parse(downloadUrl) + if err != nil { + return nil, err + } + return NewUnifiVersion(unifiVersion, unifiDownloadUrl), nil + } +} + +func writeVersionFile(version *version.Version, outDir string) error { + versionGo := []byte(fmt.Sprintf(` +// Generated code. DO NOT EDIT. + +package unifi + +const UnifiVersion = %q +`, version.Core())) + + versionGo, err := format.Source(versionGo) + if err != nil { + return err + } + + return os.WriteFile(filepath.Join(outDir, "version.generated.go"), versionGo, 0o644) } diff --git a/fields/version_test.go b/fields/version_test.go index 9637a98..2c41287 100644 --- a/fields/version_test.go +++ b/fields/version_test.go @@ -2,6 +2,7 @@ package main import ( "encoding/json" + "fmt" "net/http" "net/http/httptest" "net/url" @@ -12,9 +13,8 @@ import ( "github.com/stretchr/testify/require" ) -func TestLatestUnifiVersion(t *testing.T) { - t.Parallel() - +func assertLatestVersionUsingProvider(t *testing.T, provider func() (*UnifiVersion, error)) { + t.Helper() assert := assert.New(t) require := require.New(t) @@ -84,9 +84,65 @@ func TestLatestUnifiVersion(t *testing.T) { defer server.Close() firmwareUpdateApi = server.URL - gotVersion, gotDownload, err := latestUnifiVersion() + gotVersion, err := provider() require.NoError(err) - assert.Equal(fwVersion.Core(), gotVersion) - assert.Equal(fwDownload, gotDownload) + assert.Equal(fwVersion.Core(), gotVersion.Version) + assert.Equal(fwDownload, gotVersion.DownloadUrl) +} + +func TestLatestUnifiVersion(t *testing.T) { + t.Parallel() + assertLatestVersionUsingProvider(t, func() (*UnifiVersion, error) { + return latestUnifiVersion() + }) +} + +func TestDetermineUnifiVersion_latest(t *testing.T) { + t.Parallel() + assertLatestVersionUsingProvider(t, func() (*UnifiVersion, error) { + return determineUnifiVersion(LatestVersionMarker) + }) +} + +func TestDetermineUnifiVersion_provided(t *testing.T) { + t.Parallel() + testCases := map[string]string{ + "7.3.83+atag-7.3.83-19645": "7.3.83", + "7.3.83": "7.3.83", + "7.3": "7.3.0", + "7": "7.0.0", + } + + for providedVersion, expectedVersion := range testCases { + t.Run(providedVersion, func(t *testing.T) { + t.Parallel() + assert := assert.New(t) + require := require.New(t) + + unifiVersion, err := determineUnifiVersion(providedVersion) + require.NoError(err) + + assert.Equal(expectedVersion, unifiVersion.Version.String()) + assert.Equal(fmt.Sprintf(baseDownloadUrl, expectedVersion), unifiVersion.DownloadUrl.String()) + }) + } +} + +func TestDetermineUnifiVersion_invalid(t *testing.T) { + t.Parallel() + testCases := []string{ + "invalid", + "-1", + "", + } + assert := assert.New(t) + + for _, providedVersion := range testCases { + t.Run(providedVersion, func(t *testing.T) { + t.Parallel() + _, err := determineUnifiVersion(providedVersion) + assert.ErrorContains(err, providedVersion) + }) + } } diff --git a/unifi/fields.go b/unifi/fields.go index 554570d..9b1e1a1 100644 --- a/unifi/fields.go +++ b/unifi/fields.go @@ -2,4 +2,4 @@ package unifi // This will generate the *.generated.go files in this package for the specified // Unifi controller version. -//go:generate go run ../fields/ -version-base-dir=../fields/ -latest +//go:generate go run ../fields/ -version-base-dir=../fields/ latest