PyTorchのMulti-GPUでメモリ使用量の偏りと改善策

よくあるGPUメモリの偏り

PyTorchでマルチGPUをすると, こんな感じで f:id:atksh:20190317221346p:plain 1つのGPUのメモリだけ他に比べて多い (2倍以上)場合って OOMでバッチサイズが増やせずに悲しいですよね。

解決法

以下の方法で,8 GPUでOOMしないバッチサイズが16->32になりました。

forwardにlossの計算までを入れる

i.e.,

class Model(nn.Module):
  def __init__(self, d_in, n_class):
    super().__init__()
    self. n_class = n_class
    self.fc = nn.Linear(d_in, n_class)

  def forward(self, x, labels=None):
    x = self.fc(x)
    if labels:
      ### ここが重要!!! ###
      return x, F.cross_entropy(x.view(-1, self.n_class), labels.view(-1))
    else:
      return x

model = DataParallel(Model(100, 10))
for x, labels in dataloader:
  pred, loss = model(x, labels)
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()

criteriaを使ってloss計算すると,メインGPUのみでlossの計算が おこなわれることが,先のメモリ使用量の偏りを引き起こしているらしいです*1

よく見てみたら,BERT等のPyTorch実装も同様のforward実装になってました。最近のベストプラクティスなのかもしれない?