@@ -57,11 +57,55 @@ def tearDown(self):
57
57
shutil .rmtree (self .temp_dir )
58
58
59
59
def test_get_shapes_for_config (self ):
60
+ # Test custom shapes
60
61
shapes = get_shapes_for_config (
61
62
self .test_config ["model_params" ][0 ]["matrix_shapes" ]
62
63
)
63
64
self .assertEqual (len (shapes ), 1 )
64
65
self .assertEqual (shapes [0 ], ("custom" , [1024 , 1024 , 1024 ]))
66
+
67
+ # Test llama shapes
68
+ llama_shapes = get_shapes_for_config ([
69
+ {"name" : "llama" }
70
+ ])
71
+ self .assertEqual (len (llama_shapes ), 4 ) # 4 LLaMa shapes
72
+ self .assertTrue (any (name .startswith ("llama_attn.wqkv" ) for name , _ in llama_shapes ))
73
+ self .assertTrue (any (name .startswith ("llama_attn.w0" ) for name , _ in llama_shapes ))
74
+ self .assertTrue (any (name .startswith ("llama_ffn.w13" ) for name , _ in llama_shapes ))
75
+ self .assertTrue (any (name .startswith ("llama_ffn.w2" ) for name , _ in llama_shapes ))
76
+
77
+ # Test pow2 shapes
78
+ pow2_shapes = get_shapes_for_config ([
79
+ {"name" : "pow2" , "min_power" : 10 , "max_power" : 12 }
80
+ ])
81
+ self .assertEqual (len (pow2_shapes ), 3 ) # 3 powers of 2 (10, 11, 12)
82
+ self .assertEqual (pow2_shapes [0 ], ("pow2_0" , [1024 , 1024 , 1024 ])) # 2^10
83
+ self .assertEqual (pow2_shapes [1 ], ("pow2_1" , [2048 , 2048 , 2048 ])) # 2^11
84
+ self .assertEqual (pow2_shapes [2 ], ("pow2_2" , [4096 , 4096 , 4096 ])) # 2^12
85
+
86
+ # Test pow2_extended shapes
87
+ pow2_extended_shapes = get_shapes_for_config ([
88
+ {"name" : "pow2_extended" , "min_power" : 10 , "max_power" : 11 }
89
+ ])
90
+ self .assertEqual (len (pow2_extended_shapes ), 4 ) # 2 powers of 2, each with 2 variants
91
+ self .assertEqual (pow2_extended_shapes [0 ], ("pow2_extended_0" , [1024 , 1024 , 1024 ])) # 2^10
92
+ self .assertEqual (pow2_extended_shapes [1 ], ("pow2_extended_1" , [1536 , 1536 , 1536 ])) # 2^10 + 2^9
93
+ self .assertEqual (pow2_extended_shapes [2 ], ("pow2_extended_2" , [2048 , 2048 , 2048 ])) # 2^11
94
+ self .assertEqual (pow2_extended_shapes [3 ], ("pow2_extended_3" , [3072 , 3072 , 3072 ])) # 2^11 + 2^10
95
+
96
+ # Test sweep shapes (limited to a small range for testing)
97
+ sweep_shapes = get_shapes_for_config ([
98
+ {"name" : "sweep" , "min_power" : 8 , "max_power" : 9 }
99
+ ])
100
+ # For min_power=8, max_power=9, we should have 8 shapes (2^3 = 8 combinations)
101
+ self .assertEqual (len (sweep_shapes ), 8 )
102
+ # Check that all shapes have the expected format
103
+ for name , shape in sweep_shapes :
104
+ self .assertTrue (name .startswith ("sweep_" ))
105
+ self .assertEqual (len (shape ), 3 ) # [M, K, N]
106
+ # Check that all dimensions are powers of 2 between 2^8 and 2^9
107
+ for dim in shape :
108
+ self .assertTrue (dim in [256 , 512 ]) # 2^8, 2^9
65
109
66
110
def test_get_param_combinations (self ):
67
111
model_param = self .test_config ["model_params" ][0 ]
0 commit comments