TPUs are tightly coupled to JAX and the XLA compiler. If your model is written in PyTorch, you can use a bridge to export it to StableHLO and then compile that representation with XLA for TPUs. In theory, the XLA compiler should outperform PyTorch's TorchInductor backend.