"The second factor, and probably the more damning one, is that most ML codes don't actually use that much dynamism." I would argue that this in true precisely because it is not available in an AD system. When I tell friends and coworkers about what zygote can do they light up and start describing different use cases they have that could benefit from AD. Diff eq solving is a big one.
This is because continuous optimization is useless when crossing a discontinuity, which is what control flow creates. Even in a trivial situation like ReLU, where the control flow is mimicking a continuous transition, you have the "dead ReLU" problem, where you have to start training on the correct side of the discontinuity and make sure to never cross.
Formally, there is a generalisation of differentiation which can handle functions like ReLU (i.e. locally Lipschitz non-differentiable functions) by allowing a derivative to be set-valued. It's called the Clarke gradient. The Clarke gradient of ReLU at 0 is the closed interval [0,1]. Note that the Clarke gradient doesn't satisfy the chain rule (except in a weakened form) which might seriously mess up some assumptions about autodiff. Is this generalised derivative useful in autodiff?
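To make that concrete, here is a minimal Julia sketch (my own toy example, not any particular AD library's implementation) of what "choosing an element of the Clarke gradient" looks like for ReLU:

    relu(x) = max(x, zero(x))

    # The Clarke gradient of relu at 0 is the interval [0, 1]; a practical AD
    # rule just commits to one element of that set. Returning 0 at the kink is
    # a common convention, but 0.5 or 1 would be equally admissible.
    drelu(x) = x > 0 ? one(x) : zero(x)

    drelu(-1.0), drelu(0.0), drelu(2.0)   # (0.0, 0.0, 1.0)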
I imagine that this is a largely theoretical tool that's useful in analysing algorithms but useless for actually computing things.
The subgradient in convex analysis is a special case of the Clarke gradient. The subgradient is precisely the Clarke gradient for convex functions. Convex functions are always locally Lipschitz on the interior of their domain, so the weird cases only show up at the boundary.
[edit]
Question: Are there numerical applications in which the subgradient is actually computed, or is it a purely analytical tool?
(Stochastic) subgradient methods are used in practice to optimize non-differentiable convex functions. They have a slower convergence rate than (stochastic) gradient descent though.
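For a flavor of what that looks like numerically, here is a tiny made-up Julia example: minimizing the non-differentiable convex function f(x) = |x - 3| with a diminishing step size (a sketch of the textbook method, not production code):

    f(x) = abs(x - 3)

    # Any value in [-1, 1] is a valid subgradient of f at the kink x = 3.
    subgrad(x) = x > 3 ? 1.0 : (x < 3 ? -1.0 : 0.0)

    x = 0.0
    for k in 1:10_000
        global x -= (1 / k) * subgrad(x)   # diminishing steps: converges, but slowly
    end
    x   # ≈ 3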
The gist of it is that we endeavour to provide ‘useful’ values for the gradient at non-differentiable points, even if the traditional derivative is not defined or infinite.
You could imagine it as us smoothing out edges or discontinuities, ideally in a way that makes things like gradient descent well behaved.
Interesting read. I was very disappointed when the Swift TensorFlow project withered away. A good general purpose programming language combined with deep learning seemed like a great idea. Apple actually provides a good dev experience with Swift and CoreML (for fun I wrote a Swift/SwiftUI/CoreML app that uses two deep learning models and is in the App Store).
Wolfram Language takes an approach similar to Apple’s in providing a good number of pre-trained models, but I haven’t yet discovered any automatic differentiation examples.
Of the frameworks described in the article, I find Julia most interesting but I need to use Python and TensorFlow in my work.
I remember reading an early S4TF manifesto talking about natural language processing, image processing, etc. and thinking, that cannot be your audience, because that audience already has AD systems which support their domain. Building something that is more general for the sake of being more general is never a good idea; that's bad engineering. It had a lot of great ideas, and indeed the dev experience seemed nice. But I would venture to guess that the Google overlords had to question what the true value of S4TF was in that light. "Standard ML" cannot be your target audience if you want to work on AD extensions.
Following this thread, you can also see how the Julia tools evolved. If you look at the paper that was the synthesis for Zygote.jl, it was all about while loops and scalar operations (https://arxiv.org/abs/1810.07951). Why did that not completely change ML? Well, ML doesn't use those kinds of operations. I would say the project kind of started as a "tool looking for a problem". It did get a bit lucky that it found a problem: scientific applications need to be able to use automatic differentiation without rewriting the whole codebase to an ML library, leading to the big Julia AD manifesto of language-wide differentiable programming by directly acting on the Julia source itself rather than a language subset (http://ceur-ws.org/Vol-2587/article_8.pdf). Zygote was a good AD, but not a great AD. Why? Because it could not hit this goal, mostly because of its lack of mutation handling. Yes, it does handle standard ML just fine, but that alone does not justify its added complexity.
What has actually kept Julia AD research going is that some scientific machine learning applications, specifically physics-informed neural networks (PINNs), require very high order derivatives. For example, to solve the PDE u_t = u_xx with neural networks, you need to take the third derivative of the neural network. With Jax this can only be done with a separate language subset (https://openreview.net/pdf?id=SkxEF3FNPH), and thus a new AD for Julia to replace Zygote, known as Diffractor.jl, was devised to automatically incorporate higher order AD optimizations as part of regular usage (https://www.youtube.com/watch?v=mQnSRfseu0c). It is these PINN SciML applications that have funded its development and are its built-in audience: it solves a problem nothing else does, even if it is potentially niche. Similarly with Enzyme: it solved the problem of how to do mutation well, which is why you can see in the paper that its applications are mostly ODE and PDE solvers (Euler, RK4, the Bruss semilinear PDE) (https://proceedings.neurips.cc/paper/2020/file/9332c513ef44b...). Torchscript and Jax do not handle this domain well, so it has an audience, which may (or may not?) be niche.
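To give a sense of what "high order" means in practice, here is a rough sketch (my own toy example, nesting ForwardDiff rather than using Diffractor) of the kind of repeated input-derivatives a PINN residual needs:

    using ForwardDiff

    # A tiny stand-in "network": one hidden tanh layer with made-up weights.
    W1, b1, W2 = randn(8), randn(8), randn(8)
    u(x) = sum(W2 .* tanh.(W1 .* x .+ b1))

    u_x(x)   = ForwardDiff.derivative(u, x)      # first derivative of the net
    u_xx(x)  = ForwardDiff.derivative(u_x, x)    # second derivative, as in u_t = u_xx
    u_xxx(x) = ForwardDiff.derivative(u_xx, x)   # the third-order nesting mentioned above

    u_xxx(0.3)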
A big part of writing this blog post was to highlight this to the Julia AD crew that I regularly work with. What will keep these projects alive is understanding the engineering trade-offs that are made and who the audience is. The complexity has a cost so it better have a benefit. If that target is lost, if any benefit is a theoretical "but you may need more features some day", then the projects will lose traction. The project needs to be two-fold: identify new architectures and applications that would benefit from expanded language support from AD and build good support for those projects. Otherwise it is just training a transformer in Julia vs training a transformer in Python, and that is not justifiable.
My impression from your comment is that you don't care that much about "standard" ML users. As a "standard" ML user (pytorch/jax), and a potential Julia user in the future, this is not what I like to hear.
The idea, I imagine, is to differentiate what the Julia ML stack offers from what is already available in Python. If it offers the same thing, but without the funding from Facebook or Google, why bother switching? It has to offer something more.
If the purpose of the AD were simply to serve standard ML workflows, the current AD tools are not the right design for that. They are too complex and solve a harder problem than they would need to. A better approach would be to use the abstract interpretation afforded by Symbolics.jl, mixed with the array symbolics, and use Metatheory.jl to define simplification rules similar to XLA. You'd essentially get a slightly expanded Jax/TensorFlow with graph simplification rules that could be adjusted/improved directly from the host language. There's a prototype (ReversePropagation.jl), but it needs array support to actually be useful for this application. Or, if you don't need the whole stack to be modifiable from Julia, a better interpreter on XLA.jl would get you there. A sufficiently decent lone coder could get those up and optimized fairly quickly for standard ML applications, if that's the goal.
Diffractor.jl has a much loftier goal: optimized differentiable programming of any code from any package in the Julia ecosystem. Because it's building typed IR, it will need a full set of Julia-based analysis tools (escape analysis, loop-invariant code motion, etc.) to approach the amount of optimization XLA can do when XLA optimizations are applicable. While such passes are being developed (for example, this is the PR for putting immutable array optimizations into the language so that Diffractor-friendly immutable code can generate the optimized mutable form: https://github.com/JuliaLang/julia/pull/42465), it's at least a few years away from doing something like reliably combining multiple matrix-vector products into a matrix-matrix BLAS3 call. That leaves it on even footing today to compete against PyTorch in a kernel-vs-kernel optimization battle, but not against TensorFlow code in cases where XLA optimizations are doing something more.
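For concreteness, this is the kind of rewrite I mean, written out by hand in plain Julia (a toy example, nothing Diffractor-specific): two independent matrix-vector products fused into a single BLAS3 call, which XLA already does automatically in many cases.

    A = rand(100, 100); x = rand(100); y = rand(100)

    # Unfused: two matrix-vector (BLAS2) calls.
    r1, r2 = A * x, A * y

    # Fused: one matrix-matrix (BLAS3) call on the concatenated vectors.
    R = A * hcat(x, y)
    r1 ≈ R[:, 1] && r2 ≈ R[:, 2]   # true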
"Just wait 3 years and it will be really cool" is not a good way to start building a robust community, instead those interested in it need to ask how to demonstrate the improvements afforded by the added generality today. That's why what's making the project tick right now is the ARPA-E projects for physics-informed neural networks, the DJ4Earth project to do direct differentiation of the CLIMA climate model without changing any of the model code (https://dj4earth.github.io/), etc. Those kinds of projects are what is keeping a lot of the dev team open to be full time on these AD and compiler optimization projects. But if successful, it will also give an AD that is great for standard ML.
I am really glad to hear this. As it happens, my main post-retirement project has been to learn Swift to try out CoreML. I'm really enjoying learning this so far.
The advantage of transformers (computationally) seems to be how little sophistication the attention mechanism needs from AD systems (and how well it appears to scale with data). It's also a very static architecture from a data flow/control flow perspective.
As far as I understand, this is far different from systems that need to be modeled in continuous time, especially things like SDEs. I am curious whether things like delay embeddings will ever be modeled with mechanisms similar to attention, however.
Outside of research, transformers are rarely used for computer vision problems and CNNs remain the go-to architecture. And you actually need to do some hacks to get transformers to work with computer vision at a meaningful scale (splitting images into patches and convolving the patches to produce features to feed into the transformer).
> some hacks to get transformers to work with computer vision at a meaningful scale (splitting images into patches and convolving the patches to produce features to feed into the transformer).
Yeah. Even modern CV methods are hacky insofar as picking the “right” way to apply linear algebra. Convolution layers are hacked-up matrix multiplications that are “inspired” by human vision. Of course, the real reason for the hacks is that the form works in practice.
I’ve looked into transformers for semantic segmentation, but the patching aspect seems to make it hard too. Do you have some sources that describe these hacks in detail?
You could do a code search on GitHub. I’m pretty lazy when it comes to coding; I always seem to find a repo that has implemented an MVP of what I already had in mind. There are some gold nuggets on GitHub, like Google's DDSP implementation, which they published anonymously for academic review.
It will be interesting when Tesla and Waymo move to transformer architectures, but as you wrote, my guess is that it's not yet in production for vision tasks.
Tesla did, as mentioned at their AI Day. It is not a full transformer (a la ViT); they use a transformer decoder to fuse data from the different cameras and decode 3D coordinates directly (a la DETR).
I’m not sure they will, at least not with the research in the state it is presently. Researchers are interested in vision transformers because they’re competitive with CNNs if you give them enough training data - they don’t drastically outperform them.
Right now switching over to them would require a ton of code changes, relearning intuitions, debugging, profiling, etc. for not a ton of benefit.
As other comments note, CNNs and LSTMs are still in wide use today. If you dig deep enough, positional encoding doesn't really capture time-series information that well.
Even though this post's thesis is “trade-offs”, it doesn’t really talk about any technical advantages that Python’s AD ecosystem (TensorFlow, PyTorch, JAX) has over Julia’s (Zygote.jl, Diffractor.jl).
Maybe it has no technical advantages, unless you count being very popular and in an accessible language as a technical advantage (which it definitely could be depending on your definition of "technical").
Julia is designed for advanced numerical computing and Python isn't. The metaprogramming affordances needed for AD are much better developed in Julia than they ever will be in Python. And let's not forget the immense utility of multiple dispatch in Julia, another feature Python will probably never have. So it's not surprising that Julia is simply way more capable.
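A small made-up illustration of what multiple dispatch buys you for numerics: the method is selected on the types of all arguments, so specialized kernels compose without wrapper classes or visitor patterns.

    struct DenseMat;  data::Matrix{Float64};                    end
    struct SparseMat; nz::Dict{Tuple{Int,Int},Float64}; n::Int; end

    # One generic name, three specialized methods; dispatch picks the right one
    # based on the concrete types of both arguments.
    mul(a::DenseMat,  b::DenseMat)  = "dense * dense kernel"
    mul(a::DenseMat,  b::SparseMat) = "dense * sparse kernel"
    mul(a::SparseMat, b::SparseMat) = "sparse * sparse kernel"

    mul(DenseMat(rand(2, 2)), SparseMat(Dict((1, 1) => 1.0), 2))   # "dense * sparse kernel"

(The types and string "kernels" here are hypothetical placeholders; real packages like SparseArrays do this with actual implementations.)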
One disadvantage of the language itself is the need for compilation, which isn't that fast in my limited experience. But I would love to hear how much this affects iteration speed.
The same issue exists with Jax. XLA compilation can take up quite a bit of time, especially on larger NN models. And there's no persistent compile cache, so even if you don't change the jitted function, you need to wait for compilation again when you restart the process.
Yeah, I'd imagine that for things both Python and Julia can do AD-wise, Python may be preferable since it's interpreted and thus gives instant feedback, while all the numerical heavy lifting in packages like Jax and PyTorch is done in fast C++. So you should be getting a more appealing environment for experimentation without losing out on speed.
The Julia crowd touting multiple dispatch all the time is so strange; from what I can tell, it's actually one of the main reasons the language hasn't had much uptake.
Python is just more approachable and natural to people. Julia should learn from that.
Python is just more OO style, so people who have been taught OOP in school are comfortable with it. That will include the vast majority of generic SWEs writing generic CRUD apps.
But personally I find OOP ugly and unnatural, and Julia's model elegant and natural. And far more powerful - Julia programmers are using multiple dispatch to build out scientific computing to a sophistication not seen in any other language.
It might not be your cup of tea if you need to see object.method() in your code, but if you're more mentally flexible and want to build the next generation of technical computing tools, Julia is the place to be right now.
> But personally I find OOP ugly and unnatural, and Julia's model elegant and natural.
This definitely fits with my experience. It took me quite a while to really "get" dispatch-oriented programming as a paradigm, but once I started to get it there was no going back.
Yeah, I’m definitely mentally flexible and have coded in many paradigms. I don’t love OO and generally don’t write that way, but multiple dispatch as a primary design pattern is odd.
I’ve tried it for close to a year and the ergonomics still felt off. It reminds me of how the Scala crowd talked about functional programming, and we’ve seen how that turned out.
I hear this from a lot of people that try Julia, and yet the Julia crowd's answer is always that they are dumb. Sounds a lot like the Scala crowd…
I think 90% of the ergonomics issue is that people want dot notation and tab-autocomplete in their IDE so they can type obj.<tab> and get the methods that operate on obj. Which I agree, some version of that should exist, and there's no real reason it can't exist in Julia. The tooling is just not as mature as other languages.
Julia is far ahead in affordances to write fancy technical code and fairly behind in simple things, like standard affordances to write more ordinary code, or the ability to quickly load in data and make a plot.
I just think it's a misdiagnosis to blame multiple dispatch for this issue. It's much more about the Julia community prioritizing the needs of their target market.
Yeah, it would definitely be technically possible to build some sort of editor tab-complete for methods of a type a la `methodswith`; someone would just have to step up and build it. The lack of this sort of method autocomplete tooling hasn’t ever been a pain point for me, but evidently it is for some.
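The building block already exists at the REPL; what's missing is the editor integration. A quick sketch with toy types of my own invention:

    struct Point; x::Float64; y::Float64; end
    norm2(p::Point) = p.x^2 + p.y^2
    translate(p::Point, dx, dy) = Point(p.x + dx, p.y + dy)

    using InteractiveUtils
    methodswith(Point)   # lists norm2 and translate, i.e. the "methods of" Point

An editor could surface exactly that list when you type `p.<tab>`, without the language itself needing `object.method()` semantics.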
That is not a very charitable characterization of the position of multiple dispatch in Julia. It's not something that's optional: multiple dispatch is essential for the performance that Julia is looking to achieve. If you look at where acceleration DSLs tend to have trouble, you'll notice that it's always at the point where you get beyond built-in float primitives and onto object support. For example, Numba's object mode has the caveat that "code compiled in object mode will often run no faster than Python interpreted code, unless the Numba compiler can take advantage of loop-jitting", where loop-jitting is simply the ability to prove that some loop is compatible with moving to the nopython mode.
The reason why Julia is fast is because automatic function specialization to concrete dispatches gives type-grounded functions which allows the complete optimization to occur on high-level looking code (see https://arxiv.org/abs/2109.01950 for type-theoretic proofs). It's basically a combination of (1) define a type system in a way that allows for type-grounded functions and compile-time shape inference (shape as in, byte structure of the structs), (2) define a multiple dispatch system with automatic function specialization on concrete types, (3) have a typed IR which proves and devirtualizes all dispatches before hitting the LLVM JIT. If you simply slap the LLVM JIT on random code, you will not get that performance. But now because multiple dispatch is fundamental to performance in the language, the rest of the "game" for the language is how to design an ergonomic language around this feature and how to teach people to use it effectively as a problem solving tool.
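You can watch that specialization happen yourself. A minimal sketch (toy code, just to show the mechanism):

    # One generic definition; Julia compiles a separate, fully typed
    # specialization for each concrete element type it gets called with.
    function mysum(xs)
        s = zero(eltype(xs))
        for x in xs
            s += x
        end
        return s
    end

    using InteractiveUtils
    @code_typed mysum([1.0, 2.0])   # type-grounded, devirtualized IR for Vector{Float64}
    @code_typed mysum([1, 2])       # a separate specialization for Vector{Int}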
You actually see something similar going on in the world of Jax. With Jax, you need to be able to perform abstract interpretation to the Jax IR. In order for this to be possible with the interpreters Jax has, the functions that are being interpreted need to always have the same computational graph for the same inputs, i.e. they need to be pure functions. This is why Jax is built on functional programming paradigms. It would be similarly uncharitable to say the reason why Jax does not embrace OO is because the developers just love functional programming: the programming paradigm choice clearly falls out of what the tools needs to do.
It remains to be seen if Jax is the tool that makes more people finally embrace functional programming styles, or if enough people see pervasive performance necessary enough to change to the multiple dispatch paradigm of Julia. But what is clear is that tools that are moving away from OOP are not doing so arbitrarily, it's all about whether doing so is beneficial enough to justify the change.
I often wonder why so much effort is being put into shoehorning everything into a single language. Wouldn't it make much more sense to use a fully differentiable DSL for machine learning / xla, then call it from whatever host language you use? This approach has worked really well for SQL for the past couple of decades.
Has it worked really well? I feel ORMs are a sign it hasn't. Though I really enjoy having learned SQL and being able to interact with almost all relational databases.
Not much, really... In many languages, ORMs require just about as much boilerplate or careful error checking as copy-pasting SQL strings with supporting structs... Not saying ORMs are bad; some are OK, some are maybe even good. Project/team dependent...
That said, I find the concept of abstracting ML ingredients outside of languages a nice one, although it's not entirely novel (Python's been doing this from day 1 :D). The strength of keeping it in one language can be profound, though. Compilers can optimize across operations; calling many atomic functions from an API/server from a client loses that unless implemented carefully. That one-language benefit is a big part of what Julia has to offer.
I could see a value-add case being made if the "whole market" solution included a lot of goodies. But every time I think of what that looks like, I think it looks like Julia in 2-5 years...
Boilerplate. Writing serializers and deserializers by hand is not an efficient use of developer time.
Related to ORMs, but not quite on topic - query building. Type checked queries, parts of which can be passed around business logic, are very powerful and flexible.
There are more and more libraries that let you write SQL and bind the results into native records (objects, structs) in the host language. I find it an interesting middle ground.
Being able to do interprocedural cross-language analysis seems awesome, considering how much code is written in C++ but used in higher-level languages.
Part of the challenge is that most formulations of (reverse-mode) autodiff wind up requiring extra runtime data structures for the backwards computation step.
There’s been some great work in this space in the past 5 years.
I’ve got some stuff I worked out this fall that I’m overdue to write up, along with some prototypes: there is a way to do reverse-mode autodiff isolated to just being an invisible compiler pass, without any of the extra complexity in what are otherwise equivalent formulations!
Hey! Thanks for the link. We decided to postpone our coroutine-based implementation until native support for multi-shot delimited continuations becomes more stable. The linked section may be found here [1], for posterity.
This misses some discussion of tf.function, which does a Python-AST-level transformation to a static TF computation graph, including dynamic control flow like loops and conditional branches.
IMO automatic differentiation shouldn’t be baked into the compiler, because that forces a bit of a monoculture around one AD system and forces you to accept the trade offs of that particular AD system.
The Julia approach has been instead to expose an interface for compiler plugins that any AD (or other ‘nonstandard interpretation’ / code transformation) library can access. There are negative trade-offs to this as well, but I really like the way it’s turned out, and I think it’s given us some fantastic tools for non-AD purposes as well.
> fun fact, the Jax folks at Google Brain did have a Python source code transform AD at one point but it was scrapped essentially because of these difficulties
No, autograd acts similarly to PyTorch in that it builds a tape that it reverses, while PyTorch just comes with more optimized kernels (and kernels that act on GPUs). The AD that I was referencing was Tangent (https://github.com/google/tangent). It was an interesting project, but it's hard to see who the audience is. Generating Python source code makes things harder to analyze, and you cannot JIT compile the generated code unless you could JIT compile Python. So you might as well first trace to a JIT-compilable sublanguage and do the actions there, which is precisely what Jax does. In theory Tangent is a bit more general, and maybe you could mix it with Numba, but then it's hard to justify. If it's more general then it's not for the standard ML community, for the same reason as the Julia tools, but then it had better do better than the Julia tools in the specific niche that they are targeting. That generality means that it cannot use XLA, and thus from day 1 it wouldn't get the extra compiler optimizations that something which uses XLA does (Jax). Jax just makes much more sense for the people who were building it; it chose its niche very well.
Indeed, and that makes a lot of sense. The qualm with Tangent is that you get the source code translation but without the additional optimizations that the technique can provide. It was then natural to just target TensorFlow/XLA to do a similar thing but get the performance of TensorFlow as a result. The downside is that it loses the one true upside of Tangent, which was that, by generating Python code, it could in theory be easier for a Python programmer to debug. But this was probably the right sacrifice to make for most people.
This is possibly irrelevant to practice: But is perturbation confusion a problem which affects reverse-mode autodiff as well as forward-mode autodiff? With dual-number based approaches, the usual solution is tagging, but this is known to be incorrect when higher order functions get used.
For people who don't know, perturbation confusion is a problem that affects naive implementations of autodiff where derivatives of order bigger than 1 are not computed correctly.
You probably could get perturbation confusion in reverse mode, but it's not an easy trap like it is with forward mode. The problem with forward mode AD is that it's deceptively easy: you transform every number a into a pair of numbers (a,b) and you change every function from f(a) to (f(a),f'(a)), put the chain rule on it, and you're done. Whether you call it dual number arithmetic or a compiler transformation it's the same thing in the end. The issue with perturbation confusion is that you've created this "secret extra number" to store the derivative data, and so if two things are differentiating code at the same time, you need to make sure you turn (a,b) into (a,a',b,b') and that all layers of AD are always grabbing the right value out of there. Note that in the way I wrote it, if you assumed "the b term is always num[2] in the tuple", oops perturbation confusion, and so your generated code needs to be "smart" (but not lose efficiency!). Thus the fixes are proofs and tagging systems that ensure the right perturbation terms are always used in the right places.
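A stripped-down and deliberately naive dual-number sketch, to make the "secret extra number" concrete (this is not how ForwardDiff or any real system is implemented):

    struct Dual
        val::Float64
        der::Float64   # the hidden perturbation slot
    end
    Base.:+(a::Dual, b::Dual) = Dual(a.val + b.val, a.der + b.der)
    Base.:*(a::Dual, b::Dual) = Dual(a.val * b.val, a.der * b.val + a.val * b.der)

    # d/dx (x * x) at x = 3: seed the perturbation slot with 1.
    d(f, x) = f(Dual(x, 1.0)).der
    d(x -> x * x, 3.0)   # 6.0

    # Nesting d inside d is where naive versions get into trouble: this concrete
    # struct cannot even nest (the constructor only takes Float64), and a generic
    # untagged version can silently mix the inner and outer perturbations, which
    # is exactly perturbation confusion.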
With reverse mode AD, this is much less likely to be an issue because the AD system isn't necessarily storing and working on hidden extensions to the values; it's running a function forwards and then running a separate function backwards, having remembered some values from the forward pass. If the remembered values are correct and never modified, then generating a higher order derivative is just as safe as the first. But that last little detail is what I think is most akin to perturbation confusion in reverse mode: reverse mode has the assumption that the objects captured in the forward pass will not be changed (or will at least be back in the correct state) when it is trying to reverse. Breaking this assumption doesn't even require second derivatives; the easiest way to break it is mutation: if you walk forward by doing Ax, then the reverse pass wants to do A'v so it just keeps the pointer to A, but if A gets mutated in the meantime then using that pointer is incorrect. This is the reason why most AD systems simply disallow mutation except in very special unoptimized cases (PyTorch, Jax, Zygote, ...).
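A hand-rolled sketch of that hazard (not any particular library's tape mechanism): the pullback of A*x keeps a reference to A, so mutating A between the forward and reverse passes silently corrupts the gradient.

    A = [1.0 2.0; 3.0 4.0]
    x = [1.0, 1.0]

    y = A * x
    pullback(v) = A' * v    # the reverse pass wants the A used on the forward pass

    A .= 0.0                # mutate A between the passes...
    pullback([1.0, 1.0])    # ...and the "gradient" is now silently all zeros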
Enzyme.jl is an exception because it does a global analysis of the program it's differentiating (with proper escape analysis etc. passes at the LLVM level) in order to know that any mutation going forward will be reversed during the reverse pass, so by the time it gets back to A'*v it knows A will be the same. Higher-level ADs could go cowboy YOLO style and just assume the reversed matrix is correct (and it might be a lot of the time), though that causes some pretty major concerns for correctness. The other option is to simply make a full copy of A every time you mutate an element, so have fun if you loop through your weight matrix. The near-future Diffractor.jl approach is more like Haskell GHC, where it just wants you to give it the non-mutating code so it can try to generate the mutating code when that would be more efficient (https://github.com/JuliaLang/julia/pull/42465).
So with forward-mode AD there was an entire literature around schemes with provable safety against perturbation confusion, and I'm surprised we haven't already started seeing papers about provable safety with respect to mutation in higher-level reverse-mode AD. I would suspect that the only reason why it hasn't started is that the people who write type-theoretic proofs tend to be the functional programming pure-function folks who tell people to never mutate anyway, so the literature might instead go in the direction of escape analysis proofs to optimize immutable array code to (and beyond) the performance of mutating code on commonly mutating applications. Either way, it's getting there with the same purpose in mind.
Thank you so much for this post. I was going to have to self-discover a lot of this information for an upcoming project, and now it's laid right before me. Merry Christmas!!