Convex Neural Networks via Operator Splitting

Welcome to the official implementation of the CRONOS project! See the paper (https://arxiv.org/abs/2411.01088) for details.
We introduce CRONOS, an algorithm for convex optimization of two-layer neural networks. This repository contains the official JAX implementation from the CRONOS paper and can be installed as a pip package for binary classification tasks.
- CRONOS: Uses convex optimization to train two-layer neural networks efficiently at scale. Experiments cover full-size ImageNet, downsampled ImageNet, IMDb, Food, FMNIST, CIFAR-10, MNIST, and synthetic datasets.
- CRONOS-AM: CRONOS with Alternating Minimization. This extension allows training of multi-layer networks with arbitrary architectures (MLP, CNN, GPT, etc.).
- Scalability: CRONOS can handle high-dimensional datasets.
- Convergence: Our theoretical analysis demonstrates that CRONOS converges to the global minimum of the convex reformulation under mild assumptions.
- Performance: Large-scale numerical experiments with GPU acceleration in JAX. Optimized to be VRAM friendly without sacrificing speed.
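
To make "convex optimization of two-layer neural networks" concrete, here is a minimal JAX sketch. It assumes the standard hyperplane-arrangement reformulation from the convex neural network literature: sample gate vectors, fix the resulting ReLU activation patterns, and fit a model that is linear in its weights with a group-lasso penalty. The function names (`sample_gate_patterns`, `convex_predict`, `convex_objective`), the squared loss, and the toy data are illustrative assumptions, not this repository's API, and the cone constraints of the full reformulation are omitted for brevity.

```python
import jax
import jax.numpy as jnp


def sample_gate_patterns(key, X, num_gates):
    """Sample activation patterns d_i = 1[X g_i >= 0] from random Gaussian gates g_i."""
    G = jax.random.normal(key, (X.shape[1], num_gates))  # hypothetical gate vectors
    return (X @ G >= 0).astype(X.dtype)                  # shape (n, num_gates), entries in {0, 1}


def convex_predict(X, D, V, W):
    """Prediction sum_i diag(d_i) X (v_i - w_i); linear in the weights (V, W)."""
    return jnp.sum(D * (X @ (V - W)), axis=1)


def convex_objective(params, X, y, D, beta):
    """Squared loss plus a group-lasso penalty; convex in params = (V, W)."""
    V, W = params
    residual = convex_predict(X, D, V, W) - y
    group_norms = jnp.linalg.norm(V, axis=0).sum() + jnp.linalg.norm(W, axis=0).sum()
    return 0.5 * jnp.sum(residual ** 2) + beta * group_norms


# Toy usage on synthetic data (assumed sizes: 20 samples, 5 features, 8 gates).
key = jax.random.PRNGKey(0)
k_data, k_gate = jax.random.split(key)
X = jax.random.normal(k_data, (20, 5))
y = jnp.sign(X[:, 0])                     # toy binary labels in {-1, +1}
D = sample_gate_patterns(k_gate, X, num_gates=8)
V = jnp.zeros((5, 8))
W = jnp.zeros((5, 8))
print(convex_objective((V, W), X, y, D, beta=0.1))
```

Because the activation patterns `D` are fixed before optimization, the objective is convex in `(V, W)`, which is the property that makes operator-splitting solvers applicable to this kind of problem.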
Clone the repository and install from source:
git clone https://github.com/pilancilab/CRONOS.git
cd CRONOS
pip install -e .

If you use this code in your work, please cite the paper:
@inproceedings{feng2024cronos,
title = {CRONOS: Convex Neural Networks via Operator Splitting},
author = {Feng, Miria and Frangella, Zachary and Pilanci, Mert},
booktitle = {Advances in Neural Information Processing Systems (NeurIPS)},
year = {2024},
url = {https://arxiv.org/abs/2411.01088}
}

To do:
- add a Jupyter demo
- Hydra + OmegaConf configuration (user sets dataset, adds new datasets, template loader)
- add instructions for vision and GPT-2, especially GPT-2 (3-step run process)
- document requirements: RTX 4090 minimum; JAX, CUDA, and NVIDIA driver versions
- add sharding here, or in a separate codebase?
- consolidate the 3-step run process for GPT-2; consolidate the 2 runners
- populate tests for all modules
- populate requirements.txt
- push to PyPI
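
Until the hardware and version requirements are documented, the local setup can be recorded with a short script. This is a minimal sketch that uses only public JAX attributes; the helper name `describe_jax_environment` is hypothetical and not part of this repository.

```python
import jax


def describe_jax_environment():
    """Collect the JAX version, default backend, and visible devices."""
    devices = jax.devices()
    return {
        "jax_version": jax.__version__,
        "backend": jax.default_backend(),      # e.g. "cpu", "gpu", or "tpu"
        "devices": [str(d) for d in devices],  # one entry per visible device
    }


if __name__ == "__main__":
    for name, value in describe_jax_environment().items():
        print(f"{name}: {value}")
```

If the reported backend is `cpu` on a machine with an NVIDIA GPU, the installed JAX build likely lacks CUDA support.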
