-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathchange_batch.py
31 lines (27 loc) · 1.17 KB
/
change_batch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import onnx
def change_input_dim(model, batch_size="16"):
    """Overwrite the first (batch) dimension of the model's graph inputs in place.

    Note: every rewritten input receives the same batch dimension, so this
    only makes sense when all inputs share one batch axis at position 0.

    Args:
        model: the loaded ONNX model (e.g. the result of ``onnx.load``); it is
            mutated in place and nothing is returned.
        batch_size: the new batch dimension.
            * ``int`` or an all-decimal ``str`` (``16`` / ``"16"``): fixed size,
              written to ``dim_value``.
            * any other ``str`` (e.g. ``"N"``): symbolic/dynamic size, written
              to ``dim_param``.
            * anything else: falls back to a fixed size of 1.
            Defaults to ``"16"``, i.e. a fixed batch size of 16.

    Inputs whose name matches an initializer (weights, which older ONNX IR
    versions also list in ``graph.input``) and inputs with no known
    dimensions are left untouched.
    """
    # Decide once what to write. Fixed sizes must be checked before the generic
    # ``str`` case, otherwise a digit string like "16" would be stored as a
    # symbolic dimension *named* "16" instead of the concrete value 16.
    if isinstance(batch_size, int) or (
        isinstance(batch_size, str) and batch_size.isdecimal()
    ):
        fixed_size, symbolic_name = int(batch_size), None
    elif isinstance(batch_size, str):
        fixed_size, symbolic_name = None, batch_size
    else:
        fixed_size, symbolic_name = 1, None

    # Weights are not batched; changing their first dim would corrupt them.
    initializer_names = {init.name for init in model.graph.initializer}

    for graph_input in model.graph.input:
        if graph_input.name in initializer_names:
            continue
        dims = graph_input.type.tensor_type.shape.dim
        if not dims:
            # Scalar or unknown-rank input: there is no batch dimension to change.
            continue
        batch_dim = dims[0]
        if symbolic_name is not None:
            batch_dim.dim_param = symbolic_name
        else:
            batch_dim.dim_value = fixed_size
def apply(transform, infile, outfile):
    """Load an ONNX model, transform it in place, and write it back out.

    Args:
        transform: callable invoked with the loaded model; it is expected to
            mutate the model, and its return value is ignored.
        infile: path handed to ``onnx.load``.
        outfile: path handed to ``onnx.save`` for the transformed model.
    """
    loaded = onnx.load(infile)
    transform(loaded)  # mutates ``loaded`` in place
    onnx.save(loaded, outfile)
# Run change_input_dim over resnet50-v1-7.onnx and write the result to
# resnet50-v1-7_bs16.onnx, then report completion.
apply(
    transform=change_input_dim,
    infile="resnet50-v1-7.onnx",
    outfile="resnet50-v1-7_bs16.onnx",
)
print("The Model with a new batch size saved successfully")