PyTorch is slowly introducing native support for per-sample gradients (https://github.com/pytorch/pytorch/pull/70141). This is a good reason to get rid of the custom grad sampler code.