Skip to content

Commit

Permalink
Improve TLS configuration in migrate
Browse files Browse the repository at this point in the history
commit_hash:1ce0d11c56ce96bcc35e63c10e8b64789afc5553
  • Loading branch information
MikailBag committed Feb 10, 2025
1 parent ec2c9f8 commit 08e87c5
Showing 1 changed file with 72 additions and 11 deletions.
83 changes: 72 additions & 11 deletions perforator/cmd/migrate/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,36 @@ var migrationsClickhouse embed.FS
//go:embed migrations/postgres/*.sql
var migrationsPostgres embed.FS

type tlsConfig struct {
deprecatedInsecure bool
plaintext bool
skipTLSVerification bool
serverCA string
}

func (t *tlsConfig) validate(deprHandler func(string)) error {
if t.deprecatedInsecure {
deprHandler("--insecure is deprecated, use --plaintext or --tls-trust-all instead")
}
if t.plaintext && t.skipTLSVerification {
return errors.New("--plaintext and --tls-trust-all are mutually exclusive")
}
if t.plaintext && t.serverCA != "" {
return errors.New("--plaintext and --tls-ca are mutually exclusive")
}
if t.serverCA != "" && t.skipTLSVerification {
return errors.New("--tls-ca and --tls-trust-all are mutually exclusive")
}
return nil
}

var (
hosts []string
port uint16
database string
username string
password string
insecure bool
tls tlsConfig

rootCmd = &cobra.Command{
Use: "migrate",
Expand All @@ -72,7 +95,7 @@ func migrateCmdRunE(cmd *cobra.Command, args []string) error {
return fmt.Errorf("failed to parse version %s: %w", args[0], err)
}

return runMigrations(dbByCmd(cmd), hosts, port, database, username, password, insecure, func(m *migrate.Migrate) error {
return runMigrations(dbByCmd(cmd), hosts, port, database, username, password, tls, func(m *migrate.Migrate) error {
return m.Migrate(uint(version))
})
}
Expand All @@ -83,13 +106,13 @@ func forceCmdRunE(cmd *cobra.Command, args []string) error {
return fmt.Errorf("failed to parse version %s: %w", args[0], err)
}

return runMigrations(dbByCmd(cmd), hosts, port, database, username, password, insecure, func(m *migrate.Migrate) error {
return runMigrations(dbByCmd(cmd), hosts, port, database, username, password, tls, func(m *migrate.Migrate) error {
return m.Force(int(version))
})
}

func upCmdRunE(cmd *cobra.Command, _ []string) error {
return runMigrations(dbByCmd(cmd), hosts, port, database, username, password, insecure, func(m *migrate.Migrate) error {
return runMigrations(dbByCmd(cmd), hosts, port, database, username, password, tls, func(m *migrate.Migrate) error {
return m.Up()
})
}
Expand Down Expand Up @@ -133,7 +156,10 @@ func init() {
subcommand.Flags().StringVar(&database, "db", "perforator", "Database name")
subcommand.Flags().StringVar(&username, "user", "perforator", "Username")
subcommand.Flags().StringVar(&password, "pass", "", "Password")
subcommand.Flags().BoolVar(&insecure, "insecure", false, "Disable TLS")
subcommand.Flags().BoolVar(&tls.deprecatedInsecure, "insecure", false, "(Deprecated) disable transport security")
subcommand.Flags().BoolVar(&tls.plaintext, "plaintext", false, "Use plaintext connection")
subcommand.Flags().BoolVar(&tls.skipTLSVerification, "tls-trust-all", false, "Skip TLS verification")
subcommand.Flags().StringVar(&tls.serverCA, "tls-ca", "", "Path to CA certificate")
}
}
}
Expand All @@ -153,15 +179,19 @@ func runMigrations(
database string,
username string,
password string,
insecure bool,
tls tlsConfig,
callback func(*migrate.Migrate) error,
) error {
tlsErr := tls.validate(func(s string) { log.Printf("Warning: %s", s) })
if tlsErr != nil {
return fmt.Errorf("invalid tls configuration: %w", tlsErr)
}
errs := make([]error, 0)

log.Printf("Starting migrations")

for _, host := range hosts {
mig, err := newMigrate(db, host, port, database, username, password, insecure)
mig, err := newMigrate(db, host, port, database, username, password, tls)

if err != nil {
errs = append(errs, fmt.Errorf("failed to migrate host %s: %w", host, err))
Expand Down Expand Up @@ -191,7 +221,7 @@ func newMigrate(
database string,
username string,
password string,
insecure bool,
tls tlsConfig,
) (*migrate.Migrate, error) {
if port == 0 {
port = defaultPorts[db]
Expand Down Expand Up @@ -222,17 +252,48 @@ func newMigrate(
database,
)

var queryParams []string

switch db {
case Clickhouse:
uri += fmt.Sprintf("?x-multi-statement=true&secure=%s", strconv.FormatBool(!insecure))
queryParams = append(queryParams, "x-multi-statement=true")
if tls.deprecatedInsecure || tls.plaintext {
queryParams = append(queryParams, "secure=false")
} else {
queryParams = append(queryParams, "secure=true")
}
if tls.skipTLSVerification {
queryParams = append(queryParams, "skip_verify=false")
}
if tls.serverCA != "" {
return nil, errors.New("tls-ca is not supported for clickhouse")
}
case Postgres:
sslmode := "require"

if insecure {
if tls.deprecatedInsecure || tls.plaintext {
sslmode = "disable"
} else if tls.skipTLSVerification {
// TODO: this case looks broken in postgres
sslmode = "require"
} else {
if tls.serverCA == "" {
queryParams = append(queryParams, "sslrootcert=system")
} else {
queryParams = append(queryParams, "sslrootcert=", tls.serverCA)
}
}

uri += "?sslmode=" + sslmode
queryParams = append(queryParams, fmt.Sprint("sslmode=", sslmode))
}
if len(queryParams) > 0 {
uri += "?"
}
for i, qp := range queryParams {
if i > 0 {
uri += "&"
}
uri += qp
}

m, err := migrate.NewWithSourceInstance("iofs", d, uri)
Expand Down

0 comments on commit 08e87c5

Please sign in to comment.