diff --git a/cmd/update-manager/main.go b/cmd/update-manager/main.go index 5703596..66d57bc 100755 --- a/cmd/update-manager/main.go +++ b/cmd/update-manager/main.go @@ -43,13 +43,15 @@ func main() { var client api.UpdateAgentClient if cfg.ThingsEnabled { - client = mqtt.NewUpdateAgentThingsClient(cfg.Domain, cfg.MQTT) + client, err = mqtt.NewUpdateAgentThingsClient(cfg.Domain, cfg.MQTT) } else { - client = mqtt.NewUpdateAgentClient(cfg.Domain, cfg.MQTT) + client, err = mqtt.NewUpdateAgentClient(cfg.Domain, cfg.MQTT) } - updateManager, err := orchestration.NewUpdateManager(version, cfg, client, orchestration.NewUpdateOrchestrator(cfg)) if err == nil { - err = app.Launch(cfg, client, updateManager) + updateManager, err := orchestration.NewUpdateManager(version, cfg, client, orchestration.NewUpdateOrchestrator(cfg)) + if err == nil { + err = app.Launch(cfg, client, updateManager) + } } if err != nil { diff --git a/config/flags.go b/config/flags.go index 3a81a73..b3fc3d8 100755 --- a/config/flags.go +++ b/config/flags.go @@ -48,6 +48,9 @@ func SetupFlags(flagSet *flag.FlagSet, cfg *BaseConfig) { flagSet.StringVar(&cfg.MQTT.AcknowledgeTimeout, "mqtt-conn-ack-timeout", EnvToString("MQTT_CONN_ACK_TIMEOUT", cfg.MQTT.AcknowledgeTimeout), "Acknowledge timeout for the MQTT requests as duration string") flagSet.StringVar(&cfg.MQTT.SubscribeTimeout, "mqtt-conn-sub-timeout", EnvToString("MQTT_CONN_SUB_TIMEOUT", cfg.MQTT.SubscribeTimeout), "Subscribe timeout for the MQTT requests as duration string") flagSet.StringVar(&cfg.MQTT.UnsubscribeTimeout, "mqtt-conn-unsub-timeout", EnvToString("MQTT_CONN_UNSUB_TIMEOUT", cfg.MQTT.UnsubscribeTimeout), "Unsubscribe timeout for the MQTT requests as duration string") + flagSet.StringVar(&cfg.MQTT.CACert, "mqtt-conn-ca-cert", EnvToString("MQTT_CONN_CA_CERT", cfg.MQTT.CACert), "Specify the PEM encoded CA certificates file") + flagSet.StringVar(&cfg.MQTT.Cert, "mqtt-conn-cert", EnvToString("MQTT_CONN_CERT", cfg.MQTT.Cert), "Specify the PEM encoded certificate file to authenticate to the MQTT server/broker") + flagSet.StringVar(&cfg.MQTT.Key, "mqtt-conn-key", EnvToString("MQTT_CONN_KEY", cfg.MQTT.Key), "Specify the PEM encoded unencrypted private key file to authenticate to the MQTT server/broker") flagSet.StringVar(&cfg.Domain, "domain", EnvToString("DOMAIN", cfg.Domain), "Specify the Domain of this update agent, used as MQTT topic prefix.") diff --git a/config/flags_test.go b/config/flags_test.go index 994b8c4..ea91cc6 100644 --- a/config/flags_test.go +++ b/config/flags_test.go @@ -89,6 +89,18 @@ func TestSetupFlags(t *testing.T) { flag: "mqtt-conn-unsub-timeout", expectedType: reflect.String.String(), }, + "test_flags_mqtt-conn-ca-cert": { + flag: "mqtt-conn-ca-cert", + expectedType: reflect.String.String(), + }, + "test_flags_mqtt-conn-cert": { + flag: "mqtt-conn-cert", + expectedType: reflect.String.String(), + }, + "test_flags_mqtt-conn-key": { + flag: "mqtt-conn-key", + expectedType: reflect.String.String(), + }, "test_flags_domain": { flag: "domain", expectedType: reflect.String.String(), diff --git a/mqtt/config.go b/mqtt/config.go index c7d2e20..5256e75 100755 --- a/mqtt/config.go +++ b/mqtt/config.go @@ -23,6 +23,9 @@ const ( defaultAcknowledgeTimeout = "15s" defaultSubscribeTimeout = "15s" defaultUnsubscribeTimeout = "5s" + defaultCACert = "" + defaultCert = "" + defaultKey = "" ) // ConnectionConfig represents the mqtt client connection config @@ -36,6 +39,9 @@ type ConnectionConfig struct { AcknowledgeTimeout string `json:"acknowledgeTimeout,omitempty"` SubscribeTimeout string `json:"subscribeTimeout,omitempty"` UnsubscribeTimeout string `json:"unsubscribeTimeout,omitempty"` + CACert string `json:"caCert,omitempty"` + Cert string `json:"cert,omitempty"` + Key string `json:"key,omitempty"` } // NewDefaultConfig returns a default mqtt client connection config instance @@ -50,5 +56,8 @@ func NewDefaultConfig() *ConnectionConfig { AcknowledgeTimeout: defaultAcknowledgeTimeout, SubscribeTimeout: defaultSubscribeTimeout, UnsubscribeTimeout: defaultUnsubscribeTimeout, + CACert: defaultCACert, + Cert: defaultCert, + Key: defaultKey, } } diff --git a/mqtt/update_agent_client.go b/mqtt/update_agent_client.go index 39fdecd..b33b2f4 100755 --- a/mqtt/update_agent_client.go +++ b/mqtt/update_agent_client.go @@ -14,6 +14,7 @@ package mqtt import ( "fmt" + "net/url" "strconv" "strings" "time" @@ -21,6 +22,7 @@ import ( "github.com/eclipse-kanto/update-manager/api" "github.com/eclipse-kanto/update-manager/api/types" "github.com/eclipse-kanto/update-manager/logger" + "github.com/eclipse-kanto/update-manager/util/tls" pahomqtt "github.com/eclipse/paho.mqtt.golang" "github.com/google/uuid" @@ -49,6 +51,9 @@ type internalConnectionConfig struct { AcknowledgeTimeout time.Duration SubscribeTimeout time.Duration UnsubscribeTimeout time.Duration + CACert string + Cert string + Key string } func newInternalConnectionConfig(config *ConnectionConfig) *internalConnectionConfig { @@ -62,6 +67,9 @@ func newInternalConnectionConfig(config *ConnectionConfig) *internalConnectionCo AcknowledgeTimeout: parseDuration("mqtt-conn-ack-timeout", config.AcknowledgeTimeout, defaultAcknowledgeTimeout), SubscribeTimeout: parseDuration("mqtt-conn-sub-timeout", config.SubscribeTimeout, defaultSubscribeTimeout), UnsubscribeTimeout: parseDuration("mqtt-conn-unsub-timeout", config.UnsubscribeTimeout, defaultUnsubscribeTimeout), + CACert: config.CACert, + Cert: config.Cert, + Key: config.Key, } } @@ -99,13 +107,16 @@ type updateAgentClient struct { } // NewUpdateAgentClient instantiates a new UpdateAgentClient instance using the provided configuration options. -func NewUpdateAgentClient(domain string, config *ConnectionConfig) api.UpdateAgentClient { +func NewUpdateAgentClient(domain string, config *ConnectionConfig) (api.UpdateAgentClient, error) { client := &updateAgentClient{ mqttClient: newInternalClient(domain, newInternalConnectionConfig(config), nil), domain: domain, } - client.pahoClient = newClient(client.mqttConfig, client.onConnect) - return client + pahoClient, err := newClient(client.mqttConfig, client.onConnect) + if err == nil { + client.pahoClient = pahoClient + } + return client, err } // Domain returns the name of the domain that is handled by this client. @@ -248,7 +259,7 @@ func domainAsTopic(domain string) string { return domain + "update" } -func newClient(config *internalConnectionConfig, onConnect pahomqtt.OnConnectHandler) pahomqtt.Client { +func newClient(config *internalConnectionConfig, onConnect pahomqtt.OnConnectHandler) (pahomqtt.Client, error) { clientOptions := pahomqtt.NewClientOptions(). SetClientID(uuid.New().String()). AddBroker(config.Broker). @@ -261,7 +272,31 @@ func newClient(config *internalConnectionConfig, onConnect pahomqtt.OnConnectHan SetUsername(config.Username). SetPassword(config.Password) - return pahomqtt.NewClient(clientOptions) + u, err := url.Parse(config.Broker) + if err != nil { + return nil, err + } + if isConnectionSecure(u.Scheme) { + if len(config.CACert) == 0 { + return nil, errors.New("connection is secure, but no TLS configuration is provided") + } + tlsConfig, err := tls.NewTLSConfig(config.CACert, config.Cert, config.Key) + if err != nil { + return nil, err + } + clientOptions.SetTLSConfig(tlsConfig) + } + + return pahomqtt.NewClient(clientOptions), nil +} + +func isConnectionSecure(schema string) bool { + switch schema { + case "wss", "ssl", "tls", "mqtts", "mqtt+ssl", "tcps": + return true + default: + } + return false } func getAndPublishCurrentState(domain string, currentStateGetHandler func(string, int64) error) { diff --git a/mqtt/update_agent_things_client.go b/mqtt/update_agent_things_client.go index a7c6e03..560e053 100755 --- a/mqtt/update_agent_things_client.go +++ b/mqtt/update_agent_things_client.go @@ -45,7 +45,7 @@ type updateAgentThingsClient struct { } // NewUpdateAgentThingsClient instantiates a new UpdateAgentClient instance using the provided configuration options. -func NewUpdateAgentThingsClient(domain string, config *ConnectionConfig) api.UpdateAgentClient { +func NewUpdateAgentThingsClient(domain string, config *ConnectionConfig) (api.UpdateAgentClient, error) { internalConfig := newInternalConnectionConfig(config) client := &updateAgentThingsClient{ updateAgentClient: &updateAgentClient{ @@ -53,8 +53,11 @@ func NewUpdateAgentThingsClient(domain string, config *ConnectionConfig) api.Upd domain: domain, }, } - client.pahoClient = newClient(internalConfig, client.onConnect) - return client + pahoClient, err := newClient(internalConfig, client.onConnect) + if err == nil { + client.pahoClient = pahoClient + } + return client, err } // Domain returns the name of the domain that is handled by this client. diff --git a/updatem/orchestration/update_manager_test.go b/updatem/orchestration/update_manager_test.go index eefbb40..fd65b3d 100644 --- a/updatem/orchestration/update_manager_test.go +++ b/updatem/orchestration/update_manager_test.go @@ -47,7 +47,8 @@ func TestNewUpdateManager(t *testing.T) { cfg := createTestConfig(false, false) t.Run("test_no_error", func(t *testing.T) { - uaClient := mqtt.NewUpdateAgentClient("device", &mqtt.ConnectionConfig{}) + uaClient, err := mqtt.NewUpdateAgentClient("device", &mqtt.ConnectionConfig{}) + assert.NoError(t, err) apiUpdateManager, err := NewUpdateManager("dummyVersion", cfg, uaClient, nil) assert.NoError(t, err) updateManager := apiUpdateManager.(*aggregatedUpdateManager) diff --git a/util/tls/testdata/ca.crt b/util/tls/testdata/ca.crt new file mode 100644 index 0000000..59609aa --- /dev/null +++ b/util/tls/testdata/ca.crt @@ -0,0 +1,22 @@ +-----BEGIN CERTIFICATE----- +MIIDkzCCAnugAwIBAgIUbD0mjn2x1H5VKi54xgjfPsEugzAwDQYJKoZIhvcNAQEL +BQAwWTELMAkGA1UEBhMCQkcxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM +GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDESMBAGA1UEAwwJbG9jYWxob3N0MB4X +DTIyMTAyNzE1MDk0OFoXDTI3MTAyNzE1MDk0OFowWTELMAkGA1UEBhMCQkcxEzAR +BgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoMGEludGVybmV0IFdpZGdpdHMgUHR5 +IEx0ZDESMBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A +MIIBCgKCAQEAqRZyG9gywFvplpfPq60X7JYtkbj7OF+tOfixF9FkGDJ9Hhh4LoMv +W+rdLnxEHfI4Vicn0Wj3r/Ra1bd9yo8hzGiYdNliH9kegSbjB2Xx8N4yTCqaiZ9E +JQLcstQeXHEP3YwPXLnNfTOmbQPAbC2T9J+USlmolG1qpkuU5rQVC/sjW5M8MOmN +TuKEU6pds/j8GQKhQmsIHddwfypnBDilpYOotgeMwDqsyM4+zdSXbFaNmYoh3Tjb +FSN0WJTdPe7uv+nG03NZe6dHvN4C/8Z5uBIkeLw8yrO/2Wb7aAxwtpK3Wswzlya9 +TAwALiOI+hWuIfUmkoQapskIMhcRAnpMxwIDAQABo1MwUTAdBgNVHQ4EFgQUH2fo +QuQyEoszAL3vBOBqs8OLg7YwHwYDVR0jBBgwFoAUH2foQuQyEoszAL3vBOBqs8OL +g7YwDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAYNY+8FNt9Pr/ +I6NYzY56opMIXmMErRMBN4i5WkuZxXp7gaqPNrx42O7gaWJz7dwwGUmrb0eHyOIm +XetphU6cW1vNgaXPOggacp2wRJ6AGJn1+v/hfWU6sPt0XWPM9p5umoCeYJeZ1UlE +uodsUGEQc7b/ODROObPHAFc/18nChoiPylXtB5TcgdyzalzhL/d8B/c4QZJwvT+W +L+8IoChNQzeH4yCgoDZXaQpRfrnGjLyrpx4dojNBYd/rJsGNZa+wMwzhFZU3f3QY +jeZnnp+nVBw+/L3q/FVwCee/RsYiR797OL7wyPAJPGd4iqbh09Hv0B2YgIih86X6 +giPwyRzg7g== +-----END CERTIFICATE----- diff --git a/util/tls/testdata/certificate.pem b/util/tls/testdata/certificate.pem new file mode 100644 index 0000000..7697626 --- /dev/null +++ b/util/tls/testdata/certificate.pem @@ -0,0 +1,27 @@ +-----BEGIN CERTIFICATE----- +MIIEjzCCAncCFAEWn/QU1vzg07IeNzhaX/6pNlTeMA0GCSqGSIb3DQEBCwUAMIGD +MQswCQYDVQQGEwJVUzETMBEGA1UECAwKQ2FsaWZvcm5pYTEWMBQGA1UEBwwNU2Fu +IEZyYW5jaXNjbzEVMBMGA1UECgwMRXhhbXBsZSBJbmMuMRYwFAYDVQQLDA1JVCBE +ZXBhcnRtZW50MRgwFgYDVQQDDA93d3cuZXhhbXBsZS5jb20wHhcNMjEwNTEzMDg0 +NTU3WhcNMjExMDEwMDg0NTU3WjCBgzELMAkGA1UEBhMCVVMxFjAUBgNVBAgMDU1h +c3NhY2h1c2V0dHMxDzANBgNVBAcMBkJvc3RvbjEVMBMGA1UECgwMRXhhbXBsZSBJ +bmMuMRYwFAYDVQQLDA1IUiBEZXBhcnRtZW50MRwwGgYDVQQDDBN3d3cuZXhhbXBs +ZS1pbmMuY29tMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEApmF7J4WB +BujYYNt86YHSWrv3ffX3Odt57y4kqhWuMM8VcA2RXTlJO3SXX8xLF/+lsaZZJmfg +xR1tJ0hKfBYt78H03bjrylWLOQoRqlyVuz0SF3ueR9vx3mPIO0F0E3Q2mC4SLHbr +5kW4aj3/NszLzgvZPbOchfcktdCd1vME+pM8lPPY6z8qWJzlWjOYdymWUV8z9qlA +c3VCnYQ+UwJY51vTcMPpalByrMPWkiNic+9Onl8KHCik27vNMIVVLYLa6763UwU8 +qCKT+jZj6nlCT8w5oqoJNsaSd/EGGFtGR+qZj6V2TrsI2cNknyY7Qf/QN0+zH0nm +XsOxK6jW8q7RGwIDAQABMA0GCSqGSIb3DQEBCwUAA4ICAQBe/kcT2L54PxZqb3GU +liwYJGjB+9fkqTyMwglt8dAm3it9F/POyXtoKB8a1AuaZ/FJlR+AUOFv+f3i0ZnE +Ek0OAsllVPclv7HhywD1HzrbLh0PreGsBnYgyrW7qZKAfevus0U0GrjhcrY7zCoA +EBFWWqcWqhRCFXYwgI13ZNLhYl7r+NIWLza1bPcnWVfY2g19/nctR53ZFFiVkvlk +FCYGat0SPQWvjFIKaCNQQL6IZSxqk95W87kEWrac9A+bQzENpWLfwu86O5r6vKNK +pHmd47Sy8hyhf75/SEOQuWBEkgT+sPXU7TykvFB8kzO1Wmsz7D2/d5pvkjZF31dQ +m4ZHIuclPOETtAwiY13dI94vAhgruK0FRFn7jyfePn20CFqUCOO9cEQytysCO/n7 +4xJPbIVcvUO825Kbos71OWfNkLEi1tlLkFpe73/rSXnZRWweqAThrY7jxGxhveI4 +iYrOOYEqGdM6VfLvVhYXhsc/MDqiqJLdbhQAS/lE8bJrnVnVRYnnb0ExfNTwqVU3 +8YZB6JbT+j9556c675j0sfa0J8qgDRHsWR7EG5u5wENnDH/s142dYcjJ5S6tCumM +K/GExzmrJFHmLfqKTSxAquMYvDVufICcklL067DacRJKHkyx5KvKkgc3R8rTaX2o +D+0ioElFJXVQ7rULiWzxZs+Gdw== +-----END CERTIFICATE----- diff --git a/util/tls/testdata/empty.crt b/util/tls/testdata/empty.crt new file mode 100644 index 0000000..e69de29 diff --git a/util/tls/testdata/invalid.pem b/util/tls/testdata/invalid.pem new file mode 100644 index 0000000..eeb0277 --- /dev/null +++ b/util/tls/testdata/invalid.pem @@ -0,0 +1 @@ +---invalid data--- \ No newline at end of file diff --git a/util/tls/testdata/key.pem b/util/tls/testdata/key.pem new file mode 100644 index 0000000..e0b52a6 --- /dev/null +++ b/util/tls/testdata/key.pem @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEowIBAAKCAQEApmF7J4WBBujYYNt86YHSWrv3ffX3Odt57y4kqhWuMM8VcA2R +XTlJO3SXX8xLF/+lsaZZJmfgxR1tJ0hKfBYt78H03bjrylWLOQoRqlyVuz0SF3ue +R9vx3mPIO0F0E3Q2mC4SLHbr5kW4aj3/NszLzgvZPbOchfcktdCd1vME+pM8lPPY +6z8qWJzlWjOYdymWUV8z9qlAc3VCnYQ+UwJY51vTcMPpalByrMPWkiNic+9Onl8K +HCik27vNMIVVLYLa6763UwU8qCKT+jZj6nlCT8w5oqoJNsaSd/EGGFtGR+qZj6V2 +TrsI2cNknyY7Qf/QN0+zH0nmXsOxK6jW8q7RGwIDAQABAoIBAH09m6qgQAOnellO +XrSW2HUcUKwsXjDbGOoF3et57mknOIfkbqux14I9vUSLT2t9MIiNI0ZZo0Q9ZlDP +heHqACId6eiMrlDcG7SP88Q9dShATEII95g349T3X13bYzjRndbntx5pViE8EhlH +GblyZ2duW9SqQwREiQmjQ2zt+a1zuKUAGdysS+4101UHUj6tC/RjiNN/TCXXRJIX +GOJ4WFLY+f2bXgSmqbK7wqN9nPxmxl/+bv4hO32Gsv06ejuy/6+GFJZa+n3ASU2r +/ptr4vgCK+t6I0OWVTpvYUboEwAXam4JfAu12zLtczfXVFJjzklUvSClIG9aJ6DP +2B3LEuECgYEA1RenCuVl6sM2o858X8iOLTMuCZWX5VFKY1+vvLwB/+CBfy7/dDYw +lbv+xaots0rY9Wn784ewi9zJbdnXE1YNgj0utIMHzylXvTolnDYsoO7SYpIwqtpa +PyzPcAV3Khkd5LGe1hf9VmOJfTF/563ztLXip0HUeIvzgB/maqfAQW8CgYEAx+H4 +GZ3ycdL03x7Zvp5g5yDPZGvxqVIEIFmliagEFyBgqogbXuOvsvZDTg+gKV5/QyVg +FWokz5VtC6U9UcWfF7LJses/Hsedh9IeICd9UIzkey8UmS2mDBeGfTHVTNuczml5 +VzmTK8jGUO5aRVRyOt0tqVf6Oozo8ImdI1fOPRUCgYBI4p4wC+agNcUqoiXIXUDE +FQ1aGeCqfvOCqefiFixY6OFiLyERDrfvfy3VTi/zc1ZiGq4izfaE4C/Fcw0tf/F+ +6o5fD7JMGUf5YTocBCufoBA1xur+hVD46srI9hWcQJsI7ff2Ip50Pfd46sVk6QrC +dLPhoZKa6MOQv1iAgoAv4QKBgQDAfiO6N9vmNizQOxujcU8NBxHzOekvEOccaHj9 +Cqt1wh6V3CHPziHEjVjf8jhh3rlcZsATn3b32oV7c5SMDW9bGTkYeN7+u2pABOAy +QxVx312iLAMASW/hsT45jyZFsDFgrz7F+5J51g72nbSdk+e2PI7eyPUYMd+a1kxY +XxUkyQKBgBde+jg2UjGHAfW6ZWSPHWvi74oD42VzbJg6/o/KFaZGEyHdKWgrwBW7 +8wNyNTMOMSHyuvtMnc5oLd492aO2a+yurgzt7yIelVuphSKda0wd/0PljKb0lwlb +VzldR4bXolbyjNXky1KWmvSNkmbPFNGSCeyHmKQgSLXKGuWbj4ic +-----END RSA PRIVATE KEY----- diff --git a/util/tls/tls_config.go b/util/tls/tls_config.go new file mode 100644 index 0000000..0c8ec05 --- /dev/null +++ b/util/tls/tls_config.go @@ -0,0 +1,60 @@ +// Copyright (c) 2023 Contributors to the Eclipse Foundation +// +// See the NOTICE file(s) distributed with this work for additional +// information regarding copyright ownership. +// +// This program and the accompanying materials are made available under the +// terms of the Eclipse Public License 2.0 which is available at +// https://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0 +// which is available at https://www.apache.org/licenses/LICENSE-2.0. +// +// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0 + +package tls + +import ( + "crypto/tls" + "crypto/x509" + "os" + + "github.com/pkg/errors" +) + +// NewTLSConfig initializes the TLS. +func NewTLSConfig(rootCert, cert, key string) (*tls.Config, error) { + caCert, err := os.ReadFile(rootCert) + if err != nil { + return nil, errors.Wrap(err, "failed to load CA") + } + caCertPool := x509.NewCertPool() + if !caCertPool.AppendCertsFromPEM(caCert) { + return nil, errors.Errorf("failed to parse CA %s", rootCert) + } + + tlsConfig := &tls.Config{ + InsecureSkipVerify: false, + RootCAs: caCertPool, + MinVersion: tls.VersionTLS12, + MaxVersion: tls.VersionTLS13, + CipherSuites: supportedCipherSuites(), + } + + if len(cert) > 0 || len(key) > 0 { + cert, err := tls.LoadX509KeyPair(cert, key) + if err != nil { + return nil, errors.Wrap(err, "failed to load X509 key pair") + } + tlsConfig.Certificates = []tls.Certificate{cert} + } + + return tlsConfig, nil +} + +func supportedCipherSuites() []uint16 { + cs := tls.CipherSuites() + cid := make([]uint16, len(cs)) + for i := range cs { + cid[i] = cs[i].ID + } + return cid +} diff --git a/util/tls/tls_config_test.go b/util/tls/tls_config_test.go new file mode 100644 index 0000000..1c8234b --- /dev/null +++ b/util/tls/tls_config_test.go @@ -0,0 +1,103 @@ +// Copyright (c) 2023 Contributors to the Eclipse Foundation +// +// See the NOTICE file(s) distributed with this work for additional +// information regarding copyright ownership. +// +// This program and the accompanying materials are made available under the +// terms of the Eclipse Public License 2.0 which is available at +// https://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0 +// which is available at https://www.apache.org/licenses/LICENSE-2.0. +// +// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0 + +package tls + +import ( + "crypto/tls" + "errors" + "fmt" + "path/filepath" + "testing" +) + +const ( + nonExisting = "nonexisting.test" + invalidFile = "testdata/invalid.pem" + caCertPath = "testdata/ca.crt" + certPath = "testdata/certificate.pem" + keyPath = "testdata/key.pem" + caError = "failed to load CA: open %s: no such file or directory" + keyPairError = "failed to load X509 key pair: open %s: no such file or directory" + invalidError = "failed to load X509 key pair: tls: failed to find any PEM data in key input" +) + +func TestNewTLSConfig(t *testing.T) { + dirAbsPath, _ := filepath.Abs("./") + + tests := map[string]struct { + CACert string + Cert string + Key string + ExpectedError error + }{ + "valid_config_with_credentials": {CACert: caCertPath, Cert: certPath, Key: keyPath, ExpectedError: nil}, + "valid_config_no_credentials": {CACert: caCertPath, Cert: "", Key: "", ExpectedError: nil}, + "no_files_provided": {CACert: "", Cert: "", Key: "", ExpectedError: fmt.Errorf(caError, "")}, + "non_existing_ca_file": {CACert: nonExisting, Cert: "", Key: "", ExpectedError: fmt.Errorf(caError, nonExisting)}, + "invalid_ca_file": {CACert: invalidFile, Cert: certPath, Key: keyPath, ExpectedError: fmt.Errorf("failed to parse CA %s", invalidFile)}, + "invalid_ca_file_arg": {CACert: "\\\000", Cert: certPath, Key: keyPath, ExpectedError: errors.New("failed to load CA: open \\\000: invalid argument")}, + "not_abs_cert_file_provided": {CACert: caCertPath, Cert: nonExisting, Key: "", ExpectedError: fmt.Errorf(keyPairError, nonExisting)}, + "cert_is_directory": {CACert: caCertPath, Cert: dirAbsPath, Key: "", ExpectedError: fmt.Errorf("failed to load X509 key pair: read %s: is a directory", dirAbsPath)}, + "no_key_file_provided": {CACert: caCertPath, Cert: certPath, Key: "", ExpectedError: fmt.Errorf(keyPairError, "")}, + "not_abs_key_file_provided": {CACert: caCertPath, Cert: certPath, Key: nonExisting, ExpectedError: fmt.Errorf(keyPairError, nonExisting)}, + "empty_key_file_provided": {CACert: caCertPath, Cert: certPath, Key: "testdata/empty.crt", ExpectedError: errors.New(invalidError)}, + "cert_file_instead_key": {CACert: caCertPath, Cert: certPath, Key: caCertPath, ExpectedError: fmt.Errorf("failed to load X509 key pair: tls: found a certificate rather than a key in the PEM for the private key")}, + "invalid_key_file_provided": {CACert: caCertPath, Cert: certPath, Key: invalidFile, ExpectedError: errors.New(invalidError)}, + } + + for testName, testCase := range tests { + t.Run(testName, func(t *testing.T) { + cfg, err := NewTLSConfig(testCase.CACert, testCase.Cert, testCase.Key) + if testCase.ExpectedError != nil { + if testCase.ExpectedError.Error() != err.Error() { + t.Fatalf("expected error : %s, got: %s", testCase.ExpectedError, err) + } + if cfg != nil { + t.Fatalf("expected nil, got: %v", cfg) + } + } else { + if err != nil { + t.Fatal(err) + } + if len(cfg.Certificates) == 0 && testCase.Cert != "" && testCase.Key != "" { + t.Fatal("certificates length must not be 0") + } + if len(cfg.CipherSuites) == 0 { + t.Fatal("cipher suites length must not be 0") + } + // assert that cipher suites identifiers are contained in tls.CipherSuites + for _, csID := range cfg.CipherSuites { + if !func() bool { + for _, cs := range tls.CipherSuites() { + if cs.ID == csID { + return true + } + } + return false + }() { + t.Fatalf("cipher suite %d is not implemented", csID) + } + } + if cfg.InsecureSkipVerify { + t.Fatal("skip verify is set to true") + } + if cfg.MinVersion != tls.VersionTLS12 { + t.Fatalf("invalid min TLS version %d", cfg.MinVersion) + } + if cfg.MaxVersion != tls.VersionTLS13 { + t.Fatalf("invalid max TLS version %d", cfg.MaxVersion) + } + } + }) + } +}