원인이 되는 연산 찾기
먼저 torch.autograd 함수 중에 NaN loss가 발생했을 경우 원인을 찾아주는 함수가 있다.
autograd.set_detect_anomaly(True)
학습 코드에 위 코드를 추가해주고 실험을 하면, NaN loss가 발생하는 즉시 실행이 멈추고 NaN을 유발한 라인을 출력해준다. 주로 division by zero나 매우 작은 값에 대한 log 연산이 NaN loss를 유발한다. NaN은 loss 연산 뿐만 아니라 forward 연산, backward 연산에서도 발생할 수 있으므로 직접 찾으려면 힘든데, 위 코드를 쓰면 간편하다.
연산 수정하기
나누는 연산이 있는데 divisor가 0이 될 수 있는 경우라면, 예외 처리를 해주거나 divisor에 1e-6 등 연산에 영향을 끼치지 않는 작은 상수를 더해주면 된다. log도 마찬가지다. log(x)에서 x가 매우 작은 값이 될 수 있다면, x에 상수를 더해주면 된다. 또는 NaN 값을 0으로 바꾸어주는 torch 함수를 쓰자.
a = torch.nan_to_num(a)
주의할 점은 nan_to_num은 PyTorch 1.8.0 이후부터 지원된다.
원인이 되는 연산을 알았지만 이유를 모르겠다
- Gradient exploding / vanishing
원인이 되는 레이어의 weight과 grad를 출력해보면 알 수 있다.
torch.any(torch.isnan(weight)) # weight에 NaN 존재 여부
model.layer.grad # layer의 gradient
- Learning rate이 너무 높을 경우
- PyTorch 내장함수 중 나눗셈 연산이 있는 함수를 썼을 경우
댓글