12
12
from google .oauth2 .credentials import Credentials as OAuthCredentials
13
13
from googleapiclient import discovery , errors
14
14
15
- from sky .skylet .providers .gcp .node import MAX_POLLS , POLL_INTERVAL , GCPNodeType
15
+ from sky .skylet .providers .gcp .node import MAX_POLLS , POLL_INTERVAL , GCPNodeType , GCPCompute
16
+ from sky .skylet .providers .gcp .constants import SKYPILOT_VPC_NAME , VPC_TEMPLATE , FIREWALL_RULES_TEMPLATE
16
17
from ray .autoscaler ._private .util import check_legacy_fields
17
18
18
19
logger = logging .getLogger (__name__ )
46
47
# NOTE: iam.serviceAccountUser allows the Head Node to create worker nodes
47
48
# with ServiceAccounts.
48
49
49
-
50
50
def get_node_type (node : dict ) -> GCPNodeType :
51
51
"""Returns node type based on the keys in ``node``.
52
52
@@ -488,6 +488,96 @@ def _configure_key_pair(config, compute):
488
488
return config
489
489
490
490
491
def _check_firewall_rules(vpc_name, config, compute, required_rules=None):
    """Check if the firewall rules in the VPC are sufficient.

    Fetches the effective firewalls of ``vpc_name`` and verifies that every
    required rule has a matching effective rule. Two rules match when their
    ``sourceRanges``, ``allowed`` and ``direction`` fields agree; other
    fields (name, network, selfLink, ...) are ignored.

    Args:
        vpc_name: Name of the VPC network to inspect.
        config: Cluster config; ``config["provider"]["project_id"]`` is read.
        compute: Compute API client exposing
            ``networks().getEffectiveFirewalls(...)``.
        required_rules: Optional iterable of rule dicts that must be present.
            Defaults to ``FIREWALL_RULES_TEMPLATE``.

    Returns:
        True if every required rule is matched by an effective rule,
        False otherwise (including when the response is empty).
    """
    if required_rules is None:
        required_rules = FIREWALL_RULES_TEMPLATE

    response = (
        compute.networks()
        .getEffectiveFirewalls(
            project=config["provider"]["project_id"],
            network=vpc_name,
        )
        .execute()
    )
    if not response:
        return False
    # A non-empty response is not guaranteed to carry a "firewalls" key
    # (it may hold only other entries), so do not index it directly.
    effective_rules = response.get("firewalls", [])

    keys_to_compare = ("sourceRanges", "allowed", "direction")

    def _get_refined_rule(rule):
        """Project a rule onto the compared keys in a canonical order."""
        refined_rule = {}
        for key in keys_to_compare:
            if key not in rule:
                continue
            if key == "allowed":
                refined_rule[key] = sorted(
                    rule[key], key=lambda entry: entry["IPProtocol"]
                )
            elif key == "sourceRanges":
                # Source ranges are an unordered set of CIDRs; sort them so
                # that list order does not cause a spurious mismatch.
                refined_rule[key] = sorted(rule[key])
            else:
                refined_rule[key] = rule[key]
        return refined_rule

    refined_effective = [_get_refined_rule(rule) for rule in effective_rules]
    return all(
        _get_refined_rule(rule) in refined_effective for rule in required_rules
    )
526
+
527
+
528
def get_usable_vpc(config):
    """Return the name of a VPC network the cluster can use.

    Resolution order:
      1. If an instance is already listed for this cluster, return the
         network of its first network interface.
      2. Otherwise return the first listed VPC for which
         ``_check_firewall_rules`` succeeds.
      3. Otherwise create ``SKYPILOT_VPC_NAME`` from ``VPC_TEMPLATE`` plus one
         firewall rule per entry of ``FIREWALL_RULES_TEMPLATE`` and return it.
    """
    _, _, compute, _ = construct_clients_from_provider_config(config["provider"])

    # For backward compatibility, reuse the VPC if the VM is launched.
    gcp_compute = GCPCompute(
        compute,
        config["provider"]["project_id"],
        config["provider"]["availability_zone"],
        config["cluster_name"],
    )
    instances = gcp_compute._list_instances(label_filters=None, status_filter=None)
    if len(instances) > 0:
        interfaces = instances[0].get("networkInterfaces", [])
        if len(interfaces) > 0:
            # The network is a resource URL; its last path segment is the name.
            return interfaces[0]["network"].split("/")[-1]

    # First existing VPC whose effective firewall rules are sufficient.
    usable_vpc_name = next(
        (
            vpc["name"]
            for vpc in _list_vpcnets(config, compute)
            if _check_firewall_rules(vpc["name"], config, compute)
        ),
        None,
    )
    if usable_vpc_name is not None:
        return usable_vpc_name

    logger.info(f"Creating a default VPC network, { SKYPILOT_VPC_NAME } ...")

    # Create a default VPC network
    proj_id = config["provider"]["project_id"]
    vpc_body = VPC_TEMPLATE.copy()
    vpc_body["name"] = vpc_body["name"].format(VPC_NAME=SKYPILOT_VPC_NAME)
    vpc_body["selfLink"] = vpc_body["selfLink"].format(
        PROJ_ID=proj_id, VPC_NAME=SKYPILOT_VPC_NAME
    )
    _create_vpcnet(config, compute, vpc_body)

    # Create firewall rules
    for template in FIREWALL_RULES_TEMPLATE:
        rule_body = template.copy()
        rule_body["name"] = rule_body["name"].format(VPC_NAME=SKYPILOT_VPC_NAME)
        for url_field in ("network", "selfLink"):
            rule_body[url_field] = rule_body[url_field].format(
                PROJ_ID=proj_id, VPC_NAME=SKYPILOT_VPC_NAME
            )
        _create_firewall_rule(config, compute, rule_body)

    logger.info(f"A VPC network { SKYPILOT_VPC_NAME } created.")
    return SKYPILOT_VPC_NAME
579
+
580
+
491
581
def _configure_subnet (config , compute ):
492
582
"""Pick a reasonable subnet if not specified by the config."""
493
583
config = copy .deepcopy (config )
@@ -506,14 +596,9 @@ def _configure_subnet(config, compute):
506
596
):
507
597
return config
508
598
509
- subnets = _list_subnets (config , compute )
510
-
511
- if not subnets :
512
- raise NotImplementedError ("Should be able to create subnet." )
513
-
514
- # TODO: make sure that we have usable subnet. Maybe call
515
- # compute.subnetworks().listUsable? For some reason it didn't
516
- # work out-of-the-box
599
+ # SkyPilot: make sure there's a usable VPC
600
+ usable_vpc_name = get_usable_vpc (config )
601
+ subnets = _list_subnets (config , compute , filter = f"(name=\" { usable_vpc_name } \" )" )
517
602
default_subnet = subnets [0 ]
518
603
519
604
default_interfaces = [
@@ -542,17 +627,54 @@ def _configure_subnet(config, compute):
542
627
return config
543
628
544
629
545
- def _list_subnets (config , compute ):
630
def _create_firewall_rule(config, compute, body):
    """Insert a firewall rule described by ``body`` into the project.

    The insert operation is handed to ``wait_for_compute_global_operation``
    and that call's result is returned.
    """
    firewalls_api = compute.firewalls()
    insert_request = firewalls_api.insert(
        project=config["provider"]["project_id"], body=body
    )
    operation = insert_request.execute()
    return wait_for_compute_global_operation(
        config["provider"]["project_id"], operation, compute
    )
640
+
641
+
642
def _create_vpcnet(config, compute, body):
    """Insert a VPC network described by ``body`` into the project.

    The insert operation is handed to ``wait_for_compute_global_operation``
    and that call's result is returned.
    """
    networks_api = compute.networks()
    insert_request = networks_api.insert(
        project=config["provider"]["project_id"], body=body
    )
    operation = insert_request.execute()
    return wait_for_compute_global_operation(
        config["provider"]["project_id"], operation, compute
    )
652
+
653
+
654
def _list_vpcnets(config, compute):
    """List the VPC networks of the configured project.

    Returns the ``"items"`` entry of the list response, or an empty list
    when the response has no such entry.
    """
    networks_api = compute.networks()
    request = networks_api.list(project=config["provider"]["project_id"])
    listing = request.execute()
    return listing.get("items", [])
664
+
665
+
666
def _list_subnets(config, compute, filter=None):
    """List the subnetworks of the configured project and region.

    ``filter`` is forwarded unchanged to the list request. Returns the
    ``"items"`` entry of the response, or an empty list when the response
    has no such entry.
    """
    subnetworks_api = compute.subnetworks()
    request = subnetworks_api.list(
        project=config["provider"]["project_id"],
        region=config["provider"]["region"],
        filter=filter,
    )
    listing = request.execute()
    return listing.get("items", [])
556
678
557
679
558
680
def _get_subnet (config , subnet_id , compute ):
0 commit comments