PyTorch's choice of a dynamic computation graph [1] made it easier to debug and implement models, leading to higher adoption, even though running speed was initially slower (and training cost therefore higher).
Other decisions follow from this one.
TensorFlow started with static graphs and had to move to dynamic execution in version 2.0, which broke everything; hence the fragmentation between TensorFlow 1, TensorFlow 2, Keras, and JAX.
PyTorch's later compilation of this dynamic computation graph erased TensorFlow's remaining edge.
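To make the contrast concrete, a minimal sketch (the toy function and shapes are made up for illustration): in eager mode every intermediate is an ordinary tensor you can print or step through, and torch.compile in PyTorch 2.x claws back most of the speed a static graph used to buy.

    import torch

    def f(x, w):
        h = x @ w            # intermediates are ordinary tensors in eager mode...
        print(h.mean())      # ...so print/pdb work mid-"graph"
        return torch.relu(h).sum()

    x = torch.randn(8, 16)
    w = torch.randn(16, 4, requires_grad=True)

    loss = f(x, w)              # eager: each op runs immediately
    loss.backward()             # autograd records the dynamic graph as it runs

    f_fast = torch.compile(f)   # PyTorch 2.x: JIT-compiles and fuses the same Python code
    loss_fast = f_fast(x, w)    # recovers much of the old static-graph speed advantage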
Is the battle over? From a purely computational point of view, PyTorch's solution is very far from optimal, and billions of dollars of electricity and GPUs are burned every year, but the major players are happy with circular deals that entrench their positions. So at the pace of current AI code development, probably one or two years before PyTorch is old history.
> at the pace of current AI code development, probably one or two years before PyTorch is old history.
Ehhh, I don’t know about that.
Sure, new AI techniques and new models are coming out pretty fast, but when I go to work with a new AI project, they’re often using a version of PyTorch or CUDA from when the project began a year or two ago. It’s been super annoying having to update projects to PyTorch 2.7.0 and CUDA 12.8 so I can run them on RTX 5000 series GPUs.
All this to say: If PyTorch was going to be replaced in a year or two, we’d know the name of its killer by now, and they’d be the talk of HN. Not to mention that at this point all of the PhDs flooding into AI startups wrote their grad work in PyTorch, it has a lot of network lock-in that an upstart would have to overcome by being way better at something PyTorch can never be good at. I don’t even know what that would be.
Bear in mind that it took a few years for Tensorflow to die out due to lock in, and we all knew about PyTorch that whole time.
> a lot of network lock-in that an upstart would have to overcome by being way better at something PyTorch can never be good at
The cost of migrating higher-level code to a newer framework is going to zero: you ask your favorite agent (or intern) to port it and to check that the migration is exact. We already see this across the multitude of deep-learning frameworks.
The day an optimization trick appears that PyTorch can't do but another framework can, one that cuts your training cost 10x, PyTorch goes the way of the dodo.
The day an architecture that can't be implemented in PyTorch gets superior performance, it's bye-bye Python.
We already see this with architectures that require real-time rendering, like Gaussian Splatting (Instant NeRF), or with the caching strategies for LLM sequence generation.
PyTorch has 3 main selling points:
- Abstracting away the GPU- (or device-)specific code. This exists because of Nvidia's mess of custom optimized kernels, which you are forced to build on if you don't want to write your own.
This advantage disappears if you don't mind writing optimized kernels because the machine writes them, if you don't need CUDA because you can't use Nvidia hardware (for example, you are in China), or if you use custom silicon like Groq and need your own kernels anyway.
- Automatic differentiation. It is also one of PyTorch's weak points, because they went for easy instead of optimal and shut themselves off from some architectures. A language like Julia, thanks to its dynamic low-level compilation, can do things PyTorch won't even dream about (though Julia has its own problems, mainly related to memory allocations). And with PyTorch's introduction of the "scan" function [2], we have come full circle back to Theano, TensorFlow's/Keras's ancestor; scan is usually the pain point of the poor automatic-differentiation strategy chosen by PyTorch.
The optimal solution, as every physics PhD who has written simulations knows, is writing custom adjoint code, either by source code transformation or symbolically. It's not hard but very tedious, so it's now a great fit for your LLM (or intern, or PhD candidate running 'student gradient descent'), provided you prove or check that your gradient calculation is correct (a minimal version of that check is sketched after this list).
- Cluster orchestration and serialization: a model can be shared with fewer security risks than arbitrary source code, because you only share weights, and a model can be split between machines dynamically. But this is also a big weakness, because your code rusts as you become dependent on versioning: you are locked to the specific version your model was trained on.
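To illustrate the adjoint-checking point above, a minimal sketch: a hand-written backward wrapped in torch.autograd.Function and verified mechanically with torch.autograd.gradcheck (the toy op x*sin(x) and the class name are just for illustration).

    import torch

    # Toy op f(x) = x * sin(x) with a hand-written adjoint (backward).
    class XSinX(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            ctx.save_for_backward(x)
            return x * torch.sin(x)

        @staticmethod
        def backward(ctx, grad_out):
            (x,) = ctx.saved_tensors
            # Hand-derived adjoint: d/dx [x sin x] = sin x + x cos x
            return grad_out * (torch.sin(x) + x * torch.cos(x))

    # Mechanical check of the hand-written gradient against finite differences.
    x = torch.randn(5, dtype=torch.double, requires_grad=True)
    assert torch.autograd.gradcheck(XSinX.apply, (x,), eps=1e-6, atol=1e-8)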
There are two types of stops: soft stops and hard stops.
- A soft stop is when the dynamic-graph computation overhead is too high: you can still compute, but if you wrote the function manually or with a better framework, you could be 10x faster.
Typical examples involve manually unrolling a loop or doing kernel fusion. Other typical examples are when you have lots of small objects or need to loop in Python because the problem doesn't vectorize well, or when you could exploit sparsity by skipping the zeros (the soft-stop case is sketched just after this list).
- A hard stop is when computing the function becomes impossible, because the memory needed to do the computation in a non-optimal way explodes. Sometimes you can get away with just writing custom kernels.
The typical example where you can get away with it is custom attention layers.
A typical example where you can't is physics simulation (a sketch follows below). For instance, the force is the gradient of the energy, but you have n^2 interactions between the particles, so if you preserve anything more than zero memory per interaction during the forward pass, your memory consumption explodes. And typically, with things like Lagrangian or Hamiltonian neural networks, where you try to discover the dynamics of an energy-conserving system, you need to be able to differentiate at least three times in a row.
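To make the soft-stop case concrete, a toy sketch (sizes are arbitrary): the same reduction computed in a Python loop over many small tensors versus one vectorized call that a compiler can fuse.

    import time
    import torch

    xs = [torch.randn(32) for _ in range(10_000)]   # many small tensors

    t0 = time.perf_counter()
    total = torch.zeros(())
    for x in xs:                        # Python loop: framework overhead paid 10k times
        total = total + torch.relu(x).sum()
    t1 = time.perf_counter()

    big = torch.stack(xs)               # one big tensor instead
    total_vec = torch.relu(big).sum()   # one vectorized (fusable) call
    t2 = time.perf_counter()

    print(f"loop {t1 - t0:.4f}s  vectorized {t2 - t1:.4f}s")  # same result, very different cost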
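And for the hard-stop case, a toy sketch of the pairwise-energy example (the potential and sizes are made up): the naive autograd version materializes n x n intermediates and keeps them alive for the backward pass, which is exactly what blows up at realistic particle counts.

    import torch

    # Toy pairwise potential: E = sum_{i<j} 1 / |r_i - r_j|, force = -dE/dr.
    # The naive autograd version materializes O(n^2) intermediates and keeps them
    # alive for the backward pass; at n ~ 1e5 the (n, n, 3) buffer alone is ~120 GB.
    n = 2_000
    r = torch.randn(n, 3, requires_grad=True)

    diff = r.unsqueeze(0) - r.unsqueeze(1)         # (n, n, 3), saved for backward
    d2 = (diff ** 2).sum(-1)                       # (n, n) squared distances
    i, j = torch.triu_indices(n, n, offset=1)      # unique particle pairs i < j
    energy = (1.0 / d2[i, j].sqrt()).sum()         # scalar total energy

    (force,) = torch.autograd.grad(-energy, r)     # backward replays every stored intermediate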
There are also energy-expending stops, where you need to find workarounds to make things work, for example if you want your parameters to change shape during the optimization process, like learning point clouds of growing size; these cases spread you thin, so they won't be standardized.
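A minimal sketch of one common workaround for growing parameters (names, sizes, and the placeholder loss are illustrative): concatenate the new points into a fresh nn.Parameter and rebuild the optimizer, accepting that its internal state gets reset.

    import torch

    # Workaround when parameters must grow mid-training (e.g. densifying a point
    # cloud): build a new, larger Parameter and a fresh optimizer around it.
    points = torch.nn.Parameter(torch.randn(1_000, 3))
    opt = torch.optim.Adam([points], lr=1e-3)

    def grow(old_points, n_new, lr=1e-3):
        new_pts = torch.randn(n_new, 3)             # e.g. clones/splits of dense regions
        bigger = torch.nn.Parameter(torch.cat([old_points.data, new_pts], dim=0))
        return bigger, torch.optim.Adam([bigger], lr=lr)   # optimizer moments are reset here

    for step in range(3):
        loss = (points ** 2).sum()                  # placeholder loss
        opt.zero_grad()
        loss.backward()
        opt.step()
        if step == 1:                               # grow the parameter mid-optimization
            points, opt = grow(points, n_new=500)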
I don't know the full list, but back when it came out, TF felt like a crude set of bindings to the underlying C++/CUDA workhorse. PyTorch felt, in contrast, Pythonic. It was much closer in feeling to NumPy.
I think it was mostly the eager evaluation that made it possible to debug every step in the network forward/backward passes.
TensorFlow didn't have that at the time, which made debugging practically impossible.
I’m not sure if such an overview exists, but back when Caffe2 was still a thing and JAX was a big contender, dynamic vs. static computational graphs seemed to be a major focus point for people ranking the frameworks.