- CuPy sits much closer to CUDA than JAX does, so you can get better performance
- CuPy is generally more mature than JAX (fewer bugs)
- CuPy is more flexible thanks to cp.RawKernel
- (For those familiar with NumPy) CuPy is closer to NumPy than jax.numpy
But CuPy does not support automatic gradient computation, so if you do deep learning, use JAX instead. Or PyTorch, if you do not trust Google to maintain a project for a prolonged period of time (see https://killedbygoogle.com/).
What about CPU-only loads? What if you want to write code that will eventually run on both CPU and GPU, but in the short-to-mid term will only be used on CPU? JAX natively supports CPU (via its XLA CPU backend) but CuPy doesn't, so this seems like a potential problem for some.
Just prepare the input as a NumPy or CuPy array, and then feed it to the NumPy APIs. NumPy functions handle the call themselves if the input is a NumPy ndarray, or dispatch the execution to CuPy if the input is a CuPy ndarray.
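For anyone who wants to see what that dispatch looks like, here is a minimal sketch. It assumes CuPy is optional and falls back to NumPy when no usable GPU is found; `pick_backend` and `normalize` are made-up names for illustration, not part of either library.

```python
import numpy as np


def pick_backend():
    """Return the cupy module if it imports and a GPU is usable, else numpy."""
    try:
        import cupy as cp
        cp.zeros(1)  # forces a device allocation; raises without a working GPU
        return cp
    except Exception:
        return np


def normalize(x):
    # Only np.* functions are called here. For a CuPy ndarray, NumPy's
    # dispatch protocols (__array_function__ / __array_ufunc__) hand the
    # work over to CuPy, so the result stays on the GPU.
    return (x - np.mean(x)) / np.std(x)


xp = pick_backend()
x = xp.arange(5, dtype=xp.float64)  # input prepared on whichever backend we got
y = normalize(x)                    # same code path for both backends
```

The only backend-specific line is the one that creates the input; everything after that goes through the NumPy namespace.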
There is, but then you're using two separate libraries, which seems like a fragile point of failure compared to just using JAX. That said, since JAX uses different backends under the hood anyway, it's arguably not any worse (but it ends up being your responsibility to ensure correctness, as opposed to the JAX team's).
Real answer: CuPy has a name that is very similar to SciPy. I don’t know GPU programming, that’s why I’m using this sort of library, haha. The branding for CuPy makes it obvious. Is JAX the same thing, but implemented better somehow?
Yeah, JAX provides a nearly one-to-one reimplementation of the NumPy interface, and a decent chunk of the SciPy interface. Random number handling is a bit different, but NumPy random number handling seeeeems to be trending in the JAX direction (explicitly passed RNG objects).
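To make the "explicitly passed RNG objects" point concrete, here is a small NumPy-only sketch of that style. `noisy_sum` is a made-up function for illustration. Deriving child streams from a `SeedSequence` is only loosely analogous to splitting keys with `jax.random.split`; the two are not the same mechanism.

```python
import numpy as np


def noisy_sum(rng, n):
    # The generator is an explicit argument; no hidden global state is touched.
    return float(rng.standard_normal(n).sum())


# Two generators built from the same seed produce the same stream.
first = noisy_sum(np.random.default_rng(0), 4)
again = noisy_sum(np.random.default_rng(0), 4)

# Derive independent child streams from one parent seed.
child_seeds = np.random.SeedSequence(0).spawn(2)
rng_a, rng_b = (np.random.default_rng(s) for s in child_seeds)
draw_a = rng_a.standard_normal(3)
draw_b = rng_b.standard_normal(3)
```

Because the generator is passed in, a function's randomness is reproducible from its arguments alone.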
JAX also provides automatic differentiation (back-propagation) wherever possible, so you can run gradient-based optimization.
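A minimal sketch of what that looks like, assuming JAX is installed. The quadratic `loss`, the step size, and the iteration count are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp


def loss(w):
    # Toy objective with its minimum at w = 3 in every coordinate.
    return jnp.sum((w - 3.0) ** 2)


# jax.grad returns a new function that computes d(loss)/dw.
grad_loss = jax.grad(loss)

w = jnp.zeros(2)
for _ in range(100):
    w = w - 0.1 * grad_loss(w)  # plain gradient descent step
```

The same `loss` written against NumPy or CuPy would need its gradient derived by hand, which is the gap the earlier comment about deep learning points at.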