Using PyTorch + NumPy? You're making a mistake.


Bugs in ML code are notoriously hard to fix - they don’t cause compile errors but silently regress accuracy. Once you have endured the pain and fixed one of these, the lesson is forever etched into your brain, right? Wrong. Recently, an old foe made a comeback - a familiar bug bit me again! As before, the performance improved significantly after fixing it.

The bug is subtle and easy to make, so how many others has it bitten? Curious, I downloaded over a hundred thousand GitHub repositories that import PyTorch and analysed their source code. I kept projects that define a custom dataset, use NumPy’s random number generator with multi-process data loading, and are more or less straightforward to analyse using abstract syntax trees. Over 95% of these repositories are plagued by the problem. It’s inside PyTorch’s official tutorial, OpenAI’s code, and NVIDIA’s projects. Even Karpathy admitted falling prey to it.

The canonical way to load, pre-process and augment data in PyTorch is to subclass torch.utils.data.Dataset and override its __getitem__ method. To apply augmentations such as random cropping and image flipping, the __getitem__ method often uses NumPy to generate random numbers. The map-style dataset is then passed to a DataLoader to create batches. The training pipeline might be bottlenecked by data pre-processing, so it makes sense to load data in parallel, which is achieved by increasing the num_workers parameter of the DataLoader.
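As a concrete illustration, here is a minimal sketch of that pattern. The dataset class, the toy data and the flip probability are placeholders of my own, not taken from the article: a map-style Dataset whose __getitem__ draws from NumPy's global random number generator to decide on an augmentation, consumed by a DataLoader with num_workers set above zero.

    import numpy as np
    import torch
    from torch.utils.data import Dataset, DataLoader

    class FlipDataset(Dataset):
        """Map-style dataset that applies a random horizontal flip in __getitem__."""

        def __init__(self, images):
            self.images = images  # list of C x H x W NumPy arrays

        def __len__(self):
            return len(self.images)

        def __getitem__(self, index):
            image = self.images[index]
            # NumPy's global RNG drives the augmentation -- the pattern described above.
            if np.random.rand() < 0.5:
                image = np.flip(image, axis=-1).copy()  # copy() gives a contiguous array
            return torch.from_numpy(image)

    if __name__ == "__main__":
        data = [np.random.rand(3, 8, 8).astype(np.float32) for _ in range(16)]
        loader = DataLoader(FlipDataset(data), batch_size=4, num_workers=2)
        for batch in loader:
            print(batch.shape)  # torch.Size([4, 3, 8, 8])

With num_workers=2, each call to __getitem__ runs inside a worker process rather than the main process, so the NumPy random draws that decide the augmentation happen in those workers.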
