I've been trying to train a larger LLM across 2 nodes with 2 GPUs each, and the UX is terrible. The NVIDIA / PyTorch / Keras / JAX teams should fix this. I shouldn't have to juggle data/model/context parallelism myself, and I shouldn't "need" to control which device everything lives on. I should just specify 1) the model, and 2) a list of nodes with GPUs available to PyTorch, and then, like a heap allocator, the library should look up all the available VRAM across all nodes and all GPUs and solve which device the compute goes to — moving tensors around as it runs, reshuffling things until it just works. Whatever makes it fit. No more CUDA OOM errors!! 😡😡 Is there such an "auto-placement" mechanism yet? (A sketch of the kind of boilerplate I'm fighting today is below.)
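
To be concrete, here's roughly the manual setup I mean — a minimal plain-DDP sketch, assuming torchrun as the launcher, with a toy Linear layer standing in for the real model and made-up sizes/hyperparameters. Note this is only data parallelism: the moment the model doesn't fit on one GPU, none of this helps and I'm back to hand-rolling model/pipeline parallelism on top of it.

```python
# Launch on each node with something like:
#   torchrun --nnodes=2 --nproc_per_node=2 --node_rank=<0 or 1> \
#            --master_addr=<node0-ip> --master_port=29500 train.py
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def main():
    dist.init_process_group(backend="nccl")       # I pick the backend
    local_rank = int(os.environ["LOCAL_RANK"])    # I read the launcher's env vars
    torch.cuda.set_device(local_rank)             # I pin each process to a GPU

    # Toy stand-in for the real LLM; sizes are placeholders.
    model = torch.nn.Linear(4096, 4096).cuda(local_rank)
    model = DDP(model, device_ids=[local_rank])   # data parallelism only

    opt = torch.optim.AdamW(model.parameters(), lr=1e-4)
    for _ in range(10):
        x = torch.randn(8, 4096, device=f"cuda:{local_rank}")
        loss = model(x).pow(2).mean()
        opt.zero_grad()
        loss.backward()
        opt.step()

    dist.destroy_process_group()

if __name__ == "__main__":
    main()
```

What I want instead is to hand the library just the model and the node list and have it figure out the rest.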