One disadvantage of the language itself is the need for compilation, which isn't that fast in my limited experience. I would love to hear how much this affects iteration speed in practice.
The same issue exists with JAX. XLA compilation can take quite a bit of time, especially for larger NN models. And there's no persistent compile cache, so even if you don't change the jitted function, you have to wait for compilation again every time you restart the process.
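To make the compile cost concrete, here is a minimal sketch that times the first call to a jitted function (trace plus XLA compile) against a second call that reuses the compiled executable within the same process. The tiny MLP, its layer sizes, and the helper names `make_params`, `mlp_forward`, and `timed_call` are made up for illustration. Actual timings depend heavily on the model, the backend, and the JAX version.

```python
import time

import jax
import jax.numpy as jnp


def make_params(key, sizes):
    """Build random (weight, bias) pairs for a toy MLP (illustrative only)."""
    params = []
    for n_in, n_out in zip(sizes[:-1], sizes[1:]):
        key, sub = jax.random.split(key)
        w = jax.random.normal(sub, (n_in, n_out)) * 0.1
        params.append((w, jnp.zeros(n_out)))
    return params


@jax.jit
def mlp_forward(params, x):
    # The Python loop is unrolled at trace time, so deeper models
    # produce larger XLA programs and longer compiles.
    for w, b in params:
        x = jnp.tanh(x @ w + b)
    return x


def timed_call(fn, *args):
    """Run fn and wait for the result, returning (output, seconds)."""
    start = time.perf_counter()
    out = fn(*args)
    out.block_until_ready()  # JAX dispatch is async; wait before stopping the clock
    return out, time.perf_counter() - start


params = make_params(jax.random.PRNGKey(0), [64, 128, 128, 10])
x = jnp.ones((32, 64))

# First call: tracing + XLA compilation + execution.
out_first, t_first = timed_call(mlp_forward, params, x)
# Second call: same shapes and dtypes, so the in-process cache is hit.
out_second, t_second = timed_call(mlp_forward, params, x)

print(f"first call (with compile): {t_first:.4f} s")
print(f"second call (cached):      {t_second:.4f} s")
```

The in-process cache is keyed on input shapes and dtypes, so calling `mlp_forward` with a different batch size would trigger another compile. Restarting the Python process discards this in-memory cache, which is the restart cost described above.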
Yeah, I'd imagine that for the things both Python and Julia can do AD-wise, Python may be preferable: it's interpreted, so you get instant feedback, while all the numerical heavy lifting in packages like JAX and PyTorch is done in fast C++. So you should get a more appealing environment for experimentation without losing out on speed.