Unsloth is great! They focus on single-GPU, LoRA fine-tuning on NVIDIA GPUs. We are initially targeting multi-node, multi-TPU, full-precision training use cases.
That said, in terms of single-device speed we expect to be behind Unsloth, but not by much, thanks to the highly performant JAX+TPU stack. Additionally, we can scale to larger multi-node training runs on TPUs.
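For context, here's a minimal sketch of how that scale-out typically looks on the JAX side, using a device mesh and named sharding. The mesh axis name, toy model, and learning rate are placeholders rather than our actual training code, and a real multi-host run would also call `jax.distributed.initialize()` first:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One mesh axis spanning every TPU chip JAX can see; in a multi-host job this
# covers chips across all nodes (after jax.distributed.initialize()).
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

replicated = NamedSharding(mesh, P())           # full-precision params on every chip
batch_sharded = NamedSharding(mesh, P("data"))  # batch split across the mesh

def loss_fn(params, batch):
    # Toy linear "model" standing in for a real transformer.
    preds = batch @ params["w"]
    return jnp.mean(preds ** 2)

@jax.jit
def train_step(params, batch):
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    new_params = jax.tree_util.tree_map(lambda p, g: p - 1e-3 * g, params, grads)
    return new_params, loss

params = jax.device_put({"w": jnp.ones((512, 512), jnp.float32)}, replicated)
batch = jax.device_put(jnp.ones((64, 512), jnp.float32), batch_sharded)
params, loss = train_step(params, batch)
```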
There are still more optimizations we need to make for Llama 3.1, such as adding Pallas memory-efficient attention kernels.
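To give a flavor of what that kernel work involves, here is a toy Pallas kernel (a fused scale-and-add). It's purely illustrative, nowhere near a real attention kernel, but it shows the `pallas_call` machinery such kernels are built on:

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def scale_add_kernel(x_ref, y_ref, o_ref):
    # Each program instance reads its block of x and y and writes its block of o.
    o_ref[...] = 2.0 * x_ref[...] + y_ref[...]

@jax.jit
def scale_add(x, y):
    # With no grid/BlockSpecs, the whole array is handled as a single block.
    return pl.pallas_call(
        scale_add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    )(x, y)

x = jnp.ones((1024, 1024), jnp.float32)
y = jnp.ones((1024, 1024), jnp.float32)
print(scale_add(x, y)[0, 0])  # 3.0
```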