
I still don't get it. Why can't I use a debugger to step through derivatives when autodiff is implemented as a library?



Reverse mode autodiff is best implemented with a non-local transformation of the program: First you run the original operations forward; then you run corresponding operations in reverse order.

You can do this with a library by implementing "number" types that, as a side effect of arithmetic operations, record those operations onto a "tape", so that corresponding (different) operations can be played back later in reverse order.
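As a rough illustration of the tape idea, here is a minimal Python sketch. The names `Tape`, `Var`, `sin`, and `gradient` are made up for this example and do not come from any particular library. Each arithmetic operation computes its result and, as a side effect, appends a closure to the tape; the closures are then replayed in reverse order.

```python
import math

class Tape:
    """Hypothetical tape: one backward closure per recorded operation."""
    def __init__(self):
        self.ops = []

class Var:
    """Hypothetical "number" type that records its operations on a tape."""
    def __init__(self, value, tape):
        self.value = value
        self.grad = 0.0
        self.tape = tape

    def __add__(self, other):
        out = Var(self.value + other.value, self.tape)
        def backward():
            # d(a + b)/da = 1 and d(a + b)/db = 1
            self.grad += out.grad
            other.grad += out.grad
        self.tape.ops.append(backward)  # side effect during the forward pass
        return out

    def __mul__(self, other):
        out = Var(self.value * other.value, self.tape)
        def backward():
            # d(a * b)/da = b and d(a * b)/db = a
            self.grad += other.value * out.grad
            other.grad += self.value * out.grad
        self.tape.ops.append(backward)
        return out

def sin(x):
    out = Var(math.sin(x.value), x.tape)
    def backward():
        x.grad += math.cos(x.value) * out.grad
    x.tape.ops.append(backward)
    return out

def gradient(output, tape):
    """Seed the output with 1.0, then play the tape back in reverse order."""
    output.grad = 1.0
    for backward in reversed(tape.ops):
        backward()

tape = Tape()
x = Var(2.0, tape)
y = Var(3.0, tape)
z = x * y + sin(x)   # forward pass: three operations land on the tape
gradient(z, tape)    # reverse pass: x.grad = y + cos(x), y.grad = x
```

The loop inside `gradient` is the "little interpreter" in question: it walks a data structure built at run time rather than executing code the compiler ever saw.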

Unfortunately, those side effects have a cost during the forward pass. And during the backwards pass you are essentially running a little interpreter to execute your instructions.

Putting this stuff into the compiler lets the forward pass just be normal, native code on doubles/floats, and the same for the reverse-pass code. Moreover, all of this can now be worked on by the optimizing compiler.
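For contrast, here is a hand-written Python sketch of the kind of code a compiler-level transformation could emit for f(x, y) = x*y + sin(x). This is an assumption about the general shape of the output, not the output of any real tool, and the name `f_with_grad` is invented for the example. There is no tape and no special number type, just straight-line arithmetic on plain floats.

```python
import math

def f(x, y):
    # The function as the user wrote it.
    return x * y + math.sin(x)

def f_with_grad(x, y):
    # Forward pass: the original operations, in the original order.
    a = x * y
    b = math.sin(x)
    z = a + b

    # Reverse pass: the corresponding derivative operations, in reverse order.
    dz = 1.0
    da = dz                  # from z = a + b
    db = dz
    dx = db * math.cos(x)    # from b = sin(x)
    dx += da * y             # from a = x * y
    dy = da * x
    return z, (dx, dy)

z, (dx, dy) = f_with_grad(2.0, 3.0)
```

Because both passes are ordinary code, an optimizer can fold `dz`, `da`, and `db` away, and a debugger can step through the derivative line by line.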

This is especially important for embedded applications where you can't afford all this interpretation.

I guess the original TensorFlow "computation graph" approach is a little different to the "tape" I described because the graph is built explicitly rather than through side effects, but that just makes even more clear that you're really assembling an AST for some /other/ language and (when not using XLA) running an interpreter.
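To make the "AST for some other language" point concrete, here is a toy Python sketch of an explicitly built graph with a separate evaluator. This is not the TensorFlow API; `Node`, `placeholder`, `add`, `mul`, `sin`, and `evaluate` are all hypothetical names chosen for illustration.

```python
import math

class Node:
    """Hypothetical graph node: an operation name plus its input nodes."""
    def __init__(self, op, inputs=(), name=None):
        self.op = op
        self.inputs = inputs
        self.name = name

# Builder functions: calling them computes nothing, they only assemble a graph.
def placeholder(name):
    return Node("input", name=name)

def add(a, b):
    return Node("add", (a, b))

def mul(a, b):
    return Node("mul", (a, b))

def sin(a):
    return Node("sin", (a,))

def evaluate(node, feed):
    """A tiny interpreter for the graph "language"."""
    if node.op == "input":
        return feed[node.name]
    args = [evaluate(i, feed) for i in node.inputs]
    if node.op == "add":
        return args[0] + args[1]
    if node.op == "mul":
        return args[0] * args[1]
    if node.op == "sin":
        return math.sin(args[0])
    raise ValueError("unknown op: " + node.op)

x = placeholder("x")
y = placeholder("y")
graph = add(mul(x, y), sin(x))   # builds a tree of Nodes; no arithmetic happens yet
result = evaluate(graph, {"x": 2.0, "y": 3.0})
```

The user-facing Python only constructs the tree; all the arithmetic happens inside `evaluate`, which dispatches on operation names at run time.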

In principle some suitably powerful macro language with full access to the AST might be able to do good compile-time reverse-mode AD, but I am not aware of any language with powerful enough macros. Again, this is because the transformation is not just a local pattern replacement; it involves flipping the code upside down.

What's funny is that physicists have been able to do this kind of program transformation in a slightly clunky(?) way -- FORTRAN in, FORTRAN out -- since the 70s, supposedly.

It all makes me think that our ideas about what a "language" is, what a "compiler" is, what a "library" is, are all stuck in convention and prematurely ossified. The first compilers, which we celebrate, looked to their users like codegen (which we detest, I think)! I would love for the boundary between "language" and "compiler" to be broken down more so that AD could be more easily done "within the language"; maybe one day that will happen in Julia, or Nim, or Jai, or Terra.

But for now I think I agree with the designers that the best hope for good results, and really the most straightforward way, is to just do it in the compiler.


To complement the other reply: there is a short article on implementing source-to-source AD in Julia [1]. Julia has a special kind of macro called a generated function [2]. Instead of running during the AST lowering phase (before the compiler has evaluated any symbols), it runs during the final step of compilation, after type inference, when all the exact types are known. From that function you can return either an AST or Julia's SSA IR directly; the IR is a good fit for AD because it closely resembles the execution graph, since it avoids mutability. You can also inspect the IR of any function call and manipulate it within the language [3], so you can recursively create the tape entirely at compile time.

[1] http://blog.rogerluo.me/2019/07/27/yassad/

[2] https://docs.julialang.org/en/v1/manual/metaprogramming/inde...

[3] https://mikeinnes.github.io/IRTools.jl/latest/#Evaluating-IR...


It's already being done in Julia. See Zygote.jl.

It works on typed IR, so it is equivalent to a core compiler pass, but it comes from a third-party package that uses regular Julia staged programming.


I think it's mostly due to a couple of reasons:

1. The debugging experience is probably better. Computing a derivative can be complex: you might be seeing a high value where you were expecting a low one, and you want to "step through" the equation and see how it changes. Having to do that when "the equation" is a bunch of obscure data structures, buried in a third-party library's internal representation of functions, could get very complex very quickly.

2. You might not catch non-differentiable functions and zero derivatives. Moreover, given just the nature of math, you could see millions of inputs that wouldn't trigger an exception, ship your model, and then one day the one that yields a 0 shows up, your model crashes, and you don't know why. Having the compiler essentially act as proof that something will _never_ be zero is awesome for correctness and reliability.

If you think about it, derivatives aren't really an operation "through" the equation, but "on" the equation. You're writing some function, but instead of passing a value through it, you're changing the function itself.

So the functions are the values, and need to be changed, morphed, combined, split, etc. If a library wanted to do this, my guess is either:

a) devx would suffer, since you wouldn't be writing functions normally, but rather defining them as objects with verbose constructors, etc.

b) for the sake of devx, the library would have to do some hacky introspection and jump through hoops to get to work on the functions themselves, not with them, at the cost of performance or debuggability.
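Option (a) might look something like the following Python sketch, where a function is spelled out as an expression object so a library can transform it into its derivative. The classes `Const`, `X`, `Add`, and `Mul` and the functions `derivative` and `evaluate` are invented for illustration and are not taken from any real library.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Const:
    value: float

@dataclass(frozen=True)
class X:
    """The single input variable."""

@dataclass(frozen=True)
class Add:
    left: object
    right: object

@dataclass(frozen=True)
class Mul:
    left: object
    right: object

def derivative(e):
    """Operate "on" the function: return a new expression for d(e)/dx."""
    if isinstance(e, Const):
        return Const(0.0)
    if isinstance(e, X):
        return Const(1.0)
    if isinstance(e, Add):
        return Add(derivative(e.left), derivative(e.right))
    if isinstance(e, Mul):
        # Product rule: (u * v)' = u' * v + u * v'
        return Add(Mul(derivative(e.left), e.right),
                   Mul(e.left, derivative(e.right)))
    raise TypeError("unsupported expression: %r" % (e,))

def evaluate(e, x):
    """Pass a value "through" an expression."""
    if isinstance(e, Const):
        return e.value
    if isinstance(e, X):
        return x
    if isinstance(e, Add):
        return evaluate(e.left, x) + evaluate(e.right, x)
    if isinstance(e, Mul):
        return evaluate(e.left, x) * evaluate(e.right, x)
    raise TypeError("unsupported expression: %r" % (e,))

# f(x) = x*x + 3x has to be written with constructors instead of plain arithmetic.
f = Add(Mul(X(), X()), Mul(Const(3.0), X()))
df = derivative(f)            # an unsimplified expression equivalent to 2x + 3
value = evaluate(f, 2.0)
slope = evaluate(df, 2.0)
```

Even this tiny example shows the devx cost: `x*x + 3*x` becomes a nest of constructor calls, and a debugger stepping through `evaluate(df, 2.0)` walks the library's tree rather than anything resembling the original equation.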

There's already software we use all the time that takes these functions, breaks them apart, understands them and does things with them, though: the compiler. It transforms functions to machine code. Let's have it add a step in the middle there, and if a function is marked as being derived, let's have the compiler take it, transform it to its derivative, and then to machine code ¯\_(ツ)_/¯.



