Skip to content

Commit

Permalink
Refactor node-join script to take safer options and reuse install opt…
Browse files Browse the repository at this point in the history
…ion logic (#52196)

* Add install script using teleport-update and oneoff.sh

* Refactor node-join script to take safer options and reuse install option logic

* GoDoc + make functions private

* Address edoardo's feedback
  • Loading branch information
hugoShaka authored Feb 21, 2025
1 parent 317860d commit eadbe2d
Show file tree
Hide file tree
Showing 6 changed files with 475 additions and 392 deletions.
2 changes: 1 addition & 1 deletion lib/web/apiserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -2250,7 +2250,7 @@ func (h *Handler) installer(w http.ResponseWriter, r *http.Request, p httprouter
// https://updates.releases.teleport.dev/v1/stable/cloud/version
installUpdater := automaticUpgrades(*ping.ServerFeatures)
if installUpdater {
repoChannel = stableCloudChannelRepo
repoChannel = automaticupgrades.DefaultCloudChannelName
}
azureClientID := r.URL.Query().Get("azure-client-id")

Expand Down
1 change: 1 addition & 0 deletions lib/web/apiserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3643,6 +3643,7 @@ func TestKnownWebPathsWithAndWithoutV1Prefix(t *testing.T) {

func TestInstallDatabaseScriptGeneration(t *testing.T) {
const username = "[email protected]"
modules.SetTestModules(t, &modules.TestModules{TestBuildType: modules.BuildCommunity})

// Users should be able to create Tokens even if they can't update them
roleTokenCRD, err := types.NewRole(services.RoleNameForUser(username), types.RoleSpecV6{
Expand Down
254 changes: 52 additions & 202 deletions lib/web/join_tokens.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,33 +19,26 @@
package web

import (
"bytes"
"context"
"encoding/hex"
"fmt"
"hash/fnv"
"net/http"
"net/url"
"reflect"
"regexp"
"sort"
"strconv"
"strings"
"time"

"github.com/google/safetext/shsprintf"
"github.com/google/uuid"
"github.com/gravitational/trace"
"github.com/julienschmidt/httprouter"
"k8s.io/apimachinery/pkg/util/validation"

"github.com/gravitational/teleport"
"github.com/gravitational/teleport/api/client/proto"
"github.com/gravitational/teleport/api/types"
apiutils "github.com/gravitational/teleport/api/utils"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/httplib"
"github.com/gravitational/teleport/lib/modules"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/tlsca"
"github.com/gravitational/teleport/lib/ui"
Expand All @@ -55,8 +48,7 @@ import (
)

const (
stableCloudChannelRepo = "stable/cloud"
HeaderTokenName = "X-Teleport-TokenName"
HeaderTokenName = "X-Teleport-TokenName"
)

// nodeJoinToken contains node token fields for the UI.
Expand All @@ -80,15 +72,9 @@ type scriptSettings struct {
appURI string
joinMethod string
databaseInstallMode bool
installUpdater bool

discoveryInstallMode bool
discoveryGroup string

// automaticUpgradesVersion is the target automatic upgrades version.
// The version must be valid semver, with the leading 'v'. e.g. v15.0.0-dev
// Required when installUpdater is true.
automaticUpgradesVersion string
}

// automaticUpgrades returns whether automaticUpgrades should be enabled.
Expand Down Expand Up @@ -377,41 +363,16 @@ func (h *Handler) createTokenForDiscoveryHandle(w http.ResponseWriter, r *http.R
}, nil
}

// getAutoUpgrades checks if automaticUpgrades are enabled and returns the
// version that should be used according to auto upgrades default channel.
// If something bad happens, the error is logged and the function falls back to
// the process Teleport version.
func (h *Handler) getAutoUpgrades(ctx context.Context) (bool, string) {
var autoUpgradesVersion string
var err error
autoUpgrades := automaticUpgrades(h.GetClusterFeatures())
if autoUpgrades {
const group, updaterUUID = "", ""
autoUpgradesVersion, err = h.autoUpdateAgentVersion(ctx, group, updaterUUID)
if err != nil {
h.logger.WarnContext(ctx, "Failed to get auto upgrades version, falling back to self version.", "error", err)
return autoUpgrades, teleport.Version
}
autoUpgradesVersion = fmt.Sprintf("v%s", autoUpgradesVersion)
}
return autoUpgrades, autoUpgradesVersion

}

func (h *Handler) getNodeJoinScriptHandle(w http.ResponseWriter, r *http.Request, params httprouter.Params) (interface{}, error) {
httplib.SetScriptHeaders(w.Header())

autoUpgrades, autoUpgradesVersion := h.getAutoUpgrades(r.Context())

settings := scriptSettings{
token: params.ByName("token"),
appInstallMode: false,
joinMethod: r.URL.Query().Get("method"),
installUpdater: autoUpgrades,
automaticUpgradesVersion: autoUpgradesVersion,
token: params.ByName("token"),
appInstallMode: false,
joinMethod: r.URL.Query().Get("method"),
}

script, err := getJoinScript(r.Context(), settings, h.GetProxyClient())
script, err := h.getJoinScript(r.Context(), settings)
if err != nil {
h.logger.InfoContext(r.Context(), "Failed to return the node install script", "error", err)
w.Write(scripts.ErrorBashScript)
Expand Down Expand Up @@ -451,18 +412,14 @@ func (h *Handler) getAppJoinScriptHandle(w http.ResponseWriter, r *http.Request,
return nil, nil
}

autoUpgrades, autoUpgradesVersion := h.getAutoUpgrades(r.Context())

settings := scriptSettings{
token: params.ByName("token"),
appInstallMode: true,
appName: name,
appURI: uri,
installUpdater: autoUpgrades,
automaticUpgradesVersion: autoUpgradesVersion,
token: params.ByName("token"),
appInstallMode: true,
appName: name,
appURI: uri,
}

script, err := getJoinScript(r.Context(), settings, h.GetProxyClient())
script, err := h.getJoinScript(r.Context(), settings)
if err != nil {
h.logger.InfoContext(r.Context(), "Failed to return the app install script", "error", err)
w.Write(scripts.ErrorBashScript)
Expand All @@ -481,16 +438,12 @@ func (h *Handler) getAppJoinScriptHandle(w http.ResponseWriter, r *http.Request,
func (h *Handler) getDatabaseJoinScriptHandle(w http.ResponseWriter, r *http.Request, params httprouter.Params) (interface{}, error) {
httplib.SetScriptHeaders(w.Header())

autoUpgrades, autoUpgradesVersion := h.getAutoUpgrades(r.Context())

settings := scriptSettings{
token: params.ByName("token"),
databaseInstallMode: true,
installUpdater: autoUpgrades,
automaticUpgradesVersion: autoUpgradesVersion,
token: params.ByName("token"),
databaseInstallMode: true,
}

script, err := getJoinScript(r.Context(), settings, h.GetProxyClient())
script, err := h.getJoinScript(r.Context(), settings)
if err != nil {
h.logger.InfoContext(r.Context(), "Failed to return the database install script", "error", err)
w.Write(scripts.ErrorBashScript)
Expand All @@ -511,8 +464,6 @@ func (h *Handler) getDiscoveryJoinScriptHandle(w http.ResponseWriter, r *http.Re
queryValues := r.URL.Query()
const discoveryGroupQueryParam = "discoveryGroup"

autoUpgrades, autoUpgradesVersion := h.getAutoUpgrades(r.Context())

discoveryGroup, err := url.QueryUnescape(queryValues.Get(discoveryGroupQueryParam))
if err != nil {
h.logger.DebugContext(r.Context(), "Failed to return the discovery install script",
Expand All @@ -531,14 +482,12 @@ func (h *Handler) getDiscoveryJoinScriptHandle(w http.ResponseWriter, r *http.Re
}

settings := scriptSettings{
token: params.ByName("token"),
discoveryInstallMode: true,
discoveryGroup: discoveryGroup,
installUpdater: autoUpgrades,
automaticUpgradesVersion: autoUpgradesVersion,
token: params.ByName("token"),
discoveryInstallMode: true,
discoveryGroup: discoveryGroup,
}

script, err := getJoinScript(r.Context(), settings, h.GetProxyClient())
script, err := h.getJoinScript(r.Context(), settings)
if err != nil {
h.logger.InfoContext(r.Context(), "Failed to return the discovery install script", "error", err)
w.Write(scripts.ErrorBashScript)
Expand All @@ -554,8 +503,9 @@ func (h *Handler) getDiscoveryJoinScriptHandle(w http.ResponseWriter, r *http.Re
return nil, nil
}

func getJoinScript(ctx context.Context, settings scriptSettings, m nodeAPIGetter) (string, error) {
switch types.JoinMethod(settings.joinMethod) {
func (h *Handler) getJoinScript(ctx context.Context, settings scriptSettings) (string, error) {
joinMethod := types.JoinMethod(settings.joinMethod)
switch joinMethod {
case types.JoinMethodUnspecified, types.JoinMethodToken:
if err := validateJoinToken(settings.token); err != nil {
return "", trace.Wrap(err)
Expand All @@ -565,141 +515,55 @@ func getJoinScript(ctx context.Context, settings scriptSettings, m nodeAPIGetter
return "", trace.BadParameter("join method %q is not supported via script", settings.joinMethod)
}

clt := h.GetProxyClient()

// The provided token can be attacker controlled, so we must validate
// it with the backend before using it to generate the script.
token, err := m.GetToken(ctx, settings.token)
token, err := clt.GetToken(ctx, settings.token)
if err != nil {
return "", trace.BadParameter("invalid token")
}

// Get hostname and port from proxy server address.
proxyServers, err := m.GetProxies()
if err != nil {
return "", trace.Wrap(err)
}

if len(proxyServers) == 0 {
return "", trace.NotFound("no proxy servers found")
}

version := proxyServers[0].GetTeleportVersion()

publicAddr := proxyServers[0].GetPublicAddr()
if publicAddr == "" {
return "", trace.Errorf("proxy public_addr is not set, you must set proxy_service.public_addr to the publicly reachable address of the proxy before you can generate a node join script")
}

hostname, portStr, err := utils.SplitHostPort(publicAddr)
if err != nil {
return "", trace.Wrap(err)
}
// TODO(hugoShaka): hit the local accesspoint which has a cache instead of asking the auth every time.

// Get the CA pin hashes of the cluster to join.
localCAResponse, err := m.GetClusterCACert(ctx)
localCAResponse, err := clt.GetClusterCACert(ctx)
if err != nil {
return "", trace.Wrap(err)
}

caPins, err := tlsca.CalculatePins(localCAResponse.TLSCA)
if err != nil {
return "", trace.Wrap(err)
}

labelsList := []string{}
for labelKey, labelValues := range token.GetSuggestedLabels() {
labels := strings.Join(labelValues, " ")
labelsList = append(labelsList, fmt.Sprintf("%s=%s", labelKey, labels))
}

var dbServiceResourceLabels []string
if settings.databaseInstallMode {
suggestedAgentMatcherLabels := token.GetSuggestedAgentMatcherLabels()
dbServiceResourceLabels, err = scripts.MarshalLabelsYAML(suggestedAgentMatcherLabels, 6)
if err != nil {
return "", trace.Wrap(err)
}
}

var buf bytes.Buffer
var appServerResourceLabels []string
// If app install mode is requested but parameters are blank for some reason,
// we need to return an error.
if settings.appInstallMode {
if errs := validation.IsDNS1035Label(settings.appName); len(errs) > 0 {
return "", trace.BadParameter("appName %q must be a valid DNS subdomain: https://goteleport.com/docs/enroll-resources/application-access/guides/connecting-apps/#application-name", settings.appName)
}
if !appURIPattern.MatchString(settings.appURI) {
return "", trace.BadParameter("appURI %q contains invalid characters", settings.appURI)
}

suggestedLabels := token.GetSuggestedLabels()
appServerResourceLabels, err = scripts.MarshalLabelsYAML(suggestedLabels, 4)
if err != nil {
return "", trace.Wrap(err)
}
}

if settings.discoveryInstallMode {
if settings.discoveryGroup == "" {
return "", trace.BadParameter("discovery group is required")
}
}

packageName := types.PackageNameOSS
if modules.GetModules().BuildType() == modules.BuildEnterprise {
packageName = types.PackageNameEnt
}

// By default, it will use `stable/v<majorVersion>`, eg stable/v12
repoChannel := ""

// The install script will install the updater (teleport-ent-updater) for Cloud customers enrolled in Automatic Upgrades.
// The repo channel used must be `stable/cloud` which has the available packages for the Cloud Customer's agents.
// It pins the teleport version to the one specified by the default version channel
// This ensures the initial installed version is the same as the `teleport-ent-updater` would install.
if settings.installUpdater {
if settings.automaticUpgradesVersion == "" {
return "", trace.Wrap(err, "automatic upgrades version must be set when installUpdater is true")
}

repoChannel = stableCloudChannelRepo
// automaticUpgradesVersion has vX.Y.Z format, however the script
// expects the version to not include the `v` so we strip it
version = strings.TrimPrefix(settings.automaticUpgradesVersion, "v")
}

// This section relies on Go's default zero values to make sure that the settings
// are correct when not installing an app.
err = scripts.InstallNodeBashScript.Execute(&buf, map[string]interface{}{
"token": settings.token,
"hostname": hostname,
"port": portStr,
// The install.sh script has some manually generated configs and some
// generated by the `teleport <service> config` commands. The old bash
// version used space delimited values whereas the teleport command uses
// a comma delimeter. The Old version can be removed when the install.sh
// file has been completely converted over.
"caPinsOld": strings.Join(caPins, " "),
"caPins": strings.Join(caPins, ","),
"packageName": packageName,
"repoChannel": repoChannel,
"installUpdater": strconv.FormatBool(settings.installUpdater),
"version": shsprintf.EscapeDefaultContext(version),
"appInstallMode": strconv.FormatBool(settings.appInstallMode),
"appServerResourceLabels": appServerResourceLabels,
"appName": shsprintf.EscapeDefaultContext(settings.appName),
"appURI": shsprintf.EscapeDefaultContext(settings.appURI),
"joinMethod": shsprintf.EscapeDefaultContext(settings.joinMethod),
"labels": strings.Join(labelsList, ","),
"databaseInstallMode": strconv.FormatBool(settings.databaseInstallMode),
"db_service_resource_labels": dbServiceResourceLabels,
"discoveryInstallMode": settings.discoveryInstallMode,
"discoveryGroup": shsprintf.EscapeDefaultContext(settings.discoveryGroup),
})
installOpts, err := h.installScriptOptions(ctx)
if err != nil {
return "", trace.Wrap(err)
}

return buf.String(), nil
return "", trace.Wrap(err, "Building install script options")
}

nodeInstallOpts := scripts.InstallNodeScriptOptions{
InstallOptions: installOpts,
Token: token.GetName(),
CAPins: caPins,
// We are using the joinMethod from the script settings instead of the one from the token
// to reproduce the previous script behavior. I'm also afraid that using the
// join method from the token would provide an oracle for an attacker wanting to discover
// the join method.
// We might want to change this in the future to lookup the join method from the token
// to avoid potential mismatch and allow the caller to not care about the join method.
JoinMethod: joinMethod,
Labels: token.GetSuggestedLabels(),
LabelMatchers: token.GetSuggestedAgentMatcherLabels(),
AppServiceEnabled: settings.appInstallMode,
AppName: settings.appName,
AppURI: settings.appURI,
DatabaseServiceEnabled: settings.databaseInstallMode,
DiscoveryServiceEnabled: settings.discoveryInstallMode,
DiscoveryGroup: settings.discoveryGroup,
}

return scripts.GetNodeInstallScript(ctx, nodeInstallOpts)
}

// validateJoinToken validate a join token.
Expand Down Expand Up @@ -789,17 +653,3 @@ func isSameAzureRuleSet(r1, r2 []*types.ProvisionTokenSpecV2Azure_Rule) bool {
sortAzureRules(r2)
return reflect.DeepEqual(r1, r2)
}

type nodeAPIGetter interface {
// GetToken looks up a provisioning token.
GetToken(ctx context.Context, token string) (types.ProvisionToken, error)

// GetClusterCACert returns the CAs for the local cluster without signing keys.
GetClusterCACert(ctx context.Context) (*proto.GetClusterCACertResponse, error)

// GetProxies returns a list of registered proxies.
GetProxies() ([]types.Server, error)
}

// appURIPattern is a regexp excluding invalid characters from application URIs.
var appURIPattern = regexp.MustCompile(`^[-\w/:. ]+$`)
Loading

0 comments on commit eadbe2d

Please sign in to comment.