forked from pytorch/ao
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgen_metal_shader_lib.py
68 lines (54 loc) · 1.65 KB
/
gen_metal_shader_lib.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional
import os
import sys
import yaml
if len(sys.argv) != 2:
    # Usage errors belong on stderr: sys.exit() with a string argument prints it
    # to stderr and terminates with exit status 1 (same status as before).
    sys.exit("Usage: gen_metal_shader_lib.py <output_file>")

# Output file where the generated code will be written
OUTPUT_FILE = sys.argv[1]
# Parent directory of the directory containing this script.
MPS_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))

# Path to yaml file containing the list of .metal files to include
METAL_YAML = os.path.join(MPS_DIR, "metal.yaml")

with open(METAL_YAML, "r", encoding="utf-8") as yamlf:
    # yaml.safe_load() returns None for an empty document; treat that as
    # "no entries" instead of failing with "'NoneType' object is not iterable".
    metal_config = yaml.safe_load(yamlf) or []

# Several entries may reference the same "file": deduplicate, then sort so the
# generated output is deterministic from run to run.
metal_files = sorted({op["file"] for op in metal_config if "file" in op})
# Path to the folder containing the .metal files
METAL_DIR = os.path.join(MPS_DIR, "metal")
# Header of the generated file: a "generated by" banner, the MetalShaderLibrary
# declaration (from ATen when USE_ATEN is defined, otherwise from the torchao
# header), and the opening of a C++ raw string literal delimited by
# METAL_LOWBIT. The concatenated .metal sources are written after this text.
prefix = """/**
* This file is generated by gen_metal_shader_lib.py
*/
#ifdef USE_ATEN
using at::native::mps::MetalShaderLibrary;
#else
#include <torchao/experimental/kernels/mps/src/MetalShaderLibrary.h>
#endif
static MetalShaderLibrary metal_lowbit_quantized_lib(R"METAL_LOWBIT(
"""
# Footer: closes the raw string literal and the constructor call opened in
# `prefix`.
suffix = """
)METAL_LOWBIT");
"""
# Banner emitted before each .metal file's contents; `{}` is filled in with the
# file name via str.format().
comment = """
/**
* Contents of {}
*/
"""
# The .metal sources are embedded in a C++ raw string literal opened in
# `prefix`; this sequence would end that literal early.
RAW_STRING_END = ')METAL_LOWBIT"'

# Read and validate every source before touching OUTPUT_FILE, so a missing or
# unembeddable input fails the build without leaving a partial output behind.
metal_sources = []
for file in metal_files:
    with open(os.path.join(METAL_DIR, file), "r", encoding="utf-8") as f:
        content = f.read()
    if RAW_STRING_END in content:
        raise ValueError(
            f"{file} contains the raw-string terminator {RAW_STRING_END!r} "
            "and cannot be embedded in the generated shader library"
        )
    metal_sources.append((file, content))

# os.path.dirname() is "" for a bare file name, and os.makedirs("") raises
# FileNotFoundError, so only create the directory when there is one.
output_dir = os.path.dirname(OUTPUT_FILE)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)

with open(OUTPUT_FILE, "w", encoding="utf-8") as outf:
    outf.write(prefix)
    for file, content in metal_sources:
        outf.write(comment.format(file))
        outf.write(content)
        outf.write("\n\n")
    outf.write(suffix)