Well, the most common ML problems can be expressed as optimization over smooth functions (or reformulated that way manually). We might have to convince the ML world that branches do matter :) On the other hand, there are gradient-free approaches that solve problems with jumps in other ways, like many reinforcement learning algorithms, or metaheuristics such as genetic algorithms in simulation-based optimization. The jury's still out on "killer apps" where gradient descent can outperform these approaches reliably, but we're hoping to add to that body of knowledge...
> Why do you think similar approaches never landed on jax?
Isn't this just adding noise to some branching conditions? What would it take for a framework like Jax to "support" it? It seems like all you have to do is change
> if (x>0)
to
> if (x+n > 0)
where n is a sampled Gaussian.
Not sure this warrants any kind of change in a framework if it's truly that trivial.
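For concreteness, a minimal sketch of what that "trivial" change would look like in JAX (the sigma value and function names are just illustrative):

```python
import jax
import jax.numpy as jnp

def branch(x):
    # the original "if (x > 0)" as a hard step
    return jnp.where(x > 0, 1.0, 0.0)

def noisy_branch(x, key, sigma=0.1):
    # the "if (x + n > 0)" version, with n sampled from a Gaussian
    n = sigma * jax.random.normal(key)
    return jnp.where(x + n > 0, 1.0, 0.0)

# a single perturbed evaluation near the decision boundary
print(noisy_branch(0.05, jax.random.PRNGKey(0)))
```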
Semantically it truly is that trivial, but in practice handling expectations in AD requires additional machinery that implementations not written with nondeterminism in mind don't provide.
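Concretely, the extra machinery is about the expectation over the noise rather than any single sample: E[1{x + n > 0}] with n ~ N(0, sigma^2) equals Phi(x / sigma), which is smooth, while differentiating through one sampled branch still gives zero almost everywhere. A hedged sketch of that difference (not an existing JAX API, and not necessarily the paper's method, just the analytic-smoothing route written out by hand):

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

SIGMA = 0.1

def smoothed_branch(x):
    # E_n[ 1{x + n > 0} ] with n ~ N(0, SIGMA^2) has the closed form Phi(x / SIGMA),
    # which is smooth, so standard reverse-mode AD applies to it directly.
    return norm.cdf(x / SIGMA)

def monte_carlo_branch(x, key, num_samples=10_000):
    # Unbiased estimate of the same expectation; differentiating it naively
    # still yields zero, which is why the smoothing (or a score-function
    # estimator) has to live somewhere in the AD system.
    n = SIGMA * jax.random.normal(key, (num_samples,))
    return jnp.mean(jnp.where(x + n > 0, 1.0, 0.0))

key = jax.random.PRNGKey(0)
print(smoothed_branch(0.05), monte_carlo_branch(0.05, key))  # both ~0.69
print(jax.grad(smoothed_branch)(0.05))                       # ~3.52, the useful gradient
print(jax.grad(monte_carlo_branch)(0.05, key))               # 0.0, despite the same value
```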
- Why do you think similar approaches never landed on jax? My guess is that this is not that useful for the optimization problems currently in fashion (transformers).
- How would you convince jax to incorporate this?