credit: Justine Zeghal
from jax.experimental import mesh_utils
from jax.sharding import PositionalSharding
import jax  # jax.random / jax.device_put / jax.debug are used below but were never imported

# Build a 1-D device mesh of 8 devices and wrap it in a PositionalSharding.
# NOTE: create_device_mesh((8,)) requires the mesh size to match the number of
# available devices, so this snippet needs exactly 8 devices (e.g. a TPU v3-8).
# NOTE(review): PositionalSharding is deprecated in newer JAX releases in favour
# of jax.sharding.NamedSharding + Mesh; kept here so `sharding.reshape` works.
sharding = PositionalSharding(mesh_utils.create_device_mesh((8,)))
# An 8192 x 8192 array of standard-normal samples.
x = jax.random.normal(jax.random.PRNGKey(0), (8192, 8192))
# Place x on the devices laid out as a 4 x 2 grid: axis 0 is split 4 ways and
# axis 1 is split 2 ways, giving one (2048, 4096) shard per device.
y = jax.device_put(x, sharding.reshape(4, 2))
# Print which device holds which tile of y.
jax.debug.visualize_array_sharding(y)
---
┌──────────┬──────────┐
│  TPU 0   │  TPU 1   │
├──────────┼──────────┤
│  TPU 2   │  TPU 3   │
├──────────┼──────────┤
│  TPU 6   │  TPU 7   │
├──────────┼──────────┤
│  TPU 4   │  TPU 5   │
└──────────┴──────────┘
from jaxlib.hlo_helpers import custom_call
def _rms_norm_fwd_cuda_lowering(ctx, x, weight, eps):
    """MLIR lowering rule that emits the RMSNorm forward pass as a GPU custom call.

    Builds an XLA custom call to the target ``rms_forward_affine_mixed_dtype``
    with operands ``(x, weight)`` and two result tensors, and returns the
    custom call's ``.results``.

    NOTE(review): the elided ``[...]`` section is expected to define
    ``x_shape``, ``w_shape``, ``w_type``, ``n1`` and ``iv_element_type`` used
    below; ``ctx`` and ``eps`` are not referenced in the visible lines and are
    presumably consumed there (e.g. ``eps`` packed into the descriptor) —
    confirm against the full source.
    """
    [...]
    # Descriptor built by the compiled gpu_ops extension; it reaches the kernel
    # through `backend_config` below.
    opaque = gpu_ops.create_rms_norm_descriptor([...])
    out = custom_call(
        b"rms_forward_affine_mixed_dtype",
        result_types=[
            # Result 0: same shape as x, with the element type of w_type.
            ir.RankedTensorType.get(x_shape, w_type.element_type),
            # Result 1: a 1-D tensor of length n1 (presumably the per-row
            # inverse RMS, going by the `iv` name — TODO confirm).
            ir.RankedTensorType.get((n1,), iv_element_type),
        ],
        operands=[x, weight],
        backend_config=opaque,
        # Layouts for the two operands and the two results, in that order.
        operand_layouts=default_layouts(x_shape, w_shape),
        result_layouts=default_layouts(x_shape, (n1,)),
    ).results
    return out
# Primitive for the RMSNorm forward pass. On the "gpu" platform it is lowered
# by `_rms_norm_fwd_cuda_lowering`, which emits the custom call.
_rms_norm_fwd_p = core.Primitive("rms_norm_fwd")
mlir.register_lowering(_rms_norm_fwd_p, _rms_norm_fwd_cuda_lowering, platform="gpu")
def partition(mesh, arg_shapes, result_shape):
    """Partitioning callback passed to ``my_fft.def_partition``.

    Returns a 4-tuple ``(mesh, fft, result_sharding, arg_shardings)``:
    the mesh is passed through unchanged, ``fft`` is the per-shard
    computation, and both the result and the single operand use the sharding
    that ``supported_sharding`` derives from the first operand's sharding.

    ``result_shape`` is part of the callback signature but is not needed here.
    """
    arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes)
    # Computed once and reused: result and operand share the same sharding.
    operand_sharding = supported_sharding(arg_shardings[0], arg_shapes[0])
    return mesh, fft, operand_sharding, (operand_sharding,)
def infer_sharding_from_operands(mesh, arg_shapes, result_shape):
    """Sharding-inference callback passed to ``my_fft.def_partition``.

    Derives the output sharding from the first operand's sharding via
    ``supported_sharding``; ``mesh`` and ``result_shape`` are not consulted.
    """
    operand_shardings = jax.tree.map(lambda shape: shape.sharding, arg_shapes)
    first_sharding = operand_shardings[0]
    first_shape = arg_shapes[0]
    return supported_sharding(first_sharding, first_shape)
@custom_partitioning
def my_fft(x):
    """Apply ``fft`` to ``x``; SPMD partitioning comes from ``def_partition``."""
    return fft(x)


# Register the partitioning callbacks with the custom_partitioning wrapper.
my_fft.def_partition(
    partition=partition,
    infer_sharding_from_operands=infer_sharding_from_operands,
)
import jax_galsim as galsim
# Simulation parameters.
psf_beta = 5  # NOTE(review): unreferenced in this snippet (the PSF below is Gaussian)
psf_re = 1.0  # PSF half-light radius, arcsec
pixel_scale = 0.2  # arcsec / pixel
sky_level = 2.5e3  # counts / arcsec^2; NOTE(review): unreferenced in this snippet

# Galaxy: unit-flux exponential profile, sheared by (0.3, 0.4) and then by
# (0.01, 0.05).
gal = (
    galsim.Exponential(flux=1, scale_radius=2)
    .shear(g1=.3, g2=.4)
    .shear(g1=0.01, g2=0.05)
)
# PSF: unit-flux Gaussian with half-light radius psf_re.
psf = galsim.Gaussian(flux=1., half_light_radius=psf_re)

# Observed profile is galaxy (*) PSF, rendered at pixel_scale.
final = galsim.Convolve([gal, psf])
image = final.drawImage(scale=pixel_scale)
@jax.jacfwd
def autometacal(shear, image, psf, rec_psf, scale=None):
    """Metacalibration pipeline, differentiated with respect to ``shear``.

    The undecorated function deconvolves ``image`` by ``psf``, shears the
    result by ``(shear[0], shear[1])``, reconvolves it with ``rec_psf`` and
    draws it with ``method='no_pixel'``. Because of ``@jax.jacfwd`` (default
    ``argnums=0``), calling ``autometacal`` returns the Jacobian of that drawn
    array with respect to ``shear`` only; the other arguments are not
    differentiated.

    Args:
        shear: the differentiated argument; indexed as ``(g1, g2)``.
        image: profile to deconvolve. NOTE(review): presumably a GalSim object
            (e.g. an InterpolatedImage of the observation) — confirm.
        psf: PSF profile removed by deconvolution.
        rec_psf: PSF used for reconvolution (slightly dilated, per Step 3).
        scale: pixel scale passed to ``drawImage``. The original read an
            undefined name ``scale``; when omitted, this now defaults to the
            module-level ``pixel_scale`` used to draw the example image.
    """
    if scale is None:
        scale = pixel_scale
    # Step 1: Deconvolve the image. The module is imported as `galsim`
    # (`import jax_galsim as galsim`); the bare `jax_galsim` name is unbound.
    deconvolved = galsim.Convolve([image, galsim.Deconvolve(psf)])
    # Step 2: Apply shear
    sheared = deconvolved.shear(g1=shear[0], g2=shear[1])
    # Step 3: Reconvolve by slightly dilated PSF
    reconvolved = galsim.Convolve([sheared, rec_psf])
    return reconvolved.drawImage(scale=scale, method='no_pixel').array