Skip to content

Commit

Permalink
Make Login Provider List Dynamic (#158)
Browse files Browse the repository at this point in the history
This has been hard coded for far too long, make it dynamic based on what
providers are available.  Also preemptively fix a some bugs in the
oauth2 provider lookup function.
  • Loading branch information
spjmurray authored Jan 24, 2025
1 parent e897d64 commit e383560
Show file tree
Hide file tree
Showing 5 changed files with 293 additions and 236 deletions.
4 changes: 2 additions & 2 deletions charts/identity/Chart.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ description: A Helm chart for deploying Unikorn's IdP

type: application

version: v0.2.52-rc3
appVersion: v0.2.52-rc3
version: v0.2.52-rc4
appVersion: v0.2.52-rc4

icon: https://raw.githubusercontent.com/unikorn-cloud/assets/main/images/logos/dark-on-light/icon.png

Expand Down
77 changes: 66 additions & 11 deletions pkg/oauth2/oauth2.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ import (

var (
ErrUnsupportedProviderType = goerrors.New("unhandled provider type")
ErrReference = goerrors.New("resource reference error")
ErrUserNotDomainMapped = goerrors.New("user is not domain mapped to an organization")
ErrEmailLookup = goerrors.New("failed to lookup email")
)
Expand Down Expand Up @@ -403,12 +404,18 @@ func (a *Authenticator) Authorization(w http.ResponseWriter, r *http.Request) {
return
}

supportedTypes, err := a.getProviderTypes(r.Context())
if err != nil {
authorizationError(w, r, client.Spec.RedirectURI, ErrorServerError, "failed to get oauth2 providers")
return
}

loginQuery := url.Values{}

loginQuery.Set("state", state)
loginQuery.Set("callback", "https://"+r.Host+"/oauth2/v2/login")
// TODO: this needs to be driven by the available oauth2providers
loginQuery.Set("providers", "google microsoft github")
loginQuery.Set("providers", strings.Join(supportedTypes, " "))

// Redirect to an external login handler, if you have chosen to.
if client.Spec.LoginURI != nil {
Expand Down Expand Up @@ -461,11 +468,38 @@ func (a *Authenticator) lookupOrganization(ctx context.Context, email string) (*
return nil, ErrUserNotDomainMapped
}

// getProviders lists all identity providers.
func (a *Authenticator) getProviders(ctx context.Context) (*unikornv1.OAuth2ProviderList, error) {
resources := &unikornv1.OAuth2ProviderList{}

if err := a.client.List(ctx, resources, &client.ListOptions{Namespace: a.namespace}); err != nil {
return nil, err
}

return resources, nil
}

func (a *Authenticator) getProviderTypes(ctx context.Context) ([]string, error) {
resources, err := a.getProviders(ctx)
if err != nil {
return nil, err
}

result := make([]string, 0, len(resources.Items))

for _, resource := range resources.Items {
if resource.Spec.Type != nil && *resource.Spec.Type != "" {
result = append(result, string(*resource.Spec.Type))
}
}

return result, nil
}

// lookupProviderByType finds the provider configuration by the type chosen by the user.
func (a *Authenticator) lookupProviderByType(ctx context.Context, t unikornv1.IdentityProviderType) (*unikornv1.OAuth2Provider, error) {
var resources unikornv1.OAuth2ProviderList

if err := a.client.List(ctx, &resources, &client.ListOptions{Namespace: a.namespace}); err != nil {
resources, err := a.getProviders(ctx)
if err != nil {
return nil, err
}

Expand All @@ -479,14 +513,35 @@ func (a *Authenticator) lookupProviderByType(ctx context.Context, t unikornv1.Id
}

// lookupProviderByID finds the provider based on ID.
func (a *Authenticator) lookupProviderByID(ctx context.Context, id string) (*unikornv1.OAuth2Provider, error) {
var providerResource unikornv1.OAuth2Provider
func (a *Authenticator) lookupProviderByID(ctx context.Context, id string, organization *unikornv1.Organization) (*unikornv1.OAuth2Provider, error) {
providers := &unikornv1.OAuth2ProviderList{}

if err := a.client.Get(ctx, client.ObjectKey{Namespace: a.namespace, Name: id}, &providerResource); err != nil {
if err := a.client.List(ctx, providers); err != nil {
return nil, err
}

return &providerResource, nil
find := func(provider unikornv1.OAuth2Provider) bool {
return provider.Name == id
}

index := slices.IndexFunc(providers.Items, find)
if index < 0 {
return nil, fmt.Errorf("%w: requested provider does not exist", ErrReference)
}

provider := &providers.Items[index]

// If the provider is neither global, nor scoped to the provided organization, reject.
// NOTE: when called by the authorization endpoint and an email is provided, that email
// maps to an organization, and the provider must be in that organization to avoid
// jailbreaking. In later provider authorization and token exchanges we can trust the
// ID as it's already been checked and it has been cryptographically protected against
// tamering.
if provider.Namespace != a.namespace && (organization == nil || provider.Namespace != organization.Status.Namespace) {
return nil, fmt.Errorf("%w: requested provider not allowed", ErrReference)
}

return provider, nil
}

// newOIDCProvider abstracts away any hacks for specific providers.
Expand Down Expand Up @@ -640,7 +695,7 @@ func (a *Authenticator) Login(w http.ResponseWriter, r *http.Request) {
return
}

provider, err := a.lookupProviderByID(r.Context(), *organization.Spec.ProviderID)
provider, err := a.lookupProviderByID(r.Context(), *organization.Spec.ProviderID, organization)
if err != nil {
authorizationError(w, r, query.Get("redirect_uri"), ErrorServerError, err.Error())
return
Expand Down Expand Up @@ -886,7 +941,7 @@ func (a *Authenticator) Callback(w http.ResponseWriter, r *http.Request) {
return
}

provider, err := a.lookupProviderByID(r.Context(), state.OAuth2Provider)
provider, err := a.lookupProviderByID(r.Context(), state.OAuth2Provider, nil)
if err != nil {
authorizationError(w, r, state.ClientRedirectURI, ErrorServerError, "failed to get oauth2 provider")
return
Expand Down Expand Up @@ -1136,7 +1191,7 @@ func (a *Authenticator) TokenRefreshToken(w http.ResponseWriter, r *http.Request

// Lookup the provider details, then do a token refresh against that to update
// the access token.
provider, err := a.lookupProviderByID(r.Context(), claims.Custom.Provider)
provider, err := a.lookupProviderByID(r.Context(), claims.Custom.Provider, nil)
if err != nil {
return nil, err
}
Expand Down
Loading

0 comments on commit e383560

Please sign in to comment.