From 7d78ccaaef811689e3ee5123c44ec5d8d2adffb3 Mon Sep 17 00:00:00 2001 From: Kristiyan Gostev Date: Fri, 10 May 2024 11:23:17 +0300 Subject: [PATCH] Merge `dev-m5` branch into main (#72) [#73] Merge `dev-m5` branch into `main` * [#64] Refactor flags implementation (#65) * [#3] MQTTS support in the software-update's local connection (#63) --------- Signed-off-by: Antonia Avramova Signed-off-by: Kristiyan Gostev Co-authored-by: Antonia Avramova --- cmd/software-update/main.go | 8 +- internal/command.go | 37 +++- internal/duration.go | 2 +- internal/edge.go | 22 +++ internal/feature.go | 99 ++++++++-- internal/feature_test.go | 14 +- internal/flags.go | 303 +++++++++--------------------- internal/flags_test.go | 267 ++++++++++++-------------- internal/logger/logger.go | 10 +- internal/path_args.go | 20 +- internal/path_args_test.go | 58 ++++++ internal/utils_test.go | 16 +- util/tls/testdata/ca.crt | 22 +++ util/tls/testdata/certificate.pem | 27 +++ util/tls/testdata/empty.crt | 0 util/tls/testdata/invalid.pem | 1 + util/tls/testdata/key.pem | 27 +++ util/tls/tls_config.go | 60 ++++++ util/tls/tls_config_test.go | 104 ++++++++++ 19 files changed, 690 insertions(+), 407 deletions(-) create mode 100644 internal/path_args_test.go create mode 100644 util/tls/testdata/ca.crt create mode 100644 util/tls/testdata/certificate.pem create mode 100644 util/tls/testdata/empty.crt create mode 100644 util/tls/testdata/invalid.pem create mode 100644 util/tls/testdata/key.pem create mode 100644 util/tls/tls_config.go create mode 100644 util/tls/tls_config_test.go diff --git a/cmd/software-update/main.go b/cmd/software-update/main.go index 5b6cd11..a40265d 100644 --- a/cmd/software-update/main.go +++ b/cmd/software-update/main.go @@ -25,23 +25,23 @@ var version = "N/A" func main() { // Initialize flags. - suConfig, logConfig, err := feature.InitFlags(version) + cfg, err := feature.LoadConfig(version) if err != nil { fmt.Println(err) os.Exit(1) } // Initialize logs. - loggerOut := logger.SetupLogger(logConfig) + loggerOut := logger.SetupLogger(&cfg.LogConfig) defer loggerOut.Close() - if err := suConfig.Validate(); err != nil { + if err := cfg.Validate(); err != nil { logger.Errorf("failed to validate script-based software updatable configuration: %v\n", err) os.Exit(1) } // Create new Script-Based software updatable - edgeCtr, err := feature.InitScriptBasedSU(suConfig) + edgeCtr, err := feature.InitScriptBasedSU(&cfg.ScriptBasedSoftwareUpdatableConfig) if err != nil { logger.Errorf("failed to create script-based software updatable: %v", err) os.Exit(1) diff --git a/internal/command.go b/internal/command.go index 05339e7..3df4d03 100644 --- a/internal/command.go +++ b/internal/command.go @@ -13,6 +13,7 @@ package feature import ( + "encoding/json" "fmt" "os/exec" "path/filepath" @@ -22,11 +23,13 @@ import ( "github.com/eclipse-kanto/software-update/internal/logger" ) +// command is custom type of command name and arguments of command in order to add json unmarshal support type command struct { cmd string args []string } +// String is representation of command as combination of name and arguments of the command func (i *command) String() string { if len(i.args) == 0 { return i.cmd @@ -34,20 +37,25 @@ func (i *command) String() string { return fmt.Sprint(i.cmd, " ", strings.Join(i.args, " ")) } +// Set command from string, used for flag set func (i *command) Set(value string) error { if i.cmd == "" { - i.cmd = value - i.args = []string{} - if runtime.GOOS != "windows" && strings.HasSuffix(value, ".sh") { - i.cmd = "/bin/sh" - i.args = []string{value} - } + i.setCommand(value) } else { i.args = append(i.args, value) } return nil } +func (i *command) setCommand(value string) { + i.cmd = value + i.args = []string{} + if runtime.GOOS != "windows" && strings.HasSuffix(value, ".sh") { + i.cmd = "/bin/sh" + i.args = []string{value} + } +} + func (i *command) run(dir string, def string) (err error) { script := i.cmd args := i.args @@ -71,3 +79,20 @@ func (i *command) run(dir string, def string) (err error) { } return err } + +// UnmarshalJSON unmarshal command type +func (i *command) UnmarshalJSON(b []byte) error { + var v []string + if err := json.Unmarshal(b, &v); err != nil { + return err + } + + for num, elem := range v { + if num == 0 { + i.setCommand(elem) + } else { + i.args = append(i.args, elem) + } + } + return nil +} diff --git a/internal/duration.go b/internal/duration.go index b24d4c2..9341191 100644 --- a/internal/duration.go +++ b/internal/duration.go @@ -27,8 +27,8 @@ func (d *durationTime) UnmarshalJSON(b []byte) error { if err := json.Unmarshal(b, &v); err != nil { return err } - switch value := v.(type) { + switch value := v.(type) { case string: duration, err := time.ParseDuration(value) if err != nil { diff --git a/internal/edge.go b/internal/edge.go index ff7a92f..dcc71c5 100644 --- a/internal/edge.go +++ b/internal/edge.go @@ -14,8 +14,10 @@ package feature import ( "encoding/json" + "net/url" "github.com/eclipse-kanto/software-update/internal/logger" + "github.com/eclipse-kanto/software-update/util/tls" MQTT "github.com/eclipse/paho.mqtt.golang" "github.com/google/uuid" @@ -58,6 +60,17 @@ func newEdgeConnector(scriptSUPConfig *ScriptBasedSoftwareUpdatableConfig, ecl e if len(scriptSUPConfig.Username) > 0 { opts = opts.SetUsername(scriptSUPConfig.Username).SetPassword(scriptSUPConfig.Password) } + u, err := url.Parse(scriptSUPConfig.Broker) + if err != nil { + return nil, err + } + if isConnectionSecure(u.Scheme) { + tlsConfig, err := tls.NewTLSConfig(scriptSUPConfig.CACert, scriptSUPConfig.Cert, scriptSUPConfig.Key) + if err != nil { + return nil, err + } + opts.SetTLSConfig(tlsConfig) + } p := &EdgeConnector{mqttClient: MQTT.NewClient(opts), edgeClient: ecl} if token := p.mqttClient.Connect(); token.Wait() && token.Error() != nil { @@ -98,6 +111,15 @@ func newEdgeConnector(scriptSUPConfig *ScriptBasedSoftwareUpdatableConfig, ecl e return p, nil } +func isConnectionSecure(schema string) bool { + switch schema { + case "wss", "ssl", "tls", "mqtts", "mqtt+ssl", "tcps": + return true + default: + } + return false +} + // Close the EdgeConnector func (p *EdgeConnector) Close() { if p.cfg != nil { diff --git a/internal/feature.go b/internal/feature.go index 61a518f..d6b6d66 100644 --- a/internal/feature.go +++ b/internal/feature.go @@ -27,15 +27,36 @@ import ( ) const ( - defaultDisconnectTimeout = 250 * time.Millisecond - defaultKeepAlive = 20 * time.Second - modeStrict = "strict" modeScoped = "scoped" modeLax = "lax" typeArchive = "archive" typePlain = "plain" + + defaultDisconnectTimeout = 250 * time.Millisecond + defaultKeepAlive = 20 * time.Second + defaultBroker = "tcp://localhost:1883" + defaultUsername = "" + defaultPassword = "" + defaultCACert = "" + defaultCert = "" + defaultKey = "" + defaultStorageLocation = "." + defaultFeatureID = "SoftwareUpdatable" + defaultModuleType = "software" + defaultArtifactType = "archive" + defaultServerCert = "" + defaultDownloadRetryCount = 0 + defaultDownloadRetryInterval = "5s" + defaultInstallDirs = "" + defaultMode = modeStrict + defaultInstallCommand = "" + defaultLogFile = "log/software-update.log" + defaultLogLevel = "INFO" + defaultLogFileSize = 2 + defaultLogFileCount = 5 + defaultLogFileMaxAge = 28 ) var ( @@ -49,19 +70,22 @@ type operationFunc func() bool // ScriptBasedSoftwareUpdatableConfig provides the Script-Based SoftwareUpdatable configuration. type ScriptBasedSoftwareUpdatableConfig struct { - Broker string - Username string - Password string - StorageLocation string - FeatureID string - ModuleType string - ArtifactType string - ServerCert string - DownloadRetryCount int - DownloadRetryInterval durationTime - InstallDirs pathArgs - Mode string - InstallCommand command + Broker string `json:"broker,omitempty"` + Username string `json:"username,omitempty"` + Password string `json:"password,omitempty"` + CACert string `json:"caCert,omitempty"` + Cert string `json:"cert,omitempty"` + Key string `json:"key,omitempty"` + StorageLocation string `json:"storageLocation,omitempty"` + FeatureID string `json:"featureId,omitempty"` + ModuleType string `json:"moduleType,omitempty"` + ArtifactType string `json:"artifactType,omitempty"` + ServerCert string `json:"serverCert,omitempty"` + DownloadRetryCount int `json:"downloadRetryCount,omitempty"` + DownloadRetryInterval durationTime `json:"downloadRetryInterval,omitempty"` + InstallDirs []string `json:"installDirs,omitempty"` + Mode string `json:"mode,omitempty"` + InstallCommand command `json:"install,omitempty"` } // ScriptBasedSoftwareUpdatable is the Script-Based SoftwareUpdatable actual implementation. @@ -81,6 +105,47 @@ type ScriptBasedSoftwareUpdatable struct { installCommand *command } +// BasicConfig combine ScriptBaseSoftwareUpdatable configuration and Log configuration +type BasicConfig struct { + ScriptBasedSoftwareUpdatableConfig + logger.LogConfig + ConfigFile string `json:"configFile,omitempty"` +} + +// NewDefaultConfig returns a default mqtt client connection config instance +func NewDefaultConfig() *BasicConfig { + duration, err := time.ParseDuration(defaultDownloadRetryInterval) + if err != nil { + duration = 0 + } + return &BasicConfig{ + ScriptBasedSoftwareUpdatableConfig: ScriptBasedSoftwareUpdatableConfig{ + Broker: defaultBroker, + Username: defaultUsername, + Password: defaultPassword, + CACert: defaultCACert, + Cert: defaultCert, + Key: defaultKey, + StorageLocation: defaultStorageLocation, + FeatureID: defaultFeatureID, + ModuleType: defaultModuleType, + ArtifactType: defaultArtifactType, + ServerCert: defaultServerCert, + DownloadRetryCount: defaultDownloadRetryCount, + Mode: defaultMode, + DownloadRetryInterval: durationTime(duration), + InstallDirs: make([]string, 0), + }, + LogConfig: logger.LogConfig{ + LogFile: defaultLogFile, + LogLevel: defaultLogLevel, + LogFileSize: defaultLogFileSize, + LogFileCount: defaultLogFileCount, + LogFileMaxAge: defaultLogFileMaxAge, + }, + } +} + // InitScriptBasedSU creates a new Script-Based SoftwareUpdatable instance, listening for edge configuration. func InitScriptBasedSU(scriptSUPConfig *ScriptBasedSoftwareUpdatableConfig) (*EdgeConnector, error) { logger.Infof("New Script-Based SoftwareUpdatable [Broker: %s, Type: %s]", @@ -103,7 +168,7 @@ func InitScriptBasedSU(scriptSUPConfig *ScriptBasedSoftwareUpdatableConfig) (*Ed // Interval between download reattempts downloadRetryInterval: time.Duration(scriptSUPConfig.DownloadRetryInterval), // Install locations for local artifacts - installDirs: scriptSUPConfig.InstallDirs.args, + installDirs: scriptSUPConfig.InstallDirs, // Access mode for local artifacts accessMode: initAccessMode(scriptSUPConfig.Mode), // Define the module artifact(s) type: archive or plain diff --git a/internal/feature_test.go b/internal/feature_test.go index d753cc4..03a7dc9 100644 --- a/internal/feature_test.go +++ b/internal/feature_test.go @@ -99,9 +99,9 @@ func TestScriptBasedInitLoadDependencies(t *testing.T) { // 1. Try to init a new ScriptBasedSoftwareUpdatable with error for loading install dependencies _, _, err := mockScriptBasedSoftwareUpdatable(t, &testConfig{ - clientConnected: true, storageLocation: dir, featureID: getDefaultFlagValue(t, flagFeatureID)}) + clientConnected: true, storageLocation: dir, featureID: NewDefaultConfig().FeatureID}) if err == nil { - t.Fatalf("expected to fail when mandatory field is missing in insalled dept file") + t.Fatalf("expected to fail when mandatory field is missing in installed dept file") } } @@ -114,7 +114,7 @@ func TestScriptBasedInit(t *testing.T) { // 1. Try to init a new ScriptBasedSoftwareUpdatable with error for not connected client _, _, err := mockScriptBasedSoftwareUpdatable(t, &testConfig{ - clientConnected: false, storageLocation: dir, featureID: getDefaultFlagValue(t, flagFeatureID)}) + clientConnected: false, storageLocation: dir, featureID: NewDefaultConfig().FeatureID}) if err == nil { t.Fatal("ditto Client shall not be connected!") } @@ -150,7 +150,7 @@ func testScriptBasedSoftwareUpdatableOperations(noResume bool, t *testing.T) { // 1. Try to init a new ScriptBasedSoftwareUpdatable. feature, mc, err := mockScriptBasedSoftwareUpdatable(t, &testConfig{ - clientConnected: true, featureID: getDefaultFlagValue(t, flagFeatureID), storageLocation: dir}) + clientConnected: true, featureID: NewDefaultConfig().FeatureID, storageLocation: dir}) if err != nil { t.Fatalf("failed to initialize ScriptBasedSoftwareUpdatable: %v", err) } @@ -195,7 +195,7 @@ func testDisconnectWhileRunningOperation(feature *ScriptBasedSoftwareUpdatable, statuses = append(statuses, pullStatusChanges(mc, postDisconnectEventCount)...) waitDisconnect.Wait() - defer connectFeature(t, mc, feature, getDefaultFlagValue(t, flagFeatureID)) + defer connectFeature(t, mc, feature, NewDefaultConfig().FeatureID) if install { checkInstallStatusEvents(0, statuses, t) } else { @@ -212,7 +212,7 @@ func TestScriptBasedDownloadAndInstallMixedResources(t *testing.T) { defer os.RemoveAll(storageDir) feature, mc, err := mockScriptBasedSoftwareUpdatable(t, &testConfig{ - clientConnected: true, featureID: getDefaultFlagValue(t, flagFeatureID), storageLocation: storageDir, mode: modeLax, + clientConnected: true, featureID: NewDefaultConfig().FeatureID, storageLocation: storageDir, mode: modeLax, }) if err != nil { t.Fatalf("failed to initialize ScriptBasedSoftwareUpdatable: %v", err) @@ -300,7 +300,7 @@ func testScriptBasedSoftwareUpdatableOperationsLocal(t *testing.T, installDirs [ defer os.RemoveAll(dir) feature, mc, err := mockScriptBasedSoftwareUpdatable(t, &testConfig{ - clientConnected: true, featureID: getDefaultFlagValue(t, flagFeatureID), storageLocation: dir, + clientConnected: true, featureID: NewDefaultConfig().FeatureID, storageLocation: dir, installDirs: installDirs, mode: mode}) if err != nil { t.Fatalf("failed to initialize ScriptBasedSoftwareUpdatable: %v", err) diff --git a/internal/flags.go b/internal/flags.go index d5b57ca..a2e364b 100644 --- a/internal/flags.go +++ b/internal/flags.go @@ -16,252 +16,133 @@ import ( "encoding/json" "flag" "fmt" - "io/ioutil" - "log" + "io" "os" - "reflect" - "strconv" "strings" - "unicode" + "time" "github.com/eclipse-kanto/software-update/internal/logger" ) const ( - flagVersion = "version" - flagConfigFile = "configFile" - flagInstall = "install" - flagInstallDirs = "installDirs" + flagConfigFile = "configFile" + flagInstall = "install" ) var ( - suConfig = &ScriptBasedSoftwareUpdatableConfig{} - logConfig = &logger.LogConfig{} - - descriptions = map[string]string{ - "mode": "Artifact access mode. Restricts where local file system artifacts can be located.\nAllowed values are:" + - "\n 'strict' - artifacts can only be located in directories, included in installDirs property value" + - "\n 'scoped' - artifacts can only be located in directories or their subdirectories recursively, included in installDirs property value" + - "\n 'lax' - artifacts can be located anywhere on local file system. Use with care!", - } + modeDescription = "Artifact access mode. Restricts where local file system artifacts can be located.\nAllowed values are:" + + "\n 'strict' - artifacts can only be located in directories, included in installDirs property value" + + "\n 'scoped' - artifacts can only be located in directories or their subdirectories recursively, included in installDirs property value" + + "\n 'lax' - artifacts can be located anywhere on local file system. Use with care!" ) -type cfg struct { - Broker string `json:"broker" def:"tcp://localhost:1883" descr:"Local MQTT broker address"` - Username string `json:"username" descr:"Username for authorized local client"` - Password string `json:"password" descr:"Password for authorized local client"` - StorageLocation string `json:"storageLocation" def:"." descr:"Location of the storage"` - FeatureID string `json:"featureId" def:"SoftwareUpdatable" descr:"Feature identifier of SoftwareUpdatable"` - ModuleType string `json:"moduleType" def:"software" descr:"Module type of SoftwareUpdatable"` - ArtifactType string `json:"artifactType" def:"archive" descr:"Defines the module artifact type: archive or plain"` - Install []string `json:"install" descr:"Defines the absolute path to install script"` - ServerCert string `json:"serverCert" descr:"A PEM encoded certificate \"file\" for secure artifact download"` - DownloadRetryCount int `json:"downloadRetryCount" def:"0" descr:"Number of retries, in case of a failed download.\n By default no retries are supported."` - DownloadRetryInterval durationTime `json:"downloadRetryInterval" def:"5s" descr:"Interval between retries, in case of a failed download.\n Should be a sequence of decimal numbers, each with optional fraction and a unit suffix, such as '300ms', '1.5h', '10m30s', etc. Valid time units are 'ns', 'us' (or 'µs'), 'ms', 's', 'm', 'h'."` - InstallDirs []string `json:"installDirs" descr:"Local file system directories, where to search for module artifacts"` - Mode string `json:"mode" def:"strict" descr:"%s"` - LogFile string `json:"logFile" def:"log/software-update.log" descr:"Log file location in storage directory"` - LogLevel string `json:"logLevel" def:"INFO" descr:"Log levels are ERROR, WARN, INFO, DEBUG, TRACE"` - LogFileSize int `json:"logFileSize" def:"2" descr:"Log file size in MB before it gets rotated"` - LogFileCount int `json:"logFileCount" def:"5" descr:"Log file max rotations count"` - LogFileMaxAge int `json:"logFileMaxAge" def:"28" descr:"Log file rotations max age in days"` -} - // InitFlags tries to initialize Script-Based SoftwareUpdatable and Log configurations. // Returns true if version flag is specified for print version and exit. Returns error // if JSON configuration file cannot be read properly or missing config file is specified with flag. -func InitFlags(version string) (*ScriptBasedSoftwareUpdatableConfig, *logger.LogConfig, error) { - flgConfig := &cfg{} - printVersion := flag.Bool(flagVersion, false, "Prints current version and exits") - configFile := flag.String(flagConfigFile, "", "Defines the configuration file") - - // the install flag is set in the config object initially - flag.Var(&suConfig.InstallCommand, flagInstall, "Defines the absolute path to install script") - flag.Var(&suConfig.InstallDirs, flagInstallDirs, "Local file system directories, where to search for module artifacts") - - initFlagsWithDefaultValues(flgConfig) - flag.Parse() - if *printVersion { - fmt.Println(version) - os.Exit(0) - } - - if err := applyConfigurationFile(*configFile); err != nil { - return nil, nil, err - } - applyFlags(flgConfig) - return suConfig, logConfig, nil +func InitFlags(flagSet *flag.FlagSet, cfg *BasicConfig) { + // init log flags + flagSet.StringVar(&cfg.LogLevel, "logLevel", cfg.LogLevel, "Log levels are ERROR, WARN, INFO, DEBUG, TRACE") + flagSet.StringVar(&cfg.LogFile, "logFile", cfg.LogFile, "Log file location in storage directory") + flagSet.IntVar(&cfg.LogFileSize, "logFileSize", cfg.LogFileSize, "Log file size in MB before it gets rotated") + flagSet.IntVar(&cfg.LogFileCount, "logFileCount", cfg.LogFileCount, "Log file max rotations count") + flagSet.IntVar(&cfg.LogFileMaxAge, "logFileMaxAge", cfg.LogFileMaxAge, "Log file rotations max age in days") + + // init connection flags + flagSet.StringVar(&cfg.Broker, "broker", cfg.Broker, "Local MQTT broker address") + flagSet.StringVar(&cfg.Username, "username", cfg.Username, "Username that is a part of the credentials") + flagSet.StringVar(&cfg.Password, "password", cfg.Password, "Password that is a part of the credentials") + flagSet.StringVar(&cfg.CACert, "caCert", cfg.CACert, "A PEM encoded CA certificates file for MQTT broker connection") + flagSet.StringVar(&cfg.Cert, "cert", cfg.Cert, "A PEM encoded certificate file to authenticate to the MQTT server/broker") + flagSet.StringVar(&cfg.Key, "key", cfg.Key, "A PEM encoded unencrypted private key file to authenticate to the MQTT server/broker") + flagSet.StringVar(&cfg.StorageLocation, "storageLocation", cfg.StorageLocation, "Location of the storage") + flagSet.StringVar(&cfg.FeatureID, "featureId", cfg.FeatureID, "Feature identifier of SoftwareUpdatable") + flagSet.StringVar(&cfg.ModuleType, "moduleType", cfg.ModuleType, "Module type of SoftwareUpdatable") + flagSet.StringVar(&cfg.ArtifactType, "artifactType", cfg.ArtifactType, "Defines the module artifact type: archive or plain") + flagSet.StringVar(&cfg.ServerCert, "serverCert", cfg.ServerCert, "A PEM encoded certificate 'file' for secure artifact download") + flagSet.IntVar(&cfg.DownloadRetryCount, "downloadRetryCount", cfg.DownloadRetryCount, "Number of retries, in case of a failed download. By default no retries are supported.") + flagSet.DurationVar((*time.Duration)(&cfg.DownloadRetryInterval), "downloadRetryInterval", (time.Duration)(cfg.DownloadRetryInterval), "Interval between retries, in case of a failed download. Should be a sequence of decimal numbers, each with optional fraction and a unit suffix, such as '300ms', '1.5h', '10m30s', etc. Valid time units are 'ns', 'us' (or 'µs'), 'ms', 's', 'm', 'h'") + + flagSet.StringVar(&cfg.Mode, "mode", cfg.Mode, modeDescription) + + flagSet.Var(&cfg.InstallCommand, flagInstall, "Defines the absolute path to install script") + flagSet.Var(newPathArgs(&cfg.InstallDirs), "installDirs", "Local file system directories, where to search for module artifacts") + flagSet.StringVar(&cfg.ConfigFile, flagConfigFile, cfg.ConfigFile, "Defines the configuration file") } -func initFlagsWithDefaultValues(config interface{}) { - valueOf := reflect.ValueOf(config).Elem() - typeOf := valueOf.Type() - for i := 0; i < typeOf.NumField(); i++ { - fieldType := typeOf.Field(i) - defaultValue := fieldType.Tag.Get("def") - description := fieldType.Tag.Get("descr") - fieldValue := valueOf.FieldByName(fieldType.Name) - pointer := fieldValue.Addr().Interface() - flagName := toFlagName(fieldType.Name) - if val, ok := descriptions[flagName]; ok { - description = fmt.Sprintf(description, val) - } - switch val := fieldValue.Interface(); val.(type) { - case string: - flag.StringVar(pointer.(*string), flagName, defaultValue, description) - case int: - value, err := strconv.Atoi(defaultValue) - if err != nil { - log.Printf("error parsing integer argument %v with value %v", fieldType.Name, defaultValue) - } - flag.IntVar(pointer.(*int), flagName, value, description) - case durationTime: - v, ok := pointer.(flag.Value) - if ok { - flag.Var(v, flagName, description) - } else { - log.Println("custom type Duration must implement reflect.Value interface") - } - } +// ParseConfigFilePath returns the value for configuration file path if set. +func ParseConfigFilePath() string { + var cfgFilePath string + flagSet := flag.NewFlagSet("", flag.ContinueOnError) + flagSet.SetOutput(io.Discard) + flagSet.StringVar(&cfgFilePath, flagConfigFile, "", "Defines the configuration file") + if err := flagSet.Parse(getFlagArgs(flagConfigFile)); err != nil { + logger.Errorf("Cannot parse the configFile flag: %v", err) } + return cfgFilePath } -func loadDefaultValues() *cfg { - result := &cfg{} - valueOf := reflect.ValueOf(result).Elem() - typeOf := valueOf.Type() - for i := 0; i < typeOf.NumField(); i++ { - fieldType := typeOf.Field(i) - defaultValue := fieldType.Tag.Get("def") - fieldValue := valueOf.FieldByName(fieldType.Name) - pointer := fieldValue.Addr().Interface() - if len(defaultValue) > 0 { - fieldValue := valueOf.FieldByName(fieldType.Name) - switch fieldValue.Interface().(type) { - case string: - fieldValue.Set(reflect.ValueOf(defaultValue)) - case int: - value, err := strconv.Atoi(defaultValue) - if err != nil { - log.Printf("error parsing integer argument %v with value %v", fieldType.Name, defaultValue) - } - fieldValue.Set(reflect.ValueOf(value)) - case durationTime: - v, ok := pointer.(flag.Value) - if ok { - if err := v.Set(defaultValue); err == nil { - fieldValue.Set(reflect.ValueOf(v).Elem()) - } else { - log.Printf("error parsing argument %v with value %v - %v", fieldType.Name, defaultValue, err) - } - } else { - log.Println("custom type Duration must implement reflect.Value interface") - } - } - +func getFlagArgs(flag string) []string { + args := os.Args[1:] + flag1 := "-" + flag + flag2 := "--" + flag + for index, arg := range args { + if strings.HasPrefix(arg, flag1+"=") || strings.HasPrefix(arg, flag2+"=") { + return []string{arg} + } + if (arg == flag1 || arg == flag2) && index < len(args)-1 { + return args[index : index+2] } } - return result + return []string{} } -func applyFlags(flagsConfig interface{}) { - flagsConfigVal := reflect.ValueOf(flagsConfig).Elem() - suConfigVal := reflect.ValueOf(suConfig).Elem() - logConfigVal := reflect.ValueOf(logConfig).Elem() - flag.Visit(func(f *flag.Flag) { - name := toFieldName(f.Name) - srcFieldVal := flagsConfigVal.FieldByName(name) - if srcFieldVal.Kind() != reflect.Invalid && srcFieldVal.Kind() != reflect.Struct { - dstFieldSu := suConfigVal.FieldByName(name) - dstFieldLog := logConfigVal.FieldByName(name) - if dstFieldSu.Kind() != reflect.Invalid && dstFieldSu.Kind() != reflect.Struct { - dstFieldSu.Set(srcFieldVal) - } else if dstFieldLog.Kind() != reflect.Invalid { - dstFieldLog.Set(srcFieldVal) - } - } - }) -} +func parseFlags(cfg *BasicConfig, version string) { + flag.CommandLine = flag.NewFlagSet(os.Args[0], flag.ExitOnError) + flagSet := flag.CommandLine -func applyConfigurationFile(configFile string) error { - def := loadDefaultValues() + InitFlags(flagSet, cfg) - // Load configuration file (if possible) - if len(configFile) > 0 { - if jf, err := os.Open(configFile); err == nil { - defer jf.Close() - if buf, err := ioutil.ReadAll(jf); err == nil { - if err := json.Unmarshal(buf, def); err != nil { - return err - } - } - } else { - return fmt.Errorf("unable to locate or load the config file: %v", configFile) + fVersion := flagSet.Bool("version", false, "Prints current version and exits") + args := os.Args[1:] + flag1 := "-" + flagInstall + flag2 := "--" + flagInstall + for _, arg := range args { + if strings.HasPrefix(arg, flag1+"=") || strings.HasPrefix(arg, flag2+"=") { + cfg.InstallCommand = command{} + break } } - - // Fulfil SoftwareUpdatable configuration with default/configuration values. - copyConfigData(def, suConfig) - - // Set install command only if install flag was not visited. - // If visited, the command is set initially in the config struct when it is defined. - if len(def.Install) > 0 && isNotVisited(flagInstall) { - for _, v := range def.Install { - suConfig.InstallCommand.Set(v) - } + if err := flagSet.Parse(args); err != nil { + logger.Errorf("Cannot parse command flags: %v", err) } - if len(def.InstallDirs) > 0 && isNotVisited(flagInstallDirs) { - for _, v := range def.InstallDirs { - suConfig.InstallDirs.Set(v) - } + if *fVersion { + fmt.Println(version) + os.Exit(0) } - - // Fulfil Log configuration with default/configuration values. - copyConfigData(def, logConfig) - return nil } -func isNotVisited(name string) bool { - res := true - flag.Visit(func(f *flag.Flag) { - if f.Name == name { - res = false - } - }) - return res -} - -func copyConfigData(sourceConfig interface{}, targetConfig interface{}) { - sourceConfigVal := reflect.ValueOf(sourceConfig).Elem() - targetConfigVal := reflect.ValueOf(targetConfig).Elem() - typeOfSourceConfig := sourceConfigVal.Type() - for i := 0; i < typeOfSourceConfig.NumField(); i++ { - fieldType := typeOfSourceConfig.Field(i) - fieldToSet := targetConfigVal.FieldByName(fieldType.Name) - if fieldToSet.Kind() != reflect.Invalid && fieldToSet.Kind() != reflect.Struct { - fieldToSet.Set(sourceConfigVal.FieldByName(fieldType.Name)) - } +// LoadConfigFromFile reads the file contents and unmarshal them into the given config structure. +func LoadConfigFromFile(filePath string, config interface{}) error { + if !isFile(filePath) { + return fmt.Errorf("incorrect config file %s", filePath) } + file, err := os.ReadFile(filePath) + if err != nil { + return err + } + return json.Unmarshal(file, config) } -func toFieldName(s string) string { - s = replaceSuffix(s, "Id", "ID") - rn := []rune(s) - rn[0] = unicode.ToUpper(rn[0]) - return string(rn) -} - -func toFlagName(s string) string { - s = replaceSuffix(s, "ID", "Id") - rn := []rune(s) - rn[0] = unicode.ToLower(rn[0]) - return string(rn) -} - -func replaceSuffix(s, suff, replacement string) string { - if strings.HasSuffix(s, suff) { - return s[:len(s)-len(suff)] + replacement +// LoadConfig loads a new configuration instance using flags and config file (if set). +func LoadConfig(version string) (*BasicConfig, error) { + configFilePath := ParseConfigFilePath() + config := NewDefaultConfig() + if configFilePath != "" { + if err := LoadConfigFromFile(configFilePath, config); err != nil { + return nil, err + } } - return s + parseFlags(config, version) + return config, nil } diff --git a/internal/flags_test.go b/internal/flags_test.go index e42338e..e62e168 100644 --- a/internal/flags_test.go +++ b/internal/flags_test.go @@ -106,31 +106,34 @@ func TestFlagsHasHigherPriority(t *testing.T) { // 1. Test with log field writeToConfigFile(t, "{\""+flagLogLevel+"\": \"TRACE\"}") setFlags([]string{c(flagConfigFile, testConfigFileDirs), c(flagLogLevel, expectedResult)}) - _, lc, err := InitFlags(testVersion) + cfg, err := LoadConfig(testVersion) if err != nil { t.Errorf("not expecting error when initializing flags with log level: %v", err) } - if lc.LogLevel != expectedResult { - t.Errorf("unmatching result for flags high priority, found %v instead of %v", lc.LogLevel, expectedResult) + if cfg.LogLevel != expectedResult { + t.Errorf("unmatching result for flags high priority, found %v instead of %v", cfg.LogLevel, expectedResult) } expectedResult = "FeatureTestUpdatable" // 2. Test with software updatable config field writeToConfigFile(t, "{\""+flagFeatureID+"\": \"WrongFeatureID\"}") setFlags([]string{c(flagConfigFile, testConfigFileDirs), c(flagFeatureID, expectedResult)}) - sc, _, err := InitFlags(testVersion) + cfg, err = LoadConfig(testVersion) if err != nil { t.Errorf("not expecting error when initializing flags with featureId: %v", err) } - if sc.FeatureID != expectedResult { - t.Errorf("unmatching result for flags high priority, found %v instead of %v", sc.FeatureID, expectedResult) + if cfg.FeatureID != expectedResult { + t.Errorf("unmatching result for flags high priority, found %v instead of %v", cfg.FeatureID, expectedResult) } // 3. Test with all flags are applied instead of default values if no configuration JSON is provided expectedFlagBroker := "host:1234" expectedArtifact := "TestArchive" + expectedCACert := "TestCaCert.crt" + expectedCert := "TestCert.cert" + expectedKey := "TestKey.key" expectedFeatureID := "TestFeature" expectedInstall := "TestInstall" expectedServerCert := "TestCert" @@ -152,9 +155,12 @@ func TestFlagsHasHigherPriority(t *testing.T) { setFlags([]string{ c(flagBroker, expectedFlagBroker), c(flagArtifactType, expectedArtifact), + c(flagCACert, expectedCACert), + c(flagCert, expectedCert), + c(flagKey, expectedKey), c(flagFeatureID, expectedFeatureID), c(flagInstall, expectedInstall), - c(flagCert, expectedServerCert), + c(flagServerCert, expectedServerCert), c(flagRetryCount, strconv.Itoa(expectedDownloadRetryCount)), c(flagRetryInterval, expectedDownloadRetryInterval), c(flagInstallDirs, expectedInstallDir), @@ -171,28 +177,31 @@ func TestFlagsHasHigherPriority(t *testing.T) { c(flagVersion, expectedPrintVersion), }) - sc, lc, err = InitFlags(testVersion) + cfg, err = LoadConfig(testVersion) if err != nil { t.Errorf("not expecting error when initializing flags with featureId: %v", err) } - assertConfigEqual(sc, &ScriptBasedSoftwareUpdatableConfig{ + assertSoftwareUpdatable(t, cfg.ScriptBasedSoftwareUpdatableConfig, ScriptBasedSoftwareUpdatableConfig{ Broker: expectedFlagBroker, Username: expectedUsername, Password: expectedPassword, + CACert: expectedCACert, + Cert: expectedCert, + Key: expectedKey, ServerCert: expectedServerCert, StorageLocation: expectedStorageLocation, InstallCommand: command{cmd: expectedInstall}, DownloadRetryCount: expectedDownloadRetryCount, DownloadRetryInterval: getDurationTime(t, expectedDownloadRetryInterval), - InstallDirs: pathArgs{args: []string{expectedInstallDir}}, + InstallDirs: []string{expectedInstallDir}, Mode: expectedMode, FeatureID: expectedFeatureID, ModuleType: expectedModuleType, ArtifactType: expectedArtifact, }) - assertConfigEqual(lc, &logger.LogConfig{ + assertLogConfig(t, cfg.LogConfig, logger.LogConfig{ LogFile: expectedLogFile, LogLevel: expectedLogLevel, LogFileSize: expectedLogFileSize, @@ -201,20 +210,20 @@ func TestFlagsHasHigherPriority(t *testing.T) { }) } -// TestInitFlagsWithPrintVersion tests that initialization with flags prints version and exit -// when flag version is specified. Since the program exits, this test expects the InitFlags func to panic. -func TestInitFlagsWithPrintVersion(t *testing.T) { +// TestLoadConfigWithPrintVersion tests that initialization with flags prints version and exit +// when flag version is specified. Since the program exits, this test expects the LoadConfig func to panic. +func TestLoadConfigWithPrintVersion(t *testing.T) { defer func() { if r := recover(); r == nil { t.Errorf("expected to print the version and exit") } }() setFlags([]string{c(flagVersion, "true")}) - InitFlags(testVersion) + LoadConfig(testVersion) } -// TestInitFlagWithInvalidConfigFile tests the behaviour when wrong JSON config file is supplied. -func TestInitFlagWithInvalidConfigFile(t *testing.T) { +// TestLoadConfigWithInvalidConfigFile tests the behavior when wrong JSON config file is supplied. +func TestLoadConfigWithInvalidConfigFile(t *testing.T) { // Prepare test default dir dir := assertDirs(t, testDirFlags, true) defer os.RemoveAll(dir) @@ -222,32 +231,28 @@ func TestInitFlagWithInvalidConfigFile(t *testing.T) { //1. Test with JSON which is not starting with leading "{" character writeToConfigFile(t, "\"Broker\": \"tcp://host:1234\"}") setFlags([]string{c(flagConfigFile, testConfigFileDirs)}) - sc, lc, err := InitFlags(testVersion) - assertInitFlagsFails(t, sc, lc, err, "expecting init flags to fail if JSON format of the config file is not valid") + assertLoadConfigFails(t, "expecting init flags to fail if JSON format of the config file is not valid") //2. Test with JSON using string instead of integer for config field value writeToConfigFile(t, "{\"Broker\": \"tcp://host:1234\", \"LogFileSize\": \"20\"}") setFlags([]string{c(flagConfigFile, testConfigFileDirs)}) - sc, lc, err = InitFlags(testVersion) - assertInitFlagsFails(t, sc, lc, err, "expecting init flags to fail if provide int value to an string field") + assertLoadConfigFails(t, "expecting init flags to fail if provide int value to an string field") //3. Test with JSON using integer instead of string for config field value writeToConfigFile(t, "{\"Broker\": \"tcp://host:1234\", \"Username\": 200}") setFlags([]string{c(flagConfigFile, testConfigFileDirs)}) - sc, lc, err = InitFlags(testVersion) - assertInitFlagsFails(t, sc, lc, err, "expecting init flags to fail if provide string value to an int field") + assertLoadConfigFails(t, "expecting init flags to fail if provide string value to an int field") } -// TestInitFlagsWithMissingConfigFile tests that the error is returned when missing config file is +// TestLoadConfigWithMissingConfigFile tests that the error is returned when missing config file is // specified with flag. -func TestInitFlagsWithMissingConfigFile(t *testing.T) { +func TestLoadConfigWithMissingConfigFile(t *testing.T) { setFlags([]string{c(flagConfigFile, "TestLocation")}) - sc, lc, err := InitFlags(testVersion) - assertInitFlagsFails(t, sc, lc, err, "expecting error when initializing with missing config file flag") + assertLoadConfigFails(t, "expecting error when initializing with missing config file flag") } -// TestInitFlagsConfigAllPropertiesProvided verifies that all of the properties from the configuration file are set -func TestInitFlagsConfigAllPropertiesProvided(t *testing.T) { +// TestLoadConfigConfigAllPropertiesProvided verifies that all of the properties from the configuration file are set +func TestLoadConfigConfigAllPropertiesProvided(t *testing.T) { // Prepare test default dir dir := assertDirs(t, testDirFlags, true) defer os.RemoveAll(dir) @@ -258,27 +263,24 @@ func TestInitFlagsConfigAllPropertiesProvided(t *testing.T) { setFlags([]string{c(flagConfigFile, testConfigFileDirs)}) - expectedConfig := &ScriptBasedSoftwareUpdatableConfig{ - Broker: "tcp://host:1234", - FeatureID: "SoftwareTestUpdatable", - ArtifactType: "TestArchive", - ModuleType: "TestSoftware", - StorageLocation: dir, - Username: "TestUser", - Password: "TestPass", - DownloadRetryInterval: getDurationTime(t, "7s"), - Mode: "Scoped", - } - - expectedLogConfig := &logger.LogConfig{ - LogFile: "TestLogFile.txt", - LogLevel: "TRACE", - LogFileSize: 10, - LogFileCount: 20, - LogFileMaxAge: 30, - } - - compareConfigResult(t, expectedConfig, expectedLogConfig) + expectedConfig := NewDefaultConfig() + expectedConfig.Broker = "tcp://host:1234" + expectedConfig.FeatureID = "SoftwareTestUpdatable" + expectedConfig.ArtifactType = "TestArchive" + expectedConfig.ModuleType = "TestSoftware" + expectedConfig.StorageLocation = dir + expectedConfig.Username = "TestUser" + expectedConfig.Password = "TestPass" + expectedConfig.DownloadRetryInterval = getDurationTime(t, "7s") + expectedConfig.Mode = "Scoped" + + expectedConfig.LogFile = "TestLogFile.txt" + expectedConfig.LogLevel = "TRACE" + expectedConfig.LogFileSize = 10 + expectedConfig.LogFileCount = 20 + expectedConfig.LogFileMaxAge = 30 + + compareConfigResult(t, expectedConfig) } // TestWithEmptyConfigFile verifies if empty JSON configuration file is provided. @@ -291,15 +293,12 @@ func TestWithEmptyConfigFile(t *testing.T) { writeToConfigFile(t, "") setFlags([]string{c(flagConfigFile, testConfigFileDirs)}) - _, _, err := InitFlags(testVersion) - if err == nil { - t.Errorf("expected to fail when empty config file is supplied") - } + assertLoadConfigFails(t, "expected to fail when empty config file is supplied") } -// TestInitFlagsWithConfigMixedContent verifies if some of the configuration properties are provided +// TestLoadConfigWithConfigMixedContent verifies if some of the configuration properties are provided // and those who are not, are used from default values -func TestInitFlagsWithConfigMixedContent(t *testing.T) { +func TestLoadConfigWithConfigMixedContent(t *testing.T) { // Prepare test default dir dir := assertDirs(t, testDirFlags, true) defer os.RemoveAll(dir) @@ -308,79 +307,86 @@ func TestInitFlagsWithConfigMixedContent(t *testing.T) { writeToConfigFile(t, content) setFlags([]string{c(flagConfigFile, testConfigFileDirs)}) - expectedConfig := &ScriptBasedSoftwareUpdatableConfig{ - Broker: "tcp://host:12345", - FeatureID: getDefaultFlagValue(t, flagFeatureID), - ArtifactType: getDefaultFlagValue(t, flagArtifactType), - ModuleType: getDefaultFlagValue(t, flagModuleType), - StorageLocation: getDefaultFlagValue(t, flagStorageLocation), - Username: "test", - Password: getDefaultFlagValue(t, flagPassword), - DownloadRetryInterval: getDurationTime(t, getDefaultFlagValue(t, flagRetryInterval)), - Mode: getDefaultFlagValue(t, flagMode), - } + expectedConfig := NewDefaultConfig() + expectedConfig.Broker = "tcp://host:12345" + expectedConfig.Username = "test" - expectedLogConfig := &logger.LogConfig{ - LogFile: "test_log.txt", - LogLevel: "TRACE", - LogFileSize: getDefaultFlagValueInt(t, flagLogFileSize), - LogFileCount: getDefaultFlagValueInt(t, flagLogFileCount), - LogFileMaxAge: getDefaultFlagValueInt(t, flagLogFileMaxAge), - } + expectedConfig.LogFile = "test_log.txt" + expectedConfig.LogLevel = "TRACE" - compareConfigResult(t, expectedConfig, expectedLogConfig) + compareConfigResult(t, expectedConfig) } func TestInvalidAccessModeFlag(t *testing.T) { setFlags([]string{c(flagMode, "test"), c(flagFeatureID, "id")}) - sc, _, err := InitFlags(testVersion) + cfg, err := LoadConfig(testVersion) if err != nil { t.Errorf("not expecting error when initializing flags with invalid access mode: %v", err) } - if err = sc.Validate(); err == nil { + if err = cfg.Validate(); err == nil { t.Fatal("expecting error when validating configuration with invalid access mode flag") } } // compareConfigResult function verifies the content of the expected and actual configuration struct -func compareConfigResult(t *testing.T, expectedConfig *ScriptBasedSoftwareUpdatableConfig, expectedLogConfig *logger.LogConfig) { - suConfig, logConfig, err := InitFlags(testVersion) +func compareConfigResult(t *testing.T, expectedConfig *BasicConfig) { + cfg, err := LoadConfig(testVersion) - if suConfig == nil || logConfig == nil || err != nil { + if err != nil { t.Error("failed to init flags with valid configuration file: ", err) } - if !assertConfigEqual(expectedConfig, suConfig) || !assertConfigEqual(expectedLogConfig, logConfig) { - t.Errorf("configurations does not match, suConfig: %v != %v or logConfig: %v != %v ", - expectedConfig, suConfig, expectedLogConfig, logConfig) + assertSoftwareUpdatable(t, cfg.ScriptBasedSoftwareUpdatableConfig, expectedConfig.ScriptBasedSoftwareUpdatableConfig) + assertLogConfig(t, cfg.LogConfig, expectedConfig.LogConfig) +} + +func assertString(t *testing.T, actual, expected string) { + if expected != actual { + t.Errorf("Expected string value %s, but received value %s", expected, actual) } } -func assertConfigEqual(actual interface{}, expected interface{}) bool { - valueOfExp := reflect.ValueOf(expected).Elem() - typeOfExp := valueOfExp.Type() - valueOfAct := reflect.ValueOf(actual).Elem() - for i := 0; i < typeOfExp.NumField(); i++ { - fieldType := typeOfExp.Field(i) - expectedValue := valueOfExp.FieldByName(fieldType.Name).Interface() - actualValue := valueOfAct.FieldByName(fieldType.Name).Interface() - switch expectedValue.(type) { - case int, string: - if expectedValue != actualValue { - return false - } - default: - if !reflect.DeepEqual(expectedValue, actualValue) { - return false - } - } +func assertInt(t *testing.T, actual, expected int) { + if expected != actual { + t.Errorf("Expected int value %d, but received value %d", expected, actual) } - return true } -func assertInitFlagsFails(t *testing.T, suConfig *ScriptBasedSoftwareUpdatableConfig, - logConfig *logger.LogConfig, err error, msg string) { - if suConfig != nil || logConfig != nil || err == nil { +func assertDeep(t *testing.T, actual, expected interface{}) { + if !reflect.DeepEqual(expected, actual) { + t.Errorf("Expected %s, but received %s", expected, actual) + } +} + +func assertSoftwareUpdatable(t *testing.T, actual, expected ScriptBasedSoftwareUpdatableConfig) { + assertString(t, actual.Broker, expected.Broker) + assertString(t, actual.Username, expected.Username) + assertString(t, actual.Password, expected.Password) + assertString(t, actual.CACert, expected.CACert) + assertString(t, actual.Cert, expected.Cert) + assertString(t, actual.Key, expected.Key) + assertString(t, actual.ServerCert, expected.ServerCert) + assertString(t, actual.StorageLocation, expected.StorageLocation) + assertInt(t, actual.DownloadRetryCount, expected.DownloadRetryCount) + assertDeep(t, actual.DownloadRetryInterval, expected.DownloadRetryInterval) + assertDeep(t, actual.InstallDirs, expected.InstallDirs) + assertString(t, actual.Mode, expected.Mode) + assertString(t, actual.FeatureID, expected.FeatureID) + assertString(t, actual.ModuleType, expected.ModuleType) + assertString(t, actual.ArtifactType, expected.ArtifactType) +} + +func assertLogConfig(t *testing.T, actual, expected logger.LogConfig) { + assertString(t, actual.LogFile, expected.LogFile) + assertString(t, actual.LogLevel, expected.LogLevel) + assertInt(t, actual.LogFileSize, expected.LogFileSize) + assertInt(t, actual.LogFileCount, expected.LogFileCount) + assertInt(t, actual.LogFileMaxAge, expected.LogFileMaxAge) +} + +func assertLoadConfigFails(t *testing.T, msg string) { + _, err := LoadConfig(testVersion) + if err == nil { t.Error(msg) } } @@ -388,33 +394,33 @@ func assertInitFlagsFails(t *testing.T, suConfig *ScriptBasedSoftwareUpdatableCo // assertInstallCommand verifies the result when initializing flags with install command, // which is specified with config file or flag func assertInstallCommand(t *testing.T, expectedInstallCMD string, expectedInstallArgs string) { - sc, _, err := InitFlags(testVersion) + cfg, err := LoadConfig(testVersion) if err != nil { t.Errorf("not expecting error when initializing with install config: %v", err) } - if len(sc.InstallCommand.args) != 1 { + if len(cfg.InstallCommand.args) != 1 { t.Error("expecting install command to be set") } - if sc.InstallCommand.args[0] != expectedInstallArgs { - t.Errorf("unmatching install command args, expected %v, actual %v", expectedInstallArgs, sc.InstallCommand.args) + if cfg.InstallCommand.args[0] != expectedInstallArgs { + t.Errorf("unmatching install command args, expected %v, actual %v", expectedInstallArgs, cfg.InstallCommand.args) } - if sc.InstallCommand.cmd != expectedInstallCMD { - t.Errorf("unmatching install command path, expected %v, actual %v", expectedInstallCMD, sc.InstallCommand.cmd) + if cfg.InstallCommand.cmd != expectedInstallCMD { + t.Errorf("unmatching install command path, expected %v, actual %v", expectedInstallCMD, cfg.InstallCommand.cmd) } } // assertInstallDirs verifies the result when initializing flags with install directories, // which are specified with config file or flag func assertInstallDirs(t *testing.T, expectedInstallDir string) { - sc, _, err := InitFlags(testVersion) + cfg, err := LoadConfig(testVersion) if err != nil { t.Errorf("not expecting error when initializing with install config: %v", err) } - if len(sc.InstallDirs.args) != 1 { + if len(cfg.InstallDirs) != 1 { t.Error("expecting install directories to be set") } - if sc.InstallDirs.args[0] != expectedInstallDir { - t.Errorf("unmatching install directories args, expected %v, actual %v", expectedInstallDir, sc.InstallCommand.args) + if cfg.InstallDirs[0] != expectedInstallDir { + t.Errorf("unmatching install directories args, expected %v, actual %v", expectedInstallDir, cfg.InstallCommand) } } @@ -441,8 +447,6 @@ func setFlags(args []string) { func resetArgs() { // save the state of args before the test starts oldArgs := os.Args - suConfig = &ScriptBasedSoftwareUpdatableConfig{} - logConfig = &logger.LogConfig{} defer func() { // return the state of args for the next tests @@ -452,35 +456,6 @@ func resetArgs() { }() } -func getDefaultFlagValue(t *testing.T, flagName string) string { - flagName = toFieldName(flagName) - valueOf := reflect.ValueOf(cfg{}) - typeOf := valueOf.Type() - fieldName := toFieldName(flagName) - fieldType, ok := typeOf.FieldByName(fieldName) - if ok { - return fieldType.Tag.Get("def") - } - t.Fatalf("unable to get field %s", fieldName) - return "" // unreachable -} - -func getDefaultFlagValueInt(t *testing.T, flagName string) int { - valueOf := reflect.ValueOf(cfg{}) - typeOf := valueOf.Type() - fieldName := toFieldName(flagName) - fieldType, ok := typeOf.FieldByName(fieldName) - if ok { - result, err := strconv.Atoi(fieldType.Tag.Get("def")) - if err != nil { - t.Fatal(err) - } - return result - } - t.Fatalf("unable to get field %s", fieldName) - return 0 // unreachable -} - func getDurationTime(t *testing.T, defaultValue string) (result durationTime) { err := result.Set(defaultValue) if err != nil { diff --git a/internal/logger/logger.go b/internal/logger/logger.go index bbc1d0a..023f2a4 100644 --- a/internal/logger/logger.go +++ b/internal/logger/logger.go @@ -28,11 +28,11 @@ import ( // LogConfig represents a log configuration. type LogConfig struct { - LogFile string - LogLevel string - LogFileSize int - LogFileCount int - LogFileMaxAge int + LogFile string `json:"logFile,omitempty"` + LogLevel string `json:"logLevel,omitempty"` + LogFileSize int `json:"logFileSize,omitempty"` + LogFileCount int `json:"logFileCount,omitempty"` + LogFileMaxAge int `json:"logFileMaxAge,omitempty"` } // LogLevel represents a log level. diff --git a/internal/path_args.go b/internal/path_args.go index c4eaaa8..e44fd6d 100644 --- a/internal/path_args.go +++ b/internal/path_args.go @@ -13,18 +13,32 @@ package feature import ( + "fmt" "strings" ) type pathArgs struct { - args []string + args *[]string } func (a *pathArgs) String() string { - return strings.Join(a.args, " ") + if a.args == nil { + return "" + } + return strings.Join(*a.args, " ") } func (a *pathArgs) Set(value string) error { - a.args = append(a.args, value) + if len(value) == 0 { + return fmt.Errorf("value cannot be empty") + } + *a.args = strings.Fields(value) return nil } + +// newPathArgs creates new flag variable for slice of strings definition. +func newPathArgs(setter *[]string) *pathArgs { + return &pathArgs{ + args: setter, + } +} diff --git a/internal/path_args_test.go b/internal/path_args_test.go new file mode 100644 index 0000000..de66313 --- /dev/null +++ b/internal/path_args_test.go @@ -0,0 +1,58 @@ +// Copyright (c) 2023 Contributors to the Eclipse Foundation +// +// See the NOTICE file(s) distributed with this work for additional +// information regarding copyright ownership. +// +// This program and the accompanying materials are made available under the +// terms of the Eclipse Public License 2.0 which is available at +// https://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0 +// which is available at https://www.apache.org/licenses/LICENSE-2.0. +// +// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0 + +package feature + +import ( + "flag" + "io" + "reflect" + "testing" +) + +func TestPathArgsIsSet(t *testing.T) { + f := flag.NewFlagSet("testing", flag.ContinueOnError) + + var s []string + v := newPathArgs(&s) + f.Var(v, "S", "S") + + args := []string{"-S=a b"} + err := f.Parse(args) + if err != nil { + t.Errorf("Expected no error, but received %s", err) + } + + if "a b" != v.String() { + t.Errorf("Expected string value %s, but received value %s", "a b", v.String()) + } + + expected := []string{"a", "b"} + if !reflect.DeepEqual(expected, s) { + t.Errorf("Expected %s, but received %s", expected, s) + } +} + +func TestPathArgsInvalid(t *testing.T) { + f := flag.NewFlagSet("testing", flag.ContinueOnError) + f.SetOutput(io.Discard) + + var s []string + v := newPathArgs(&s) + f.Var(v, "S", "S") + + args := []string{"-S="} + err := f.Parse(args) + if err == nil { + t.Errorf("Expected error, but not received") + } +} diff --git a/internal/utils_test.go b/internal/utils_test.go index 83417d6..b2474ab 100644 --- a/internal/utils_test.go +++ b/internal/utils_test.go @@ -41,6 +41,9 @@ const ( flagBroker = "broker" flagUsername = "username" flagPassword = "password" + flagCACert = "caCert" + flagCert = "cert" + flagKey = "key" flagStorageLocation = "storageLocation" flagFeatureID = "featureId" flagModuleType = "moduleType" @@ -51,9 +54,11 @@ const ( flagLogFileSize = "logFileSize" flagLogFileCount = "logFileCount" flagLogFileMaxAge = "logFileMaxAge" - flagCert = "serverCert" + flagServerCert = "serverCert" flagRetryCount = "downloadRetryCount" flagRetryInterval = "downloadRetryInterval" + flagInstallDirs = "installDirs" + flagVersion = "version" ) // testConfig is used to provide mock data @@ -92,16 +97,13 @@ func assertDirs(t *testing.T, name string, create bool) string { func connectFeature(t *testing.T, mc *mockedClient, feature *ScriptBasedSoftwareUpdatable, featureID string) error { t.Helper() - supConfig := &ScriptBasedSoftwareUpdatableConfig{ - Broker: getDefaultFlagValue(t, flagBroker), - FeatureID: featureID, - ModuleType: getDefaultFlagValue(t, flagModuleType), - } + cfg := NewDefaultConfig() + cfg.ScriptBasedSoftwareUpdatableConfig.FeatureID = featureID edgeCfg := &edgeConfiguration{ DeviceID: model.NewNamespacedID(testTopicNamespace, testTopicEntryID).String(), TenantID: testTenantID, } - return feature.Connect(mc, supConfig, edgeCfg) + return feature.Connect(mc, &cfg.ScriptBasedSoftwareUpdatableConfig, edgeCfg) } // mockScriptBasedSoftwareUpdatable create new ScriptBasedSoftwareUpdatable with mocked MQTT clients. diff --git a/util/tls/testdata/ca.crt b/util/tls/testdata/ca.crt new file mode 100644 index 0000000..59609aa --- /dev/null +++ b/util/tls/testdata/ca.crt @@ -0,0 +1,22 @@ +-----BEGIN CERTIFICATE----- +MIIDkzCCAnugAwIBAgIUbD0mjn2x1H5VKi54xgjfPsEugzAwDQYJKoZIhvcNAQEL +BQAwWTELMAkGA1UEBhMCQkcxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM +GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDESMBAGA1UEAwwJbG9jYWxob3N0MB4X +DTIyMTAyNzE1MDk0OFoXDTI3MTAyNzE1MDk0OFowWTELMAkGA1UEBhMCQkcxEzAR +BgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoMGEludGVybmV0IFdpZGdpdHMgUHR5 +IEx0ZDESMBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A +MIIBCgKCAQEAqRZyG9gywFvplpfPq60X7JYtkbj7OF+tOfixF9FkGDJ9Hhh4LoMv +W+rdLnxEHfI4Vicn0Wj3r/Ra1bd9yo8hzGiYdNliH9kegSbjB2Xx8N4yTCqaiZ9E +JQLcstQeXHEP3YwPXLnNfTOmbQPAbC2T9J+USlmolG1qpkuU5rQVC/sjW5M8MOmN +TuKEU6pds/j8GQKhQmsIHddwfypnBDilpYOotgeMwDqsyM4+zdSXbFaNmYoh3Tjb +FSN0WJTdPe7uv+nG03NZe6dHvN4C/8Z5uBIkeLw8yrO/2Wb7aAxwtpK3Wswzlya9 +TAwALiOI+hWuIfUmkoQapskIMhcRAnpMxwIDAQABo1MwUTAdBgNVHQ4EFgQUH2fo +QuQyEoszAL3vBOBqs8OLg7YwHwYDVR0jBBgwFoAUH2foQuQyEoszAL3vBOBqs8OL +g7YwDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAYNY+8FNt9Pr/ +I6NYzY56opMIXmMErRMBN4i5WkuZxXp7gaqPNrx42O7gaWJz7dwwGUmrb0eHyOIm +XetphU6cW1vNgaXPOggacp2wRJ6AGJn1+v/hfWU6sPt0XWPM9p5umoCeYJeZ1UlE +uodsUGEQc7b/ODROObPHAFc/18nChoiPylXtB5TcgdyzalzhL/d8B/c4QZJwvT+W +L+8IoChNQzeH4yCgoDZXaQpRfrnGjLyrpx4dojNBYd/rJsGNZa+wMwzhFZU3f3QY +jeZnnp+nVBw+/L3q/FVwCee/RsYiR797OL7wyPAJPGd4iqbh09Hv0B2YgIih86X6 +giPwyRzg7g== +-----END CERTIFICATE----- diff --git a/util/tls/testdata/certificate.pem b/util/tls/testdata/certificate.pem new file mode 100644 index 0000000..7697626 --- /dev/null +++ b/util/tls/testdata/certificate.pem @@ -0,0 +1,27 @@ +-----BEGIN CERTIFICATE----- +MIIEjzCCAncCFAEWn/QU1vzg07IeNzhaX/6pNlTeMA0GCSqGSIb3DQEBCwUAMIGD +MQswCQYDVQQGEwJVUzETMBEGA1UECAwKQ2FsaWZvcm5pYTEWMBQGA1UEBwwNU2Fu +IEZyYW5jaXNjbzEVMBMGA1UECgwMRXhhbXBsZSBJbmMuMRYwFAYDVQQLDA1JVCBE +ZXBhcnRtZW50MRgwFgYDVQQDDA93d3cuZXhhbXBsZS5jb20wHhcNMjEwNTEzMDg0 +NTU3WhcNMjExMDEwMDg0NTU3WjCBgzELMAkGA1UEBhMCVVMxFjAUBgNVBAgMDU1h +c3NhY2h1c2V0dHMxDzANBgNVBAcMBkJvc3RvbjEVMBMGA1UECgwMRXhhbXBsZSBJ +bmMuMRYwFAYDVQQLDA1IUiBEZXBhcnRtZW50MRwwGgYDVQQDDBN3d3cuZXhhbXBs +ZS1pbmMuY29tMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEApmF7J4WB +BujYYNt86YHSWrv3ffX3Odt57y4kqhWuMM8VcA2RXTlJO3SXX8xLF/+lsaZZJmfg +xR1tJ0hKfBYt78H03bjrylWLOQoRqlyVuz0SF3ueR9vx3mPIO0F0E3Q2mC4SLHbr +5kW4aj3/NszLzgvZPbOchfcktdCd1vME+pM8lPPY6z8qWJzlWjOYdymWUV8z9qlA +c3VCnYQ+UwJY51vTcMPpalByrMPWkiNic+9Onl8KHCik27vNMIVVLYLa6763UwU8 +qCKT+jZj6nlCT8w5oqoJNsaSd/EGGFtGR+qZj6V2TrsI2cNknyY7Qf/QN0+zH0nm +XsOxK6jW8q7RGwIDAQABMA0GCSqGSIb3DQEBCwUAA4ICAQBe/kcT2L54PxZqb3GU +liwYJGjB+9fkqTyMwglt8dAm3it9F/POyXtoKB8a1AuaZ/FJlR+AUOFv+f3i0ZnE +Ek0OAsllVPclv7HhywD1HzrbLh0PreGsBnYgyrW7qZKAfevus0U0GrjhcrY7zCoA +EBFWWqcWqhRCFXYwgI13ZNLhYl7r+NIWLza1bPcnWVfY2g19/nctR53ZFFiVkvlk +FCYGat0SPQWvjFIKaCNQQL6IZSxqk95W87kEWrac9A+bQzENpWLfwu86O5r6vKNK +pHmd47Sy8hyhf75/SEOQuWBEkgT+sPXU7TykvFB8kzO1Wmsz7D2/d5pvkjZF31dQ +m4ZHIuclPOETtAwiY13dI94vAhgruK0FRFn7jyfePn20CFqUCOO9cEQytysCO/n7 +4xJPbIVcvUO825Kbos71OWfNkLEi1tlLkFpe73/rSXnZRWweqAThrY7jxGxhveI4 +iYrOOYEqGdM6VfLvVhYXhsc/MDqiqJLdbhQAS/lE8bJrnVnVRYnnb0ExfNTwqVU3 +8YZB6JbT+j9556c675j0sfa0J8qgDRHsWR7EG5u5wENnDH/s142dYcjJ5S6tCumM +K/GExzmrJFHmLfqKTSxAquMYvDVufICcklL067DacRJKHkyx5KvKkgc3R8rTaX2o +D+0ioElFJXVQ7rULiWzxZs+Gdw== +-----END CERTIFICATE----- diff --git a/util/tls/testdata/empty.crt b/util/tls/testdata/empty.crt new file mode 100644 index 0000000..e69de29 diff --git a/util/tls/testdata/invalid.pem b/util/tls/testdata/invalid.pem new file mode 100644 index 0000000..eeb0277 --- /dev/null +++ b/util/tls/testdata/invalid.pem @@ -0,0 +1 @@ +---invalid data--- \ No newline at end of file diff --git a/util/tls/testdata/key.pem b/util/tls/testdata/key.pem new file mode 100644 index 0000000..e0b52a6 --- /dev/null +++ b/util/tls/testdata/key.pem @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEowIBAAKCAQEApmF7J4WBBujYYNt86YHSWrv3ffX3Odt57y4kqhWuMM8VcA2R +XTlJO3SXX8xLF/+lsaZZJmfgxR1tJ0hKfBYt78H03bjrylWLOQoRqlyVuz0SF3ue +R9vx3mPIO0F0E3Q2mC4SLHbr5kW4aj3/NszLzgvZPbOchfcktdCd1vME+pM8lPPY +6z8qWJzlWjOYdymWUV8z9qlAc3VCnYQ+UwJY51vTcMPpalByrMPWkiNic+9Onl8K +HCik27vNMIVVLYLa6763UwU8qCKT+jZj6nlCT8w5oqoJNsaSd/EGGFtGR+qZj6V2 +TrsI2cNknyY7Qf/QN0+zH0nmXsOxK6jW8q7RGwIDAQABAoIBAH09m6qgQAOnellO +XrSW2HUcUKwsXjDbGOoF3et57mknOIfkbqux14I9vUSLT2t9MIiNI0ZZo0Q9ZlDP +heHqACId6eiMrlDcG7SP88Q9dShATEII95g349T3X13bYzjRndbntx5pViE8EhlH +GblyZ2duW9SqQwREiQmjQ2zt+a1zuKUAGdysS+4101UHUj6tC/RjiNN/TCXXRJIX +GOJ4WFLY+f2bXgSmqbK7wqN9nPxmxl/+bv4hO32Gsv06ejuy/6+GFJZa+n3ASU2r +/ptr4vgCK+t6I0OWVTpvYUboEwAXam4JfAu12zLtczfXVFJjzklUvSClIG9aJ6DP +2B3LEuECgYEA1RenCuVl6sM2o858X8iOLTMuCZWX5VFKY1+vvLwB/+CBfy7/dDYw +lbv+xaots0rY9Wn784ewi9zJbdnXE1YNgj0utIMHzylXvTolnDYsoO7SYpIwqtpa +PyzPcAV3Khkd5LGe1hf9VmOJfTF/563ztLXip0HUeIvzgB/maqfAQW8CgYEAx+H4 +GZ3ycdL03x7Zvp5g5yDPZGvxqVIEIFmliagEFyBgqogbXuOvsvZDTg+gKV5/QyVg +FWokz5VtC6U9UcWfF7LJses/Hsedh9IeICd9UIzkey8UmS2mDBeGfTHVTNuczml5 +VzmTK8jGUO5aRVRyOt0tqVf6Oozo8ImdI1fOPRUCgYBI4p4wC+agNcUqoiXIXUDE +FQ1aGeCqfvOCqefiFixY6OFiLyERDrfvfy3VTi/zc1ZiGq4izfaE4C/Fcw0tf/F+ +6o5fD7JMGUf5YTocBCufoBA1xur+hVD46srI9hWcQJsI7ff2Ip50Pfd46sVk6QrC +dLPhoZKa6MOQv1iAgoAv4QKBgQDAfiO6N9vmNizQOxujcU8NBxHzOekvEOccaHj9 +Cqt1wh6V3CHPziHEjVjf8jhh3rlcZsATn3b32oV7c5SMDW9bGTkYeN7+u2pABOAy +QxVx312iLAMASW/hsT45jyZFsDFgrz7F+5J51g72nbSdk+e2PI7eyPUYMd+a1kxY +XxUkyQKBgBde+jg2UjGHAfW6ZWSPHWvi74oD42VzbJg6/o/KFaZGEyHdKWgrwBW7 +8wNyNTMOMSHyuvtMnc5oLd492aO2a+yurgzt7yIelVuphSKda0wd/0PljKb0lwlb +VzldR4bXolbyjNXky1KWmvSNkmbPFNGSCeyHmKQgSLXKGuWbj4ic +-----END RSA PRIVATE KEY----- diff --git a/util/tls/tls_config.go b/util/tls/tls_config.go new file mode 100644 index 0000000..e9fd9ad --- /dev/null +++ b/util/tls/tls_config.go @@ -0,0 +1,60 @@ +// Copyright (c) 2023 Contributors to the Eclipse Foundation +// +// See the NOTICE file(s) distributed with this work for additional +// information regarding copyright ownership. +// +// This program and the accompanying materials are made available under the +// terms of the Eclipse Public License 2.0 which is available at +// https://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0 +// which is available at https://www.apache.org/licenses/LICENSE-2.0. +// +// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0 + +package tls + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + "os" +) + +// NewTLSConfig initializes the TLS. +func NewTLSConfig(rootCert, cert, key string) (*tls.Config, error) { + fmt.Println("NewTLS CONFIG!", rootCert) + caCert, err := os.ReadFile(rootCert) + if err != nil { + return nil, fmt.Errorf("failed to load CA: %s", err) + } + caCertPool := x509.NewCertPool() + if !caCertPool.AppendCertsFromPEM(caCert) { + return nil, fmt.Errorf("failed to parse CA %s", rootCert) + } + + tlsConfig := &tls.Config{ + InsecureSkipVerify: false, + RootCAs: caCertPool, + MinVersion: tls.VersionTLS12, + MaxVersion: tls.VersionTLS13, + CipherSuites: supportedCipherSuites(), + } + + if len(cert) > 0 || len(key) > 0 { + cert, err := tls.LoadX509KeyPair(cert, key) + if err != nil { + return nil, fmt.Errorf("failed to load X509 key pair: %s", err) + } + tlsConfig.Certificates = []tls.Certificate{cert} + } + + return tlsConfig, nil +} + +func supportedCipherSuites() []uint16 { + cs := tls.CipherSuites() + cid := make([]uint16, len(cs)) + for i := range cs { + cid[i] = cs[i].ID + } + return cid +} diff --git a/util/tls/tls_config_test.go b/util/tls/tls_config_test.go new file mode 100644 index 0000000..aa8b875 --- /dev/null +++ b/util/tls/tls_config_test.go @@ -0,0 +1,104 @@ +// Copyright (c) 2023 Contributors to the Eclipse Foundation +// +// See the NOTICE file(s) distributed with this work for additional +// information regarding copyright ownership. +// +// This program and the accompanying materials are made available under the +// terms of the Eclipse Public License 2.0 which is available at +// https://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0 +// which is available at https://www.apache.org/licenses/LICENSE-2.0. +// +// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0 + +package tls + +import ( + "crypto/tls" + "errors" + "fmt" + "path/filepath" + "testing" +) + +const ( + nonExisting = "nonexisting.test" + invalidFile = "testdata/invalid.pem" + caCertPath = "testdata/ca.crt" + certPath = "testdata/certificate.pem" + keyPath = "testdata/key.pem" + caError = "failed to load CA: open %s: no such file or directory" + keyPairError = "failed to load X509 key pair: open %s: no such file or directory" + invalidError = "failed to load X509 key pair: tls: failed to find any PEM data in key input" +) + +func TestNewTLSConfig(t *testing.T) { + dirAbsPath, _ := filepath.Abs("./") + + tests := map[string]struct { + CACert string + Cert string + Key string + ExpectedError error + }{ + "valid_config_with_credentials": {CACert: caCertPath, Cert: certPath, Key: keyPath, ExpectedError: nil}, + "valid_config_no_credentials": {CACert: caCertPath, Cert: "", Key: "", ExpectedError: nil}, + "no_files_provided": {CACert: "", Cert: "", Key: "", ExpectedError: fmt.Errorf(caError, "")}, + "non_existing_ca_file": {CACert: nonExisting, Cert: "", Key: "", ExpectedError: fmt.Errorf(caError, nonExisting)}, + "invalid_ca_file": {CACert: invalidFile, Cert: certPath, Key: keyPath, ExpectedError: fmt.Errorf("failed to parse CA %s", invalidFile)}, + "invalid_ca_file_arg": {CACert: "\\\000", Cert: certPath, Key: keyPath, ExpectedError: errors.New("failed to load CA: open \\\000: invalid argument")}, + "not_abs_cert_file_provided": {CACert: caCertPath, Cert: nonExisting, Key: "", ExpectedError: fmt.Errorf(keyPairError, nonExisting)}, + "cert_is_directory": {CACert: caCertPath, Cert: dirAbsPath, Key: "", ExpectedError: fmt.Errorf("failed to load X509 key pair: read %s: is a directory", dirAbsPath)}, + "no_key_file_provided": {CACert: caCertPath, Cert: certPath, Key: "", ExpectedError: fmt.Errorf(keyPairError, "")}, + "not_abs_key_file_provided": {CACert: caCertPath, Cert: certPath, Key: nonExisting, ExpectedError: fmt.Errorf(keyPairError, nonExisting)}, + "empty_key_file_provided": {CACert: caCertPath, Cert: certPath, Key: "testdata/empty.crt", ExpectedError: errors.New(invalidError)}, + "cert_file_instead_key": {CACert: caCertPath, Cert: certPath, Key: caCertPath, ExpectedError: fmt.Errorf("failed to load X509 key pair: tls: found a certificate rather than a key in the PEM for the private key")}, + "invalid_key_file_provided": {CACert: caCertPath, Cert: certPath, Key: invalidFile, ExpectedError: errors.New(invalidError)}, + // "real_files": {CACert: "/home/antonia/tony/kanto/mosquitto_new_certs/ca.crt", Cert: "/home/antonia/tony/kanto/mosquitto_new_certs/client.crt", Key: "/home/antonia/tony/kanto/mosquitto_new_certs/client.key", ExpectedError: nil}, + } + + for testName, testCase := range tests { + t.Run(testName, func(t *testing.T) { + cfg, err := NewTLSConfig(testCase.CACert, testCase.Cert, testCase.Key) + if testCase.ExpectedError != nil { + if testCase.ExpectedError.Error() != err.Error() { + t.Fatalf("expected error : %s, got: %s", testCase.ExpectedError, err) + } + if cfg != nil { + t.Fatalf("expected nil, got: %v", cfg) + } + } else { + if err != nil { + t.Fatal(err) + } + if len(cfg.Certificates) == 0 && testCase.Cert != "" && testCase.Key != "" { + t.Fatal("certificates length must not be 0") + } + if len(cfg.CipherSuites) == 0 { + t.Fatal("cipher suites length must not be 0") + } + // assert that cipher suites identifiers are contained in tls.CipherSuites + for _, csID := range cfg.CipherSuites { + if !func() bool { + for _, cs := range tls.CipherSuites() { + if cs.ID == csID { + return true + } + } + return false + }() { + t.Fatalf("cipher suite %d is not implemented", csID) + } + } + if cfg.InsecureSkipVerify { + t.Fatal("skip verify is set to true") + } + if cfg.MinVersion != tls.VersionTLS12 { + t.Fatalf("invalid min TLS version %d", cfg.MinVersion) + } + if cfg.MaxVersion != tls.VersionTLS13 { + t.Fatalf("invalid max TLS version %d", cfg.MaxVersion) + } + } + }) + } +}