Skip to content

Commit

Permalink
Run ECS launch in wrapper for better error handling (#63)
Browse files Browse the repository at this point in the history
* Simplify, and add better error handling for ECS launches
  • Loading branch information
jawadqur authored May 2, 2023
1 parent 9bcd8d1 commit 105e14f
Show file tree
Hide file tree
Showing 4 changed files with 285 additions and 199 deletions.
27 changes: 12 additions & 15 deletions hatchery/ecs.go
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ func terminateEcsWorkspace(ctx context.Context, userName string, accessToken str
return fmt.Sprintf("Service '%s' is in status: %s", userToResourceName(userName, "pod"), *delServiceOutput.Service.Status), nil
}

func launchEcsWorkspace(userName string, hash string, accessToken string, payModel PayModel) {
func launchEcsWorkspace(userName string, hash string, accessToken string, payModel PayModel) error {
// Set up background context, as this runs in a goroutine
ctx := context.Background()

Expand All @@ -368,7 +368,7 @@ func launchEcsWorkspace(userName string, hash string, accessToken string, payMod
if err != nil {
// Log error and return without launching workspace
Config.Logger.Printf("Failed to launch ECS workspace for user %v, Error: %v", userName, err)
return
return err
}
cpu, err := cpu(hatchApp.CPULimit)
if err != nil {
Expand All @@ -380,7 +380,7 @@ func launchEcsWorkspace(userName string, hash string, accessToken string, payMod
_, err = svc.launchEcsCluster(userName)
if err != nil {
Config.Logger.Printf("Failed to launch ECS cluster for user %v, Error: %v", userName, err)
return
return err
}

// Get Gen3 API key to be used in workspace
Expand Down Expand Up @@ -422,23 +422,23 @@ func launchEcsWorkspace(userName string, hash string, accessToken string, payMod
volumes, err := svc.EFSFileSystem(userName)
if err != nil {
Config.Logger.Printf("Failed to set up EFS for user %v, Error: %v", userName, err)
return
return err
}

Config.Logger.Printf("Setting up task role for user %s", userName)
taskRole, err := svc.taskRole(userName)
if err != nil {
// Log the error
Config.Logger.Printf("Failed to set up task role for user %v, Error: %v", userName, err)
return
return err
}

Config.Logger.Printf("Setting up execution role for user %s", userName)
_, err = svc.CreateEcsTaskExecutionRole()
if err != nil {
// Log the error
Config.Logger.Printf("Failed to set up execution role for user %v, Error: %v", userName, err)
return
return err
}

Config.Logger.Printf("Setting up ECS task definition for user %s", userName)
Expand Down Expand Up @@ -518,7 +518,7 @@ func launchEcsWorkspace(userName string, hash string, accessToken string, payMod
if aerr != nil {
Config.Logger.Printf("Error occurred when deleting API Key with ID %s for user %s: %s\n", apiKey.KeyID, userName, err.Error())
}
return
return err
}

Config.Logger.Printf("Launching ECS workspace service for user %s", userName)
Expand All @@ -530,18 +530,19 @@ func launchEcsWorkspace(userName string, hash string, accessToken string, payMod
if aerr != nil {
Config.Logger.Printf("Error occurred when deleting API Key with ID %s for user %s: %s\n", apiKey.KeyID, userName, err.Error())
}
return
return err
}

Config.Logger.Printf("Setting up Transit Gateway for user %s", userName)
err = setupTransitGateway(userName)
if err != nil {
// Log the error
Config.Logger.Printf("Failed to set up Transit Gateway for user %v, Error: %v", userName, err)
return
return err
}

Config.Logger.Printf("Launched ECS workspace service at %s for user %s\n", launchTask, userName)
return nil
}

// Launch ECS service for task definition + LB for routing
Expand Down Expand Up @@ -591,15 +592,11 @@ func (sess *CREDS) launchService(ctx context.Context, taskDefArn string, userNam
if aerr.Error() == "InvalidParameterException: Creation of service was not idempotent." {
Config.Logger.Print("Service already exists.. ")
return "", nil
} else {
Config.Logger.Println(ecs.ErrCodeInvalidParameterException, aerr.Error())
}
}
} else {

Config.Logger.Println(err.Error())
return "", err
}
Config.Logger.Println(err.Error())
return "", err
}
Config.Logger.Printf("Service launched: %s", *result.Service.ClusterArn)
err = createLocalService(ctx, userName, hash, *loadBalancer.LoadBalancers[0].DNSName, payModel)
Expand Down
17 changes: 16 additions & 1 deletion hatchery/hatchery.go
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ func launch(w http.ResponseWriter, r *http.Request) {
// Sending a 200 response straight away, but starting the launch in a goroutine
// TODO: Do more sanity checks before returning 200.
w.WriteHeader(http.StatusOK)
go launchEcsWorkspace(userName, hash, accessToken, *payModel)
go launchEcsWorkspaceWrapper(userName, hash, accessToken, *payModel)
fmt.Fprintf(w, "Launch accepted")
return
} else {
Expand Down Expand Up @@ -351,3 +351,18 @@ func statusEcs(ctx context.Context, userName string, accessToken string, awsAcct
}
return result, nil
}

// Wrapper function to launch ECS workspace in a goroutine.
// Terminates workspace if launch fails for whatever reason
func launchEcsWorkspaceWrapper(userName string, hash string, accessToken string, payModel PayModel) {

err := launchEcsWorkspace(userName, hash, accessToken, payModel)
if err != nil {
Config.Logger.Printf("Error: %s", err)
// Terminate ECS workspace if launch fails.
_, err = terminateEcsWorkspace(context.Background(), userName, accessToken, payModel.AWSAccountId)
if err != nil {
Config.Logger.Printf("Error: %s", err)
}
}
}
143 changes: 143 additions & 0 deletions hatchery/ram.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
package hatchery

import (
"fmt"
"os"
"strings"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/ram"
)

func acceptTransitGatewayShare(pm *PayModel, userName string, sess *session.Session, ramArn *string) error {
roleARN := "arn:aws:iam::" + pm.AWSAccountId + ":role/csoc_adminvm"
svc := NewSVC(sess, roleARN)
err := svc.acceptTGWShare(ramArn)
if err != nil {
// Log error
Config.Logger.Printf(err.Error())
return err
}
return nil
}

func (creds *CREDS) acceptTGWShare(ramArn *string) error {
session := session.Must(session.NewSession(&aws.Config{
Credentials: creds.creds,
Region: aws.String("us-east-1"),
}))
svc := ram.New(session)

ramInvitationInput := &ram.GetResourceShareInvitationsInput{
ResourceShareArns: []*string{
ramArn,
},
}
resourceShareInvitation, err := svc.GetResourceShareInvitations(ramInvitationInput)
if err != nil {
// Log error
Config.Logger.Printf(err.Error())
return err
}

if len(resourceShareInvitation.ResourceShareInvitations) == 0 {
// Log that there are no invitations
Config.Logger.Printf("No invitations found something fishy is going on")
return nil
} else {
if *resourceShareInvitation.ResourceShareInvitations[0].Status != "ACCEPTED" {
_, err := svc.AcceptResourceShareInvitation(&ram.AcceptResourceShareInvitationInput{
ResourceShareInvitationArn: resourceShareInvitation.ResourceShareInvitations[0].ResourceShareInvitationArn,
})
if err != nil {
return err
}
// Log that invitation was accepted
Config.Logger.Printf("Resource share invitation accepted")
return nil
}
// Log that invitation was already accepted
Config.Logger.Printf("Resource share invitation already accepted")
return nil
}
}

func shareTransitGateway(session *session.Session, tgwArn string, accountid string) (*string, error) {
// Share resources using resource share in Resource Access Manager
// https://docs.aws.amazon.com/sdk-for-go/api/service/ram/#ResourceShare
svc := ram.New(session)

// RAM name cannot contain dots
ramName := strings.ReplaceAll(os.Getenv("GEN3_ENDPOINT"), ".", "-") + "-ram"
getResourceShareInput := &ram.GetResourceSharesInput{
Name: aws.String(ramName),
ResourceOwner: aws.String("SELF"),
ResourceShareStatus: aws.String("ACTIVE"),
}
exRs, err := svc.GetResourceShares(getResourceShareInput)
if err != nil {
return nil, err
}
if len(exRs.ResourceShares) == 0 {
Config.Logger.Printf("Did not find existing resource share, creating a resource share")
resourceShareInput := &ram.CreateResourceShareInput{
// Indicates whether principals outside your organization in Organizations can
// be associated with a resource share.
AllowExternalPrincipals: aws.Bool(true),
Name: aws.String(ramName),
Principals: []*string{aws.String(accountid)},
ResourceArns: []*string{aws.String(tgwArn)},
Tags: []*ram.Tag{
{
Key: aws.String("Name"),
Value: aws.String(ramName),
},
{
Key: aws.String("Environment"),
Value: aws.String(os.Getenv("GEN3_ENDPOINT")),
},
},
}
resourceShare, err := svc.CreateResourceShare(resourceShareInput)
if err != nil {
return nil, err
}
return resourceShare.ResourceShare.ResourceShareArn, nil
} else {
Config.Logger.Printf("Found existing resource share, associating resource share with account")
listResourcesInput := &ram.ListResourcesInput{
ResourceOwner: aws.String("SELF"),
ResourceArns: []*string{&tgwArn},
}
listResources, err := svc.ListResources(listResourcesInput)
if err != nil {
return nil, err
}

listPrincipalsInput := &ram.ListPrincipalsInput{
ResourceArn: aws.String(tgwArn),
Principals: []*string{aws.String(accountid)},
ResourceOwner: aws.String("SELF"),
}
listPrincipals, err := svc.ListPrincipals(listPrincipalsInput)
if err != nil {
Config.Logger.Printf("failed to ListPrincipals: %s", listPrincipalsInput)
return nil, fmt.Errorf("failed to ListPrincipals: %s", err)
}
if len(listPrincipals.Principals) == 0 || len(listResources.Resources) == 0 {
associateResourceShareInput := &ram.AssociateResourceShareInput{
Principals: []*string{aws.String(accountid)},
ResourceArns: []*string{&tgwArn},
ResourceShareArn: exRs.ResourceShares[len(exRs.ResourceShares)-1].ResourceShareArn,
}
_, err := svc.AssociateResourceShare(associateResourceShareInput)
if err != nil {
return nil, err
}
} else {
Config.Logger.Printf("TransitGateway is already shared with AWS account %s ", *listPrincipals.Principals[0].Id)
}
return exRs.ResourceShares[len(exRs.ResourceShares)-1].ResourceShareArn, nil
}
}
Loading

0 comments on commit 105e14f

Please sign in to comment.