diff --git a/hatchery/ecs.go b/hatchery/ecs.go index 95842600..c445fbc7 100644 --- a/hatchery/ecs.go +++ b/hatchery/ecs.go @@ -353,7 +353,7 @@ func terminateEcsWorkspace(ctx context.Context, userName string, accessToken str return fmt.Sprintf("Service '%s' is in status: %s", userToResourceName(userName, "pod"), *delServiceOutput.Service.Status), nil } -func launchEcsWorkspace(userName string, hash string, accessToken string, payModel PayModel) { +func launchEcsWorkspace(userName string, hash string, accessToken string, payModel PayModel) error { // Set up background context, as this runs in a goroutine ctx := context.Background() @@ -368,7 +368,7 @@ func launchEcsWorkspace(userName string, hash string, accessToken string, payMod if err != nil { // Log error and return without launching workspace Config.Logger.Printf("Failed to launch ECS workspace for user %v, Error: %v", userName, err) - return + return err } cpu, err := cpu(hatchApp.CPULimit) if err != nil { @@ -380,7 +380,7 @@ func launchEcsWorkspace(userName string, hash string, accessToken string, payMod _, err = svc.launchEcsCluster(userName) if err != nil { Config.Logger.Printf("Failed to launch ECS cluster for user %v, Error: %v", userName, err) - return + return err } // Get Gen3 API key to be used in workspace @@ -422,7 +422,7 @@ func launchEcsWorkspace(userName string, hash string, accessToken string, payMod volumes, err := svc.EFSFileSystem(userName) if err != nil { Config.Logger.Printf("Failed to set up EFS for user %v, Error: %v", userName, err) - return + return err } Config.Logger.Printf("Setting up task role for user %s", userName) @@ -430,7 +430,7 @@ func launchEcsWorkspace(userName string, hash string, accessToken string, payMod if err != nil { // Log the error Config.Logger.Printf("Failed to set up task role for user %v, Error: %v", userName, err) - return + return err } Config.Logger.Printf("Setting up execution role for user %s", userName) @@ -438,7 +438,7 @@ func launchEcsWorkspace(userName string, hash string, accessToken string, payMod if err != nil { // Log the error Config.Logger.Printf("Failed to set up execution role for user %v, Error: %v", userName, err) - return + return err } Config.Logger.Printf("Setting up ECS task definition for user %s", userName) @@ -518,7 +518,7 @@ func launchEcsWorkspace(userName string, hash string, accessToken string, payMod if aerr != nil { Config.Logger.Printf("Error occurred when deleting API Key with ID %s for user %s: %s\n", apiKey.KeyID, userName, err.Error()) } - return + return err } Config.Logger.Printf("Launching ECS workspace service for user %s", userName) @@ -530,7 +530,7 @@ func launchEcsWorkspace(userName string, hash string, accessToken string, payMod if aerr != nil { Config.Logger.Printf("Error occurred when deleting API Key with ID %s for user %s: %s\n", apiKey.KeyID, userName, err.Error()) } - return + return err } Config.Logger.Printf("Setting up Transit Gateway for user %s", userName) @@ -538,10 +538,11 @@ func launchEcsWorkspace(userName string, hash string, accessToken string, payMod if err != nil { // Log the error Config.Logger.Printf("Failed to set up Transit Gateway for user %v, Error: %v", userName, err) - return + return err } Config.Logger.Printf("Launched ECS workspace service at %s for user %s\n", launchTask, userName) + return nil } // Launch ECS service for task definition + LB for routing @@ -591,15 +592,11 @@ func (sess *CREDS) launchService(ctx context.Context, taskDefArn string, userNam if aerr.Error() == "InvalidParameterException: Creation of service was not idempotent." { Config.Logger.Print("Service already exists.. ") return "", nil - } else { - Config.Logger.Println(ecs.ErrCodeInvalidParameterException, aerr.Error()) } } - } else { - - Config.Logger.Println(err.Error()) - return "", err } + Config.Logger.Println(err.Error()) + return "", err } Config.Logger.Printf("Service launched: %s", *result.Service.ClusterArn) err = createLocalService(ctx, userName, hash, *loadBalancer.LoadBalancers[0].DNSName, payModel) diff --git a/hatchery/hatchery.go b/hatchery/hatchery.go index 3e2f3aa4..fe840979 100644 --- a/hatchery/hatchery.go +++ b/hatchery/hatchery.go @@ -242,7 +242,7 @@ func launch(w http.ResponseWriter, r *http.Request) { // Sending a 200 response straight away, but starting the launch in a goroutine // TODO: Do more sanity checks before returning 200. w.WriteHeader(http.StatusOK) - go launchEcsWorkspace(userName, hash, accessToken, *payModel) + go launchEcsWorkspaceWrapper(userName, hash, accessToken, *payModel) fmt.Fprintf(w, "Launch accepted") return } else { @@ -351,3 +351,18 @@ func statusEcs(ctx context.Context, userName string, accessToken string, awsAcct } return result, nil } + +// Wrapper function to launch ECS workspace in a goroutine. +// Terminates workspace if launch fails for whatever reason +func launchEcsWorkspaceWrapper(userName string, hash string, accessToken string, payModel PayModel) { + + err := launchEcsWorkspace(userName, hash, accessToken, payModel) + if err != nil { + Config.Logger.Printf("Error: %s", err) + // Terminate ECS workspace if launch fails. + _, err = terminateEcsWorkspace(context.Background(), userName, accessToken, payModel.AWSAccountId) + if err != nil { + Config.Logger.Printf("Error: %s", err) + } + } +} diff --git a/hatchery/ram.go b/hatchery/ram.go new file mode 100644 index 00000000..7cbdb278 --- /dev/null +++ b/hatchery/ram.go @@ -0,0 +1,143 @@ +package hatchery + +import ( + "fmt" + "os" + "strings" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/ram" +) + +func acceptTransitGatewayShare(pm *PayModel, userName string, sess *session.Session, ramArn *string) error { + roleARN := "arn:aws:iam::" + pm.AWSAccountId + ":role/csoc_adminvm" + svc := NewSVC(sess, roleARN) + err := svc.acceptTGWShare(ramArn) + if err != nil { + // Log error + Config.Logger.Printf(err.Error()) + return err + } + return nil +} + +func (creds *CREDS) acceptTGWShare(ramArn *string) error { + session := session.Must(session.NewSession(&aws.Config{ + Credentials: creds.creds, + Region: aws.String("us-east-1"), + })) + svc := ram.New(session) + + ramInvitationInput := &ram.GetResourceShareInvitationsInput{ + ResourceShareArns: []*string{ + ramArn, + }, + } + resourceShareInvitation, err := svc.GetResourceShareInvitations(ramInvitationInput) + if err != nil { + // Log error + Config.Logger.Printf(err.Error()) + return err + } + + if len(resourceShareInvitation.ResourceShareInvitations) == 0 { + // Log that there are no invitations + Config.Logger.Printf("No invitations found something fishy is going on") + return nil + } else { + if *resourceShareInvitation.ResourceShareInvitations[0].Status != "ACCEPTED" { + _, err := svc.AcceptResourceShareInvitation(&ram.AcceptResourceShareInvitationInput{ + ResourceShareInvitationArn: resourceShareInvitation.ResourceShareInvitations[0].ResourceShareInvitationArn, + }) + if err != nil { + return err + } + // Log that invitation was accepted + Config.Logger.Printf("Resource share invitation accepted") + return nil + } + // Log that invitation was already accepted + Config.Logger.Printf("Resource share invitation already accepted") + return nil + } +} + +func shareTransitGateway(session *session.Session, tgwArn string, accountid string) (*string, error) { + // Share resources using resource share in Resource Access Manager + // https://docs.aws.amazon.com/sdk-for-go/api/service/ram/#ResourceShare + svc := ram.New(session) + + // RAM name cannot contain dots + ramName := strings.ReplaceAll(os.Getenv("GEN3_ENDPOINT"), ".", "-") + "-ram" + getResourceShareInput := &ram.GetResourceSharesInput{ + Name: aws.String(ramName), + ResourceOwner: aws.String("SELF"), + ResourceShareStatus: aws.String("ACTIVE"), + } + exRs, err := svc.GetResourceShares(getResourceShareInput) + if err != nil { + return nil, err + } + if len(exRs.ResourceShares) == 0 { + Config.Logger.Printf("Did not find existing resource share, creating a resource share") + resourceShareInput := &ram.CreateResourceShareInput{ + // Indicates whether principals outside your organization in Organizations can + // be associated with a resource share. + AllowExternalPrincipals: aws.Bool(true), + Name: aws.String(ramName), + Principals: []*string{aws.String(accountid)}, + ResourceArns: []*string{aws.String(tgwArn)}, + Tags: []*ram.Tag{ + { + Key: aws.String("Name"), + Value: aws.String(ramName), + }, + { + Key: aws.String("Environment"), + Value: aws.String(os.Getenv("GEN3_ENDPOINT")), + }, + }, + } + resourceShare, err := svc.CreateResourceShare(resourceShareInput) + if err != nil { + return nil, err + } + return resourceShare.ResourceShare.ResourceShareArn, nil + } else { + Config.Logger.Printf("Found existing resource share, associating resource share with account") + listResourcesInput := &ram.ListResourcesInput{ + ResourceOwner: aws.String("SELF"), + ResourceArns: []*string{&tgwArn}, + } + listResources, err := svc.ListResources(listResourcesInput) + if err != nil { + return nil, err + } + + listPrincipalsInput := &ram.ListPrincipalsInput{ + ResourceArn: aws.String(tgwArn), + Principals: []*string{aws.String(accountid)}, + ResourceOwner: aws.String("SELF"), + } + listPrincipals, err := svc.ListPrincipals(listPrincipalsInput) + if err != nil { + Config.Logger.Printf("failed to ListPrincipals: %s", listPrincipalsInput) + return nil, fmt.Errorf("failed to ListPrincipals: %s", err) + } + if len(listPrincipals.Principals) == 0 || len(listResources.Resources) == 0 { + associateResourceShareInput := &ram.AssociateResourceShareInput{ + Principals: []*string{aws.String(accountid)}, + ResourceArns: []*string{&tgwArn}, + ResourceShareArn: exRs.ResourceShares[len(exRs.ResourceShares)-1].ResourceShareArn, + } + _, err := svc.AssociateResourceShare(associateResourceShareInput) + if err != nil { + return nil, err + } + } else { + Config.Logger.Printf("TransitGateway is already shared with AWS account %s ", *listPrincipals.Principals[0].Id) + } + return exRs.ResourceShares[len(exRs.ResourceShares)-1].ResourceShareArn, nil + } +} diff --git a/hatchery/transitgateway.go b/hatchery/transitgateway.go index 26219a6a..52480b9d 100644 --- a/hatchery/transitgateway.go +++ b/hatchery/transitgateway.go @@ -10,21 +10,80 @@ import ( "github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/ec2" - "github.com/aws/aws-sdk-go/service/ram" ) +// This function sets up the transit gateway between the account hatchery is running in, and the account workspaces will run. func setupTransitGateway(userName string) error { - Config.Logger.Printf("Setting up transit gateway") - _, err := createTransitGateway(userName) + // Create new AWS session to be used by this function + sess := session.Must(session.NewSession(&aws.Config{ + // TODO: Make this configurable + Region: aws.String("us-east-1"), + })) + + pm, err := getCurrentPayModel(userName) + if err != nil { + return err + } + + Config.Logger.Printf("Setting up transit gateway in main account") + tgwid, tgwarn, tgwRouteTableId, err := createTransitGateway(sess, userName) if err != nil { return fmt.Errorf("error creating transit gateway: %s", err.Error()) } + + // This transit gateway attachment connects the main account to the remote account + // This is needed once per environment, instead of once per user. + err = createLocalTransitGatewayAttachment(userName, *tgwid, tgwRouteTableId) + if err != nil { + return fmt.Errorf("error creating local transit gateway attachment: %s", err.Error()) + } + + // Make sure transit gateway is shared with the remote account + ramArn, err := shareTransitGateway(sess, *tgwarn, pm.AWSAccountId) + if err != nil { + return err + } + + // Accept transit gateway share in remote account + err = acceptTransitGatewayShare(pm, *tgwarn, sess, ramArn) + if err != nil { + return err + } + Config.Logger.Printf("Setting up remote account ") err = setupRemoteAccount(userName, false) if err != nil { return fmt.Errorf("failed to setup remote account: %s", err.Error()) } - Config.Logger.Printf("Remote account setup complete") + return nil +} + +func createLocalTransitGatewayAttachment(userName string, tgwid string, tgwRouteTableId *string) error { + vpcid := os.Getenv("GEN3_VPCID") + + sess := session.Must(session.NewSession(&aws.Config{ + // TODO: Make this configurable + Region: aws.String("us-east-1"), + })) + + // ec2 session to main AWS account. + ec2Local := ec2.New(sess) + // Create Transit Gateway Attachment in local VPC + // Config.Logger.Printf("Creating tgw attachment in local VPC: %s", vpcid) + tgwAttachment, err := createTransitGatewayAttachments(ec2Local, vpcid, tgwid, true, nil, userName) + if err != nil { + return err + } + Config.Logger.Printf("Attachment created: %s", *tgwAttachment) + + // Create Transit Gateway Route Table + _, err = TGWRoutes(userName, tgwRouteTableId, tgwAttachment, ec2Local, true, false, nil) + if err != nil { + // Log error + Config.Logger.Printf("Failed to create TGW route table: %s", err.Error()) + return err + } + return nil } @@ -39,7 +98,6 @@ func teardownTransitGateway(userName string) error { } -// TODO: Change the name of this function to match HUB/SPOKE model func describeMainNetwork(vpcid string, svc *ec2.EC2) (*NetworkInfo, error) { networkInfo := NetworkInfo{} vpcInput := &ec2.DescribeVpcsInput{ @@ -100,20 +158,11 @@ func describeMainNetwork(vpcid string, svc *ec2.EC2) (*NetworkInfo, error) { return &networkInfo, nil } -func createTransitGateway(userName string) (*string, error) { - pm, err := getCurrentPayModel(userName) - if err != nil { - return nil, err - } - sess := session.Must(session.NewSession(&aws.Config{ - // TODO: Make this configurable - Region: aws.String("us-east-1"), - })) +func createTransitGateway(sess *session.Session, userName string) (tgwid *string, tgwarn *string, tgwRouteTableId *string, err error) { // ec2 session to main AWS account. ec2Local := ec2.New(sess) - vpcid := os.Getenv("GEN3_VPCID") tgwName := strings.ReplaceAll(os.Getenv("GEN3_ENDPOINT"), ".", "-") + "-tgw" // Check for existing transit gateway exTg, err := ec2Local.DescribeTransitGateways(&ec2.DescribeTransitGatewaysInput{ @@ -129,7 +178,7 @@ func createTransitGateway(userName string) (*string, error) { }, }) if err != nil { - return nil, fmt.Errorf("failed to DescribeTransitGateways: %s", err.Error()) + return nil, nil, nil, fmt.Errorf("failed to DescribeTransitGateways: %s", err.Error()) } // Create Transit Gateway if it doesn't exist @@ -160,54 +209,14 @@ func createTransitGateway(userName string) (*string, error) { }, }) if err != nil { - return nil, err + return nil, nil, nil, err } Config.Logger.Printf("Transit gateway created: %s", *tg.TransitGateway.TransitGatewayId) - // Create Transit Gateway Attachment in local VPC - Config.Logger.Printf("Creating tgw attachment in local VPC: %s", vpcid) - tgwAttachment, err := createTransitGatewayAttachments(ec2Local, vpcid, *tg.TransitGateway.TransitGatewayId, true, nil, userName) - if err != nil { - return nil, err - } - Config.Logger.Printf("Attachment created: %s", *tgwAttachment) - - // Create Transit Gateway Route Table - _, err = TGWRoutes(userName, tg.TransitGateway.Options.AssociationDefaultRouteTableId, tgwAttachment, ec2Local, true, false, nil) - if err != nil { - // Log error - Config.Logger.Printf("Failed to create TGW route table: %s", err.Error()) - return nil, err - } - resourceshare, err := shareTransitGateway(sess, *tg.TransitGateway.TransitGatewayArn, pm.AWSAccountId) - if err != nil { - return nil, err - } - Config.Logger.Printf("Resources shared: %s", *resourceshare) - return tg.TransitGateway.TransitGatewayId, nil + return tg.TransitGateway.TransitGatewayId, tg.TransitGateway.TransitGatewayArn, tg.TransitGateway.Options.AssociationDefaultRouteTableId, nil } else { Config.Logger.Print("Existing transit gateway found. Skipping creation...") - tgwAttachment, err := createTransitGatewayAttachments(ec2Local, vpcid, *exTg.TransitGateways[len(exTg.TransitGateways)-1].TransitGatewayId, true, nil, userName) - if err != nil { - return nil, err - } - Config.Logger.Printf("Local TGW Attachment created: %s", *tgwAttachment) - resourceshare, err := shareTransitGateway(sess, *exTg.TransitGateways[len(exTg.TransitGateways)-1].TransitGatewayArn, pm.AWSAccountId) - if err != nil { - return nil, err - } - - // Updating the route table to include the new attachment - Config.Logger.Printf("Updating route table to include new attachment: %s", *tgwAttachment) - _, err = TGWRoutes(userName, exTg.TransitGateways[len(exTg.TransitGateways)-1].Options.AssociationDefaultRouteTableId, tgwAttachment, ec2Local, true, false, nil) - if err != nil { - // Log error - Config.Logger.Printf("Failed to create TGW route table: %s", err.Error()) - return nil, err - } - - Config.Logger.Printf("Resources shared: %s", *resourceshare) - return exTg.TransitGateways[len(exTg.TransitGateways)-1].TransitGatewayId, nil + return exTg.TransitGateways[len(exTg.TransitGateways)-1].TransitGatewayId, exTg.TransitGateways[len(exTg.TransitGateways)-1].TransitGatewayArn, exTg.TransitGateways[len(exTg.TransitGateways)-1].Options.AssociationDefaultRouteTableId, nil } } @@ -228,11 +237,8 @@ func createTransitGatewayAttachments(svc *ec2.EC2, vpcid string, tgwid string, l if aerr, ok := err.(awserr.Error); ok { switch aerr.Code() { case "InvalidTransitGatewayID.NotFound": - // Accept any pending invites again - err = sess.acceptTGWShare() - if err != nil { - return nil, err - } + // Sleep for 10 seconds before trying again.. + time.Sleep(10 * time.Second) _, err = svc.DescribeTransitGateways(tgInput) if err != nil { return nil, fmt.Errorf("cannot DescribeTransitGateways again: %s", err.Error()) @@ -244,7 +250,7 @@ func createTransitGatewayAttachments(svc *ec2.EC2, vpcid string, tgwid string, l return nil, err } for *exTg.TransitGateways[0].State != "available" { - Config.Logger.Printf("TransitGateway is in state: %s ... Waiting for 5 seconds", *exTg.TransitGateways[0].State) + Config.Logger.Printf("TransitGateway is in state: %s ... Waiting for 10 seconds", *exTg.TransitGateways[0].State) // sleep for 10 sec time.Sleep(10 * time.Second) exTg, _ = svc.DescribeTransitGateways(tgInput) @@ -258,6 +264,16 @@ func createTransitGatewayAttachments(svc *ec2.EC2, vpcid string, tgwid string, l if err != nil { return nil, fmt.Errorf("Failed to get network info: %s", err) } + + tgwAttachmentName := "" + if local { + // Name the local tgwAttachment after the environment, instead of the user. + // This is shared between all users in the environment. + tgwAttachmentName = os.Getenv("GEN3_ENDPOINT") + "-tgwa" + } else { + tgwAttachmentName = userToResourceName(userName, "service") + "tgwa" + } + exTgwAttachmentInput := &ec2.DescribeTransitGatewayAttachmentsInput{ Filters: []*ec2.Filter{ { @@ -279,7 +295,8 @@ func createTransitGatewayAttachments(svc *ec2.EC2, vpcid string, tgwid string, l return nil, err } if len(exTgwAttachment.TransitGatewayAttachments) == 0 { - tgwAttachmentName := userToResourceName(userName, "service") + "tgwa" + // Create the transit gateway attachment + Config.Logger.Printf("Local transitgateway attachment not found, creating new one") tgwAttachmentInput := &ec2.CreateTransitGatewayVpcAttachmentInput{ TransitGatewayId: exTg.TransitGateways[0].TransitGatewayId, VpcId: networkInfo.vpc.Vpcs[len(networkInfo.vpc.Vpcs)-1].VpcId, @@ -306,8 +323,10 @@ func createTransitGatewayAttachments(svc *ec2.EC2, vpcid string, tgwid string, l if err != nil { return nil, fmt.Errorf("cannot create transitgatewayattachment: %s", err.Error()) } + Config.Logger.Printf("Created transitgatewayattachment: %s", *tgwAttachment.TransitGatewayVpcAttachment.TransitGatewayAttachmentId) return tgwAttachment.TransitGatewayVpcAttachment.TransitGatewayAttachmentId, nil } else { + Config.Logger.Printf("Local transitgateway attachment found, using existing one") return exTgwAttachment.TransitGatewayAttachments[0].TransitGatewayAttachmentId, nil } } @@ -369,80 +388,6 @@ func deleteTransitGatewayAttachment(svc *ec2.EC2, tgwid string, userName string) return delTGWAttachment.TransitGatewayVpcAttachment.TransitGatewayAttachmentId, nil } -func shareTransitGateway(session *session.Session, tgwArn string, accountid string) (*string, error) { - svc := ram.New(session) - - ramName := strings.ReplaceAll(os.Getenv("GEN3_ENDPOINT"), ".", "-") + "-ram" - getResourceShareInput := &ram.GetResourceSharesInput{ - Name: aws.String(ramName), - ResourceOwner: aws.String("SELF"), - ResourceShareStatus: aws.String("ACTIVE"), - } - exRs, err := svc.GetResourceShares(getResourceShareInput) - if err != nil { - return nil, err - } - if len(exRs.ResourceShares) == 0 { - Config.Logger.Printf("Did not find existing resource share, creating a resource share") - resourceShareInput := &ram.CreateResourceShareInput{ - // Indicates whether principals outside your organization in Organizations can - // be associated with a resource share. - AllowExternalPrincipals: aws.Bool(true), - Name: aws.String(ramName), - Principals: []*string{aws.String(accountid)}, - ResourceArns: []*string{aws.String(tgwArn)}, - Tags: []*ram.Tag{ - { - Key: aws.String("Name"), - Value: aws.String(ramName), - }, - { - Key: aws.String("Environment"), - Value: aws.String(os.Getenv("GEN3_ENDPOINT")), - }, - }, - } - resourceShare, err := svc.CreateResourceShare(resourceShareInput) - if err != nil { - return nil, err - } - return resourceShare.ResourceShare.ResourceShareArn, nil - } else { - listResourcesInput := &ram.ListResourcesInput{ - ResourceOwner: aws.String("SELF"), - ResourceArns: []*string{&tgwArn}, - } - listResources, err := svc.ListResources(listResourcesInput) - if err != nil { - return nil, err - } - - listPrincipalsInput := &ram.ListPrincipalsInput{ - ResourceArn: aws.String(tgwArn), - Principals: []*string{aws.String(accountid)}, - ResourceOwner: aws.String("SELF"), - } - listPrincipals, err := svc.ListPrincipals(listPrincipalsInput) - if err != nil { - return nil, fmt.Errorf("failed to ListPrincipals: %s", err) - } - if len(listPrincipals.Principals) == 0 || len(listResources.Resources) == 0 { - associateResourceShareInput := &ram.AssociateResourceShareInput{ - Principals: []*string{aws.String(accountid)}, - ResourceArns: []*string{&tgwArn}, - ResourceShareArn: exRs.ResourceShares[len(exRs.ResourceShares)-1].ResourceShareArn, - } - _, err := svc.AssociateResourceShare(associateResourceShareInput) - if err != nil { - return nil, err - } - } else { - Config.Logger.Printf("TransitGateway is already shared with AWS account %s ", *listPrincipals.Principals[0].Id) - } - return exRs.ResourceShares[len(exRs.ResourceShares)-1].ResourceShareArn, nil - } -} - func setupRemoteAccount(userName string, teardown bool) error { pm, err := getCurrentPayModel(userName) if err != nil { @@ -462,10 +407,6 @@ func setupRemoteAccount(userName string, teardown bool) error { }))) vpcid := os.Getenv("GEN3_VPCID") - err = svc.acceptTGWShare() - if err != nil { - return err - } exTgInput := &ec2.DescribeTransitGatewaysInput{ Filters: []*ec2.Filter{ { @@ -484,10 +425,10 @@ func setupRemoteAccount(userName string, teardown bool) error { } for len(exTg.TransitGateways) == 0 { Config.Logger.Printf("Waiting to find ex_tgw") - err := svc.acceptTGWShare() - if err != nil { - return err - } + // err := svc.acceptTGWShare() + // if err != nil { + // return err + // } exTg, err = ec2Local.DescribeTransitGateways(exTgInput) if err != nil { return err @@ -532,34 +473,6 @@ func setupRemoteAccount(userName string, teardown bool) error { return nil } -func (creds *CREDS) acceptTGWShare() error { - session := session.Must(session.NewSession(&aws.Config{ - Credentials: creds.creds, - Region: aws.String("us-east-1"), - })) - svc := ram.New(session) - - resourceShareInvitation, err := svc.GetResourceShareInvitations(&ram.GetResourceShareInvitationsInput{}) - if err != nil { - return err - } - - if len(resourceShareInvitation.ResourceShareInvitations) == 0 { - return nil - } else { - if *resourceShareInvitation.ResourceShareInvitations[0].Status != "ACCEPTED" { - _, err := svc.AcceptResourceShareInvitation(&ram.AcceptResourceShareInvitationInput{ - ResourceShareInvitationArn: resourceShareInvitation.ResourceShareInvitations[0].ResourceShareInvitationArn, - }) - if err != nil { - return err - } - return nil - } - return nil - } -} - func TGWRoutes(userName string, tgwRoutetableId *string, tgwAttachmentId *string, svc *ec2.EC2, local bool, teardown bool, sess *CREDS) (*string, error) { var networkInfo *NetworkInfo vpcid := os.Getenv("GEN3_VPCID") @@ -627,13 +540,31 @@ func TGWRoutes(userName string, tgwRoutetableId *string, tgwAttachmentId *string } if len(exRoutes.Routes) == 1 { - delRouteInput := &ec2.DeleteTransitGatewayRouteInput{ - DestinationCidrBlock: networkInfo.vpc.Vpcs[0].CidrBlock, - TransitGatewayRouteTableId: tgwRoutetableId, - } - _, err := svc.DeleteTransitGatewayRoute(delRouteInput) - if err != nil { - return nil, err + if local { + // Delete route only if it's blackhole route + if *exRoutes.Routes[0].State == "blackhole" { + Config.Logger.Printf("Route is blackhole, deleting") + delRouteInput := &ec2.DeleteTransitGatewayRouteInput{ + DestinationCidrBlock: networkInfo.vpc.Vpcs[0].CidrBlock, + TransitGatewayRouteTableId: tgwRoutetableId, + } + _, err := svc.DeleteTransitGatewayRoute(delRouteInput) + if err != nil { + return nil, err + } + } else { + Config.Logger.Printf("Route already exists for %s", *networkInfo.vpc.Vpcs[0].CidrBlock) + return exRoutesInput.TransitGatewayRouteTableId, nil + } + } else { + delRouteInput := &ec2.DeleteTransitGatewayRouteInput{ + DestinationCidrBlock: networkInfo.vpc.Vpcs[0].CidrBlock, + TransitGatewayRouteTableId: tgwRoutetableId, + } + _, err := svc.DeleteTransitGatewayRoute(delRouteInput) + if err != nil { + return nil, err + } } } tgRouteInput := &ec2.CreateTransitGatewayRouteInput{