
PyTorch Memory Optimization Tips

Tags: PyTorch, Performance


1. Use torch.no_grad() for inference

# Gradients aren't needed at inference time; no_grad() skips building the
# autograd graph, so intermediate activations aren't kept alive.
with torch.no_grad():
    output = model(inputs)

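A minimal end-to-end sketch (the model and sizes are invented for illustration): besides saving activation memory, the context manager leaves the output detached from any graph.

```python
import torch
import torch.nn as nn

# Toy model purely for illustration; any nn.Module behaves the same way.
model = nn.Linear(64, 10)
model.eval()  # also freezes dropout/batch-norm behavior for inference

inputs = torch.randn(8, 64)
with torch.no_grad():
    output = model(inputs)

# No autograd graph was built, so the output carries no gradient state.
print(output.requires_grad)  # False
```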
2. Clear cache when switching tasks

# Returns unused cached blocks to the GPU driver. Tensors that are still
# referenced are NOT freed, so drop references (del / reassignment) first.
torch.cuda.empty_cache()

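A sketch of the usual pattern when moving from one task to the next (the tensor and its size are made up; the CUDA calls are guarded so the snippet also runs on CPU-only machines):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

big = torch.empty(1024, 1024, device=device)  # workspace for task A (4 MiB fp32)
# ... finish task A ...

del big  # drop the last reference so the caching allocator can reuse the block
if torch.cuda.is_available():
    torch.cuda.empty_cache()          # hand cached blocks back to the driver
    print(torch.cuda.memory_reserved())  # reserved bytes shrink after the call
```

Note that `empty_cache()` doesn't make PyTorch itself faster or leaner; it matters when another process (or another framework in the same process) needs the GPU memory PyTorch is holding in its cache.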
3. Use gradient checkpointing for large models

from torch.utils.checkpoint import checkpoint

# Activations inside the checkpointed segment are discarded after the forward
# pass and recomputed during backward, trading extra compute for memory.
output = checkpoint(model.layer, inputs, use_reentrant=False)

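A runnable sketch with a hypothetical deep stack (block count and sizes are invented): checkpointing each block keeps only the segment boundaries alive during the forward pass, and gradients still reach every parameter.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Hypothetical stack of identical blocks, standing in for a large model.
blocks = nn.Sequential(*[
    nn.Sequential(nn.Linear(128, 128), nn.ReLU()) for _ in range(4)
])

x = torch.randn(16, 128)
out = x
for block in blocks:
    # Each block's internal activations are recomputed during backward
    # instead of being stored, cutting peak activation memory.
    out = checkpoint(block, out, use_reentrant=False)

out.sum().backward()  # gradients flow through the recomputed segments
```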
4. Mixed precision training

# torch.cuda.amp.autocast is deprecated; the device-agnostic torch.autocast
# runs selected ops in half precision, roughly halving activation memory.
with torch.autocast(device_type="cuda", dtype=torch.float16):
    output = model(inputs)