IT 지식 창고
(pytorch) pytorch lightning 사용 시, 이어서 학습하기
casim
2023. 2. 8. 18:34
컴퓨터의 강제종료, RAM 또는 VRAM의 부족 등과 같은 부득이한 상황에서 학습이 중단된 경우 이어서 학습하는 방법이 있습니다.
물론, 이어서 학습하기 위해서는 중단되기전에 .ckpt file을 저장해야합니다. (callbacks.ModelCheckpoint() 함수 활용)
trainer.fit(model, train, val, ckpt_path=checkpoint_path)
pytorch lightning으로 학습할 때 사용하는 fit 함수 안에 ckpt_path에 이어서 학습할 .ckpt file경로를 넣어주면 됩니다.
그러면, 해당 .ckpt를 사용하여 weight를 update하고, 시작 epoch을 이어서 설정합니다. (만약 10epoch에서 중단되었다면, 11epoch부터 다시 시작합니다.)
* keras와 같은 경우는 이어서 학습 시 initial_epoch이라는 변수가 있기 때문에 시작 epoch을 직접 설정한 후 이어서 학습합니다.