Hacker News

You probably could get perturbation confusion in reverse mode, but it's not an easy trap like it is with forward mode. The problem with forward-mode AD is that it's deceptively easy: you transform every number a into a pair of numbers (a,b), change every function f so that it maps (a,b) to (f(a), f'(a)*b) (that's the chain rule), and you're done. Whether you call it dual-number arithmetic or a compiler transformation, it's the same thing in the end. The source of perturbation confusion is that you've created this "secret extra number" to store the derivative data, so if two things are differentiating code at the same time, you need to make sure you turn (a,b) into (a,a',b,b') and that every layer of AD always grabs the right value out of there. Note that in the way I wrote it, if you assumed "the b term is always num[2] in the tuple", oops, perturbation confusion, so your generated code needs to be "smart" (but not lose efficiency!). Thus the fixes are proofs and tagging systems that ensure the right perturbation terms are always used in the right places.
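To make the forward-mode trap concrete, here is a minimal pure-Python sketch. The names Dual, derivative and nested, and the integer tag scheme, are all hypothetical and not any particular library's API. It runs the classic nested test case d/dx [x * d/dy (x + y)] at x = y = 1, whose true answer is 1. With tagged=False every nesting level writes its perturbation into the same slot (the analogue of always reading num[2]), so the inner derivative also picks up x's perturbation and the result comes out as 2. With tagged=True each derivative call gets a fresh tag, and arithmetic always splits off the most recently created (innermost) tag first.

```python
import itertools

# Fresh tags for the tagged scheme (tag 1 is reserved for the naive mode).
# A larger tag means a more deeply nested derivative call.
_fresh_tags = itertools.count(2)


class Dual:
    """primal + tangent * eps_tag with eps_tag**2 == 0.
    primal and tangent may themselves be Duals carrying smaller tags."""

    def __init__(self, tag, primal, tangent):
        self.tag, self.primal, self.tangent = tag, primal, tangent

    def __add__(self, other):
        return add(self, other)

    def __mul__(self, other):
        return mul(self, other)

    __radd__ = __add__   # addition is commutative
    __rmul__ = __mul__   # so is multiplication


def tag_of(x):
    return x.tag if isinstance(x, Dual) else 0


def split(x, tag):
    """View x as (primal, tangent) w.r.t. eps_tag; anything without
    this tag is treated as a constant with zero tangent."""
    if isinstance(x, Dual) and x.tag == tag:
        return x.primal, x.tangent
    return x, 0


def add(a, b):
    t = max(tag_of(a), tag_of(b))  # handle the innermost perturbation first
    if t == 0:
        return a + b               # both are plain numbers
    ap, at = split(a, t)
    bp, bt = split(b, t)
    return Dual(t, add(ap, bp), add(at, bt))


def mul(a, b):
    t = max(tag_of(a), tag_of(b))
    if t == 0:
        return a * b
    ap, at = split(a, t)
    bp, bt = split(b, t)
    # product rule: (ap + at*eps)(bp + bt*eps) = ap*bp + (ap*bt + at*bp)*eps
    return Dual(t, mul(ap, bp), add(mul(ap, bt), mul(at, bp)))


def derivative(f, x, tagged=True):
    # Naive mode reuses the same slot (tag 1) at every nesting level.
    t = next(_fresh_tags) if tagged else 1
    _, d = split(f(Dual(t, x, 1)), t)
    return d


def nested(tagged):
    # d/dx [ x * (d/dy (x + y) at y=1) ] at x=1; the true answer is 1.
    return derivative(lambda x: x * derivative(lambda y: x + y, 1, tagged), 1, tagged)


print(nested(tagged=True))   # 1
print(nested(tagged=False))  # 2  (perturbation confusion)
```

Ordinary nested use, such as derivative(lambda x: derivative(lambda y: y * y * y, x), 2), also works in the tagged mode and gives 12. Real tagging systems have to handle harder cases (for example perturbations escaping through closures) that this sketch doesn't attempt.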

With reverse-mode AD, this is much less likely to be an issue because the AD system isn't necessarily storing and working on hidden extensions to the values; it runs a function forwards and then runs a separate function backwards, having remembered some values from the forward pass. If the remembered values are correct and never modified, then generating a higher-order derivative is just as safe as generating the first. That last detail is what I think is most akin to perturbation confusion in reverse mode: reverse mode assumes that the objects captured in the forward pass will not be changed (or will at least be back in the correct state) by the time it reverses. Breaking this assumption doesn't even require second derivatives; the easiest way to break it is mutation. If you walk forward by computing Ax, the reverse pass wants to compute A'v (A transposed times v), so it just keeps a pointer to A, but if A gets mutated in the meantime, using that pointer gives the wrong answer. This is why most AD systems (PyTorch, Jax, Zygote, ...) simply disallow mutation except in very special, unoptimized cases.
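A small sketch of the mutation problem, assuming a hypothetical closure-based pullback in pure Python (matvec and matvec_with_pullback are illustrative names, not a real API). The forward pass computes Ax, and the pullback for x computes A'v by reading the captured reference to A, so an in-place write between the two passes silently changes the gradient:

```python
def matvec(A, x):
    """y = A x for a list-of-lists matrix A."""
    return [sum(a * xi for a, xi in zip(row, x)) for row in A]


def matvec_with_pullback(A, x):
    """Forward pass plus a pullback closure.

    The pullback returns the gradient w.r.t. x, i.e. A^T v. It does not
    copy A; it only keeps a reference ("the pointer to A") and reads it
    again when the reverse pass runs.
    """
    y = matvec(A, x)

    def pullback(v):
        return [sum(A[i][j] * v[i] for i in range(len(A))) for j in range(len(x))]

    return y, pullback


A = [[1.0, 2.0],
     [3.0, 4.0]]
y, pullback = matvec_with_pullback(A, [1.0, 1.0])

print(y)                     # [3.0, 7.0]
print(pullback([1.0, 0.0]))  # [1.0, 2.0], the correct A^T v

A[0][0] = 100.0              # in-place mutation between forward and reverse pass
print(pullback([1.0, 0.0]))  # [100.0, 2.0], silently wrong
```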

Enzyme.jl is an exception because it performs a global analysis of the program it's differentiating (with escape analysis and other passes at the LLVM level) in order to know that any mutation made going forward will be undone during the reverse pass, so by the time it gets back to A'*v it knows A will be the same. Higher-level ADs could go cowboy YOLO style and just assume the captured matrix is still correct when the reverse pass reaches it (and it might be a lot of the time), though that raises some pretty major correctness concerns. The other option is to simply make a full copy of A every time you mutate an element, so have fun if you loop through your weight matrix. Diffractor.jl's near-future approach is more like Haskell's GHC: it wants you to give it non-mutating code so it can try to generate the mutating code when that would be more efficient (https://github.com/JuliaLang/julia/pull/42465).
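The copy-on-every-mutation option can be sketched the same way. CopyOnWriteMatrix, sweep and elements_copied are hypothetical names for illustration only. Because every element write replaces the whole buffer, pullbacks that captured the old buffer stay correct, but writing every element of an n x n matrix copies n^2 elements n^2 times, i.e. n^4 element copies in total:

```python
class CopyOnWriteMatrix:
    """Hypothetical wrapper illustrating "copy A on every mutation".

    Each element write allocates a fresh copy of the whole buffer, so a
    pullback that captured an older buffer still sees the values from its
    own forward pass.
    """

    def __init__(self, rows):
        self.data = [list(r) for r in rows]
        self.elements_copied = 0  # bookkeeping to make the cost visible

    def __setitem__(self, index, value):
        i, j = index
        new = [list(r) for r in self.data]  # full copy of A
        self.elements_copied += sum(len(r) for r in self.data)
        new[i][j] = value
        self.data = new  # the old buffer is never written again

    def matvec_with_pullback(self, x):
        A = self.data  # captured buffer; later writes replace self.data instead
        y = [sum(a * xi for a, xi in zip(row, x)) for row in A]

        def pullback(v):
            return [sum(A[i][j] * v[i] for i in range(len(A))) for j in range(len(x))]

        return y, pullback


M = CopyOnWriteMatrix([[1.0, 2.0], [3.0, 4.0]])
y, pullback = M.matvec_with_pullback([1.0, 1.0])
M[0, 0] = 100.0
print(pullback([1.0, 0.0]))  # [1.0, 2.0], still correct


def sweep(n):
    """Write every element of an n x n matrix once, like a weight update loop."""
    W = CopyOnWriteMatrix([[0.0] * n for _ in range(n)])
    for i in range(n):
        for j in range(n):
            W[i, j] = 1.0
    return W.elements_copied


print(sweep(10))  # 10000: 100 writes, each copying 100 elements
```

Correctness is restored, but the cost of a simple in-place sweep goes from O(n^2) to O(n^4), which is the "have fun" part.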

So with forward-mode AD there was an entire literature around schemes that are provably safe against perturbation confusion, and I'm surprised we haven't already started seeing papers about provable safety with respect to mutation in higher-level reverse-mode AD. I suspect the only reason it hasn't started is that the people who write type-theoretic proofs tend to be the pure-functional-programming folks who tell people never to mutate anyway, so the literature might instead go in the direction of escape-analysis proofs that optimize immutable array code to (and beyond) the performance of mutating code on applications that commonly mutate. Either way, it's getting there with the same purpose in mind.
