diff --git a/pkg/broker/broker.go b/pkg/broker/broker.go index 0c98a79b..c0229a6f 100644 --- a/pkg/broker/broker.go +++ b/pkg/broker/broker.go @@ -77,8 +77,8 @@ func (d *ProvisioningSettings) ForService(service string) (*ServiceProvisioningS type MinibrokerClient interface { Init(repoURL string) error ListServices() ([]osb.Service, error) - Provision(instanceID, serviceID, planID, namespace string, acceptsIncomplete bool, provisionParams map[string]interface{}) (string, error) - Bind(instanceID, serviceID, bindingID string, acceptsIncomplete bool, bindParams map[string]interface{}) (string, error) + Provision(instanceID, serviceID, planID, namespace string, acceptsIncomplete bool, provisionParams *minibroker.ProvisionParams) (string, error) + Bind(instanceID, serviceID, bindingID string, acceptsIncomplete bool, bindParams *minibroker.BindParams) (string, error) Unbind(instanceID, bindingID string) error GetBinding(instanceID, bindingID string) (*osb.GetBindingResponse, error) Deprovision(instanceID string, acceptsIncomplete bool) (string, error) @@ -182,7 +182,14 @@ func (b *Broker) Provision(request *osb.ProvisionRequest, _ *broker.RequestConte params = request.Parameters } - operationName, err := b.client.Provision(request.InstanceID, request.ServiceID, request.PlanID, namespace, request.AcceptsIncomplete, params) + operationName, err := b.client.Provision( + request.InstanceID, + request.ServiceID, + request.PlanID, + namespace, + request.AcceptsIncomplete, + minibroker.NewProvisionParams(params), + ) if err != nil { klog.V(4).Infof("broker: failed to provision request %q: %v", request.InstanceID, err) return nil, err @@ -247,7 +254,13 @@ func (b *Broker) Bind(request *osb.BindRequest, _ *broker.RequestContext) (*brok b.Lock() defer b.Unlock() - operationName, err := b.client.Bind(request.InstanceID, request.ServiceID, request.BindingID, request.AcceptsIncomplete, request.Parameters) + operationName, err := b.client.Bind( + request.InstanceID, + request.ServiceID, + request.BindingID, + request.AcceptsIncomplete, + minibroker.NewBindParams(request.Parameters), + ) if err != nil { klog.V(4).Infof("broker: failed to bind %q: %v", request.InstanceID, err) return nil, err diff --git a/pkg/broker/broker_test.go b/pkg/broker/broker_test.go index b08a8499..48734382 100644 --- a/pkg/broker/broker_test.go +++ b/pkg/broker/broker_test.go @@ -27,6 +27,7 @@ import ( "github.com/kubernetes-sigs/minibroker/pkg/broker" "github.com/kubernetes-sigs/minibroker/pkg/broker/mocks" + "github.com/kubernetes-sigs/minibroker/pkg/minibroker" ) //go:generate mockgen -destination=./mocks/mock_broker.go -package=mocks github.com/kubernetes-sigs/minibroker/pkg/broker MinibrokerClient @@ -80,12 +81,12 @@ var _ = Describe("Broker", func() { Describe("Provision", func() { var ( - provisionParams = map[string]interface{}{ + provisionParams = minibroker.NewProvisionParams(map[string]interface{}{ "key": "value", - } + }) provisionRequest = &osb.ProvisionRequest{ ServiceID: "redis", - Parameters: provisionParams, + Parameters: provisionParams.Object, } requestContext = &osbbroker.RequestContext{} ) @@ -113,7 +114,7 @@ var _ = Describe("Broker", func() { provisionRequest.ServiceID = service provisioningSettings, found := provisioningSettings.ForService(service) Expect(found).To(BeTrue()) - params := provisioningSettings.OverrideParams + params := minibroker.NewProvisionParams(provisioningSettings.OverrideParams) mbclient.EXPECT(). Provision(gomock.Any(), gomock.Eq(service), gomock.Any(), gomock.Eq(namespace), gomock.Any(), gomock.Eq(params)) diff --git a/pkg/broker/mocks/mock_broker.go b/pkg/broker/mocks/mock_broker.go index 9f5af9c9..ee8ba924 100644 --- a/pkg/broker/mocks/mock_broker.go +++ b/pkg/broker/mocks/mock_broker.go @@ -6,6 +6,7 @@ package mocks import ( gomock "github.com/golang/mock/gomock" + minibroker "github.com/kubernetes-sigs/minibroker/pkg/minibroker" v2 "github.com/pmorie/go-open-service-broker-client/v2" reflect "reflect" ) @@ -34,7 +35,7 @@ func (m *MockMinibrokerClient) EXPECT() *MockMinibrokerClientMockRecorder { } // Bind mocks base method -func (m *MockMinibrokerClient) Bind(arg0, arg1, arg2 string, arg3 bool, arg4 map[string]interface{}) (string, error) { +func (m *MockMinibrokerClient) Bind(arg0, arg1, arg2 string, arg3 bool, arg4 *minibroker.BindParams) (string, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Bind", arg0, arg1, arg2, arg3, arg4) ret0, _ := ret[0].(string) @@ -138,7 +139,7 @@ func (mr *MockMinibrokerClientMockRecorder) ListServices() *gomock.Call { } // Provision mocks base method -func (m *MockMinibrokerClient) Provision(arg0, arg1, arg2, arg3 string, arg4 bool, arg5 map[string]interface{}) (string, error) { +func (m *MockMinibrokerClient) Provision(arg0, arg1, arg2, arg3 string, arg4 bool, arg5 *minibroker.ProvisionParams) (string, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Provision", arg0, arg1, arg2, arg3, arg4, arg5) ret0, _ := ret[0].(string) diff --git a/pkg/minibroker/mariadb.go b/pkg/minibroker/mariadb.go index 876289b6..9ba905da 100644 --- a/pkg/minibroker/mariadb.go +++ b/pkg/minibroker/mariadb.go @@ -1,5 +1,5 @@ /* -Copyright 2019 The Kubernetes Authors. +Copyright 2020 The Kubernetes Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,15 +17,26 @@ limitations under the License. package minibroker import ( + "fmt" + "net/url" + "github.com/pkg/errors" corev1 "k8s.io/api/core/v1" ) -const mariadbProtocolName = "mysql" +const ( + mariadbProtocolName = "mysql" + rootMariadbUsername = "root" +) type MariadbProvider struct{} -func (p MariadbProvider) Bind(services []corev1.Service, params map[string]interface{}, chartSecrets map[string]interface{}) (*Credentials, error) { +func (p MariadbProvider) Bind( + services []corev1.Service, + _ *BindParams, + provisionParams *ProvisionParams, + chartSecrets Object, +) (Object, error) { service := services[0] if len(service.Spec.Ports) == 0 { return nil, errors.Errorf("no ports found") @@ -34,58 +45,40 @@ func (p MariadbProvider) Bind(services []corev1.Service, params map[string]inter host := buildHostFromService(service) - dbParams, ok := params["db"].(map[string]interface{}) - if !ok { - dbParams = make(map[string]interface{}) + database, err := provisionParams.DigStringOr("db.name", "") + if err != nil { + return nil, fmt.Errorf("failed to get database name: %w", err) } - - database := "" - dbVal, ok := dbParams["name"] - if ok { - database, ok = dbVal.(string) - if !ok { - return nil, errors.Errorf("db.name not a string") - } + user, err := provisionParams.DigStringOr("db.user", rootMariadbUsername) + if err != nil { + return nil, fmt.Errorf("failed to get username: %w", err) } - var user, password string - userVal, ok := dbParams["user"] - if ok { - user, ok = userVal.(string) - if !ok { - return nil, errors.Errorf("db.user not a string") - } - - passwordVal, ok := chartSecrets["mariadb-password"] - if !ok { - return nil, errors.Errorf("mariadb-password not found in secret keys") - } - password, ok = passwordVal.(string) - if !ok { - return nil, errors.Errorf("password not a string") - } + var passwordKey string + if user == rootMariadbUsername { + passwordKey = "mariadb-root-password" } else { - user = "root" - - rootPassword, ok := chartSecrets["mariadb-root-password"] - if !ok { - return nil, errors.Errorf("mariadb-root-password not found in secret keys") - } - password, ok = rootPassword.(string) - if !ok { - return nil, errors.Errorf("password not a string") - } + passwordKey = "mariadb-password" + } + password, err := chartSecrets.DigString(passwordKey) + if err != nil { + return nil, fmt.Errorf("failed to get password: %w", err) } - creds := Credentials{ - Protocol: mariadbProtocolName, - Port: svcPort.Port, - Host: host, - Username: user, - Password: password, - Database: database, + creds := Object{ + "protocol": mariadbProtocolName, + "port": svcPort.Port, + "host": host, + "username": user, + "password": password, + "database": database, + "uri": (&url.URL{ + Scheme: mariadbProtocolName, + User: url.UserPassword(user, password), + Host: fmt.Sprintf("%s:%d", host, svcPort.Port), + Path: database, + }).String(), } - creds.URI = buildURI(creds) - return &creds, nil + return creds, nil } diff --git a/pkg/minibroker/minibroker.go b/pkg/minibroker/minibroker.go index 9766869e..86776af5 100644 --- a/pkg/minibroker/minibroker.go +++ b/pkg/minibroker/minibroker.go @@ -1,5 +1,5 @@ /* -Copyright 2019 The Kubernetes Authors. +Copyright 2020 The Kubernetes Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -284,7 +284,7 @@ func (c *Client) ListServices() ([]osb.Service, error) { // Provision a new service instance. Returns the async operation key (if // acceptsIncomplete is set). -func (c *Client) Provision(instanceID, serviceID, planID, namespace string, acceptsIncomplete bool, provisionParams map[string]interface{}) (string, error) { +func (c *Client) Provision(instanceID, serviceID, planID, namespace string, acceptsIncomplete bool, provisionParams *ProvisionParams) (string, error) { klog.V(3).Infof("minibroker: provisioning intance %q, service %q, namespace %q, params %v", instanceID, serviceID, namespace, provisionParams) ctx := context.TODO() @@ -369,7 +369,7 @@ func (c *Client) Provision(instanceID, serviceID, planID, namespace string, acce } // provisionSynchronously will provision the service instance synchronously. -func (c *Client) provisionSynchronously(instanceID, namespace, serviceID, planID, chartName, chartVersion string, provisionParams map[string]interface{}) error { +func (c *Client) provisionSynchronously(instanceID, namespace, serviceID, planID, chartName, chartVersion string, provisionParams *ProvisionParams) error { klog.V(3).Infof("minibroker: provisioning %s/%s using helm chart %s@%s", serviceID, planID, chartName, chartVersion) chartDef, err := c.helm.GetChart(chartName, chartVersion) @@ -377,7 +377,7 @@ func (c *Client) provisionSynchronously(instanceID, namespace, serviceID, planID return err } - release, err := c.helm.ChartClient().Install(chartDef, namespace, provisionParams) + release, err := c.helm.ChartClient().Install(chartDef, namespace, provisionParams.Object) if err != nil { return err } @@ -394,7 +394,7 @@ func (c *Client) provisionSynchronously(instanceID, namespace, serviceID, planID return err } for _, service := range services.Items { - err := c.labelService(service, instanceID, provisionParams) + err := c.labelService(service, instanceID) if err != nil { return err } @@ -424,7 +424,7 @@ func (c *Client) provisionSynchronously(instanceID, namespace, serviceID, planID return nil } -func (c *Client) labelService(service corev1.Service, instanceID string, params map[string]interface{}) error { +func (c *Client) labelService(service corev1.Service, instanceID string) error { ctx := context.TODO() labeledService := service.DeepCopy() @@ -484,7 +484,7 @@ func (c *Client) labelSecret(secret corev1.Secret, instanceID string) error { // Bind the given service instance (of the given service) asynchronously; the // binding operation key is returned. -func (c *Client) Bind(instanceID, serviceID, bindingID string, acceptsIncomplete bool, bindParams map[string]interface{}) (string, error) { +func (c *Client) Bind(instanceID, serviceID, bindingID string, acceptsIncomplete bool, bindParams *BindParams) (string, error) { klog.V(3).Infof("minibroker: binding instance %q, service %q, binding %q, binding params %v", instanceID, serviceID, bindingID, bindParams) config, err := c.getConfigMap(instanceID) if err != nil { @@ -501,7 +501,7 @@ func (c *Client) Bind(instanceID, serviceID, bindingID string, acceptsIncomplete rawProvisionParams := config.Data[ProvisionParamsKey] operationName := generateOperationName(OperationPrefixBind) - var provisionParams map[string]interface{} + var provisionParams *ProvisionParams err = json.Unmarshal([]byte(rawProvisionParams), &provisionParams) if err != nil { return "", errors.Wrapf(err, "could not unmarshall provision parameters for instance %q", instanceID) @@ -510,14 +510,28 @@ func (c *Client) Bind(instanceID, serviceID, bindingID string, acceptsIncomplete if acceptsIncomplete { klog.V(3).Infof("minibroker: initializing asynchronous binding %q", bindingID) go func() { - _ = c.bindSynchronously(instanceID, serviceID, bindingID, releaseNamespace, bindParams, provisionParams) + _ = c.bindSynchronously( + instanceID, + serviceID, + bindingID, + releaseNamespace, + bindParams, + provisionParams, + ) klog.V(3).Infof("minibroker: asynchronously bound instance %q, service %q, binding %q", instanceID, serviceID, bindingID) }() return operationName, nil } klog.V(3).Infof("minibroker: initializing synchronous binding %q", bindingID) - if err = c.bindSynchronously(instanceID, serviceID, bindingID, releaseNamespace, bindParams, provisionParams); err != nil { + if err := c.bindSynchronously( + instanceID, + serviceID, + bindingID, + releaseNamespace, + bindParams, + provisionParams, + ); err != nil { return "", err } @@ -529,20 +543,18 @@ func (c *Client) Bind(instanceID, serviceID, bindingID string, acceptsIncomplete // bindSynchronously creates a new binding for the given service instance. All // results are only reported via the service instance configmap (under the // appropriate key for the binding) for lookup by LastBindingOperationState(). -func (c *Client) bindSynchronously(instanceID, serviceID, bindingID, releaseNamespace string, bindParams, provisionParams map[string]interface{}) error { +func (c *Client) bindSynchronously( + instanceID, + serviceID, + bindingID, + releaseNamespace string, + bindParams *BindParams, + provisionParams *ProvisionParams, +) error { ctx := context.TODO() // Wrap most of the code in an inner function to simplify error handling err := func() error { - // Smoosh all the params together - params := make(map[string]interface{}, len(bindParams)+len(provisionParams)) - for k, v := range provisionParams { - params[k] = v - } - for k, v := range bindParams { - params[k] = v - } - filterByInstance := metav1.ListOptions{ LabelSelector: labels.SelectorFromSet(map[string]string{ InstanceLabel: instanceID, @@ -569,7 +581,7 @@ func (c *Client) bindSynchronously(instanceID, serviceID, bindingID, releaseName return osb.HTTPStatusCodeError{StatusCode: http.StatusNotFound} } - data := make(map[string]interface{}) + data := make(Object) for _, secret := range secrets.Items { for key, value := range secret.Data { data[key] = string(value) @@ -579,11 +591,16 @@ func (c *Client) bindSynchronously(instanceID, serviceID, bindingID, releaseName // Apply additional provisioning logic for Service Catalog Enabled services provider, ok := c.providers[serviceID] if ok { - creds, err := provider.Bind(services.Items, params, data) + creds, err := provider.Bind( + services.Items, + bindParams, + provisionParams, + data, + ) if err != nil { return errors.Wrapf(err, "unable to bind instance %s", instanceID) } - for k, v := range creds.ToMap() { + for k, v := range creds { data[k] = v } } @@ -591,7 +608,7 @@ func (c *Client) bindSynchronously(instanceID, serviceID, bindingID, releaseName // Record the result for later fetching bindingResponse := osb.GetBindingResponse{ Credentials: data, - Parameters: bindParams, + Parameters: bindParams.Object, } bindingResponseJSON, err := json.Marshal(bindingResponse) if err != nil { diff --git a/pkg/minibroker/minibroker_test.go b/pkg/minibroker/minibroker_test.go index 31aeff31..186f833d 100644 --- a/pkg/minibroker/minibroker_test.go +++ b/pkg/minibroker/minibroker_test.go @@ -1,5 +1,5 @@ /* -Copyright 2019 The Kubernetes Authors. +Copyright 2020 The Kubernetes Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/pkg/minibroker/mongodb.go b/pkg/minibroker/mongodb.go index 20278617..e501a1a2 100644 --- a/pkg/minibroker/mongodb.go +++ b/pkg/minibroker/mongodb.go @@ -1,5 +1,5 @@ /* -Copyright 2019 The Kubernetes Authors. +Copyright 2020 The Kubernetes Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,15 +17,26 @@ limitations under the License. package minibroker import ( + "fmt" + "net/url" + "github.com/pkg/errors" corev1 "k8s.io/api/core/v1" ) -const mongodbProtocolName = "mongodb" +const ( + mongodbProtocolName = "mongodb" + rootMongodbUsername = "root" +) type MongodbProvider struct{} -func (p MongodbProvider) Bind(services []corev1.Service, params map[string]interface{}, chartSecrets map[string]interface{}) (*Credentials, error) { +func (p MongodbProvider) Bind( + services []corev1.Service, + _ *BindParams, + provisionParams *ProvisionParams, + chartSecrets Object, +) (Object, error) { service := services[0] if len(service.Spec.Ports) == 0 { return nil, errors.Errorf("no ports found") @@ -34,53 +45,40 @@ func (p MongodbProvider) Bind(services []corev1.Service, params map[string]inter host := buildHostFromService(service) - database := "" - dbVal, ok := params["mongodbDatabase"] - if ok { - database, ok = dbVal.(string) - if !ok { - return nil, errors.Errorf("mongodbDatabase not a string") - } + database, err := provisionParams.DigStringOr("mongodbDatabase", "") + if err != nil { + return nil, fmt.Errorf("failed to get database name: %w", err) + } + user, err := provisionParams.DigStringOr("mongodbUsername", rootMongodbUsername) + if err != nil { + return nil, fmt.Errorf("failed to get username: %w", err) } - var user, password string - userVal, ok := params["mongodbUsername"] - if ok { - user, ok = userVal.(string) - if !ok { - return nil, errors.Errorf("mongodbUsername not a string") - } - - passwordVal, ok := chartSecrets["mongodb-password"] - if !ok { - return nil, errors.Errorf("mongodb-password not found in secret keys") - } - password, ok = passwordVal.(string) - if !ok { - return nil, errors.Errorf("password not a string") - } + var passwordKey string + if user == rootMongodbUsername { + passwordKey = "mongodb-root-password" } else { - user = "root" - - rootPassword, ok := chartSecrets["mongodb-root-password"] - if !ok { - return nil, errors.Errorf("mongodb-root-password not found in secret keys") - } - password, ok = rootPassword.(string) - if !ok { - return nil, errors.Errorf("password not a string") - } + passwordKey = "mongodb-password" + } + password, err := chartSecrets.DigString(passwordKey) + if err != nil { + return nil, fmt.Errorf("failed to get password: %w", err) } - creds := Credentials{ - Protocol: mongodbProtocolName, - Port: svcPort.Port, - Host: host, - Username: user, - Password: password, - Database: database, + creds := Object{ + "protocol": mongodbProtocolName, + "port": svcPort.Port, + "host": host, + "username": user, + "password": password, + "database": database, + "uri": (&url.URL{ + Scheme: mongodbProtocolName, + User: url.UserPassword(user, password), + Host: fmt.Sprintf("%s:%d", host, svcPort.Port), + Path: database, + }).String(), } - creds.URI = buildURI(creds) - return &creds, nil + return creds, nil } diff --git a/pkg/minibroker/mysql.go b/pkg/minibroker/mysql.go index bc6d3958..5703cab5 100644 --- a/pkg/minibroker/mysql.go +++ b/pkg/minibroker/mysql.go @@ -1,5 +1,5 @@ /* -Copyright 2019 The Kubernetes Authors. +Copyright 2020 The Kubernetes Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,15 +17,26 @@ limitations under the License. package minibroker import ( + "fmt" + "net/url" + "github.com/pkg/errors" corev1 "k8s.io/api/core/v1" ) -const mysqlProtocolName = "mysql" +const ( + mysqlProtocolName = "mysql" + rootMysqlUsername = "root" +) type MySQLProvider struct{} -func (p MySQLProvider) Bind(services []corev1.Service, params map[string]interface{}, chartSecrets map[string]interface{}) (*Credentials, error) { +func (p MySQLProvider) Bind( + services []corev1.Service, + _ *BindParams, + provisionParams *ProvisionParams, + chartSecrets Object, +) (Object, error) { service := services[0] if len(service.Spec.Ports) == 0 { return nil, errors.Errorf("no ports found") @@ -34,53 +45,40 @@ func (p MySQLProvider) Bind(services []corev1.Service, params map[string]interfa host := buildHostFromService(service) - database := "" - dbVal, ok := params["mysqlDatabase"] - if ok { - database, ok = dbVal.(string) - if !ok { - return nil, errors.Errorf("mysqlDatabase not a string") - } + database, err := provisionParams.DigStringOr("mysqlDatabase", "") + if err != nil { + return nil, fmt.Errorf("failed to get database name: %w", err) + } + user, err := provisionParams.DigStringOr("mysqlUser", rootMysqlUsername) + if err != nil { + return nil, fmt.Errorf("failed to get username: %w", err) } - var user, password string - userVal, ok := params["mysqlUser"] - if ok { - user, ok = userVal.(string) - if !ok { - return nil, errors.Errorf("mysqlUser not a string") - } - - passwordVal, ok := chartSecrets["mysql-password"] - if !ok { - return nil, errors.Errorf("mysql-password not found in secret keys") - } - password, ok = passwordVal.(string) - if !ok { - return nil, errors.Errorf("password not a string") - } + var passwordKey string + if user == rootMysqlUsername { + passwordKey = "mysql-root-password" } else { - user = "root" - - rootPassword, ok := chartSecrets["mysql-root-password"] - if !ok { - return nil, errors.Errorf("mysql-root-password not found in secret keys") - } - password, ok = rootPassword.(string) - if !ok { - return nil, errors.Errorf("password not a string") - } + passwordKey = "mysql-password" + } + password, err := chartSecrets.DigString(passwordKey) + if err != nil { + return nil, fmt.Errorf("failed to get password: %w", err) } - creds := Credentials{ - Protocol: mysqlProtocolName, - Port: svcPort.Port, - Host: host, - Username: user, - Password: password, - Database: database, + creds := Object{ + "protocol": mysqlProtocolName, + "port": svcPort.Port, + "host": host, + "username": user, + "password": password, + "database": database, + "uri": (&url.URL{ + Scheme: mysqlProtocolName, + User: url.UserPassword(user, password), + Host: fmt.Sprintf("%s:%d", host, svcPort.Port), + Path: database, + }).String(), } - creds.URI = buildURI(creds) - return &creds, nil + return creds, nil } diff --git a/pkg/minibroker/postgres.go b/pkg/minibroker/postgres.go index d10aa618..82ec8d22 100644 --- a/pkg/minibroker/postgres.go +++ b/pkg/minibroker/postgres.go @@ -1,5 +1,5 @@ /* -Copyright 2019 The Kubernetes Authors. +Copyright 2020 The Kubernetes Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,15 +17,26 @@ limitations under the License. package minibroker import ( + "fmt" + "net/url" + "github.com/pkg/errors" corev1 "k8s.io/api/core/v1" ) -const postgresqlProtocolName = "postgresql" +const ( + postgresqlProtocolName = "postgresql" + defaultPostgresqlUsername = "postgres" +) type PostgresProvider struct{} -func (p PostgresProvider) Bind(services []corev1.Service, params map[string]interface{}, chartSecrets map[string]interface{}) (*Credentials, error) { +func (p PostgresProvider) Bind( + services []corev1.Service, + _ *BindParams, + provisionParams *ProvisionParams, + chartSecrets Object, +) (Object, error) { service := services[0] if len(service.Spec.Ports) == 0 { return nil, errors.Errorf("no ports found") @@ -34,72 +45,53 @@ func (p PostgresProvider) Bind(services []corev1.Service, params map[string]inte host := buildHostFromService(service) - dbVal, ok := params["postgresqlDatabase"] - if !ok { + database, err := provisionParams.DigStringAltOr( // Some older chart versions use postgresDatabase instead of postgresqlDatabase. - dbVal, ok = params["postgresDatabase"] - if !ok { - dbVal = "" - } - } - database, ok := dbVal.(string) - if !ok { - return nil, errors.Errorf("database not a string") + []string{"postgresqlDatabase", "postgresDatabase"}, + "", + ) + if err != nil { + return nil, fmt.Errorf("failed to get database name: %w", err) } - - userVal, ok := params["postgresqlUsername"] - if !ok { + user, err := provisionParams.DigStringAltOr( // Some older chart versions use postgresUsername instead of postgresqlUsername. - userVal, ok = params["postgresUsername"] - if !ok { - userVal = "postgres" - } - } - user, ok := userVal.(string) - if !ok { - return nil, errors.Errorf("username not a string") + []string{"postgresqlUsername", "postgresUsername"}, + defaultPostgresqlUsername, + ) + if err != nil { + return nil, fmt.Errorf("failed to get username: %w", err) } - var password string - if user != "postgres" { - // postgresql-postgres-password is used when postgresqlPostgresPassword is set and - // postgresqlUsername is not 'postgres'. - passwordVal, ok := chartSecrets["postgresql-postgres-password"] - if !ok { - passwordVal, ok = chartSecrets["postgresql-password"] - if !ok { - return nil, errors.Errorf("password not found in secret keys") - } - } - password, ok = passwordVal.(string) - if !ok { - return nil, errors.Errorf("password not a string") - } + var passwordKey, altPasswordKey string + // postgresql-postgres-password is used when postgresqlPostgresPassword is set and + // postgresqlUsername is not 'postgres'. + if _, ok := provisionParams.Dig("postgresqlPostgresPassword"); ok && user != defaultPostgresqlUsername { + passwordKey = "postgresql-postgres-password" } else { - passwordVal, ok := chartSecrets["postgresql-password"] - if !ok { - // Chart versions <2.0 use postgres-password instead of postgresql-password. - // See https://github.com/kubernetes-sigs/minibroker/issues/17 - passwordVal, ok = chartSecrets["postgres-password"] - if !ok { - return nil, errors.Errorf("password not found in secret keys") - } - } - password, ok = passwordVal.(string) - if !ok { - return nil, errors.Errorf("password not a string") - } + passwordKey = "postgresql-password" + // Chart versions <2.0 use postgres-password instead of postgresql-password. + // See https://github.com/kubernetes-sigs/minibroker/issues/17 + altPasswordKey = "postgres-password" + } + password, err := chartSecrets.DigStringAlt([]string{passwordKey, altPasswordKey}) + if err != nil { + return nil, fmt.Errorf("failed to get password: %w", err) } - creds := Credentials{ - Protocol: postgresqlProtocolName, - Port: svcPort.Port, - Host: host, - Username: user, - Password: password, - Database: database, + creds := Object{ + "protocol": postgresqlProtocolName, + "port": svcPort.Port, + "host": host, + "username": user, + "password": password, + "database": database, + "uri": (&url.URL{ + Scheme: postgresqlProtocolName, + User: url.UserPassword(user, password), + Host: fmt.Sprintf("%s:%d", host, svcPort.Port), + Path: database, + }).String(), } - creds.URI = buildURI(creds) - return &creds, nil + return creds, nil } diff --git a/pkg/minibroker/provider.go b/pkg/minibroker/provider.go index bcdaf19e..80e1ef56 100644 --- a/pkg/minibroker/provider.go +++ b/pkg/minibroker/provider.go @@ -1,5 +1,5 @@ /* -Copyright 2019 The Kubernetes Authors. +Copyright 2020 The Kubernetes Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,53 +17,135 @@ limitations under the License. package minibroker import ( - "encoding/json" "fmt" + "strings" corev1 "k8s.io/api/core/v1" ) +// Provider is the interface for the Service Provider. Its methods wrap service-specific logic. type Provider interface { - Bind(service []corev1.Service, params map[string]interface{}, chartSecrets map[string]interface{}) (*Credentials, error) + Bind( + service []corev1.Service, + bindParams *BindParams, + provisionParams *ProvisionParams, + chartSecrets Object, + ) (Object, error) } -type Credentials struct { - Protocol string - URI string `json:"uri,omitempty"` - Username string `json:"username,omitempty"` - Password string `json:"password,omitempty"` - Host string `json:"host,omitempty"` - Port int32 `json:"port,omitempty"` - Database string `json:"database,omitempty"` +// Object is a wrapper around map[string]interface{} that implements methods for helping with +// digging and type asserting. +type Object map[string]interface{} + +var ( + // ErrDigNotFound is the error for a key not found in the Object. + ErrDigNotFound = fmt.Errorf("key not found") + // ErrDigNotString is the error for a key that is not a string. + ErrDigNotString = fmt.Errorf("key is not a string") +) + +// Dig digs the Object based on the provided key. +// key must be in the format "foo.bar.baz". Each segment represents a level in the Object. +func (o Object) Dig(key string) (interface{}, bool) { + if key == "" { + return nil, false + } + keyParts := strings.Split(key, ".") + var part interface{} = o + var ok bool + for _, keyPart := range keyParts { + if keyPart == "" { + return nil, false + } + switch p := part.(type) { + case map[string]interface{}: + if part, ok = p[keyPart]; !ok { + return nil, false + } + case Object: + if part, ok = p[keyPart]; !ok { + return nil, false + } + default: + return nil, false + } + } + return part, ok +} + +// DigString wraps Object.Dig and type-asserts the found key. +func (o Object) DigString(key string) (string, error) { + val, ok := o.Dig(key) + if !ok { + return "", ErrDigNotFound + } + valStr, ok := val.(string) + if !ok { + return "", ErrDigNotString + } + return valStr, nil +} + +// DigStringAlt digs for any of the given keys, returning the first found. It returns an error if +// none of the alternative keys are found. +func (o Object) DigStringAlt(altKeys []string) (string, error) { + for _, altKey := range altKeys { + valStr, err := o.DigString(altKey) + if err == ErrDigNotFound { + continue + } + if err != nil { + return "", err + } + return valStr, nil + } + return "", ErrDigNotFound } -// ToMap converts the credentials into the OSB API credentials response -// see https://github.com/openservicebrokerapi/servicebroker/blob/master/spec.md#device-object -// { -// "credentials": { -// "uri": "mysql://mysqluser:pass@mysqlhost:3306/dbname", -// "username": "mysqluser", -// "password": "pass", -// "host": "mysqlhost", -// "port": 3306, -// "database": "dbname" -// } -// } -func (c Credentials) ToMap() map[string]interface{} { - var result map[string]interface{} - j, _ := json.Marshal(c) - json.Unmarshal(j, &result) - return result +// DigStringOr wraps Object.DigString and returns defaultValue if the value was not found. +func (o Object) DigStringOr(key string, defaultValue string) (string, error) { + str, err := o.DigString(key) + if err == ErrDigNotFound { + return defaultValue, nil + } + if err != nil { + return "", err + } + return str, nil } -func buildURI(c Credentials) string { - if c.Database == "" { - return fmt.Sprintf("%s://%s:%s@%s:%d", - c.Protocol, c.Username, c.Password, c.Host, c.Port) +// DigStringAltOr wraps Object.DigStringAlt and returns defaultValue if none of the alternative +// keys are found. +func (o Object) DigStringAltOr(altKeys []string, defaultValue string) (string, error) { + str, err := o.DigStringAlt(altKeys) + if err == ErrDigNotFound { + return defaultValue, nil } + if err != nil { + return "", err + } + return str, nil +} + +// BindParams is a specialization of Object for binding parameters, ensuring type checking. +type BindParams struct { + Object +} + +// NewBindParams constructs a new BindParams. +func NewBindParams(m map[string]interface{}) *BindParams { + return &BindParams{Object: m} +} + +// ProvisionParams is a specialization of Object for provisioning parameters, ensuring type +// checking. +type ProvisionParams struct { + Object +} - return fmt.Sprintf("%s://%s:%s@%s:%d/%s", - c.Protocol, c.Username, c.Password, c.Host, c.Port, c.Database) +// NewProvisionParams constructs a new ProvisionParams. +func NewProvisionParams(m map[string]interface{}) *ProvisionParams { + return &ProvisionParams{Object: m} } func buildHostFromService(service corev1.Service) string { diff --git a/pkg/minibroker/provider_test.go b/pkg/minibroker/provider_test.go new file mode 100644 index 00000000..c17ca490 --- /dev/null +++ b/pkg/minibroker/provider_test.go @@ -0,0 +1,265 @@ +/* +Copyright 2020 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package minibroker + +import "testing" + +func TestObjectDig(t *testing.T) { + tests := []struct { + obj Object + key string + expectedVal interface{} + expectedOk bool + }{ + { + Object{"foo": "baz"}, + "bar", + nil, + false, + }, + { + Object{"foo": Object{"bar": "baz"}}, + "foo.foo", + nil, + false, + }, + { + Object{"foo": Object{"bar": "baz"}}, + "foo.bar.bar", + nil, + false, + }, + { + Object{"foo": Object{"bar": "baz"}}, + "", + nil, + false, + }, + { + Object{"foo": Object{"": "baz"}}, + "foo.", + nil, + false, + }, + { + Object{"foo": Object{"bar": "baz"}}, + "foo.", + nil, + false, + }, + { + Object{"foo": Object{"bar": "baz"}}, + "foo..bar", + nil, + false, + }, + { + Object{"": Object{"bar": "baz"}}, + ".bar", + nil, + false, + }, + { + Object{"foo": "baz"}, + "foo", + "baz", + true, + }, + { + Object{"foo": Object{"bar": "baz"}}, + "foo.bar", + "baz", + true, + }, + } + + for _, tt := range tests { + val, ok := tt.obj.Dig(tt.key) + if ok != tt.expectedOk { + t.Errorf("Object.Dig(%s): expected ok %v, actual ok %v", tt.key, tt.expectedOk, ok) + } + if val != tt.expectedVal { + t.Errorf("Object.Dig(%s): expected val %v, actual val %v", tt.key, tt.expectedVal, val) + } + } +} + +func TestObjectDigString(t *testing.T) { + tests := []struct { + obj Object + key string + expectedVal string + expectedErr error + }{ + { + Object{"foo": "baz"}, + "bar", + "", + ErrDigNotFound, + }, + { + Object{"foo": 3}, + "foo", + "", + ErrDigNotString, + }, + { + Object{"foo": Object{"bar": "baz"}}, + "foo.bar", + "baz", + nil, + }, + } + + for _, tt := range tests { + val, err := tt.obj.DigString(tt.key) + if err != tt.expectedErr { + t.Errorf("Object.DigString(%s): expected err %v, actual err %v", tt.key, tt.expectedErr, err) + } + if val != tt.expectedVal { + t.Errorf("Object.DigString(%s): expected val %v, actual val %v", tt.key, tt.expectedVal, val) + } + } +} + +func TestObjectDigStringAlt(t *testing.T) { + tests := []struct { + obj Object + altKeys []string + expectedVal string + expectedErr error + }{ + { + Object{"foo": "baz"}, + []string{"bar", "baz"}, + "", + ErrDigNotFound, + }, + { + Object{"foo": 3}, + []string{"bar", "foo"}, + "", + ErrDigNotString, + }, + { + Object{"foo": Object{"bar": "baz"}}, + []string{"foo", "foo.bar"}, + "", + ErrDigNotString, + }, + { + Object{"foo": Object{"bar": "baz"}}, + []string{"foo.foo", "foo.bar"}, + "baz", + nil, + }, + } + + for _, tt := range tests { + val, err := tt.obj.DigStringAlt(tt.altKeys) + if err != tt.expectedErr { + t.Errorf("Object.DigStringAlt(%v): expected err %v, actual err %v", tt.altKeys, tt.expectedErr, err) + } + if val != tt.expectedVal { + t.Errorf("Object.DigStringAlt(%v): expected val %v, actual val %v", tt.altKeys, tt.expectedVal, val) + } + } +} + +func TestObjectDigStringOr(t *testing.T) { + tests := []struct { + obj Object + key string + defaultVal string + expectedVal string + expectedErr error + }{ + { + Object{"foo": 1}, + "foo", + "default", + "", + ErrDigNotString, + }, + { + Object{"foo": "baz"}, + "bar", + "default", + "default", + nil, + }, + { + Object{"foo": "baz"}, + "foo", + "default", + "baz", + nil, + }, + } + + for _, tt := range tests { + val, err := tt.obj.DigStringOr(tt.key, tt.defaultVal) + if err != tt.expectedErr { + t.Errorf("Object.DigStringOr(%s): expected err %v, actual err %v", tt.key, tt.expectedErr, err) + } + if val != tt.expectedVal { + t.Errorf("Object.DigStringOr(%s): expected val %v, actual val %v", tt.key, tt.expectedVal, val) + } + } +} + +func TestObjectDigStringAltOr(t *testing.T) { + tests := []struct { + obj Object + altKeys []string + defaultVal string + expectedVal string + expectedErr error + }{ + { + Object{"foo": 1}, + []string{"bar", "foo"}, + "default", + "", + ErrDigNotString, + }, + { + Object{"foo": "baz"}, + []string{"bar", "baz"}, + "default", + "default", + nil, + }, + { + Object{"foo": "baz"}, + []string{"bar", "foo"}, + "default", + "baz", + nil, + }, + } + + for _, tt := range tests { + val, err := tt.obj.DigStringAltOr(tt.altKeys, tt.defaultVal) + if err != tt.expectedErr { + t.Errorf("Object.DigStringAltOr(%v): expected err %v, actual err %v", tt.altKeys, tt.expectedErr, err) + } + if val != tt.expectedVal { + t.Errorf("Object.DigStringAltOr(%v): expected val %v, actual val %v", tt.altKeys, tt.expectedVal, val) + } + } +} diff --git a/pkg/minibroker/rabbitmq.go b/pkg/minibroker/rabbitmq.go index b4366003..e9b3bec8 100644 --- a/pkg/minibroker/rabbitmq.go +++ b/pkg/minibroker/rabbitmq.go @@ -24,67 +24,58 @@ import ( corev1 "k8s.io/api/core/v1" ) -const amqpProtocolName = "amqp" +const ( + amqpProtocolName = "amqp" + defaultRabbitmqUsername = "user" +) type RabbitmqProvider struct{} -func (p RabbitmqProvider) Bind(services []corev1.Service, params map[string]interface{}, chartSecrets map[string]interface{}) (*Credentials, error) { +func (p RabbitmqProvider) Bind( + services []corev1.Service, + _ *BindParams, + provisionParams *ProvisionParams, + chartSecrets Object, +) (Object, error) { if len(services) == 0 { return nil, errors.Errorf("no services to process") } service := services[0] - var amqpPort *corev1.ServicePort + var svcPort *corev1.ServicePort for _, port := range service.Spec.Ports { if port.Name == amqpProtocolName { - amqpPort = &port + svcPort = &port break } } - if amqpPort == nil { + if svcPort == nil { return nil, errors.Errorf("no amqp port found") } - rabbitmqParams, ok := params["rabbitmq"].(map[string]interface{}) - if !ok { - rabbitmqParams = make(map[string]interface{}) + user, err := provisionParams.DigStringOr("rabbitmq.username", defaultRabbitmqUsername) + if err != nil { + return nil, fmt.Errorf("failed to get username: %w", err) } - var username string - usernameVal, ok := rabbitmqParams["username"] - if ok { - username, ok = usernameVal.(string) - if !ok { - return nil, errors.Errorf("username not a string") - } - } else { - username = "user" - } - - passwordVal, ok := chartSecrets["rabbitmq-password"] - if !ok { - return nil, errors.Errorf("password not found in secret keys") - } - password, ok := passwordVal.(string) - if !ok { - return nil, errors.Errorf("password not a string") + password, err := chartSecrets.DigString("rabbitmq-password") + if err != nil { + return nil, fmt.Errorf("failed to get password: %w", err) } host := buildHostFromService(service) - creds := Credentials{ - Protocol: amqpProtocolName, - Port: amqpPort.Port, - Host: host, - Username: username, - Password: password, - Database: "/", + creds := Object{ + "protocol": amqpProtocolName, + "port": svcPort.Port, + "host": host, + "username": user, + "password": password, + "uri": (&url.URL{ + Scheme: amqpProtocolName, + User: url.UserPassword(user, password), + Host: fmt.Sprintf("%s:%d", host, svcPort.Port), + }).String(), } - creds.URI = buildRabbitmqURI(creds) - - return &creds, nil -} -func buildRabbitmqURI(c Credentials) string { - return fmt.Sprintf("%s://%s:%s@%s:%d/%s", - c.Protocol, c.Username, c.Password, c.Host, c.Port, url.QueryEscape(c.Database)) + return creds, nil } diff --git a/pkg/minibroker/redis.go b/pkg/minibroker/redis.go index 9e9892a9..49d6d580 100644 --- a/pkg/minibroker/redis.go +++ b/pkg/minibroker/redis.go @@ -1,5 +1,5 @@ /* -Copyright 2019 The Kubernetes Authors. +Copyright 2020 The Kubernetes Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,6 +17,9 @@ limitations under the License. package minibroker import ( + "fmt" + "net/url" + "github.com/pkg/errors" corev1 "k8s.io/api/core/v1" ) @@ -25,7 +28,12 @@ const redisProtocolName = "redis" type RedisProvider struct{} -func (p RedisProvider) Bind(services []corev1.Service, params map[string]interface{}, chartSecrets map[string]interface{}) (*Credentials, error) { +func (p RedisProvider) Bind( + services []corev1.Service, + _ *BindParams, + _ *ProvisionParams, + chartSecrets Object, +) (Object, error) { var masterSvc *corev1.Service for _, svc := range services { if svc.Spec.Selector["role"] == "master" { @@ -44,23 +52,22 @@ func (p RedisProvider) Bind(services []corev1.Service, params map[string]interfa host := buildHostFromService(*masterSvc) - var password string - passwordVal, ok := chartSecrets["redis-password"] - if !ok { - return nil, errors.Errorf("redis-password not found in secret keys") - } - password, ok = passwordVal.(string) - if !ok { - return nil, errors.Errorf("password not a string") + password, err := chartSecrets.DigString("redis-password") + if err != nil { + return nil, fmt.Errorf("failed to get password: %w", err) } - creds := Credentials{ - Protocol: redisProtocolName, - Port: svcPort.Port, - Host: host, - Password: password, + creds := Object{ + "protocol": redisProtocolName, + "port": svcPort.Port, + "host": host, + "password": password, + "uri": (&url.URL{ + Scheme: redisProtocolName, + User: url.UserPassword("", password), + Host: fmt.Sprintf("%s:%d", host, svcPort.Port), + }).String(), } - creds.URI = buildURI(creds) - return &creds, nil + return creds, nil }