Skip to content

Commit

Permalink
add sharding to play examples
Browse files Browse the repository at this point in the history
  • Loading branch information
ACea15 committed Jan 31, 2025
1 parent 4602fac commit ad5ae99
Showing 1 changed file with 40 additions and 0 deletions.
40 changes: 40 additions & 0 deletions docs/play_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -715,3 +715,43 @@ def outer_function(x, static_val):
# Call the compiled function with a static argument
result = compiled_function(3, 10)
print(result) # Expected to print 13

#############################################################


from functools import partial

import jax
import jax.numpy as jnp

from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map
jax.config.update("jax_enable_x64", True)

import os
os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={8}"

mesh = jax.make_mesh((4, 2), ('x', 'y'))

a = jnp.arange( 8 * 16.).reshape(8, 16)
b = jnp.arange(16 * 4.).reshape(16, 4)

@partial(shard_map, mesh=mesh, in_specs=(P('x', 'y'), P('y', None)),
out_specs=P('x'))
def matmul_basic(a_block, b_block):
# a_block: f32[2, 8]
# b_block: f32[8, 4]
c_partialsum = jnp.dot(a_block, b_block)
c_block = jax.lax.psum(c_partialsum, 'y')
# c_block: f32[2, 4]
#return c_block
return jnp.broadcast_to(c_block.reshape(c_block.shape + (1,)), c_block.shape + (3,))

c = matmul_basic(a, b) # c: f32[8, 4]

from jax.tree_util import tree_map, tree_all

def allclose(a, b):
return tree_all(tree_map(partial(jnp.allclose, atol=1e-2, rtol=1e-2), a, b))

allclose(c[:,:,2], jnp.dot(a, b))

0 comments on commit ad5ae99

Please sign in to comment.