
Why not Jax?


> Why not Jax?

- JAX Windows support is lacking

- CuPy is much closer to CUDA than JAX, so you can get better performance

- CuPy is generally more mature than JAX (fewer bugs)

- CuPy is more flexible thanks to cp.RawKernel

- (For those familiar with NumPy) CuPy is closer to NumPy than jax.numpy

But CuPy does not support automatic gradient computation, so if you do deep learning, use JAX instead. Or PyTorch, if you do not trust Google to maintain a project for a prolonged period of time https://killedbygoogle.com/
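The `cp.RawKernel` flexibility mentioned above means you can embed hand-written CUDA C directly in a Python script. A minimal sketch (the kernel name `my_add` and the launch configuration are illustrative, not from the thread; it requires CuPy and a CUDA-capable GPU, and the guard makes it a no-op otherwise):

```python
try:
    import cupy as cp

    # Compile a raw CUDA C kernel from source.
    add_kernel = cp.RawKernel(r'''
    extern "C" __global__
    void my_add(const float* x, const float* y, float* out, int n) {
        int i = blockDim.x * blockIdx.x + threadIdx.x;
        if (i < n) out[i] = x[i] + y[i];
    }
    ''', 'my_add')

    n = 1024
    x = cp.arange(n, dtype=cp.float32)
    y = cp.arange(n, dtype=cp.float32)
    out = cp.empty_like(x)
    # Launch with 4 blocks of 256 threads each.
    add_kernel((n // 256,), (256,), (x, y, out, cp.int32(n)))
    ran_on_gpu = True
except ImportError:
    ran_on_gpu = False
```

This level of control (writing the actual CUDA kernel) is what the comment means by CuPy being "closer to CUDA" than JAX.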


What about CPU-only workloads? What if one wants to write code that will eventually run on both CPU and GPU, but in the short-to-mid term will only be used on CPU? Since JAX natively supports CPU (with a NumPy backend) but CuPy doesn't, this seems like a potential problem for some.


Isn't there a way to dynamically select between numpy and cupy, depending on whether you want cpu or gpu code?


NumPy has a mechanism to dispatch execution to CuPy: https://numpy.org/neps/nep-0018-array-function-protocol.html

Just prepare the input as a NumPy or CuPy array, then feed it to NumPy APIs. NumPy functions will handle the computation themselves if the input is a NumPy ndarray, or dispatch execution to CuPy if the input is a CuPy ndarray.
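Concretely, the NEP 18 dispatch described above looks like this (a sketch; the CuPy branch only runs if CuPy and a GPU are available):

```python
import numpy as np

x_cpu = np.arange(10)
total = np.sum(x_cpu)          # plain NumPy, runs on the CPU

# If CuPy is installed, the very same np.sum call dispatches to CuPy's
# GPU implementation via the __array_function__ protocol (NEP 18).
try:
    import cupy as cp
    x_gpu = cp.arange(10)
    total_gpu = np.sum(x_gpu)  # handled by CuPy; the result stays on the GPU
except ImportError:
    pass
```

The caller's code never mentions CuPy after input preparation; the array type alone decides where the work runs.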


> Isn't there a way to dynamically select between numpy and cupy, depending on whether you want cpu or gpu code?

CuPy is an (almost) drop-in replacement for NumPy, so the following works surprisingly often:

    if use_cpu:
        import numpy as np
    else:
        import cupy as np
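A per-array variant of this pattern uses CuPy's `cupy.get_array_module` helper, which returns either the `numpy` or the `cupy` module depending on its argument's type. A sketch (the `softmax` function is just an illustrative example):

```python
import numpy as np

try:
    import cupy as cp
except ImportError:
    cp = None  # fall back to CPU-only

def softmax(x):
    # get_array_module inspects x and returns numpy or cupy,
    # so one function body serves both backends.
    xp = cp.get_array_module(x) if cp is not None else np
    e = xp.exp(x - x.max())
    return e / e.sum()

probs = softmax(np.array([1.0, 2.0, 3.0]))  # CPU; pass a CuPy array for GPU
```

This keeps the backend choice out of the import statement, at the cost of one extra line per function.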


> surprisingly

This is the problem with these kinds of methods: it works, until it doesn't, in some unpredictable way.


There is, but then you're using two separate libraries, which seems like a fragile point of failure compared to just using JAX. Then again, since JAX uses different backends under the hood anyway, it's arguably no worse; it's just that ensuring correctness becomes your responsibility rather than the JAX team's.


> CuPy does not support automatic gradient computation, so if you do deep learning, use JAX instead

DL is a major use case; is CuPy planning to add automatic gradient computation?


Real answer: CuPy has a name that is very similar to SciPy. I don't know GPUs; that's why I'm using this sort of library, haha. The branding for CuPy makes its purpose obvious. Is JAX the same thing, but implemented better somehow?


Yeah, JAX provides a one-to-one reimplementation of the NumPy interface, and a decent chunk of the SciPy interface. Random number handling is a bit different, but NumPy's random number handling seems to be trending in the JAX direction (explicitly passed RNG objects).

Jax also provides back-propagation wherever possible, so you can optimize.
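As a sketch of both points above (explicitly passed RNG state, and back-propagation via `jax.grad`), assuming JAX is installed; the guard makes it a no-op otherwise:

```python
try:
    import jax
    import jax.numpy as jnp

    key = jax.random.PRNGKey(0)          # RNG state is an explicit value,
    x = jax.random.normal(key, (3,))     # passed to every sampling call

    def loss(w):
        return jnp.sum((w * x) ** 2)

    grads = jax.grad(loss)(jnp.ones(3))  # gradient via back-propagation
    have_jax = True
except ImportError:
    have_jax = False
```

The `loss` function here is an arbitrary example; the point is that any NumPy-style computation written with `jax.numpy` can be differentiated this way.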


yes


CuPy came out a long time before JAX; I remember using it in a project for my BSc around 2015-2016.

Cool to see that it's still kicking!



