forked from livingshade/Metis
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathcost_het_cluster.py
87 lines (73 loc) · 4.49 KB
/
cost_het_cluster.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
# Copyright 2024 Samsung Electronics Co., Ltd. All Rights Reserved
import argparse
from typing import Dict, List, Tuple
import sys
import os
import time
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from arguments import parse_args
from data_loader import ProfileDataLoader
from model.cost_estimator import HeteroCostEstimator
from model.activation_parameter import GPTActivationAndParam
from model.device_group import StagePerformance
from model.load_balancer import LayerLoadBalancer
from search_space.plan import IntraStagePlanGenerator, InterStagePlanGenerator
from gpu_cluster import GPUCluster
from utils import ModelConfig
def cost_het_cluster(args: argparse.Namespace, gpu_cluster: GPUCluster, profile_data: Dict, model_config: ModelConfig,
                     cost_estimator: HeteroCostEstimator, layer_load_balancer: LayerLoadBalancer,
                     cache) -> Tuple[List[Tuple], Dict, int]:
    """Enumerate inter-/intra-stage plans and estimate the cost of each combination.

    Args:
        args: Parsed arguments. This function reads ``gbs``, ``num_layers``,
            ``min_group_scale_variance``, ``max_permute_len``,
            ``max_profiled_tp_degree``, ``max_profiled_batch_size`` and
            ``use_strat``.
        gpu_cluster: Supplies the device types and the total device count, and
            is forwarded to ``StagePerformance``.
        profile_data: Forwarded to ``StagePerformance``.
        model_config: Forwarded to ``StagePerformance``.
        cost_estimator: Its ``get_cost`` scores one (inter-stage plan,
            intra-stage plan) combination.
        layer_load_balancer: Forwarded to ``IntraStagePlanGenerator``.
        cache: Returned unchanged; it is neither read nor modified here. The
            script entry point in this file passes a dict.

    Returns:
        A 3-tuple ``(estimate_costs, cache, total_count)`` where

        * ``estimate_costs`` is a list of 7-tuples ``(node_sequence,
          device_groups, strategies, batches, layer_partition,
          num_repartition, cost)``, one per combination that was costed
          without a ``KeyError``;
        * ``cache`` is the argument passed in;
        * ``total_count`` is the sum of the ``count`` attribute of every
          intra-stage plan generator (presumably the number of plans each
          generator examined -- confirm against ``IntraStagePlanGenerator``).
    """
    estimate_costs = []
    inter_stage_plan_generator = InterStagePlanGenerator(device_types=set(gpu_cluster.get_device_types()),
                                                         num_devices=gpu_cluster.get_total_num_devices(),
                                                         gbs=args.gbs, num_layers=args.num_layers,
                                                         variance=args.min_group_scale_variance,
                                                         max_permute_len=args.max_permute_len)
    total_count = 0
    for inter_stage_plan in inter_stage_plan_generator:
        stage_performance = StagePerformance(model_config, profile_data, gpu_cluster, inter_stage_plan)
        rank_device_map = stage_performance.get_device_placement()
        intra_stage_plan_generator = IntraStagePlanGenerator(inter_stage_plan, stage_performance,
                                                             layer_load_balancer,
                                                             args.max_profiled_tp_degree,
                                                             args.max_profiled_batch_size,
                                                             args.use_strat)
        while intra_stage_plan_generator.has_next:
            intra_stage_plan = intra_stage_plan_generator.next()
            try:
                cost = cost_estimator.get_cost(inter_stage_plan, intra_stage_plan.strategies,
                                               intra_stage_plan.layer_partition, rank_device_map)
                estimate_costs.append((inter_stage_plan.node_sequence, inter_stage_plan.device_groups,
                                       intra_stage_plan.strategies, inter_stage_plan.batches,
                                       intra_stage_plan.layer_partition, intra_stage_plan.num_repartition, cost))
            except KeyError as e:
                # A combination that raises KeyError is reported and skipped
                # rather than aborting the whole search.
                print(f'KeyError: {e}')
        # NOTE(review): accumulated once per inter-stage plan, after its
        # intra-stage generator is exhausted -- confirm this placement.
        total_count += intra_stage_plan_generator.count
    return estimate_costs, cache, total_count
if __name__ == '__main__':
    # Script entry point: load the cluster description and profiled data, run
    # the plan search `args.trials` times to measure its average wall-clock
    # time, then print the lowest-cost plan found by the last trial.
    args = parse_args()

    # Validate up front: zero trials would otherwise fail later with a
    # ZeroDivisionError when the average is computed.
    trials = args.trials
    if trials < 1:
        raise ValueError(f'args.trials must be at least 1, got {trials}')

    gpu_cluster = GPUCluster(hostfile_path=args.hostfile_path, clusterfile_path=args.clusterfile_path)
    data_loader = ProfileDataLoader(args.profile_data_path)
    profile_data, _ = data_loader.load_profile_data_all()
    # Explicit raise instead of `assert`, which is stripped under `python -O`.
    if not profile_data:
        raise ValueError('There is no profiled data at the specified path.')

    model_config = ModelConfig(model_name=args.model_name, num_layers=args.num_layers,
                               sequence_length=args.sequence_length, vocab_size=args.vocab_size,
                               hidden_size=args.hidden_size, attention_head_size=args.attention_head_size)
    cache = {}
    model_volume = GPTActivationAndParam(model_config, profile_data['model']['parameters'])
    cost_estimator = HeteroCostEstimator(profile_data, model_config, model_volume, gpu_cluster)
    layer_load_balancer = LayerLoadBalancer(gpu_cluster, profile_data, model_config, args.gbs)

    total_time_ms = 0.0
    for _ in range(trials):
        # perf_counter is monotonic, so system clock adjustments cannot skew
        # the measured interval.
        start_time = time.perf_counter()
        estimate_costs, cache, count = cost_het_cluster(args, gpu_cluster, profile_data, model_config,
                                                        cost_estimator, layer_load_balancer, cache)
        total_time_ms += (time.perf_counter() - start_time) * 1000
    print(f'Average time: {total_time_ms / trials} ms')

    print("count:", count)
    print(
        'rank, cost, node_sequence, device_groups, strategies(dp_deg, tp_deg), batches(number of batch), layer_partition')
    # Only the best (rank 1) plan is reported, so a single min() pass replaces
    # sorting the whole list. min() returns the first minimal element, the
    # same one a stable sort would place first.
    if estimate_costs:
        best = min(estimate_costs, key=lambda kv: kv[6])
        print(f'1, {best[6]}, {best[0]}, {best[1]}, {best[2]}, {best[3]}, {best[4]}')