关于Pytorch 训练时显存ok,但是load checkpoints时显存out of memory问题
这个问题主要是由于一下几点:
在load时先将checkpoints load到了gpu上,再load到model的地址,这样中间就多了一次存储。
模型在load 之前就使用cuda()放在了gpu上,这样也会造成空间使冗余的情况
解决方式:
使用先load 在cpu上,然后load到model地址,最后push到gpu上的操作。
1 | cpts = torch.load(os.path.join(checkponits_dir, "model.pth"), map_location='cpu') |