When profiling FP, runtime seems to be dominated by `MemcpyD2H`, as seen in the Perfetto UI. The actual computations take very little time in between.
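To reproduce this kind of trace in isolation, here is a minimal sketch using `jax.profiler.trace`. It assumes a toy stand-in for the real FP: `project_one_angle` and `fp_scan` are hypothetical names, and the "projection" is just weighted axis sums, not a real projector. The `create_perfetto_trace=True` flag writes a trace file that the Perfetto UI can open.

```python
import os
import tempfile

import jax
import jax.numpy as jnp


def project_one_angle(volume, angle):
    # Hypothetical toy projector: volume is (rows, n, n), output is (rows, n).
    return jnp.cos(angle) * volume.sum(axis=1) + jnp.sin(angle) * volume.sum(axis=2)


@jax.jit
def fp_scan(volume, angles):
    # One scan step per angle, mirroring the structure described in this issue.
    def step(carry, angle):
        return carry, project_one_angle(volume, angle)

    _, sino = jax.lax.scan(step, None, angles)
    return sino  # (n_angles, rows, n)


volume = jnp.ones((8, 32, 32))
angles = jnp.linspace(0.0, jnp.pi, 64)
fp_scan(volume, angles).block_until_ready()  # compile outside the traced region

log_dir = tempfile.mkdtemp()
with jax.profiler.trace(log_dir, create_perfetto_trace=True):
    sino = fp_scan(volume, angles).block_until_ready()

# Trace files (including a Perfetto-compatible one) are written under log_dir.
trace_files = [f for _, _, files in os.walk(log_dir) for f in files]
print(sino.shape, len(trace_files))
```

Warming up before entering the trace keeps compilation out of the profile, so the timeline only shows the per-iteration execution pattern.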
I think this could be related to JAX driving each iteration of `jax.lax.scan` from the host, which requires a synchronization at every step. This is discussed here and here.
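One way to check the structural part of this hypothesis without a profiler is to inspect what the jitted function lowers to. A sketch, again with a hypothetical toy `project_one_angle`: the `scan` version lowers to a `while` loop in StableHLO, while the `vmap` version contains no loop. This only shows that a loop exists; whether each iteration costs a device-to-host copy depends on how the XLA backend executes that loop, which this snippet does not measure.

```python
import jax
import jax.numpy as jnp


def project_one_angle(volume, angle):
    # Hypothetical toy projector: volume is (rows, n, n), output is (rows, n).
    return jnp.cos(angle) * volume.sum(axis=1) + jnp.sin(angle) * volume.sum(axis=2)


def fp_scan(volume, angles):
    # Sequential over angles: lowers to a loop.
    _, sino = jax.lax.scan(lambda c, a: (c, project_one_angle(volume, a)), None, angles)
    return sino


def fp_vmap(volume, angles):
    # Batched over angles: no loop in the lowered program.
    return jax.vmap(project_one_angle, in_axes=(None, 0))(volume, angles)


volume = jnp.ones((8, 32, 32))
angles = jnp.linspace(0.0, jnp.pi, 64)

scan_hlo = jax.jit(fp_scan).lower(volume, angles).as_text()
vmap_hlo = jax.jit(fp_vmap).lower(volume, angles).as_text()
print("while" in scan_hlo, "while" in vmap_hlo)
```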
I don't know how to resolve this in pure JAX. Alternatives to the current implementation are `vmap`-ing instead of scanning, or unrolling the scan (its `unroll` argument). IIRC both led to longer runtimes.
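A small harness for comparing these variants side by side, under the same toy-projector assumption. `unroll` is a real argument of `jax.lax.scan`; `make_fp_scan`, `fp_vmap` and the timing helper `bench` are hypothetical. `bench` times wall-clock execution after a warm-up call, so compile time is excluded. Numbers from this toy will not transfer to the real FP; the point is to check that all variants agree and to time them on identical inputs.

```python
import time

import jax
import jax.numpy as jnp


def project_one_angle(volume, angle):
    # Hypothetical toy projector: volume is (rows, n, n), output is (rows, n).
    return jnp.cos(angle) * volume.sum(axis=1) + jnp.sin(angle) * volume.sum(axis=2)


def make_fp_scan(unroll):
    @jax.jit
    def fp(volume, angles):
        step = lambda c, a: (c, project_one_angle(volume, a))
        return jax.lax.scan(step, None, angles, unroll=unroll)[1]

    return fp


@jax.jit
def fp_vmap(volume, angles):
    # All angles at once: no loop, but per-angle intermediates are materialized together.
    return jax.vmap(project_one_angle, in_axes=(None, 0))(volume, angles)


def bench(fn, *args, repeats=10):
    fn(*args).block_until_ready()  # warm-up / compile
    start = time.perf_counter()
    for _ in range(repeats):
        out = fn(*args)
    out.block_until_ready()
    return out, (time.perf_counter() - start) / repeats


volume = jax.random.normal(jax.random.PRNGKey(0), (8, 32, 32))
angles = jnp.linspace(0.0, jnp.pi, 64)

variants = {"scan": make_fp_scan(1), "scan_unroll8": make_fp_scan(8), "vmap": fp_vmap}
results = {name: bench(fn, volume, angles) for name, fn in variants.items()}
for name, (_, seconds) in results.items():
    print(f"{name:>14}: {seconds * 1e3:.3f} ms")
```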
Ultimately, we should avoid a memcpy at every projection, which is what I think is happening now. We could swap the nesting of `scan` and `vmap`, i.e. `vmap` over angles but `scan` over detector rows. I would expect that to be slower, but we could try it.
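The swapped nesting could look roughly like this, assuming the projector can be split into a per-row function (`project_row` is hypothetical). Whether the real FP decomposes this way depends on the geometry: it holds for e.g. parallel-beam, where each detector row sees one slice, but not for cone-beam, where rays cross rows. Note that the loop then has `n_rows` iterations instead of `n_angles`, so if per-iteration overhead is the bottleneck this only helps when there are fewer rows than angles.

```python
import jax
import jax.numpy as jnp


def project_row(slice2d, angle):
    # Hypothetical per-row toy projector: one (n, n) slice -> (n,) detector row.
    return jnp.cos(angle) * slice2d.sum(axis=0) + jnp.sin(angle) * slice2d.sum(axis=1)


@jax.jit
def fp_rows_scan_angles_vmap(volume, angles):
    # Proposed: outer scan over detector rows, inner vmap over all angles.
    def step(carry, slice2d):
        return carry, jax.vmap(project_row, in_axes=(None, 0))(slice2d, angles)

    _, rows = jax.lax.scan(step, None, volume)  # (rows, n_angles, n)
    return jnp.swapaxes(rows, 0, 1)  # (n_angles, rows, n), same layout as below


@jax.jit
def fp_angles_scan(volume, angles):
    # Current structure for comparison: scan over angles, vmap over rows.
    def step(carry, angle):
        return carry, jax.vmap(project_row, in_axes=(0, None))(volume, angle)

    _, sino = jax.lax.scan(step, None, angles)
    return sino


volume = jax.random.normal(jax.random.PRNGKey(0), (8, 32, 32))
angles = jnp.linspace(0.0, jnp.pi, 64)
swapped = fp_rows_scan_angles_vmap(volume, angles)
current = fp_angles_scan(volume, angles)
print(swapped.shape, float(jnp.max(jnp.abs(swapped - current))))
```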