-
Notifications
You must be signed in to change notification settings - Fork 23
/
Copy pathupgrade.py
executable file
·575 lines (439 loc) · 19.5 KB
/
upgrade.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
#!/usr/bin/env python3
import argparse
from datetime import timedelta
import math
import os
import re
import shutil
import subprocess
import sys
import time
#
# Constants
#
# Substrings matched against the EC2 "Name" tag to select each service's
# instances
CONSUL_TAG = 'consul'
NOMAD_SERVER_TAG = 'nomad-server'
NOMAD_CLIENT_TAG = 'nomad-client'
VAULT_TAG = 'vault'

# Default API addresses of the local Consul and Nomad agents
CONSUL_ADDR = 'http://127.0.0.1:8500'
NOMAD_ADDR = 'http://127.0.0.1:4646'

# Vault defaults
VAULT_USERNAME = 'ubuntu'
VAULT_TLS_SERVER = 'vault.service.consul'
VAULT_PORT = 8200
VAULT_CONSUL_SERVICE_NAME = 'vault'
# Default number of unseal keys required to fully unseal a Vault server
VAULT_UNSEAL_COUNT = 3
# Patterns searched for in the output of the Vault unseal command
SEALED_RE = re.compile(r'Sealed\s+true')
UNSEALED_RE = re.compile(r'Sealed\s+false')

# Seconds to wait between two consecutive health checks
CHECK_INTERVAL_SECS = 5
# Max seconds to wait for AWS and the service to agree on the same set of
# instances after killing
TIMEOUT_SECS = 300

# Commands that must be available on PATH before starting
CMDS = ['aws', 'consul', 'curl', 'nomad', 'jq', 'scp', 'ssh']
#
# Assertion commands
#
def assert_yes_no(response):
    """Exit the program unless *response* is "y" (case-insensitive).

    A reply of "n" exits with an abort message; anything else exits with an
    invalid-response message.
    """
    answer = response.lower()
    if answer == 'y':
        return
    sys.exit('Aborting...' if answer == 'n'
             else 'Invalid response, aborting...')
def assert_instance_count(instances):
    """Exit the program when fewer than 3 instances are given."""
    if len(instances) >= 3:
        return
    sys.exit(
        'This process can only be run when there are >= 3 servers to maintain quorum.')
def assert_kill_count(kill_count):
    """Exit the program when *kill_count* equals 0."""
    if kill_count != 0:
        return
    sys.exit('Kill count cannot be 0.')
def assert_arg(flag, arg):
    """Exit the program, naming *flag*, when *arg* is falsy (i.e. not set)."""
    if arg:
        return
    message = 'Flag "{}" must be set!'.format(flag)
    sys.exit(message)
def assert_collection_len(collection, length):
    """Exit the program unless *collection* holds exactly *length* items.

    Note that the exit message includes the collection itself.
    """
    if len(collection) == length:
        return
    message = 'Collection "{}" does not have expected length {}!'.format(
        collection, length)
    sys.exit(message)
def assert_file_exists(path):
    """Exit the program when nothing exists at *path*."""
    if os.path.exists(path):
        return
    message = '"{}" does not exist!'.format(path)
    sys.exit(message)
def assert_cmds_exist(cmds):
    """Exit the program at the first command in *cmds* not found on PATH."""
    for name in cmds:
        if shutil.which(name):
            continue
        sys.exit('Require command "{}"!'.format(name))
def assert_same_instances(aws_instances, service_nodes):
    """Exit the program unless *aws_instances* equals *service_nodes*."""
    if aws_instances == service_nodes:
        return
    message = 'AWS set of instances: "{}" is different from current set of service nodes: "{}"! Aborting...'.format(
        aws_instances, service_nodes)
    sys.exit(message)
#
# Test
#
# Can be used to swap out kill_fn
def fake_kill_fn(aws_instance):
    """Print a message naming *aws_instance* instead of terminating it."""
    message = 'Fake killing instance "{}"'.format(aws_instance)
    print(message)
# Can be used to swap out check_fn
def fake_check_fn(prev_instances, killing_idx):
    """Print the arguments and always report success (returns True)."""
    message = 'Fake checking {} at killing index {}'.format(
        prev_instances, killing_idx)
    print(message)
    return True
# Can be used to override the default get_new_instances_fn
def fake_get_new_instances(prev_instances, curr_instances):
    """Return the first element of *curr_instances*; *prev_instances* is ignored."""
    snapshot = [*curr_instances]
    return snapshot[0]
#
# General
#
def unique(seq):
    """Return the items of *seq* as a list without duplicates.

    The first occurrence of each item is kept and the original order is
    preserved. Items must be hashable.
    """
    seen = set()
    ordered = []
    for item in seq:
        if item in seen:
            continue
        seen.add(item)
        ordered.append(item)
    return ordered
def calc_max_kill_count(instances):
    """Return how many of *instances* can be killed at once while keeping quorum.

    The instances left running must still form a simple majority; see
    https://www.consul.io/docs/internals/consensus.html#deployment-table
    The result is never negative.
    """
    total = len(instances)
    quorum = total // 2 + 1
    return max(total - quorum, 0)
def invoke_shell(cmd):
    """Run *cmd* through the shell and return its stripped, UTF-8 decoded stdout.

    Raises subprocess.CalledProcessError when the command exits non-zero.
    """
    raw_output = subprocess.check_output(cmd, shell=True)
    text = raw_output.decode('utf8')
    return text.strip()
def get_instance_ids_from_tag(tag_pattern):
    """Return the set of IDs of running EC2 instances whose Name tag contains *tag_pattern*.

    Shells out to the `aws` CLI and extracts the IDs with `jq`.

    Raises:
        AssertionError: when the lookup command fails; the underlying error
            is chained as the cause.
    """
    # NOTE(review): the command is a pipeline, so its exit status is
    # presumably that of `jq`; a failing `aws` call may yield an empty set
    # instead of an error - confirm.
    cmd = """
        aws ec2 describe-instances --filter \
            "Name=tag:Name,Values=*{}*" \
            "Name=instance-state-name,Values=running" | \
        jq --raw-output '.Reservations[].Instances[].InstanceId'
    """.format(tag_pattern)
    try:
        return set(invoke_shell(cmd).split())
    except Exception as err:
        # Catch Exception rather than everything, so that Ctrl-C
        # (KeyboardInterrupt) and sys.exit() still propagate untouched
        raise AssertionError(
            'Cannot find instance IDs with tag pattern "{}"!'.format(tag_pattern)) from err
def get_instance_ip_addrs_from_tag(tag_pattern):
    """Return the set of private IPs of running EC2 instances whose Name tag contains *tag_pattern*.

    Shells out to the `aws` CLI and extracts the addresses with `jq`.

    Raises:
        AssertionError: when the lookup command fails; the underlying error
            is chained as the cause.
    """
    cmd = """
        aws ec2 describe-instances --filter \
            "Name=tag:Name,Values=*{}*" \
            "Name=instance-state-name,Values=running" | \
        jq --raw-output '.Reservations[].Instances[].PrivateIpAddress'
    """.format(tag_pattern)
    try:
        return set(invoke_shell(cmd).split())
    except Exception as err:
        # Catch Exception rather than everything, so that Ctrl-C
        # (KeyboardInterrupt) and sys.exit() still propagate untouched
        raise AssertionError(
            'Cannot get instance IP addresses with tag pattern "{}"!'.format(tag_pattern)) from err
def get_instance_ip_addr_from_id(id):
    """Return the private IP address of the running EC2 instance with the given ID.

    Raises:
        AssertionError: when the lookup command fails (the underlying error
            is chained as the cause) or when it produces no address at all.
    """
    failure_message = 'Cannot find instance IP address with Node ID: "{}"!'.format(id)
    cmd = """
        aws ec2 describe-instances --filter \
            "Name=instance-state-name,Values=running" | \
        jq --raw-output '.Reservations[].Instances[] |
            select(.InstanceId == "{}") |
            .PrivateIpAddress'
    """.format(id)
    try:
        ip_addr = invoke_shell(cmd).strip()
    except Exception as err:
        # Catch Exception rather than everything, so that Ctrl-C
        # (KeyboardInterrupt) and sys.exit() still propagate untouched
        raise AssertionError(failure_message) from err
    # jq prints nothing when no instance matches; fail here instead of
    # handing an empty address to the caller
    if not ip_addr:
        raise AssertionError(failure_message)
    return ip_addr
def get_instance_ip_addrs_from_ids(ids):
    """Lazily yield the private IP address of every instance ID in *ids*.

    The result is a one-shot iterator; each address is looked up only when
    the iterator is advanced.
    """
    return (get_instance_ip_addr_from_id(instance_id) for instance_id in ids)
def get_new_instances_from_prev(prev_instances, curr_instances):
    """Return the members of *curr_instances* that are absent from *prev_instances*."""
    newcomers = curr_instances.difference(prev_instances)
    return newcomers
def kill_fn(aws_instance):
    """Terminate *aws_instance* through its auto scaling group.

    The command is invoked with --no-should-decrement-desired-capacity, i.e.
    the group's desired capacity is left unchanged.

    Raises:
        AssertionError: when the terminate command fails; the underlying
            error is chained as the cause.
    """
    cmd = """
        aws autoscaling terminate-instance-in-auto-scaling-group \
            --no-should-decrement-desired-capacity \
            --instance-id {}
    """.format(aws_instance)
    try:
        invoke_shell(cmd)
    except Exception as err:
        # Catch Exception rather than everything, so that Ctrl-C
        # (KeyboardInterrupt) and sys.exit() still propagate untouched
        raise AssertionError('Unable to terminate AWS instance "{}" from its ASG!'
                             .format(aws_instance)) from err
def try_until_timeout(check_fn, prev_aws_instances, killing_idx, check_interval, timeout):
    """Poll *check_fn* until it succeeds or *timeout* has elapsed.

    Args:
        check_fn: callable invoked as check_fn(prev_aws_instances, killing_idx);
            a truthy result ends the polling successfully.
        prev_aws_instances: passed through to check_fn unchanged.
        killing_idx: passed through to check_fn unchanged.
        check_interval: timedelta to sleep between two attempts.
        timeout: timedelta after which polling gives up.

    Returns:
        True as soon as check_fn succeeds, False when the timeout is reached
        first. check_fn is always tried at least once.
    """
    wait_secs = check_interval.total_seconds()
    timeout_secs = timeout.total_seconds()
    # Measure wall-clock time so that the time spent inside check_fn counts
    # against the timeout too, not only the time spent sleeping
    started_at = time.monotonic()
    while True:
        if check_fn(prev_aws_instances, killing_idx):
            return True
        elapsed_secs = time.monotonic() - started_at
        if elapsed_secs >= timeout_secs:
            return False
        print('> Waiting for {}s ({}s elapsed)'.format(
            wait_secs, round(elapsed_secs, 1)))
        time.sleep(wait_secs)
def check_service_up(prev_aws_instances, killing_idx, kill_count, tag_pattern, get_service_nodes_fn):
    """Tell whether the killed batch has been fully replaced and joined the service.

    Args:
        prev_aws_instances: set of instance IDs snapshotted before the kill.
        killing_idx: 0-based index of the most recently killed instance.
        kill_count: configured number of instances killed per batch.
        tag_pattern: Name tag pattern selecting the service's AWS instances.
        get_service_nodes_fn: zero-argument callable returning the set of
            node IDs currently known to the service.

    Returns:
        True when AWS is back at the previous instance count, exactly one
        new instance exists per kill of the current batch, and the service
        reports the very same set of nodes as AWS; False otherwise.
    """
    # We assume that prev_aws_instances contain the same number of entries as
    # the original
    instance_count = len(prev_aws_instances)
    curr_aws_instances = get_instance_ids_from_tag(tag_pattern)
    # Number of kills in the batch that ends at killing_idx. Checks run after
    # every kill_count kills or after the very last kill, so this is
    # kill_count for a full batch and the remainder for a trailing partial
    # one. (The former min(instance_count - killing_idx, kill_count) came out
    # too small for full batches close to the end, e.g. 7 instances killed 3
    # at a time.)
    expected_new = killing_idx % kill_count + 1
    if len(curr_aws_instances) != instance_count:
        return False
    if len(curr_aws_instances.difference(prev_aws_instances)) != expected_new:
        return False
    # Only query the service once AWS itself looks right
    return curr_aws_instances == get_service_nodes_fn()
def kill_check_post(kill_fn, check_fn, post_fn, tag_pattern, kill_count, check_interval, timeout, get_new_instances_fn=get_new_instances_from_prev):
    """Replace all instances matching *tag_pattern*, *kill_count* at a time.

    For every batch: kill the instances, poll until the replacements are up,
    then run the post action on the new instances.

    Args:
        kill_fn: callable(instance_id) terminating one instance.
        check_fn: callable(prev_aws_instances, idx) returning truthy once the
            replacements of the current batch are healthy.
        post_fn: callable(new_instances) run after every successful check.
        tag_pattern: Name tag pattern selecting the instances to replace.
        kill_count: number of instances to kill before each check; must not
            be 0.
        check_interval: timedelta between two health checks.
        timeout: timedelta to wait for a batch to become healthy.
        get_new_instances_fn: callable(prev_instances, curr_instances)
            returning the instances regarded as new.

    Raises:
        AssertionError: when a batch is not healthy within *timeout*.
    """
    print("Killing {} instance(s) in one go...".format(kill_count))
    orig_aws_instances = get_instance_ids_from_tag(tag_pattern)
    instance_count = len(orig_aws_instances)
    prev_aws_instances = orig_aws_instances
    for idx, orig_aws_instance in enumerate(orig_aws_instances):
        # Snapshot the running instances once, before the first kill of each
        # batch. Re-fetching before every kill could produce a snapshot that
        # already lacks the instances killed earlier in the same batch.
        if idx % kill_count == 0:
            prev_aws_instances = get_instance_ids_from_tag(tag_pattern)
        print('Killing instance "{} ({}/{})"...'.format(orig_aws_instance,
                                                        idx + 1, instance_count))
        kill_fn(orig_aws_instance)
        # Check only after every N kills
        if (idx + 1) % kill_count == 0 or (idx + 1) == instance_count:
            print('KILL okay! Waiting for new instance(s) to spin up...')
            # Check for new instances to be up
            if not try_until_timeout(check_fn, prev_aws_instances, idx, check_interval, timeout):
                raise AssertionError(
                    'New instance(s) is/are unable to join the service after timeout of {}s, aborting...'.format(timeout.total_seconds()))
            curr_instances = get_instance_ids_from_tag(tag_pattern)
            new_instances = get_new_instances_fn(
                prev_aws_instances, curr_instances)
            print('New instance(s) found: {}'.format(new_instances))
            # Post check_fn action
            post_fn(new_instances)
#
# Consul specifics
#
def list_consul_peers(address):
    """Return the set of Consul raft peers reported via the HTTP API at *address*.

    Only peers whose line matches an instance ID pattern ("i-" followed by
    hex digits) are kept; the first space-separated field of each such line
    is returned.

    Raises:
        AssertionError: when the command fails; the underlying error is
            chained as the cause.
    """
    cmd = """
        consul operator raft list-peers --http-addr {} | \
        grep -E i-[0-9a-f]+ | \
        cut -d " " -f 1
    """.format(address)
    try:
        return set(invoke_shell(cmd).split())
    except Exception as err:
        # Catch Exception rather than everything, so that Ctrl-C
        # (KeyboardInterrupt) and sys.exit() still propagate untouched
        raise AssertionError('Unable to obtain peers from Consul operator raft list from "{}"!'
                             .format(address)) from err
#
# Nomad Server specifics
#
def list_nomad_server_members(address):
    """Return the set of alive Nomad server members reported at *address*.

    Each member name is reduced to its leading instance ID ("i-" followed by
    hex digits) by stripping everything from the first dot onwards.

    Raises:
        AssertionError: when the command fails; the underlying error is
            chained as the cause.
    """
    # Raw string: the sed expression relies on literal backslashes
    cmd = r"""
        nomad server members -address {} | grep alive | sed -E 's/^(i-[0-9a-f]+)\..+$/\1/'
    """.format(address)
    try:
        return set(invoke_shell(cmd).split())
    except Exception as err:
        # Catch Exception rather than everything, so that Ctrl-C
        # (KeyboardInterrupt) and sys.exit() still propagate untouched
        raise AssertionError('Unable to obtain Nomad Server members from "{}"!'
                             .format(address)) from err
#
# Vault specifics
#
def list_vault_members(consul_addr, service_name):
    """Return the set of Consul node names registered for the Vault service.

    Queries the Consul catalog at *consul_addr* for *service_name* twice,
    once with the "standby" tag and once with the "active" tag, and merges
    the node names of both answers.

    Raises:
        AssertionError: when the command fails; the underlying error is
            chained as the cause.
    """
    cmd = """
        ( \
            curl -s {addr}/v1/catalog/service/{name}?tag=standby | jq --raw-output '.[].Node' && \
            curl -s {addr}/v1/catalog/service/{name}?tag=active | jq --raw-output '.[].Node' \
        ) | \
        cat
    """.format(
        addr=consul_addr,
        name=service_name)
    try:
        return set(invoke_shell(cmd).split())
    except Exception as err:
        # Catch Exception rather than everything, so that Ctrl-C
        # (KeyboardInterrupt) and sys.exit() still propagate untouched
        raise AssertionError(
            'Unable to obtain Vault members from Consul catalog!') from err
def unseal_vault(ip_addr, vault_port, tls_server, ca_cert, unseal_key):
    """Submit one unseal key to the Vault server at https://ip_addr:vault_port.

    Args:
        ip_addr: address of the Vault server to unseal.
        vault_port: port of the Vault server.
        tls_server: name passed as -tls-server-name.
        ca_cert: path passed as -ca-cert.
        unseal_key: the unseal key to submit.

    Returns:
        The output of the `vault operator unseal` command.

    Raises:
        AssertionError: when the command fails; the underlying error is
            chained as the cause.
    """
    # NOTE(review): the unseal key is placed on the command line of a shell
    # command; presumably visible to other local users while it runs -
    # confirm whether that is acceptable.
    cmd = """
        vault operator unseal \
            -address https://{addr}:{port} \
            -tls-server-name={tls_server} \
            -ca-cert={ca_cert} {unseal_key}
    """.format(
        addr=ip_addr,
        port=vault_port,
        tls_server=tls_server,
        ca_cert=ca_cert,
        unseal_key=unseal_key)
    try:
        return invoke_shell(cmd)
    except Exception as err:
        # Catch Exception rather than everything, so that Ctrl-C
        # (KeyboardInterrupt) and sys.exit() still propagate untouched
        raise AssertionError(
            'Unable to unseal vault in "https://{}:{}" with given unseal key!'
            .format(ip_addr, vault_port)) from err
def unseal_and_check_vault(new_instances, vault_port, tls_server, ca_cert_path, unseal_keys):
    """Unseal the Vault server on each new instance, verifying every step.

    Every key in *unseal_keys* is submitted in turn to each instance. The
    command output must report "Sealed true" after every key except the last
    one, and "Sealed false" after the last one.

    Raises:
        AssertionError: when the reported seal status is not the expected one.
    """
    for ip_addr in get_instance_ip_addrs_from_ids(new_instances):
        for key_no, key in enumerate(unseal_keys, start=1):
            # Only the final key is expected to flip the server to unsealed
            if key_no == len(unseal_keys):
                expected_re = UNSEALED_RE
            else:
                expected_re = SEALED_RE
            print('Unsealing vault with key #{} for "{}"...'.format(
                key_no, ip_addr))
            output = unseal_vault(
                ip_addr, vault_port, tls_server, ca_cert_path, key)
            if expected_re.search(output) is None:
                raise AssertionError(
                    'Unexpected seal status after using key #{}'.format(key_no))
            print('Unseal vault with key #{} okay!'.format(key_no))
#
# High level
#
def upgrade_consul(consul_tag_pattern, address, check_interval, timeout, fast_mode):
    """Replace every Consul instance, waiting for each batch to rejoin the raft peers.

    Args:
        consul_tag_pattern: Name tag pattern selecting the Consul instances.
        address: Consul HTTP address used to list the raft peers.
        check_interval: timedelta between two health checks.
        timeout: timedelta to wait for a batch to become healthy.
        fast_mode: when truthy, kill as many instances per batch as quorum
            allows instead of one at a time.
    """
    print("Upgrading Consul instances...")

    # Sanity check: AWS and Consul must agree before anything gets killed
    instance_ids = get_instance_ids_from_tag(consul_tag_pattern)
    assert_instance_count(instance_ids)
    peers = list_consul_peers(address)
    print('AWS instances: {}'.format(instance_ids))
    print(' Consul nodes: {}'.format(peers))
    assert_same_instances(instance_ids, peers)

    if fast_mode:
        kill_count = calc_max_kill_count(instance_ids)
    else:
        kill_count = 1
    assert_kill_count(kill_count)

    def consul_is_up(prev_aws_instances, idx):
        return check_service_up(
            prev_aws_instances, idx, kill_count, consul_tag_pattern,
            lambda: list_consul_peers(address))

    def nothing_to_do(new_instances):
        # Consul needs no action on freshly started instances
        pass

    kill_check_post(kill_fn, consul_is_up, nothing_to_do,
                    consul_tag_pattern, kill_count, check_interval, timeout)
def upgrade_nomad_server(nomad_server_tag_pattern, address, check_interval, timeout, fast_mode):
    """Replace every Nomad Server instance, waiting for each batch to rejoin.

    Args:
        nomad_server_tag_pattern: Name tag pattern selecting the Nomad
            Server instances.
        address: Nomad address used to list the server members.
        check_interval: timedelta between two health checks.
        timeout: timedelta to wait for a batch to become healthy.
        fast_mode: when truthy, kill as many instances per batch as quorum
            allows instead of one at a time.
    """
    print("Upgrading Nomad Servers instances...")

    # Sanity check: AWS and Nomad must agree before anything gets killed
    instance_ids = get_instance_ids_from_tag(nomad_server_tag_pattern)
    assert_instance_count(instance_ids)
    members = list_nomad_server_members(address)
    print('AWS instances: {}'.format(instance_ids))
    print('Nomad servers: {}'.format(members))
    assert_same_instances(instance_ids, members)

    if fast_mode:
        kill_count = calc_max_kill_count(instance_ids)
    else:
        kill_count = 1
    assert_kill_count(kill_count)

    def nomad_server_is_up(prev_aws_instances, idx):
        return check_service_up(
            prev_aws_instances, idx, kill_count, nomad_server_tag_pattern,
            lambda: list_nomad_server_members(address))

    def nothing_to_do(new_instances):
        # Nomad servers need no action on freshly started instances
        pass

    kill_check_post(kill_fn, nomad_server_is_up, nothing_to_do,
                    nomad_server_tag_pattern, kill_count, check_interval,
                    timeout)
# TODO - Need to figure out how to properly wait, synchronously, for all
# upgraded allocs to get reallocated first
# def upgrade_nomad_client(nomad_client_tag_pattern, address, check_interval, timeout, fast_mode):
# pass
def upgrade_vault(tag_pattern, consul_addr, service_name, vault_port, tls_server, ca_cert_path, unseal_count, check_interval, timeout, fast_mode):
    """Replace every Vault instance and unseal each replacement.

    Args:
        tag_pattern: Name tag pattern selecting the Vault instances.
        consul_addr: Consul address whose catalog lists the Vault nodes.
        service_name: name of the Vault service in the Consul catalog.
        vault_port: port of the Vault servers, used for unsealing.
        tls_server: TLS server name used when unsealing.
        ca_cert_path: path to the CA certificate used when unsealing.
        unseal_count: number of unseal keys to prompt for; 0 skips manual
            unsealing altogether (e.g. when auto unseal is in use).
        check_interval: timedelta between two health checks.
        timeout: timedelta to wait for a batch to become healthy.
        fast_mode: when truthy, kill as many instances per batch as quorum
            allows instead of one at a time.
    """
    print("Upgrading Vault instances...")

    # Sanity check
    aws_instances = get_instance_ids_from_tag(tag_pattern)
    assert_instance_count(aws_instances)
    vault_servers = list_vault_members(consul_addr, service_name)
    print('AWS instances: {}'.format(aws_instances))
    print('Vault servers: {}'.format(vault_servers))
    assert_same_instances(aws_instances, vault_servers)

    kill_count = 1 if not fast_mode else calc_max_kill_count(aws_instances)
    assert_kill_count(kill_count)

    # Always bound, so that post_fn below can never hit an undefined name
    # when no keys are requested
    unseal_keys = set()
    if unseal_count > 0:
        # Prompt for unseal key only after all the instance assertion
        print('Enter any {} Vault unseal key(s):'.format(unseal_count))
        for _ in range(unseal_count):
            unseal_keys.add(input().strip())
        # A set drops duplicates, so this also rejects a key entered twice
        assert_collection_len(unseal_keys, unseal_count)

    def check_fn(prev_aws_instances, idx):
        return check_service_up(prev_aws_instances, idx, kill_count, tag_pattern,
                                lambda: list_vault_members(consul_addr, service_name))

    def post_fn(new_instances):
        # Nothing to unseal manually when no keys were requested
        if not unseal_keys:
            return
        unseal_and_check_vault(new_instances, vault_port,
                               tls_server, ca_cert_path, unseal_keys)

    kill_check_post(
        kill_fn,
        check_fn,
        post_fn,
        tag_pattern,
        kill_count,
        check_interval,
        timeout)
#
# Main
#
if __name__ == '__main__':
    # Pass the text as the description: the first positional parameter of
    # ArgumentParser is the program name shown in the usage line
    parser = argparse.ArgumentParser(
        description='Script to upgrade service instances')

    # Dashes get converted into underscore when accessing the fields
    parser.add_argument(
        'service', nargs='+', help='Service type to upgrade. consul | nomad-server | nomad-client | vault')
    parser.add_argument('--consul-tag', default=CONSUL_TAG,
                        help='Tag pattern of Consul instances. Defaults to "{}".'.format(CONSUL_TAG))
    parser.add_argument('--consul-addr', default=CONSUL_ADDR,
                        help='Consul server address to connect to. Used by both consul and vault commands. Defaults to "{}".'.format(CONSUL_ADDR))
    parser.add_argument('--nomad-server-tag', default=NOMAD_SERVER_TAG,
                        help='Tag pattern of Nomad Server instances. Defaults to "{}".'.format(NOMAD_SERVER_TAG))
    parser.add_argument('--nomad-client-tag', default=NOMAD_CLIENT_TAG,
                        help='Tag pattern of Nomad Client instances. Defaults to "{}".'.format(NOMAD_CLIENT_TAG))
    parser.add_argument('--nomad-addr', default=NOMAD_ADDR,
                        help='Nomad Server address to connect to. Defaults to "{}".'.format(NOMAD_ADDR))
    parser.add_argument('--vault-tag', default=VAULT_TAG,
                        help='Tag pattern of Vault instances. Defaults to "{}".'.format(VAULT_TAG))
    parser.add_argument('--vault-tls-server', default=VAULT_TLS_SERVER,
                        help='TLS server to point to when connecting to the Vault server via TLS. Defaults to "{}".'.format(VAULT_TLS_SERVER))
    parser.add_argument('--vault-ca-cert',
                        help='Path to CA certificate on this host machine for unsealing')
    # type=int rejects a non-numeric port at parse time
    parser.add_argument('--vault-port', type=int, default=VAULT_PORT,
                        help='Port to use for Vault unsealing. Defaults to {}.'.format(VAULT_PORT))
    parser.add_argument('--vault-consul-service-name', default=VAULT_CONSUL_SERVICE_NAME,
                        help='Vault consul service name to perform API calls on. Defaults to "{}".'.format(VAULT_CONSUL_SERVICE_NAME))
    parser.add_argument('--vault-unseal-count', type=int, default=VAULT_UNSEAL_COUNT,
                        help='Number of unseal keys required to fully unseal a new Vault server. Set to 0 if you are using auto unseal. Defaults to {}.'.format(VAULT_UNSEAL_COUNT))
    parser.add_argument('--check-interval', type=int, default=CHECK_INTERVAL_SECS,
                        help='Interval of checking success for every upgrade step. Defaults to "{}" seconds.'.format(CHECK_INTERVAL_SECS))
    parser.add_argument('--timeout', type=int, default=TIMEOUT_SECS,
                        help='Max time value to allow for ensuring consistency in upgrade after killing. Defaults to "{}" seconds.'.format(TIMEOUT_SECS))
    parser.add_argument('--fast', action='store_true',
                        help='Activates the FAST and FURIOUS (but DANGEROUS) mode. Kill max allowable instances within quorum each time to reduce total time taken.')

    args = parser.parse_args()

    # Services are upgraded in the order given, each one at most once
    services = unique(args.service)
    check_interval = timedelta(seconds=args.check_interval)
    timeout = timedelta(seconds=args.timeout)
    fast_mode = args.fast

    # Give verbose warning here
    if fast_mode:
        sys.stdout.write(
            'Are you sure you want the FAST and FURIOUS mode turned on (y/n)? ')
        sys.stdout.flush()
        assert_yes_no(sys.stdin.readline().strip())
        print('FAST and FURIOUS mode activated! Be mindful about maintaining the quorum in each service.')

    print('Services to upgrade (in order): {}'.format(services))
    assert_cmds_exist(CMDS)

    for service in services:
        # All environment addresses should not contain trailing slash
        # e.g. http://consul.x.y OR https://consul.x.y
        if service == 'consul':
            upgrade_consul(args.consul_tag, args.consul_addr,
                           check_interval, timeout, fast_mode)
            print('DONE Consul upgrading!')
        elif service == 'nomad-server':
            upgrade_nomad_server(args.nomad_server_tag,
                                 args.nomad_addr, check_interval, timeout, fast_mode)
            print('DONE Nomad Server upgrading!')
        elif service == 'nomad-client':
            # TODO
            # upgrade_nomad_client(args.nomad_client_tag, args.nomad_addr, check_interval, timeout, fast_mode)
            # Do not report success: no instance has been touched
            print('Skipping Nomad Client: upgrading is not implemented yet!')
        elif service == 'vault':
            vault_ca_cert = args.vault_ca_cert
            # The CA certificate is only needed for manual unsealing
            if args.vault_unseal_count > 0:
                assert_arg('--vault-ca-cert', vault_ca_cert)
                assert_file_exists(vault_ca_cert)
            upgrade_vault(args.vault_tag, args.consul_addr, args.vault_consul_service_name,
                          args.vault_port, args.vault_tls_server, vault_ca_cert,
                          args.vault_unseal_count,
                          check_interval, timeout, fast_mode)
            print('DONE Vault upgrading!')
        else:
            print('Ignoring unknown command "{}"'.format(service))

    print("DONE!")