The same issue exists with JAX. XLA compilation can take quite a bit of time, especially for larger NN models. And there's no persistent compile cache enabled by default, so even if you don't change the jitted function, you have to wait for compilation again every time you restart the process.
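A rough way to see this cost is to time the first and second call of a jitted function in the same process. This is only a sketch: `mlp_layer` and `timed_call` are made-up names for illustration, and the toy function compiles far faster than a real model would.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def mlp_layer(w, x):
    # Hypothetical toy stand-in for a model step.
    return jnp.tanh(x @ w).sum()


def timed_call(fn, *args):
    # Hypothetical helper: run fn and return (result, elapsed seconds).
    start = time.perf_counter()
    out = fn(*args)
    out.block_until_ready()  # JAX dispatches asynchronously, so wait for the result
    return out, time.perf_counter() - start


w = jnp.ones((256, 256))
x = jnp.ones((32, 256))

# The first call traces the function and compiles it with XLA.
first_out, first_s = timed_call(mlp_layer, w, x)
# The second call reuses the compiled executable held in this process's memory.
second_out, second_s = timed_call(mlp_layer, w, x)

print(f"first call: {first_s:.4f}s, second call: {second_s:.4f}s")
```

The first call should be noticeably slower than the second. Restart the Python process and the slow first call comes back, since the compiled executable only lived in that process's memory.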