`static_argnums` is really just a way to give JAX a few extra assumptions so it can build quasi-static code even when the function uses dynamic constructs. In this example it forces the trace down only one of the two branches (whichever one the static argument selects). The compiled code is then specialized to that branch, so it is only correct for inputs that take the same path. A value that should have gone down the other branch needs a fresh trace and compile (JAX does this automatically, keyed on the static value). That's why the real solution is `lax.cond`, as mentioned in the post: it always traces both branches and leaves the choice to runtime (and under `vmap` with a batched predicate it does compute both). If the computation isn't actually quasi-static, there's no good choice for a static argnum. See the factorial example.
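A minimal sketch of the distinction, assuming a toy branch (double vs. add one) and a loop-based factorial standing in for the post's example. The function names are hypothetical, not taken from the post.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax


# static_argnums=(1,) makes `double` a compile-time Python value, so the
# Python `if` runs during tracing and only the chosen branch is traced.
@partial(jax.jit, static_argnums=(1,))
def step_static(x, double):
    if double:
        return x * 2.0
    return x + 1.0


# lax.cond keeps the decision in the compiled program: both branches are
# traced, and the predicate is an ordinary traced (runtime) value.
@jax.jit
def step_cond(x, double):
    return lax.cond(double, lambda v: v * 2.0, lambda v: v + 1.0, x)


# A computation whose structure depends on the input value (the loop length
# of a factorial). Making n static "works" only by compiling a new program
# for every distinct n.
@partial(jax.jit, static_argnums=(0,))
def factorial_static(n):
    acc = jnp.float32(1.0)
    for i in range(2, n + 1):  # unrolled at trace time; length depends on n
        acc = acc * i
    return acc


# Keeping n dynamic means expressing the loop with a JAX control-flow
# primitive instead, here lax.fori_loop with a traced upper bound.
@jax.jit
def factorial_dynamic(n):
    return lax.fori_loop(2, n + 1, lambda i, acc: acc * i, jnp.float32(1.0))
```

Calling `step_static` or `factorial_static` with a new static value triggers a fresh trace and compile, while `step_cond` and `factorial_dynamic` compile once. `step_cond` also works under `jax.vmap` with a batch of predicates (where the cond becomes a select and both branches get evaluated), which a static argument can't do at all.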