LSTM with Batch Size Decomposition
ashokkumarkannan1 committed Feb 9, 2025
1 parent 705bca2 commit 000102c
Showing 3 changed files with 124 additions and 180 deletions.
2 changes: 0 additions & 2 deletions python/tvm/contrib/forge_compile.py
@@ -1174,5 +1174,3 @@ def default(self, obj):
os.makedirs(os.path.dirname(store_path), exist_ok=True)
with open(store_path, 'w') as file:
    file.write(serilized_str)
-
-logger.info(f"Successfully stored serilized TVM graph to {store_path} path")
300 changes: 123 additions & 177 deletions python/tvm/relay/frontend/pytorch.py
@@ -36,7 +36,7 @@
from tvm.ir import IRModule
from tvm.ir.type import DictType
from tvm.topi.utils import get_const_tuple
-
+from tvm import relay
from .. import analysis as _analysis
from .. import expr as _expr
from .. import function as _function
@@ -60,6 +60,7 @@
__all__ = ["from_pytorch"]


+
# This returns a "subgraph" which puts variables whenever
# the type is known. It also records things to map the input
# nodes to the extracted graph's nodes.
@@ -157,7 +158,7 @@ def _is_int_seq(seq):
class PyTorchOpConverter:
    """A helper class for holding PyTorch op converters."""

    def __init__(self, prelude, default_dtype, use_parser_friendly_name=False):
        self.prelude = prelude
        self.default_dtype = default_dtype
        self.create_convert_map()
@@ -3923,193 +3924,137 @@ def lstm_layers(self, input_data, layer_weights_dicts, bidirectional, dtype, dropout_p

        return _op.stack(input_seqs, 0), final_hiddens

-    def lstm(self, inputs, input_types):
-        """
-        Description of LSTM in pytorch:
-        https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html
-        Native implementation for torch version less than 1.8.0 (projection is unsupported):
-        https://github.com/pytorch/pytorch/blob/70c8daf43946b53af6493d058899ef952d27d339/aten/src/ATen/native/RNN.cpp#L1396
-        Native implementation for torch version from 1.8.0 and higher (projection is supported):
-        https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/RNN.cpp#L1483
-        """
-        # TODO (vvchernov): support dropout
-        assert len(inputs) == 9, "Input of size 9 is expected"
-        # Unpack inputs, note that if optional and not provided then value will be None.
-        _X = inputs[0]
-        # _X shape (seq_num, batch, feature_size) or (batch, seq_num, feature_size)
-
-        hidden_states = inputs[1]
-        assert len(hidden_states) == 2, "lstm expects two hidden states"
-        h_0 = hidden_states[0]
-        c_0 = hidden_states[1]
-        # H0 shape (hidden_layers_num, batch, proj_size) if projection
-        # else (hidden_layers_num, batch, hidden_size)
-        # C0 shape (hidden_layers_num, batch, hidden_size)
-
-        _weights = inputs[2]
-        # If no projection
-        # Wi layer[0] shape (4 * hidden_size, feature_size)
-        # Wh layer[0] shape (4 * hidden_size, hidden_size)
-        # Bi layer[0] shape (4 * hidden_size)
-        # Bh layer[0] shape (4 * hidden_size)
-
-        # Wi layer[>0] shape (4 * hidden_size, hidden_size * num_directions)
-        # Wh layer[>0] shape (4 * hidden_size, hidden_size)
-        # Bi layer[>0] shape (4 * hidden_size)
-        # Bh layer[>0] shape (4 * hidden_size)
-
-        # If projection
-        # Wi layer[0] shape (4 * hidden_size, feature_size)
-        # Wh layer[0] shape (4 * hidden_size, proj_size)
-        # Bi layer[0] shape (4 * hidden_size)
-        # Bh layer[0] shape (4 * hidden_size)
-        # P layer[0] shape (proj_size, hidden_size)
-
-        # Wi layer[>0] shape (4 * hidden_size, proj_size * num_directions)
-        # Wh layer[>0] shape (4 * hidden_size, proj_size)
-        # Bi layer[>0] shape (4 * hidden_size)
-        # Bh layer[>0] shape (4 * hidden_size)
-        # P layer[>0] shape (proj_size, hidden_size)
-
-        # Scalar inputs
-        has_biases = inputs[3]
-        num_layers = inputs[4]
-        dropout_p = inputs[5]  # dropout probability, if 0.0 it means there is no dropout
-        # train = inputs[6]
-        bidirectional = inputs[7]
-        batch_first = inputs[8]
-
-        num_directions = 1
-        if bidirectional:
-            num_directions = 2
-
-        rsd = len(_weights) % num_layers
-        assert rsd == 0, "The number of weights must be a multiple of the number of layers!"
-        rsd = (len(_weights) / num_layers) % num_directions
-        assert (
-            rsd == 0
-        ), "The number of weights in layer must be a multiple of the number of directions!"
-        has_proj = False
-        proj_size = 0
-        weights_num = int(len(_weights) / num_layers / num_directions)
-        if has_biases:
-            if weights_num == 5:
-                has_proj = True
-                proj_size = _infer_shape(_weights[4])[0]
-            else:
-                assert weights_num == 4, "The weights number in layer is expected equal to 4"
-        else:
-            if weights_num == 3:
-                has_proj = True
-                proj_size = _infer_shape(_weights[2])[0]
-            else:
-                assert weights_num == 2, "The weights number in layer is expected equal to 2"
-
-        X = _op.transpose(_X, (1, 0, 2)) if batch_first else _X
-        # TODO (vvchernov): Which data type should be used? from input or weights?
-        # Instead of it _infer_type(X).checked_type.dtype can be used
-        X_dtype = input_types[0]
-        X_shape = _infer_shape(X)  # (seq_num, batch, feature_size)
-
-        hidden_size = _infer_shape(_weights[0])[0] / 4
-        batch_size = X_shape[1]
-
-        # Initialize hidden states if not provided.
-        layers_h = []
-        layers_c = []
-        hidden_layers_num = num_directions * num_layers
-        if h_0 is None:
-            if has_proj:
-                h_0 = _op.zeros((batch_size, proj_size), X_dtype)
-            else:
-                h_0 = _op.zeros((batch_size, hidden_size), X_dtype)
-            for i in range(hidden_layers_num):
-                layers_h.append(h_0)
-        else:
-            layers_h = unbind(h_0, 0)
-        if c_0 is None:
-            c_0 = _op.zeros((batch_size, hidden_size), X_dtype)
-            for i in range(hidden_layers_num):
-                layers_c.append(c_0)
-        else:
-            layers_c = unbind(c_0, 0)
-
-        layer_weights_dicts = []
-        k = 0  # layer counter
-        if has_biases:
-            names = ["hidden_state", "cell_state", "w_inp", "w_hid", "b_inp", "b_hid"]
-            if bidirectional:
-                rsd = len(_weights) % (2 * weights_num)
-                assert rsd == 0, "got an incorrect number of LSTM weights"
-                for i in range(0, len(_weights), 2 * weights_num):
-                    fw_tensors = [layers_h[2 * k], layers_c[2 * k], *_weights[i : i + 4]]
-                    fw_weights_dict = dict(zip(names, fw_tensors))
-                    if has_proj:
-                        fw_weights_dict["proj"] = _weights[i + 4]
-                    j = i + weights_num
-                    rev_tensors = [layers_h[2 * k + 1], layers_c[2 * k + 1], *_weights[j : j + 4]]
-                    rev_weights_dict = dict(zip(names, rev_tensors))
-                    if has_proj:
-                        rev_weights_dict["proj"] = _weights[j + 4]
-                    layer_weights_dicts.append([fw_weights_dict, rev_weights_dict])
-                    k += 1
-            else:
-                assert len(_weights) % weights_num == 0, "got an incorrect number of LSTM weights"
-                for i in range(0, len(_weights), weights_num):
-                    fw_tensors = [layers_h[k], layers_c[k], *_weights[i : i + 4]]
-                    fw_weights_dict = dict(zip(names, fw_tensors))
-                    if has_proj:
-                        fw_weights_dict["proj"] = _weights[i + 4]
-                    layer_weights_dicts.append([fw_weights_dict])
-                    k += 1
-        else:
-            names = ["hidden_state", "cell_state", "w_inp", "w_hid"]
-            if bidirectional:
-                rsd = len(_weights) % (2 * weights_num)
-                assert rsd == 0, "got an incorrect number of LSTM weights"
-                for i in range(0, len(_weights), 2 * weights_num):
-                    fw_tensors = [layers_h[2 * k], layers_c[2 * k], *_weights[i : i + 2]]
-                    fw_weights_dict = dict(zip(names, fw_tensors))
-                    if has_proj:
-                        fw_weights_dict["proj"] = _weights[i + 2]
-                    j = i + weights_num
-                    rev_tensors = [layers_h[2 * k + 1], layers_c[2 * k + 1], *_weights[j : j + 2]]
-                    rev_weights_dict = dict(zip(names, rev_tensors))
-                    if has_proj:
-                        rev_weights_dict["proj"] = _weights[j + 2]
-                    layer_weights_dicts.append([fw_weights_dict, rev_weights_dict])
-                    k += 1
-            else:
-                assert len(_weights) % weights_num == 0, "got an incorrect number of LSTM weights"
-                for i in range(0, len(_weights), weights_num):
-                    fw_tensors = [layers_h[k], layers_c[k], *_weights[i : i + 2]]
-                    fw_weights_dict = dict(zip(names, fw_tensors))
-                    if has_proj:
-                        fw_weights_dict["proj"] = _weights[i + 2]
-                    layer_weights_dicts.append([fw_weights_dict])
-                    k += 1
-        assert (
-            len(layer_weights_dicts) == num_layers and k == num_layers
-        ), "For stacked LSTM number of weights sets should be the same as number of layers!"
-
-        outputs = self.lstm_layers(
-            X, layer_weights_dicts, bidirectional, dtype=X_dtype, dropout_p=dropout_p
-        )
-
-        # output shape = (seq_num, batch, hidden_size) or
-        # (seq_num, batch, 2*feature_size) for bidirectional
-        output = outputs[0]
-
-        hy = []
-        cy = []
-        for hidden in outputs[1]:
-            hy.append(hidden[0])
-            cy.append(hidden[1])
-
-        if batch_first:
-            output = _op.transpose(output, (1, 0, 2))
-
-        return (output, _op.stack(hy, 0), _op.stack(cy, 0))
+    def lstm_cell_tvm(self, input_t, h_prev, c_prev, weight_ih, weight_hh, bias_ih, bias_hh, batch_size_t, hidden_size):
+        """
+        Single LSTM cell operation using TVM ops for packed input.
+        Args:
+            input_t: Input tensor at current timestep [batch_size_t, input_size].
+            h_prev: Previous hidden state [1, batch_size_t, hidden_size].
+            c_prev: Previous cell state [1, batch_size_t, hidden_size].
+            weight_ih: Input-hidden weights [4 * hidden_size, input_size].
+            weight_hh: Hidden-hidden weights [4 * hidden_size, hidden_size].
+            bias_ih: Input-hidden biases [4 * hidden_size].
+            bias_hh: Hidden-hidden biases [4 * hidden_size].
+            batch_size_t: Current batch size.
+            hidden_size: Hidden state size.
+        Returns:
+            h_next: Next hidden state [1, batch_size_t, hidden_size].
+            c_next: Next cell state [1, batch_size_t, hidden_size].
+        """
+        # Multiply input by input-hidden weights and add bias
+        gates_input = relay.nn.dense(input_t, weight_ih) + bias_ih
+
+        # Multiply previous hidden state by hidden-hidden weights and add bias
+        h_prev_squeezed = relay.squeeze(h_prev, axis=[0])
+        gates_hidden = relay.nn.dense(h_prev_squeezed, weight_hh) + bias_hh
+
+        # Compute gates by adding the input and hidden projections
+        gates = relay.add(gates_input, gates_hidden)
+
+        # Split gate values into input, forget, cell, and output gates
+        i, f, g, o = relay.split(gates, indices_or_sections=4, axis=-1)
+
+        # Apply activations: sigmoid for input, forget, output gates, and tanh for cell gate
+        i = relay.sigmoid(i)
+        f = relay.sigmoid(f)
+        g = relay.tanh(g)
+        o = relay.sigmoid(o)
+
+        # Compute next cell state
+        c_next = relay.add(relay.multiply(f, c_prev), relay.multiply(i, g))
+
+        # Compute next hidden state
+        h_next = relay.multiply(o, relay.tanh(c_next))
+
+        # # Add batch dimension back to hidden and cell states
+        # h_next = relay.expand_dims(h_next, axis=0)
+        # c_next = relay.expand_dims(c_next, axis=0)
+
+        return h_next, c_next
+
+    def lstm(self, inputs, input_types):
+        """
+        Custom LSTM operation for packed sequences using TVM Relay ops.
+        Args:
+            input: Relay expression for packed input data (concatenated sequences) of shape [225, 6].
+            batch_sizes: Relay expression for batch sizes at each timestep of shape [9].
+            hx: Relay expression for initial hidden states of shape [1, 25, 32].
+            cx: Relay expression for initial cell states of shape [1, 25, 32].
+            weight_ih: Relay expression for input-hidden weights.
+            weight_hh: Relay expression for hidden-hidden weights.
+            bias_ih: Relay expression for input-hidden biases.
+            bias_hh: Relay expression for hidden-hidden biases.
+            hidden_size: Hidden state size of the LSTM.
+        Returns:
+            output: Packed output data of shape [225, 32].
+            h_n: Final hidden state of shape [1, 25, 32].
+            c_n: Final cell state of shape [1, 25, 32].
+        """
+        # Initialize output list to store hidden states at each timestep
+        outputs = []
+
+        # Unpack packed input data, batch sizes, states, and flat weights
+        input = inputs[0]
+        batch_sizes = inputs[1]
+        weight_ih = inputs[3][0]
+        weight_hh = inputs[3][1]
+        bias_ih = inputs[3][2]
+        bias_hh = inputs[3][3]
+        hidden_size = _infer_value(inputs[2][0], {}).numpy().shape
+        hidden_size = hidden_size[-1]
+        seq_len = len(inputs)
+        hx = inputs[2][0]
+        cx = inputs[2][1]
+        hidden, cell = hx, cx
+        offset = 0  # To keep track of position in the packed input tensor
+
+        # Loop through each timestep in the sequence
+        # Manually setting batch_sizes_val
+        batch_sizes_data = np.array([25] * seq_len).astype("int32")
+        # input_shape = (225, 6)
+        # hidden_shape = (1, 25, 32)
+
+        for t in range(batch_sizes_data.shape[0]):
+            # Extract the batch size for the current timestep
+            # batch_size_t = relay.take(batch_sizes, relay.const(t, dtype="int32"))
+            batch_size_t = batch_sizes_data[t].item()
+
+            # Extract the input slice for the current batch size
+            input_t = relay.strided_slice(input, begin=[offset], end=[offset + batch_size_t], axes=[0], slice_mode="end")
+
+            # Adjust hidden and cell states to match current batch size
+            hidden_t = relay.strided_slice(hidden, begin=[0], end=[batch_size_t], axes=[1], slice_mode="end")
+            cell_t = relay.strided_slice(cell, begin=[0], end=[batch_size_t], axes=[1], slice_mode="end")
+
+            # Apply LSTM cell for the current timestep
+            h_next, c_next = self.lstm_cell_tvm(input_t, hidden_t, cell_t, weight_ih, weight_hh, bias_ih, bias_hh, batch_size_t, hidden_size)
+
+            # Append the hidden state to outputs
+            outputs.append(h_next)
+
+            # Update hidden and cell states for the next timestep
+            # sliced_hidden = relay.strided_slice(hidden, begin=[0, 0, 0], end=[hidden_shape[0], batch_size_t, hidden_shape[2]])
+            sliced_hidden = relay.strided_slice(hidden, begin=[0], end=[batch_size_t], axes=[1], slice_mode="end")
+            hidden = relay.concatenate([h_next, sliced_hidden], axis=1)
+            sliced_cell = relay.strided_slice(cell, begin=[0], end=[batch_size_t], axes=[1], slice_mode="end")
+            cell = relay.concatenate([c_next, sliced_cell], axis=1)
+
+            # Update offset to process the next batch of active sequences
+            offset += batch_size_t
+
+        # Concatenate the output hidden states for all timesteps to form the packed output
+        output = relay.concatenate(outputs, axis=1)  # Output should be [225, 32] as given
+        output = relay.squeeze(output, axis=[0])
+
+        return output, h_next, c_next

    def all_any_common(self, op, inputs, input_types):
        # return True
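For reference, the gate arithmetic in lstm_cell_tvm follows the standard LSTM cell equations with PyTorch's (i, f, g, o) gate ordering, and relay.nn.dense computes x @ W^T. A minimal NumPy sketch of one step (a hypothetical cross-check helper, not part of this commit):

import numpy as np

def lstm_cell_reference(x_t, h_prev, c_prev, w_ih, w_hh, b_ih, b_hh):
    # x_t: [batch, input_size]; h_prev/c_prev: [batch, hidden_size]
    # w_ih: [4 * hidden_size, input_size]; w_hh: [4 * hidden_size, hidden_size]
    sigmoid = lambda z: 1.0 / (1.0 + np.exp(-z))
    # Same projections as relay.nn.dense: x @ W^T + b
    gates = x_t @ w_ih.T + b_ih + h_prev @ w_hh.T + b_hh
    # PyTorch packs the gates in (input, forget, cell, output) order
    i, f, g, o = np.split(gates, 4, axis=-1)
    i, f, o = sigmoid(i), sigmoid(f), sigmoid(o)
    g = np.tanh(g)
    c_next = f * c_prev + i * g      # next cell state
    h_next = o * np.tanh(c_next)     # next hidden state
    return h_next, c_next

Feeding the same tensors through lstm_cell_tvm (after building the Relay graph) and through this helper should agree up to floating-point tolerance.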
@@ -6070,6 +6015,7 @@ def outplace_inplace_ops(opnodes):
def from_pytorch(
script_module,
input_infos,
custom_convert_map=None,
default_dtype="float32",
use_parser_friendly_name=False,
@@ -6137,7 +6083,7 @@ def from_pytorch(
prelude = Prelude(mod)
enable_lower_all_tuples = True

    converter = PyTorchOpConverter(prelude, default_dtype, use_parser_friendly_name)

graph = script_module.graph.copy()
# Check if lower_all_tuples pass can be enabled
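As a usage sketch for the shapes quoted in the new lstm docstring — packed input [225, 6], batch_sizes [9], states [1, 25, 32] — these correspond to 25 sequences of length 9 with 6 input features and hidden size 32, i.e. 25 * 9 = 225 packed rows, with 25 active sequences at every timestep. A hypothetical conversion driver (the module, tensor names, and input_infos below are assumptions, not part of this commit) could look like:

import torch
import torch.nn as nn
from tvm.relay.frontend import from_pytorch

class PackedLSTM(nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(input_size=6, hidden_size=32, num_layers=1)

    def forward(self, data, h0, c0):
        # 25 sequences of length 9 -> packed data [225, 6], batch_sizes = [25] * 9
        batch_sizes = torch.full((9,), 25, dtype=torch.int64)
        packed = nn.utils.rnn.PackedSequence(data, batch_sizes)
        out, (h_n, c_n) = self.lstm(packed, (h0, c0))
        return out.data, h_n, c_n

model = PackedLSTM().eval()
example = (torch.randn(225, 6), torch.zeros(1, 25, 32), torch.zeros(1, 25, 32))
traced = torch.jit.trace(model, example)
mod, params = from_pytorch(
    traced, [("data", (225, 6)), ("h0", (1, 25, 32)), ("c0", (1, 25, 32))]
)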
2 changes: 1 addition & 1 deletion python/tvm/relay/op/contrib/forge/forge_passes.py
@@ -4310,4 +4310,4 @@ def run_forge_compile_passes(relay_module, params=None, inputs=None, target=None
target=target,
framework_outputs=framework_outputs,
verify_cfg=verify_cfg
-)
\ No newline at end of file
+)

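Downstream of the frontend, the converted module flows through the forge compile passes whose call site is shown above; that signature exposes relay_module, params, inputs, target, framework_outputs, and verify_cfg. A hedged sketch of invoking it, using only the parameter names visible in this diff (all argument values are assumptions, not part of this commit):

from tvm.relay.op.contrib.forge.forge_passes import run_forge_compile_passes

# mod/params from from_pytorch above; remaining arguments assumed
relay_module = run_forge_compile_passes(
    mod,
    params=params,
    inputs=example,  # framework-level example inputs (assumed format)
    target=None,     # default target (assumed)
)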