
[PyTorch] loss nan 문제 해결
Study/CS
2024. 11. 25. 20:13
과제 중 train, validation loss 모두 nan이 나오는 현상을 확인했다.에러 로그도 없어 골치 아프던 와중 발견한 아래의 코드torch.autograd.set_detect_anomaly(True) autograd 중 어디에서 오류가 발생했는지 감지하고 출력하는 함수로, 코드 첫 줄에 넣어주면 된다. 로그에서 보이듯 backward 중 sqrt(0)의 derivative 계산에서 nan이 나왔다.torch.sqrt(x) 대신 torch.sqrt(x + 1e-8)로 바꾸어 해결하였다.