-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathElasticCloudAdapter.py
140 lines (116 loc) · 5.35 KB
/
ElasticCloudAdapter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import io
import os
import paramiko
from paramiko import ssh_exception
import yaml
from socket import timeout
# TODO: When paramiko is updated > 2.4.2 remove this warning squelch
# Last checked March 14, 2019
import warnings
warnings.filterwarnings(action='ignore', module='.*paramiko.*')
# Uncomment to enable paramiko logging
# import logging
# logging.basicConfig()
# logging.getLogger("paramiko").setLevel(logging.DEBUG)
class ElasticCloudAdapter:
    """Base class for cloud-provider adapters driven over SSH.

    On construction the adapter loads the provider's settings (from
    ``cloud_config/config.yaml`` or, if that file is absent, from environment
    variables) and prepares a paramiko SSH client plus the credentials used
    to reach provider VMs.

    ``expand``, ``shrink`` and ``dump_state`` raise ``NotImplementedError``
    here and must be supplied by subclasses.
    """

    # Action identifiers. Nothing in this class consumes them; presumably
    # subclasses or callers use them to label scaling decisions.
    ACTION_SHRINK = 'shrink'
    ACTION_EXPAND = 'expand'
    ACTION_DO_NOTHING = 'do_nothing'

    def __init__(self, provider_name):
        """Load configuration for ``provider_name`` and set up SSH access.

        Raises:
            KeyError: if ``provider_name`` is not a key under ``services``
                in the loaded configuration.
        """
        self._load_configuration(provider_name)
        self._load_ssh_configuration()

    def _load_configuration(self, service_name):
        """Populate ``self.config`` with the settings for ``service_name``.

        Settings come from ``cloud_config/config.yaml`` (relative to the
        current working directory) when that file exists; otherwise they are
        assembled from ``GCE_*`` environment variables, in which case ``gce``
        is the only service defined. The top-level ``BROKER_URL`` is copied
        into the per-service dict.

        Raises:
            KeyError: if ``service_name`` (or ``BROKER_URL`` / ``services``
                in the YAML file) is missing.
            ValueError: if a numeric ``GCE_*`` variable is not an integer.
        """
        config_filename = 'cloud_config/config.yaml'

        # Load config from yaml OR from environment
        if os.path.exists(config_filename):
            with open(config_filename) as f:
                # NOTE(review): FullLoader can construct more than plain
                # data types; fine for a trusted local file, but consider
                # yaml.safe_load if this file can come from elsewhere.
                service_config = yaml.load(f, Loader=yaml.FullLoader)
        else:
            env = os.environ
            service_config = {
                'BROKER_URL': env.get("BROKER_URL"),
                'services': {
                    'gce': {
                        'max': int(env.get('GCE_MAX', 3)),
                        'min': int(env.get('GCE_MIN', 1)),
                        'shrink_sensitivity': int(env.get('GCE_SHRINK_SENSITIVITY', 3)),
                        'expand_sensitivity': int(env.get('GCE_EXPAND_SENSITIVITY', 1)),
                        'image_name': env.get('GCE_IMAGE_NAME'),
                        # Passed through as the raw string (or None).
                        'use_gpus': env.get('GCE_USE_GPUS'),
                        'vm_size': env.get('GCE_VM_SIZE', "n1-standard-1"),
                        'datacenter': env.get('GCE_DATACENTER', "us-west1-a"),
                        'service_account_key': env.get('GCE_SERVICE_ACCOUNT_KEY'),
                        'service_account_file': env.get('GCE_SERVICE_ACCOUNT_FILE'),
                    },
                },
            }

        self.config = service_config['services'][service_name]
        self.config['BROKER_URL'] = service_config['BROKER_URL']

    def _load_ssh_configuration(self):
        """Create the paramiko client and resolve the SSH user and key.

        Credentials are looked up in this order:

        1. The ``ElasticCloud`` host entry of ``~/.ssh/config`` (its ``user``
           and first ``identityfile``).
        2. User ``ubuntu`` with the private key text in the ``GCE_SSH_PRIV``
           environment variable.
        3. User ``ubuntu`` with ``~/.ssh/id_rsa``.

        Sets ``self.ssh_client``, ``self.username`` and ``self.pkey``.
        """
        # Paramiko ssh library set up
        ssh_config_filename = os.path.expanduser('~/.ssh/config')
        self.ssh_client = paramiko.SSHClient()
        self.ssh_client.load_system_host_keys()
        self.ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        config = paramiko.config.SSHConfig()

        try:
            with open(ssh_config_filename) as f:
                config.parse(f)
            elastic_cloud_ssh_config = 'ElasticCloud'
            user_config = config.lookup(elastic_cloud_ssh_config)
            pkey_fn = user_config['identityfile'][0]
            self.username = user_config['user']
            self.pkey = paramiko.RSAKey.from_private_key_file(pkey_fn)
        except (FileNotFoundError, KeyError):
            # We don't have a specific entry for this (or a referenced file
            # is missing), use defaults
            self.username = "ubuntu"
            ssh_key = os.environ.get("GCE_SSH_PRIV")
            # Report only whether the variable is set: the value is a
            # private key and must never be written to stdout or logs.
            key_state = 'set' if ssh_key else 'not set'
            print("Unable to find ElasticCloud SSHConfig, attempting to use "
                  f"GCE_SSH_PRIV env var ({key_state})")
            if ssh_key:
                self.pkey = paramiko.RSAKey.from_private_key(io.StringIO(ssh_key))
            else:
                # try to use default local one
                self.pkey = paramiko.RSAKey.from_private_key_file(
                    os.path.expanduser("~/.ssh/id_rsa"))

    def _forget_known_host(self, host):
        """Remove entries whose first field equals ``host`` from known_hosts.

        Operates on ``~/.ssh/known_hosts`` and does nothing if the file does
        not exist. Blank lines are kept as they are. The filtered content is
        computed before the file is reopened for writing, so a problem while
        filtering cannot leave the file truncated, and the file is rewritten
        only when at least one entry was actually removed.
        """
        known_hosts_filename = os.path.expanduser('~/.ssh/known_hosts')
        if not os.path.exists(known_hosts_filename):
            return

        with open(known_hosts_filename, 'r') as f:
            lines = f.readlines()

        kept = []
        for line in lines:
            fields = line.split()
            # An empty ``fields`` means a blank line: nothing to compare.
            if fields and fields[0] == host:
                continue
            kept.append(line)

        if len(kept) != len(lines):
            with open(known_hosts_filename, 'w') as f:
                f.writelines(kept)

    def _connect(self, host):
        """Open an SSH connection to ``host`` on ``self.ssh_client``.

        Any known_hosts entry for ``host`` is dropped first (the address may
        have been reused by a different VM). Connection errors are printed
        and re-raised.

        Raises:
            paramiko.ssh_exception.NoValidConnectionsError,
            paramiko.ssh_exception.AuthenticationException: host unreachable
                or credentials rejected.
            socket.timeout: no response within 10 seconds.
        """
        # Delete old known hosts entry for GCE VM ip address
        # NOTE(review): the client loaded the system host keys when it was
        # created, so editing the file here presumably does not affect keys
        # it already holds in memory -- confirm against paramiko behaviour.
        self._forget_known_host(host)

        try:
            print(f"Attempting to connect to {self.username}@{host}")
            self.ssh_client.connect(host, username=self.username, pkey=self.pkey, timeout=10)
        except (ssh_exception.NoValidConnectionsError, ssh_exception.AuthenticationException):
            print("ERROR :: Could not connect to host, maybe it is spinning up/down?")
            raise
        except timeout:
            print('ssh timed out')
            raise
        except Exception as e:
            print(e)
            raise

    def _run_ssh_command(self, host, command):
        """Connect to ``host`` and execute ``command`` there.

        Returns:
            The ``(stdin, stdout, stderr)`` tuple from
            ``SSHClient.exec_command``.

        Raises:
            socket.timeout,
            paramiko.ssh_exception.NoValidConnectionsError,
            paramiko.ssh_exception.AuthenticationException: re-raised after
                logging. A failed connect must stop here: continuing would
                run ``command`` on a client that is not connected to
                ``host`` (and which may still hold a session to an earlier
                host).
        """
        try:
            self._connect(host)
        except timeout:
            print('ssh timed out')
            raise
        except (ssh_exception.NoValidConnectionsError, ssh_exception.AuthenticationException):
            print("ERROR :: Could not connect to host, maybe it is spinning down?")
            raise

        try:
            stdin, stdout, stderr = self.ssh_client.exec_command(command)
        except (ssh_exception.NoValidConnectionsError, ssh_exception.AuthenticationException):
            print("ERROR :: Could not exec command on host, maybe it is spinning up/down?")
            # Propagate the real error; there are no streams to return.
            raise
        return (stdin, stdout, stderr)

    def expand(self):
        """Provider-specific hook; not implemented in the base class.

        Presumably adds capacity for the provider -- subclasses define the
        exact behaviour.
        """
        raise NotImplementedError

    def shrink(self):
        """Provider-specific hook; not implemented in the base class.

        Presumably removes capacity for the provider -- subclasses define
        the exact behaviour.
        """
        raise NotImplementedError

    def dump_state(self):
        """Provider-specific hook; not implemented in the base class."""
        raise NotImplementedError