For any kind of research or experimental work, I cannot imagine using anything other than PyTorch, with the caveat that I do think JAX is extremely impressive and I've been meaning to learn more about it for a while.
Even though I've been working with Tensorflow for a few years now and feel like I understand the API pretty well, to some extent that just means I'm _really_ good at navigating the documentation, because there's no way to intuit how things work. And I still run into bizarre performance issues pretty much every time I profile graphs. Some ops are just inefficient - oh, but it was fixed in 2.x.yy! Oh, but then it broke again in 2.x.yy+1! Sigh.
However - and I know this is a bit of a tired trope - any kind of industrial deployment is just vastly, vastly easier with Tensorflow. I'm currently working on ultra-low-latency model development targeting a Tensorflow-Lite inference engine (C API, wrapped via Rust) and it's just incredibly easy. With some elbow grease and a willingness to dive into low-level TF-Lite optimisations, one can see end-to-end model inference times on the order of 10-100us for simple models (say, a fully connected DNN with a few million parameters), and between 100us-1ms for fairly complex models utilising contemporary architectures in computer vision or NLP. Memory overhead stays low, and you get fine control over the semantics of the inference computation.
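To make the claim concrete, here's a minimal benchmarking sketch in Python (model path is a placeholder - substitute your own flatbuffer). The C API we call from Rust follows the same create/allocate/invoke lifecycle (`TfLiteInterpreterCreate`, `TfLiteInterpreterAllocateTensors`, `TfLiteInterpreterInvoke`), so modulo Python call overhead this is a reasonable proxy for the hot path:

```python
import time
import numpy as np
import tensorflow as tf

# Placeholder path - substitute your own compiled .tflite flatbuffer.
interpreter = tf.lite.Interpreter(model_path="model.tflite", num_threads=1)
interpreter.allocate_tensors()  # one-time setup, kept out of the hot path

inp = interpreter.get_input_details()[0]
out = interpreter.get_output_details()[0]
x = np.zeros(inp["shape"], dtype=inp["dtype"])

# Warm up, then time the steady-state invoke loop.
for _ in range(100):
    interpreter.set_tensor(inp["index"], x)
    interpreter.invoke()

n = 10_000
start = time.perf_counter()
for _ in range(n):
    interpreter.set_tensor(inp["index"], x)
    interpreter.invoke()
    _ = interpreter.get_tensor(out["index"])
print(f"~{(time.perf_counter() - start) / n * 1e6:.0f} us per inference")
```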
As a nice cherry on top, we can take the same Tensorflow SavedModels that get compiled to TF-Lite files and instead convert them to Tensorflow.js for easy web deployment, which is a great portability upside.
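Both deployment targets branch off the same artefact. A sketch of the TF-Lite side (paths are placeholders), with the web route noted as a comment:

```python
import tensorflow as tf

# Placeholder directory - any exported SavedModel works.
converter = tf.lite.TFLiteConverter.from_saved_model("saved_model_dir")
converter.optimizations = [tf.lite.Optimize.DEFAULT]  # optional: enables e.g. quantisation
with open("model.tflite", "wb") as f:
    f.write(converter.convert())

# The web path starts from the same SavedModel, via the tensorflowjs pip package:
#   tensorflowjs_converter --input_format=tf_saved_model saved_model_dir web_model
```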
However, I know there's some incredible progress being made on what one might call 'environment-agnostic computational graph ILs' (on second thought, let's not keep that name), which should open up more options for inference engines and graph optimisations (operator fusion, rollups, hardware-dependent stuff, etc.).
Overall I feel like things have been continuously getting better for the last 5 years or so. I'm pleased to see so many more options.
Agreed - JAX is really cool. It will be interesting to see how TF & JAX develop considering they're both made by Google. I also think JAX has the potential to be the fastest, although right now it's neck-and-neck with PyTorch.
Yes - a lot of TF users don't realize that the "tricks of the trade" for wrangling TF don't apply in PT, because it just works more easily.
I agree that industry-centric applications should probably use TF. TFX is just invaluable. Have you checked out Google's Coral devices? TFLite + Coral = revolution for a lot of industries.
Thanks for all your comments - I'm also really excited to see what the coming years bring. While we might debate whether PT or TF is better, they're both undoubtedly improving very rapidly! So excited to see how ML/DL applications start permeating other industries.
> 10-100us for simple models (say, a fully connected DNN with a few million parameters)
I basically don't believe you. I'm a researcher in this area (DNNs on FPGAs), and you cannot get these latencies on real models without going to an FPGA (and you're not synthesizing Verilog from TF, unless you're one of my competitors...). On GPU, kernel launch overhead alone is on the order of 10us per kernel, so across the hundreds of kernels in a real network you're into milliseconds before any math happens. For example, here's a talk given at GTC a couple of years ago where they do get down to 35us (on tensor cores) using persistent kernels, but on a mickey-mouse network.
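If anyone wants to sanity-check the launch-overhead figure, here's a crude measurement sketch (PyTorch purely for convenience; any CUDA host code shows the same effect). It times trivial kernels back-to-back, so it approximates per-launch cost rather than compute:

```python
import time
import torch

assert torch.cuda.is_available()
x = torch.zeros(1, device="cuda")

for _ in range(100):  # warm-up: the first launches pay extra one-off costs
    x.add_(1.0)
torch.cuda.synchronize()

n = 10_000
start = time.perf_counter()
for _ in range(n):
    x.add_(1.0)  # one trivial kernel launch per call
torch.cuda.synchronize()  # drain the queue before stopping the clock
print(f"~{(time.perf_counter() - start) / n * 1e6:.1f} us per launch")
```

Multiply a few microseconds by the hundreds of kernels in a real network and you're in millisecond territory from launches alone.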
CPU (where you don't have to deal with async CUDA calls) won't save you either; again, here's a paper from USENIX (so you know it's legit) that shows the lowest times for real networks on CPU are ~2ms - and that's on resnet18, which at ~11M weights is exactly the "few million parameters" class being claimed.
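The ballpark is easy to reproduce. Here's a sketch with eager-mode PyTorch on CPU; expect results well above the ~2ms an aggressively tuned inference engine manages:

```python
import time
import torch
import torchvision

model = torchvision.models.resnet18().eval()  # ~11.7M parameters
x = torch.randn(1, 3, 224, 224)

with torch.no_grad():
    for _ in range(10):  # warm-up
        model(x)
    n = 100
    start = time.perf_counter()
    for _ in range(n):
        model(x)
print(f"~{(time.perf_counter() - start) / n * 1e3:.1f} ms per forward pass")
```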