Submitted by Style Pass, 2024-06-10 01:30:07

On MNIST and CIFAR-10, I was able to achieve validation loss comparable to that of other DiT implementations. I plan to train on ImageNet soon and will update this README with the results.

I wanted to brush up on my JAX knowledge, and I also hadn't implemented a full MMDiT from scratch before. So I figured I'd try to do both at once! :)

TensorBoard is used for logging. Samples are saved to the samples directory as image grids, where the X axis indexes the batch and the Y axis indexes successive iterations of the sampling loop.
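
The README doesn't show how these grids are assembled. As a minimal sketch, assuming the intermediate outputs of the sampling loop are stacked into a JAX array of shape (steps, batch, H, W, C) with pixel values in [-1, 1], the tiling could look like this. make_sample_grid and to_uint8 are hypothetical helpers, not functions from this repo.

```python
import jax.numpy as jnp


def make_sample_grid(trajectory: jnp.ndarray) -> jnp.ndarray:
    """Tile a sampling trajectory into a single image (hypothetical helper).

    trajectory: shape (steps, batch, H, W, C), one slice per sampling iteration.
    Returns shape (steps * H, batch * W, C): moving right (X) steps through the
    batch, moving down (Y) steps through the sampling iterations.
    """
    steps, batch, h, w, c = trajectory.shape
    # Reorder to (steps, H, batch, W, C) so a plain reshape lays images side by side.
    grid = jnp.transpose(trajectory, (0, 2, 1, 3, 4))
    return grid.reshape(steps * h, batch * w, c)


def to_uint8(images: jnp.ndarray) -> jnp.ndarray:
    """Map pixel values from the assumed [-1, 1] range to [0, 255] for logging."""
    return jnp.clip((images + 1.0) * 127.5, 0, 255).astype(jnp.uint8)


# Example: 10 sampling steps for a batch of 8 MNIST-sized images.
trajectory = jnp.zeros((10, 8, 28, 28, 1))
grid = to_uint8(make_sample_grid(trajectory))
print(grid.shape)  # (280, 224, 1)
```

The transpose before the reshape is what keeps each image contiguous; reshaping the original layout directly would interleave rows from different images.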

train.py contains the training loop and serves as the project's main entrypoint. Calling the main() function runs the training loop; the Trainer class can also be used on its own to load the model and to train it or run inference outside of this loop.
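
To illustrate the split between a main() loop and a reusable trainer object, here is a toy sketch in plain JAX. Everything in it is hypothetical: ToyTrainer and run_toy_training are not the repo's Trainer or main(), and where the real trainer wraps an MMDiT and a diffusion loss, this stand-in fits a linear model with mean-squared error and hand-written SGD. It only shows the pattern of an object that owns its parameters and a jitted update step, so it can be driven by a loop or used directly.

```python
import jax
import jax.numpy as jnp


class ToyTrainer:
    """Hypothetical stand-in: owns params plus jitted update and predict functions."""

    def __init__(self, apply_fn, params, lr=1e-2):
        self.params = params

        def loss_fn(params, x, y):
            # Mean-squared error between predictions and targets.
            return jnp.mean((apply_fn(params, x) - y) ** 2)

        def update(params, x, y):
            loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
            # Plain SGD applied leaf-by-leaf over the parameter pytree.
            new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
            return new_params, loss

        self._update = jax.jit(update)
        self._apply = jax.jit(apply_fn)

    def train_step(self, x, y):
        self.params, loss = self._update(self.params, x, y)
        return float(loss)

    def predict(self, x):
        return self._apply(self.params, x)


def linear(params, x):
    return x @ params["w"] + params["b"]


def run_toy_training(steps=200):
    """Hypothetical main()-style loop driving the trainer on synthetic data."""
    x = jax.random.normal(jax.random.PRNGKey(0), (64, 3))
    y = x @ jnp.array([[1.0], [-2.0], [0.5]]) + 0.3
    params = {"w": jnp.zeros((3, 1)), "b": jnp.zeros((1,))}
    trainer = ToyTrainer(linear, params, lr=0.1)
    losses = [trainer.train_step(x, y) for _ in range(steps)]
    return trainer, losses


if __name__ == "__main__":
    trainer, losses = run_toy_training()
    print(f"loss: {losses[0]:.3f} -> {losses[-1]:.6f}")
```

Because the update function is compiled once in the constructor, the same object can be reused from a script, a notebook, or an evaluation job without going through the training loop.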

I decided not to use either TensorFlow or PyTorch data loading, to keep external dependencies to a minimum. Instead, datasets are loaded with the Hugging Face Datasets library and processed by the process_batch function. To add a new dataset, add a new entry to the DATASET_CONFIGS dictionary in train.py.
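
A new DATASET_CONFIGS entry might look roughly like the sketch below. The README doesn't document the real schema, so the config keys (hf_name, image_key, label_key, image_size, channels, num_classes) and this version of process_batch are assumptions for illustration; the actual code in train.py may differ. The column names are intended to match the mnist and cifar10 datasets on the Hugging Face Hub (check the dataset card before relying on them), and the sketch assumes images have already been decoded into uint8 NumPy arrays.

```python
import numpy as np
import jax.numpy as jnp

# Hypothetical schema: one entry per dataset, keyed by a short name.
DATASET_CONFIGS = {
    "mnist": {"hf_name": "mnist", "image_key": "image", "label_key": "label",
              "image_size": 28, "channels": 1, "num_classes": 10},
    "cifar10": {"hf_name": "cifar10", "image_key": "img", "label_key": "label",
                "image_size": 32, "channels": 3, "num_classes": 10},
}


def process_batch(batch, config):
    """Hypothetical version: turn a dict of raw images and labels into model inputs."""
    images = np.asarray(batch[config["image_key"]], dtype=np.float32)
    if images.ndim == 3:  # grayscale (B, H, W) -> (B, H, W, 1)
        images = images[..., None]
    size, channels = config["image_size"], config["channels"]
    if images.shape[1:] != (size, size, channels):
        raise ValueError(f"expected (B, {size}, {size}, {channels}), got {images.shape}")
    # Rescale [0, 255] to [-1, 1], a common input range for diffusion training.
    images = images / 127.5 - 1.0
    labels = np.asarray(batch[config["label_key"]], dtype=np.int32)
    return jnp.asarray(images), jnp.asarray(labels)
```

With this layout, supporting another image-classification dataset only requires a new dictionary entry naming its columns and image shape; process_batch itself stays unchanged.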
