From 000102ceed5419f00f0c5cccfd8da0f9c9a88e50 Mon Sep 17 00:00:00 2001
From: Ashok Kumar Kannan
Date: Thu, 26 Dec 2024 09:36:18 +0000
Subject: [PATCH] LSTM with batch-size decomposition for packed sequences

---
 python/tvm/contrib/forge_compile.py  |   2 -
 python/tvm/relay/frontend/pytorch.py | 291 ++++++-----------
 2 files changed, 115 insertions(+), 178 deletions(-)

diff --git a/python/tvm/contrib/forge_compile.py b/python/tvm/contrib/forge_compile.py
index c367e5804..8c6d58754 100644
--- a/python/tvm/contrib/forge_compile.py
+++ b/python/tvm/contrib/forge_compile.py
@@ -1174,5 +1174,3 @@ def default(self, obj):
     os.makedirs(os.path.dirname(store_path), exist_ok=True)
     with open(store_path, 'w') as file:
         file.write(serilized_str)
-
-    logger.info(f"Successfully stored serilized TVM graph to {store_path} path")
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index c4a7b06b7..ff0245a55 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -36,7 +36,7 @@
 from tvm.ir import IRModule
 from tvm.ir.type import DictType
 from tvm.topi.utils import get_const_tuple
-
+from tvm import relay
 from .. import analysis as _analysis
 from .. import expr as _expr
 from .. import function as _function
@@ -3923,193 +3924,137 @@ def lstm_layers(self, input_data, layer_weights_dicts, bidirectional, dtype, dro
 
         return _op.stack(input_seqs, 0), final_hiddens
 
-    def lstm(self, inputs, input_types):
+    def lstm_cell_tvm(self, input_t, h_prev, c_prev, weight_ih, weight_hh, bias_ih, bias_hh):
         """
-        Description of LSTM in pytorch:https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html
-        Native implementation for torch version less than 1.8.0 (projection is unsupported):
-        https://github.com/pytorch/pytorch/blob/70c8daf43946b53af6493d058899ef952d27d339/aten/ \
-        src/ATen/native/RNN.cpp#L1396
-        Native implementation for torch version from 1.8.0 and higher (projection is supported):
-        https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/RNN.cpp#L1483
+        Single LSTM cell step over one timestep of a packed input, built from TVM Relay ops.
+
+        Args:
+            input_t: Input tensor at the current timestep [batch_size_t, input_size].
+            h_prev: Previous hidden state [1, batch_size_t, hidden_size].
+            c_prev: Previous cell state [1, batch_size_t, hidden_size].
+            weight_ih: Input-hidden weights [4 * hidden_size, input_size].
+            weight_hh: Hidden-hidden weights [4 * hidden_size, hidden_size].
+            bias_ih: Input-hidden biases [4 * hidden_size].
+            bias_hh: Hidden-hidden biases [4 * hidden_size].
+
+        Returns:
+            h_next: Next hidden state [1, batch_size_t, hidden_size].
+            c_next: Next cell state [1, batch_size_t, hidden_size].
         """
-        # TODO (vvchernov): support dropout
-        assert len(inputs) == 9, "Input of size 9 is expected"
-        # Unpack inputs, note that if optional and not provided then value will be None.
-        _X = inputs[0]
-        # _X shape (seq_num, batch, feature_size) or (batch, seq_num, feature_size)
+        # Input projection: nn.dense computes input_t @ weight_ih.T, then the bias is added
+        gates_input = relay.nn.dense(input_t, weight_ih) + bias_ih
 
-        hidden_states = inputs[1]
-        assert len(hidden_states) == 2, "lstm expects two hidden states"
-        h_0 = hidden_states[0]
-        c_0 = hidden_states[1]
-        # H0 shape (hidden_layers_num, batch, proj_size) if projection
-        # else (hidden_layers_num, batch, hidden_size)
-        # C0 shape (hidden_layers_num, batch, hidden_size)
+        # Hidden projection: drop the leading state axis so nn.dense sees [batch_size_t, hidden_size]
+        h_prev_squeezed = relay.squeeze(h_prev, axis=[0])
+        gates_hidden = relay.nn.dense(h_prev_squeezed, weight_hh) + bias_hh
 
-        _weights = inputs[2]
-        # If no projection
-        # Wi layer[0] shape (4 * hidden_size, feature_size)
-        # Wh layer[0] shape (4 * hidden_size, hidden_size)
-        # Bi layer[0] shape (4 * hidden_size)
-        # Bh layer[0] shape (4 * hidden_size)
-
-        # Wi layer[>0] shape (4 * hidden_size, hidden_size * num_directions)
-        # Wh layer[>0] shape (4 * hidden_size, hidden_size)
-        # Bi layer[>0] shape (4 * hidden_size)
-        # Bh layer[>0] shape (4 * hidden_size)
-
-        # If projection
-        # Wi layer[0] shape (4 * hidden_size, feature_size)
-        # Wh layer[0] shape (4 * hidden_size, proj_size)
-        # Bi layer[0] shape (4 * hidden_size)
-        # Bh layer[0] shape (4 * hidden_size)
-        # P layer[0] shape (proj_size, hidden_size)
-
-        # Wi layer[>0] shape (4 * hidden_size, proj_size * num_directions)
-        # Wh layer[>0] shape (4 * hidden_size, proj_size)
-        # Bi layer[>0] shape (4 * hidden_size)
-        # Bh layer[>0] shape (4 * hidden_size)
-        # P layer[>0] shape (proj_size, hidden_size)
+        # Combine the input and hidden projections into the pre-activation gate values
+        gates = relay.add(gates_input, gates_hidden)
 
-        # Scalar inputs
-        has_biases = inputs[3]
-        num_layers = inputs[4]
-        dropout_p = inputs[5]  # dropout probability, if 0.0 it means there is no dropout
-        # train = inputs[6]
-        bidirectional = inputs[7]
-        batch_first = inputs[8]
+        # Split into input (i), forget (f), cell (g) and output (o) gates;
+        # PyTorch packs them in this order along the 4 * hidden_size axis
+        i, f, g, o = relay.split(gates, indices_or_sections=4, axis=-1)
 
-        num_directions = 1
-        if bidirectional:
-            num_directions = 2
+        # Gate activations: sigmoid for i, f, o and tanh for the cell candidate g
+        i = relay.sigmoid(i)
+        f = relay.sigmoid(f)
+        g = relay.tanh(g)
+        o = relay.sigmoid(o)
 
-        rsd = len(_weights) % num_layers
-        assert rsd == 0, "The number of weights must be a multiple of the number of layers!"
-        rsd = (len(_weights) / num_layers) % num_directions
-        assert (
-            rsd == 0
-        ), "The number of weights in layer must be a multiple of the number of directions!"
-        has_proj = False
-        proj_size = 0
-        weights_num = int(len(_weights) / num_layers / num_directions)
-        if has_biases:
-            if weights_num == 5:
-                has_proj = True
-                proj_size = _infer_shape(_weights[4])[0]
-            else:
-                assert weights_num == 4, "The weights number in layer is expected equal to 4"
-        else:
-            if weights_num == 3:
-                has_proj = True
-                proj_size = _infer_shape(_weights[2])[0]
-            else:
-                assert weights_num == 2, "The weights number in layer is expected equal to 2"
+        # Next cell state; broadcasting against the 3-D c_prev keeps it [1, batch_size_t, hidden_size]
+        c_next = relay.add(relay.multiply(f, c_prev), relay.multiply(i, g))
 
-        X = _op.transpose(_X, (1, 0, 2)) if batch_first else _X
-        # TODO (vvchernov): Which data type should be used? from input or weights?
-        # Instead of it _infer_type(X).checked_type.dtype can be used
-        X_dtype = input_types[0]
-        X_shape = _infer_shape(X)  # (seq_num, batch, feature_size)
+        # Next hidden state, again [1, batch_size_t, hidden_size]
+        h_next = relay.multiply(o, relay.tanh(c_next))
 
-        hidden_size = _infer_shape(_weights[0])[0] / 4
-        batch_size = X_shape[1]
+        # No expand_dims is needed here: broadcasting against the 3-D c_prev
+        # already left h_next and c_next with the leading state axis
 
-        # Initialize hidden states if not provided.
-        layers_h = []
-        layers_c = []
-        hidden_layers_num = num_directions * num_layers
-        if h_0 is None:
-            if has_proj:
-                h_0 = _op.zeros((batch_size, proj_size), X_dtype)
-            else:
-                h_0 = _op.zeros((batch_size, hidden_size), X_dtype)
-            for i in range(hidden_layers_num):
-                layers_h.append(h_0)
-        else:
-            layers_h = unbind(h_0, 0)
-        if c_0 is None:
-            c_0 = _op.zeros((batch_size, hidden_size), X_dtype)
-            for i in range(hidden_layers_num):
-                layers_c.append(c_0)
-        else:
-            layers_c = unbind(c_0, 0)
+        return h_next, c_next
 
-        layer_weights_dicts = []
-        k = 0  # layer counter
-        if has_biases:
-            names = ["hidden_state", "cell_state", "w_inp", "w_hid", "b_inp", "b_hid"]
-            if bidirectional:
-                rsd = len(_weights) % (2 * weights_num)
-                assert rsd == 0, "got an incorrect number of LSTM weights"
-                for i in range(0, len(_weights), 2 * weights_num):
-                    fw_tensors = [layers_h[2 * k], layers_c[2 * k], *_weights[i : i + 4]]
-                    fw_weights_dict = dict(zip(names, fw_tensors))
-                    if has_proj:
-                        fw_weights_dict["proj"] = _weights[i + 4]
-                    j = i + weights_num
-                    rev_tensors = [layers_h[2 * k + 1], layers_c[2 * k + 1], *_weights[j : j + 4]]
-                    rev_weights_dict = dict(zip(names, rev_tensors))
-                    if has_proj:
-                        rev_weights_dict["proj"] = _weights[j + 4]
-                    layer_weights_dicts.append([fw_weights_dict, rev_weights_dict])
-                    k += 1
-            else:
-                assert len(_weights) % weights_num == 0, "got an incorrect number of LSTM weights"
-                for i in range(0, len(_weights), weights_num):
-                    fw_tensors = [layers_h[k], layers_c[k], *_weights[i : i + 4]]
-                    fw_weights_dict = dict(zip(names, fw_tensors))
-                    if has_proj:
-                        fw_weights_dict["proj"] = _weights[i + 4]
-                    layer_weights_dicts.append([fw_weights_dict])
-                    k += 1
-        else:
-            names = ["hidden_state", "cell_state", "w_inp", "w_hid"]
-            if bidirectional:
-                rsd = len(_weights) % (2 * weights_num)
-                assert rsd == 0, "got an incorrect number of LSTM weights"
-                for i in range(0, len(_weights), 2 * weights_num):
-                    fw_tensors = [layers_h[2 * k], layers_c[2 * k], *_weights[i : i + 2]]
-                    fw_weights_dict = dict(zip(names, fw_tensors))
-                    if has_proj:
-                        fw_weights_dict["proj"] = _weights[i + 2]
-                    j = i + weights_num
-                    rev_tensors = [layers_h[2 * k + 1], layers_c[2 * k + 1], *_weights[j : j + 2]]
-                    rev_weights_dict = dict(zip(names, rev_tensors))
-                    if has_proj:
-                        rev_weights_dict["proj"] = _weights[j + 2]
-                    layer_weights_dicts.append([fw_weights_dict, rev_weights_dict])
-                    k += 1
-            else:
-                assert len(_weights) % weights_num == 0, "got an incorrect number of LSTM weights"
-                for i in range(0, len(_weights), weights_num):
-                    fw_tensors = [layers_h[k], layers_c[k], *_weights[i : i + 2]]
-                    fw_weights_dict = dict(zip(names, fw_tensors))
-                    if has_proj:
-                        fw_weights_dict["proj"] = _weights[i + 2]
-                    layer_weights_dicts.append([fw_weights_dict])
-                    k += 1
-        assert (
-            len(layer_weights_dicts) == num_layers and k == num_layers
-        ), "For stacked LSTM number of weights sets should be the same as number of layers!"
-
-        outputs = self.lstm_layers(
-            X, layer_weights_dicts, bidirectional, dtype=X_dtype, dropout_p=dropout_p
-        )
-
-        # output shape = (seq_num, batch, hidden_size) or
-        # (seq_num, batch, 2*feature_size) for bidirectional
-        output = outputs[0]
-
-        hy = []
-        cy = []
-        for hidden in outputs[1]:
-            hy.append(hidden[0])
-            cy.append(hidden[1])
+    def lstm(self, inputs, input_types):
+        """
+        LSTM conversion for packed sequences, implemented with TVM Relay ops.
+
+        This handles the packed-input overload of aten::lstm (single layer,
+        unidirectional, no dropout), whose arguments are:
+            inputs[0]: packed input data [sum(batch_sizes), input_size]
+                       (e.g. [225, 6] for batch 25, seq_len 9, input_size 6).
+            inputs[1]: batch sizes per timestep, shape [seq_len] (e.g. [9]).
+            inputs[2]: initial hidden and cell states, each [1, batch, hidden_size]
+                       (e.g. [1, 25, 32]).
+            inputs[3]: flat weights: weight_ih, weight_hh, bias_ih, bias_hh.
+
+        Returns:
+            output: packed output data [sum(batch_sizes), hidden_size] (e.g. [225, 32]).
+            h_n: final hidden state [1, batch, hidden_size].
+            c_n: final cell state [1, batch, hidden_size].
+
+        Note: h_n and c_n are taken from the last timestep, so for unequal
+        sequence lengths the states of sequences that ended earlier are not
+        gathered.
+        """
+        # Hidden states emitted at each timestep (the packed output)
+        outputs = []
 
-        if batch_first:
-            output = _op.transpose(output, (1, 0, 2))
+        # Unpack the aten::lstm arguments
+        packed_input = inputs[0]
+        batch_sizes = inputs[1]
+        hx = inputs[2][0]
+        cx = inputs[2][1]
+        weight_ih = inputs[3][0]
+        weight_hh = inputs[3][1]
+        bias_ih = inputs[3][2]
+        bias_hh = inputs[3][3]
+
+        hidden, cell = hx, cx
+        offset = 0  # Position of the current timestep's rows in the packed input
+
+        # The per-timestep batch sizes must be known at compile time so the
+        # packed input can be sliced; for traced graphs they are constants
+        # and can be evaluated here.
+        batch_sizes_data = _infer_value(batch_sizes, {}).numpy().astype("int32")
+
+        # Loop through each timestep in the sequence
+        for t in range(batch_sizes_data.shape[0]):
+            # Number of still-active sequences at the current timestep
+            batch_size_t = batch_sizes_data[t].item()
+
+            # Rows [offset, offset + batch_size_t) of the packed input belong to timestep t
+            input_t = relay.strided_slice(packed_input, begin=[offset], end=[offset + batch_size_t], axes=[0], slice_mode="end")
+
+            # Slice the hidden and cell states down to the active batch
+            hidden_t = relay.strided_slice(hidden, begin=[0], end=[batch_size_t], axes=[1], slice_mode="end")
+            cell_t = relay.strided_slice(cell, begin=[0], end=[batch_size_t], axes=[1], slice_mode="end")
+
+            # Apply the LSTM cell for the current timestep
+            h_next, c_next = self.lstm_cell_tvm(input_t, hidden_t, cell_t, weight_ih, weight_hh, bias_ih, bias_hh)
+
+            # The hidden state is both the timestep's output and the next state
+            outputs.append(h_next)
+
+            # Prepend the updated states so that the front slice read at the
+            # next timestep picks them up; stale entries accumulate in the
+            # tail but are never read again
+            hidden = relay.concatenate([h_next, hidden_t], axis=1)
+            cell = relay.concatenate([c_next, cell_t], axis=1)
+
+            # Advance to the rows of the next timestep
+            offset += batch_size_t
+
+        # Concatenate the per-timestep hidden states along the batch axis and
+        # drop the leading state axis to form the packed output (e.g. [225, 32])
+        output = relay.concatenate(outputs, axis=1)
+        output = relay.squeeze(output, axis=[0])
+
+        return output, h_next, c_next
 
-        return (output, _op.stack(hy, 0), _op.stack(cy, 0))
 
     def all_any_common(self, op, inputs, input_types):
         # return True
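The packed layout the converter relies on can be seen with a small standalone sketch (toy sizes, not part of the patch): rows [offset, offset + batch_sizes[t]) of the packed data hold timestep t for the still-active sequences, which is exactly the walk the loop above performs.

    import torch
    from torch.nn.utils.rnn import pack_padded_sequence, pad_sequence

    # Three sequences of lengths 4, 3, 1 with feature size 2 (arbitrary toy sizes)
    seqs = pad_sequence([torch.randn(l, 2) for l in (4, 3, 1)])  # [4, 3, 2]
    packed = pack_padded_sequence(seqs, lengths=[4, 3, 1])
    print(packed.batch_sizes)  # tensor([3, 2, 2, 1]): active sequences per timestep

    # The converter's loop walks packed.data exactly like this:
    offset = 0
    for t, bs in enumerate(packed.batch_sizes.tolist()):
        step = packed.data[offset:offset + bs]  # rows of timestep t, shape [bs, 2]
        print(t, tuple(step.shape))
        offset += bs

This also shows why the hidden and cell states are sliced down to batch_size_t each step: once a sequence ends, its rows stop appearing in the packed data.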
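A second sketch cross-checks the gate arithmetic of lstm_cell_tvm against torch.nn.LSTMCell for one timestep. It is a simplified rewrite made for review (a free function over 2-D states rather than the 3-D broadcasting form used in the patch); shapes follow the docstring example (batch 25, input_size 6, hidden_size 32).

    import numpy as np
    import torch
    import tvm
    from tvm import relay

    B, I, H = 25, 6, 32  # batch, input_size, hidden_size from the docstring example
    x = np.random.rand(B, I).astype("float32")
    h0 = np.random.rand(B, H).astype("float32")
    c0 = np.random.rand(B, H).astype("float32")

    cell = torch.nn.LSTMCell(I, H)
    w_ih = cell.weight_ih.detach().numpy()  # [4H, I], gate order i, f, g, o
    w_hh = cell.weight_hh.detach().numpy()  # [4H, H]
    b_ih = cell.bias_ih.detach().numpy()    # [4H]
    b_hh = cell.bias_hh.detach().numpy()    # [4H]

    # Same gate math as lstm_cell_tvm, written over 2-D states
    xv = relay.var("x", shape=(B, I))
    hv = relay.var("h", shape=(B, H))
    cv = relay.var("c", shape=(B, H))
    gates = (relay.nn.dense(xv, relay.const(w_ih)) + relay.const(b_ih)
             + relay.nn.dense(hv, relay.const(w_hh)) + relay.const(b_hh))
    i, f, g, o = relay.split(gates, indices_or_sections=4, axis=-1)
    c1 = relay.sigmoid(f) * cv + relay.sigmoid(i) * relay.tanh(g)
    h1 = relay.sigmoid(o) * relay.tanh(c1)
    mod = tvm.IRModule.from_expr(relay.Function([xv, hv, cv], relay.Tuple([h1, c1])))

    h_tvm, c_tvm = relay.create_executor("graph", mod=mod).evaluate()(x, h0, c0)
    h_ref, c_ref = cell(torch.from_numpy(x), (torch.from_numpy(h0), torch.from_numpy(c0)))
    np.testing.assert_allclose(h_tvm.numpy(), h_ref.detach().numpy(), rtol=1e-5, atol=1e-5)
    np.testing.assert_allclose(c_tvm.numpy(), c_ref.detach().numpy(), rtol=1e-5, atol=1e-5)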