PyTorchのMulti-GPUでメモリ使用量の偏りと改善策
よくあるGPUメモリの偏り
PyTorchでマルチGPUをすると, こんな感じで 1つのGPUのメモリだけ他に比べて多い (2倍以上)場合って OOMでバッチサイズが増やせずに悲しいですよね。
解決法
以下の方法で,8 GPUでOOMしないバッチサイズが16->32になりました。
◯forward
にlossの計算までを入れる
i.e.,
class Model(nn.Module): def __init__(self, d_in, n_class): super().__init__() self. n_class = n_class self.fc = nn.Linear(d_in, n_class) def forward(self, x, labels=None): x = self.fc(x) if labels: ### ここが重要!!! ### return x, F.cross_entropy(x.view(-1, self.n_class), labels.view(-1)) else: return x model = DataParallel(Model(100, 10)) for x, labels in dataloader: pred, loss = model(x, labels) optimizer.zero_grad() loss.backward() optimizer.step()
criteria
を使ってloss計算すると,メインGPUのみでlossの計算が
おこなわれることが,先のメモリ使用量の偏りを引き起こしているらしいです*1。
よく見てみたら,BERT等のPyTorch実装も同様のforward
実装になってました。最近のベストプラクティスなのかもしれない?