Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gather tests #40

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/common/api_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ std::pair<tt::target::DataType, size_t> MapBufferTypeToElementType(PJRT_Buffer_T
case PJRT_Buffer_Type_S4:
case PJRT_Buffer_Type_S8:
case PJRT_Buffer_Type_S16:
return std::make_pair(tt::target::DataType::UInt16, 2);
case PJRT_Buffer_Type_S32:
case PJRT_Buffer_Type_S64:
case PJRT_Buffer_Type_U4:
Expand Down
8 changes: 6 additions & 2 deletions src/common/module_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,15 @@ void ModuleBuilder::BuildModule(std::string_view code, std::string_view format,
{
throw std::runtime_error("Failed to run MLIR compiler pass pipeline.");
}
DLOG_F(LOG_DEBUG, "TTIR Module");
shlo_pm.addPass(mlir::tt::ttir::createTTIRGatherPatternMatch());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of adding pass manually, can you just invoke the pipeline, like we do for TTIRToTTNN below?

if (mlir::failed(shlo_pm.run(mlir_module.get())))
{
throw std::runtime_error("Failed to convert gather op");
}
DLOG_F(LOG_DEBUG, "TTIR to TTIR Module");
if (log_level > 0)
mlir_module->dump();


mlir::PassManager pm(mlir_module.get()->getName());
mlir::tt::ttnn::TTIRToTTNNBackendPipelineOptions options;
mlir::tt::ttnn::createTTIRToTTNNBackendPipeline(pm, options);
Expand Down
142 changes: 142 additions & 0 deletions tests/TTIR/test_gather_op.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@

from jax import grad, jit, vmap
import jax.numpy as jnp
import jax
import os
import sys
import jax._src.xla_bridge as xb
from jax.lax import GatherDimensionNumbers
import flax.linen as nn

def initialize():
backend = "tt"
path = os.path.join(os.path.dirname(__file__), "../../build/src/tt/pjrt_plugin_tt.so")
if not os.path.exists(path):
raise FileNotFoundError(f"Could not find tt_pjrt C API plugin at {path}")
print("Loading tt_pjrt C API plugin", file=sys.stderr)
xb.discover_pjrt_plugins()
plugin = xb.register_plugin('tt', priority=500, library_path=path, options=None)
print("Loaded", file=sys.stderr)
jax.config.update("jax_platforms", "tt,cpu")

def jax_take():
print("\n\n Before operand:\n\n")
tensor = jnp.zeros((32000, 1024), dtype=jnp.float32)
print("\n\nBefore start_indices:\n\n")
indices = jnp.array([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]],dtype=jnp.int16)

print("\n\nBefore take:\n\n")

try:
# Use jit to force compilation and IR generation
@jax.jit
def take_fn(tensor, indices):
return jnp.take(tensor, indices, axis=0)

print("\n\nBefore take:\n\n")
gathered = take_fn(tensor, indices)
print(gathered.shape)
except Exception as e:
print("Error:", e)

def jax_indexing():
print("\n\n Before operand:\n\n")
tensor = jnp.zeros((32000, 1024), dtype=jnp.float32)
print("\n\nBefore start_indices:\n\n")
indices = jnp.array([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]],dtype=jnp.int16)

print("\n\nBefore indexing:\n\n")

try:
# Use jit to force compilation and IR generation
@jax.jit
def indexing_fn(tensor, indices):
return tensor[indices[0]]

print("\n\nBefore take:\n\n")
gathered = indexing_fn(tensor, indices)
print(gathered.shape)
except Exception as e:
print("Error:", e)

def jax_vmap():
print("\n\n Before operand:\n\n")
tensor = jnp.zeros((32000, 1024), dtype=jnp.float32)
print("\n\nBefore start_indices:\n\n")
indices = jnp.array([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]],dtype=jnp.int16)

print("\n\nBefore vmap:\n\n")

try:
# Use jit to force compilation and IR generation
@jax.jit
def vmap_fn(index):
return tensor[index]

print("\n\nBefore vmap:\n\n")
gathered = vmap(vmap_fn)(indices[0])
print(gathered.shape)
except Exception as e:
print("Error:", e)

class EmbeddingModel(nn.Module):
vocab_size: int
embedding_dim: int

@nn.compact
def __call__(self, indices):
embedding = nn.Embed(
num_embeddings=self.vocab_size,
features=self.embedding_dim,
dtype=jnp.float32
)
return embedding(indices)

def flax_embed():
print("\n\nInitializing model:\n\n")

# Model parameters
vocab_size = 32000
embedding_dim = 1024

# Create and initialize the model
model = EmbeddingModel(vocab_size=vocab_size, embedding_dim=embedding_dim)

# Create sample indices
indices = jnp.array([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]],
dtype=jnp.int16) # Changed to int32 as per Flax conventions

print("\n\nBefore embedding lookup:\n\n")

try:
# Initialize parameters
key = jax.random.PRNGKey(0)
params = model.init(key, indices)

# JIT the forward pass
@jax.jit
def embed_fn(params, indices):
return model.apply(params, indices)

print("\n\nPerforming embedding lookup:\n\n")
embedded = embed_fn(params, indices)
print(embedded.shape)

except Exception as e:
print("Error:", e)


if __name__ == "__main__":
initialize()
print("\n\nBefore valid_jax_gather_example\n\n")
jax_take() # output sizes match with gather, fails during stablehlo

# the following tests fail before shlo
# flax_embed()
# jax_indexing() # output shape does not match
## jax_vmap() # uses dynamic slice which fails

2 changes: 1 addition & 1 deletion third_party/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# SPDX-License-Identifier: Apache-2.0
#

set(TT_MLIR_VERSION "50f0f035f53cd3c755be8c3650c55cf3d4e3170b")
set(TT_MLIR_VERSION "1d8da1fe3160960c0c8428fa8eddc17ec615e1bc")
set(LOGURU_VERSION "4adaa185883e3c04da25913579c451d3c32cfac1")

if (TOOLCHAIN STREQUAL "ON")
Expand Down
Loading