Skip to content

Commit b0894c8

Browse files
authored
Add GCP VPC creation if no VPC available (skypilot-org#1594)
* add vpc creation * shorten * create vpcnet in ray up instead * implement best effort * fix * add constants.py * fix * quote.. * comments * comment * fix * backward compatibility * comment * list_instances * fix * fix
1 parent 7adb54e commit b0894c8

File tree

5 files changed

+208
-23
lines changed

5 files changed

+208
-23
lines changed

sky/backends/backend_utils.py

+6
Original file line numberDiff line numberDiff line change
@@ -911,13 +911,19 @@ def write_cluster_config(
911911
tpu_name = cluster_name
912912

913913
user_file_dir = os.path.expanduser(f'{SKY_USER_FILE_PATH}/')
914+
915+
from sky.skylet.providers.gcp import config as gcp_config # pylint: disable=import-outside-toplevel
916+
config = common_utils.read_yaml(os.path.expanduser(config_dict['ray']))
917+
vpc_name = gcp_config.get_usable_vpc(config)
918+
914919
scripts = tuple(
915920
fill_template(
916921
template_name,
917922
dict(
918923
resources_vars, **{
919924
'tpu_name': tpu_name,
920925
'gcp_project_id': gcp_project_id,
926+
'vpc_name': vpc_name,
921927
}),
922928
# Use new names for TPU scripts so that different runs can use
923929
# different TPUs. Put in SKY_USER_FILE_PATH to be consistent

sky/skylet/providers/gcp/config.py

+134-12
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212
from google.oauth2.credentials import Credentials as OAuthCredentials
1313
from googleapiclient import discovery, errors
1414

15-
from sky.skylet.providers.gcp.node import MAX_POLLS, POLL_INTERVAL, GCPNodeType
15+
from sky.skylet.providers.gcp.node import MAX_POLLS, POLL_INTERVAL, GCPNodeType, GCPCompute
16+
from sky.skylet.providers.gcp.constants import SKYPILOT_VPC_NAME, VPC_TEMPLATE, FIREWALL_RULES_TEMPLATE
1617
from ray.autoscaler._private.util import check_legacy_fields
1718

1819
logger = logging.getLogger(__name__)
@@ -46,7 +47,6 @@
4647
# NOTE: iam.serviceAccountUser allows the Head Node to create worker nodes
4748
# with ServiceAccounts.
4849

49-
5050
def get_node_type(node: dict) -> GCPNodeType:
5151
"""Returns node type based on the keys in ``node``.
5252
@@ -488,6 +488,96 @@ def _configure_key_pair(config, compute):
488488
return config
489489

490490

491+
def _check_firewall_rules(vpc_name, config, compute):
492+
"""Check if the firewall rules in the VPC are sufficient."""
493+
required_rules = FIREWALL_RULES_TEMPLATE.copy()
494+
495+
operation = (
496+
compute.networks().
497+
getEffectiveFirewalls(
498+
project=config["provider"]["project_id"],
499+
network=vpc_name
500+
)
501+
)
502+
response = operation.execute()
503+
if len(response) == 0:
504+
return False
505+
effective_rules = response["firewalls"]
506+
507+
def _get_refined_rule(rule):
508+
KEY_TO_COMPARE = {"sourceRanges", "allowed", "direction"}
509+
refined_rule = {}
510+
for k in KEY_TO_COMPARE:
511+
if k not in rule:
512+
continue
513+
if k == "allowed":
514+
refined_rule[k] = sorted(rule[k], key=lambda x: x["IPProtocol"])
515+
else:
516+
refined_rule[k] = rule[k]
517+
return refined_rule
518+
519+
required_rules = list(map(_get_refined_rule, required_rules))
520+
effective_rules = list(map(_get_refined_rule, effective_rules))
521+
522+
for rule in required_rules:
523+
if rule not in effective_rules:
524+
return False
525+
return True
526+
527+
528+
def get_usable_vpc(config):
529+
"""Return a usable VPC.
530+
531+
If not found, create a new one with sufficient firewall rules.
532+
"""
533+
_, _, compute, _ = construct_clients_from_provider_config(config["provider"])
534+
535+
# For backward compatibility, reuse the VPC if the VM is launched.
536+
resource = GCPCompute(
537+
compute,
538+
config["provider"]["project_id"],
539+
config["provider"]["availability_zone"],
540+
config["cluster_name"],
541+
)
542+
node = resource._list_instances(label_filters=None, status_filter=None)
543+
if len(node) > 0:
544+
netInterfaces = node[0].get("networkInterfaces", [])
545+
if len(netInterfaces) > 0:
546+
vpc_name = netInterfaces[0]["network"].split("/")[-1]
547+
return vpc_name
548+
549+
vpcnets_all = _list_vpcnets(config, compute)
550+
551+
usable_vpc_name = None
552+
for vpc in vpcnets_all:
553+
if _check_firewall_rules(vpc["name"], config, compute):
554+
usable_vpc_name = vpc["name"]
555+
break
556+
557+
if usable_vpc_name is None:
558+
logger.info(f"Creating a default VPC network, {SKYPILOT_VPC_NAME}...")
559+
560+
# Create a default VPC network
561+
proj_id = config["provider"]["project_id"]
562+
body = VPC_TEMPLATE.copy()
563+
body["name"] = body["name"].format(VPC_NAME=SKYPILOT_VPC_NAME)
564+
body["selfLink"] = body["selfLink"].format(PROJ_ID=proj_id, VPC_NAME=SKYPILOT_VPC_NAME)
565+
_create_vpcnet(config, compute, body)
566+
567+
# Create firewall rules
568+
for rule in FIREWALL_RULES_TEMPLATE:
569+
body = rule.copy()
570+
body["name"] = body["name"].format(VPC_NAME=SKYPILOT_VPC_NAME)
571+
body["network"] = body["network"].format(PROJ_ID=proj_id, VPC_NAME=SKYPILOT_VPC_NAME)
572+
body["selfLink"] = body["selfLink"].format(PROJ_ID=proj_id, VPC_NAME=SKYPILOT_VPC_NAME)
573+
_create_firewall_rule(config, compute, body)
574+
575+
usable_vpc_name = SKYPILOT_VPC_NAME
576+
logger.info(f"A VPC network {SKYPILOT_VPC_NAME} created.")
577+
578+
return usable_vpc_name
579+
580+
491581
def _configure_subnet(config, compute):
492582
"""Pick a reasonable subnet if not specified by the config."""
493583
config = copy.deepcopy(config)
@@ -506,14 +596,9 @@ def _configure_subnet(config, compute):
506596
):
507597
return config
508598

509-
subnets = _list_subnets(config, compute)
510-
511-
if not subnets:
512-
raise NotImplementedError("Should be able to create subnet.")
513-
514-
# TODO: make sure that we have usable subnet. Maybe call
515-
# compute.subnetworks().listUsable? For some reason it didn't
516-
# work out-of-the-box
599+
# SkyPilot: make sure there's a usable VPC
600+
usable_vpc_name = get_usable_vpc(config)
601+
subnets = _list_subnets(config, compute, filter=f"(name=\"{usable_vpc_name}\")")
517602
default_subnet = subnets[0]
518603

519604
default_interfaces = [
@@ -542,17 +627,54 @@ def _configure_subnet(config, compute):
542627
return config
543628

544629

545-
def _list_subnets(config, compute):
630+
def _create_firewall_rule(config, compute, body):
631+
operation = (
632+
compute.firewalls()
633+
.insert(project=config["provider"]["project_id"], body=body)
634+
.execute()
635+
)
636+
response = wait_for_compute_global_operation(
637+
config["provider"]["project_id"], operation, compute
638+
)
639+
return response
640+
641+
642+
def _create_vpcnet(config, compute, body):
643+
operation = (
644+
compute.networks()
645+
.insert(project=config["provider"]["project_id"], body=body)
646+
.execute()
647+
)
648+
response = wait_for_compute_global_operation(
649+
config["provider"]["project_id"], operation, compute
650+
)
651+
return response
652+
653+
654+
def _list_vpcnets(config, compute):
655+
response = (
656+
compute.networks()
657+
.list(
658+
project=config["provider"]["project_id"],
659+
)
660+
.execute()
661+
)
662+
663+
return response["items"] if "items" in response else []
664+
665+
666+
def _list_subnets(config, compute, filter=None):
546667
response = (
547668
compute.subnetworks()
548669
.list(
549670
project=config["provider"]["project_id"],
550671
region=config["provider"]["region"],
672+
filter=filter,
551673
)
552674
.execute()
553675
)
554676

555-
return response["items"]
677+
return response["items"] if "items" in response else []
556678

557679

558680
def _get_subnet(config, subnet_id, compute):

sky/skylet/providers/gcp/constants.py

+52
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
SKYPILOT_VPC_NAME = "skypilot-vpc"
2+
3+
# Below parameters are from the default VPC on GCP.
4+
# https://cloud.google.com/vpc/docs/firewalls#more_rules_default_vpc
5+
VPC_TEMPLATE = {
6+
"name": "{VPC_NAME}",
7+
"selfLink": "projects/{PROJ_ID}/global/networks/{VPC_NAME}",
8+
"autoCreateSubnetworks": True,
9+
"mtu": 1460,
10+
"routingConfig": {"routingMode": "GLOBAL"},
11+
}
12+
FIREWALL_RULES_TEMPLATE = [
13+
{
14+
"name": "{VPC_NAME}-allow-custom",
15+
"description": "Allows connection from any source to any instance on the network using custom protocols.",
16+
"network": "projects/{PROJ_ID}/global/networks/{VPC_NAME}",
17+
"selfLink": "projects/{PROJ_ID}/global/firewalls/{VPC_NAME}-allow-custom",
18+
"direction": "INGRESS",
19+
"priority": 65534,
20+
"allowed": [
21+
{'IPProtocol': 'tcp', 'ports': ['0-65535']},
22+
{'IPProtocol': 'udp', 'ports': ['0-65535']},
23+
{'IPProtocol': 'icmp'}
24+
],
25+
"sourceRanges": ["10.128.0.0/9"],
26+
},
27+
{
28+
"name": "{VPC_NAME}-allow-ssh",
29+
"description": "Allows TCP connections from any source to any instance on the network using port 22.",
30+
"network": "projects/{PROJ_ID}/global/networks/{VPC_NAME}",
31+
"selfLink": "projects/{PROJ_ID}/global/firewalls/{VPC_NAME}-allow-ssh",
32+
"direction": "INGRESS",
33+
"priority": 65534,
34+
"allowed": [{
35+
"IPProtocol": "tcp",
36+
"ports": ["22"],
37+
}],
38+
"sourceRanges": ["0.0.0.0/0"],
39+
},
40+
{
41+
"name": "{VPC_NAME}-allow-icmp",
42+
"description": "Allows ICMP connections from any source to any instance on the network.",
43+
"network": "projects/{PROJ_ID}/global/networks/{VPC_NAME}",
44+
"selfLink": "projects/{PROJ_ID}/global/firewalls/{VPC_NAME}-allow-icmp",
45+
"direction": "INGRESS",
46+
"priority": 65534,
47+
"allowed": [{
48+
"IPProtocol": "icmp",
49+
}],
50+
"sourceRanges": ["0.0.0.0/0"],
51+
},
52+
]

sky/skylet/providers/gcp/node.py

+14-10
Original file line numberDiff line numberDiff line change
@@ -378,7 +378,8 @@ def list_instances(self, label_filters: Optional[dict] = None,
378378
return self._list_instances(label_filters, non_terminated_status)
379379

380380
def _list_instances(self, label_filters: Optional[dict],
381-
status_filter: List[str]) -> List[GCPComputeNode]:
381+
status_filter: Optional[List[str]]
382+
) -> List[GCPComputeNode]:
382383
label_filters = label_filters or {}
383384

384385
if label_filters:
@@ -395,16 +396,19 @@ def _list_instances(self, label_filters: Optional[dict],
395396
else:
396397
label_filter_expr = ""
397398

398-
instance_state_filter_expr = (
399-
"("
400-
+ " OR ".join(
401-
[
402-
"(status = {status})".format(status=status)
403-
for status in status_filter
404-
]
399+
if status_filter:
400+
instance_state_filter_expr = (
401+
"("
402+
+ " OR ".join(
403+
[
404+
"(status = {status})".format(status=status)
405+
for status in status_filter
406+
]
407+
)
408+
+ ")"
405409
)
406-
+ ")"
407-
)
410+
else:
411+
instance_state_filter_expr = ""
408412

409413
cluster_name_filter_expr = "(labels.{key} = {value})".format(
410414
key=TAG_RAY_CLUSTER_NAME, value=self.cluster_name

sky/templates/gcp-tpu-create.sh.j2

+2-1
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,5 @@ gcloud config set project $PROJECT_ID
99
yes | gcloud compute tpus create {{tpu_name}} \
1010
--zone={{zones}} \
1111
--version={{runtime_version}} \
12-
--accelerator-type={{tpu_type}}
12+
--accelerator-type={{tpu_type}} \
13+
--network={{vpc_name}}

0 commit comments

Comments
 (0)