| """ |
| Sharpness-Aware Minimization (SAM) optimizer wrapper. |
| Steers training toward parameters in flat minima of the loss, for better out-of-distribution (OOD) generalization. |
| Reference: Foret et al., "Sharpness-Aware Minimization for Efficiently Improving Generalization" (ICLR 2021) |
| """ |
|
|
| import torch |
|
|
|
|
class SAM(torch.optim.Optimizer):
    """Sharpness-Aware Minimization wrapper around a base optimizer.

    One SAM update needs two forward/backward passes:

    1. With gradients at the current weights ``w`` populated, ``first_step()``
       moves every parameter that has a gradient to ``w + rho * g / ||g||``,
       where ``||g||`` is the global L2 norm over all parameter gradients.
    2. After the caller recomputes gradients at that perturbed point,
       ``second_step()`` restores the original weights and lets the base
       optimizer apply the perturbed-point gradients.

    ``step(closure)`` runs both phases in one call.

    Args:
        params: Iterable of tensors or parameter-group dicts.
        base_optimizer: Optimizer class (e.g. ``torch.optim.SGD``); it is
            constructed over this wrapper's parameter groups.
        rho: Radius of the ascent (perturbation) step; must be non-negative.
        **kwargs: Forwarded to ``base_optimizer`` and stored in the group
            defaults.

    Raises:
        ValueError: If ``rho`` is negative.
    """

    def __init__(self, params, base_optimizer, rho=0.05, **kwargs):
        if rho < 0.0:
            raise ValueError(f"Invalid rho, should be non-negative: {rho}")
        defaults = dict(rho=rho, **kwargs)
        super().__init__(params, defaults)
        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        # Make the group list itself shared, so LR schedulers / add_param_group
        # on the wrapper are seen by the base optimizer and vice versa.
        self.param_groups = self.base_optimizer.param_groups
        self.defaults.update(self.base_optimizer.defaults)

    @torch.no_grad()
    def first_step(self):
        """Perturb parameters along the normalized gradient (ascent step).

        Each parameter with a gradient is saved, then shifted by
        ``rho * grad / (||g|| + 1e-12)``. Parameters without a gradient are
        left untouched and are not recorded.
        """
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            scale = group['rho'] / (grad_norm + 1e-12)
            for p in group['params']:
                if p.grad is None:
                    continue
                # Keep an exact copy so second_step can restore bit-for-bit
                # instead of relying on p + e - e == p in floating point.
                self.state[p]['old_p'] = p.detach().clone()
                # scale lives on the "shared" device; move it to p's
                # device/dtype so multi-device models work.
                p.add_(p.grad * scale.to(p))

    @torch.no_grad()
    def second_step(self):
        """Restore the weights saved by first_step and apply the base update.

        Restoration depends on what first_step recorded, not on whether a
        parameter still has a gradient. So a parameter that received no
        gradient at the perturbed point is still returned to its original
        value. The saved copy is removed so it cannot leak into a later
        iteration.
        """
        for group in self.param_groups:
            for p in group['params']:
                state = self.state.get(p)
                if not state or 'old_p' not in state:
                    continue
                p.copy_(state.pop('old_p'))
        self.base_optimizer.step()

    def _grad_norm(self):
        """Return the global L2 norm of all existing gradients.

        The result is a 0-dim tensor on the device of the first parameter of
        the first group. It is zero when no parameter has a gradient.
        """
        shared_device = self.param_groups[0]['params'][0].device
        norms = [
            p.grad.norm(p=2).to(shared_device)
            for group in self.param_groups
            for p in group['params']
            if p.grad is not None
        ]
        if not norms:
            # torch.stack([]) would raise; nothing to perturb in this case.
            return torch.zeros((), device=shared_device)
        return torch.norm(torch.stack(norms), p=2)

    def step(self, closure=None):
        """Run a full SAM update using ``closure`` for the second pass.

        Gradients at the current weights must already be populated before
        calling. ``closure`` must recompute the loss, call ``backward()`` and
        return the loss; it is evaluated at the perturbed weights with
        autograd enabled.

        Returns:
            The value returned by ``closure``.

        Raises:
            NotImplementedError: If no closure is given; use first_step() /
                second_step() manually in that case.
        """
        if closure is None:
            raise NotImplementedError("SAM requires manual first_step() and second_step() calls")
        self.first_step()
        self.zero_grad()
        with torch.enable_grad():
            loss = closure()
        self.second_step()
        return loss

    def zero_grad(self, set_to_none=None):
        """Clear gradients through the base optimizer.

        Args:
            set_to_none: Forwarded to the base optimizer's ``zero_grad`` when
                given. When ``None``, the base optimizer's own default is used.
        """
        if set_to_none is None:
            self.base_optimizer.zero_grad()
        else:
            self.base_optimizer.zero_grad(set_to_none=set_to_none)
|
|