과제 중 train, validation loss 모두 nan이 나오는 현상을 확인했다.
에러 로그도 없어 골치 아프던 와중 발견한 아래의 코드
torch.autograd.set_detect_anomaly(True)
autograd 중 어디에서 오류가 발생했는지 감지하고 출력하는 함수로, 코드 첫 줄에 넣어주면 된다.
로그에서 보이듯 backward 중 sqrt(0)의 derivative 계산에서 nan이 나왔다.
torch.sqrt(x) 대신 torch.sqrt(x + 1e-8)로 바꾸어 해결하였다.
'Study > CS' 카테고리의 다른 글
Lab3: Attacklab (phase 5) (0) | 2023.11.05 |
---|---|
Lab3: Attacklab (phase 4) (0) | 2023.11.04 |
Lab3: Attacklab (phase 3) (0) | 2023.11.01 |
Lab3: Attacklab (phase 2) (1) | 2023.10.30 |
Lab3: Attacklab (phase 1) (0) | 2023.10.29 |