In my previous two posts, "Ways to use torch.compile" and "Ways to use torch.export", I often said that PyTorch would be good for a use case, but there might be some downsides. Some of the downsides are foundational and difficult to remove. But some... just seem like a little something is missing from PyTorch. Here are some things I hope we will end up shipping in 2025!
A programming model for PT2. A programming model is an abstract description of the system that is both simple (so anyone can understand it and keep it in their head all at once) and can be used to predict the system's behavior. The torch.export programming model is an example of such a description. Beyond export, we would like to help users understand why all aspects of PT2 behave the way they do (e.g., via improved error messages), and give them simple, predictable tools for working around problems when they arise. The programming model helps us clearly define the intrinsic complexity of our compiler, which we must educate users about. This is a big effort involving many folks on the PyTorch team and I hope we can share more about this effort soon.
Pre-compilation: beyond single graph export. Whenever someone realizes that torch.compile compilation is taking a substantial amount of time on expensive cluster machines, the first thing they ask is, "Why don't we just compile it in advance?" Supporting precompilation through the torch.compile API exactly as it is today is not so easy; unlike a traditional compiler, which gets the source program directly as input, users of torch.compile must actually run their Python program to hit the regions of code that are intended to be compiled. Nor can these regions be trivially enumerated and then compiled: not only must the user know all the metadata of the input tensors flowing into a region, they might not even know what the compiled graphs are if the model has graph breaks.
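To make the "you have to run it to find out" problem concrete, here is a toy sketch in plain Python. It is not how torch.compile is implemented, and it does not import PyTorch at all; `LazyCompiler`, `FakeTensor`, `encoder`, `rare_branch` and `model` are hypothetical names made up for illustration. The only assumption it models is that a region gets compiled the first time it is actually called, and that the compiled artifact is keyed on the metadata (here, shape and dtype) of its inputs.

```python
class FakeTensor:
    """Hypothetical stand-in for a tensor: only the metadata a compiler keys on."""

    def __init__(self, shape, dtype="float32"):
        self.shape = tuple(shape)
        self.dtype = dtype


class LazyCompiler:
    """Toy JIT: 'compiles' a region on first call, per distinct input metadata."""

    def __init__(self):
        self.cache = {}        # (region name, input metadata) -> "compiled" callable
        self.compile_log = []  # every key we had to compile, in order

    def compile(self, fn):
        def wrapper(*tensors):
            key = (fn.__name__, tuple((t.shape, t.dtype) for t in tensors))
            if key not in self.cache:
                # A real compiler would do expensive work here.
                self.compile_log.append(key)
                self.cache[key] = fn
            return self.cache[key](*tensors)

        return wrapper


jit = LazyCompiler()


@jit.compile
def encoder(x):
    return FakeTensor(x.shape, x.dtype)


@jit.compile
def rare_branch(x):
    return FakeTensor((x.shape[0], 1), x.dtype)


def model(x, take_rare_branch=False):
    # Plain Python control flow decides which compiled regions are reached.
    h = encoder(x)
    return rare_branch(h) if take_rare_branch else h


model(FakeTensor((8, 16)))   # compiles encoder for shape (8, 16)
model(FakeTensor((8, 16)))   # cache hit, nothing new is compiled
model(FakeTensor((32, 16)))  # new input shape, so encoder is compiled again
compiles_before_rare = len(jit.compile_log)

model(FakeTensor((8, 16)), take_rare_branch=True)  # first time rare_branch runs
regions_seen = [name for name, _ in jit.compile_log]
```

Before the last call, nothing in `jit.compile_log` even mentions `rare_branch`: the toy compiler cannot know that region exists, or what shapes flow into it, until some run of the program reaches it. That is the same reason a "just compile everything ahead of time" button is harder than it sounds.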