これは、論文「Zero: 1兆のパラメーターモデルのトレーニングに向けたメモリ最適化」で紹介されているゼロDPの実装です。
オプティマイザの状態、グラデーション、パラメータの断片を複数のデバイス/ノードに保持します。これにより、メモリ消費量が元のモデルと同じになります。ここで、はパラメーターの数、はシャードの数、パラメーターごとのオプティマイザーのバイト数です。は、16 ビットの精度を前提としたパラメーターとグラデーションのメモリです。つまり、パラメーターとグラデーションごとに 2 バイトです。Adam オプティマイザー用です。これは、パラメーターのコピーと fp32 のパラメーターごとに 2 つのモーメントを保持しているためです
。ゼロDPの通信量は。比較のためにデータ並行トレーニングの通信量は
.これは名前が付けられていますがZero3
、残留メモリ消費を対象とするゼロRメモリ最適化は実装しておらず、DPがゼロの部分のみを実装しています。この実装では、パラメータのサブセットのみのトレーニングをサポートしています
32import functools
33from typing import List, Optional, Tuple
34
35import torch
36import torch.distributed as dist
37from torch import nn40class Zero3Layer(nn.Module):chunk
各シャードはパラメータをリストに保持します。chunk[0]
はトレーニング可能なパラメーター用で、chunk[1]
固定パラメーター用です
49    chunk: List[nn.Parameter]chunk
これはリスト内のチャンクのサイズです。
51    chunk_size: List[int]最初のチャンクはトレーニング可能なパラメーター用です。
53    TRAINING_PARAMS_IDX = 0これは、トレーニング可能なパラメーターと固定パラメーターとしてリストに分割されたパラメーターのリストです。
56    param_refs: List[List[nn.Parameter]]パラメータを取得する CUDA ストリーム
59    fetch_stream: Optional[torch.cuda.Stream]勾配をバックアップ/蓄積するためのCUDAストリーム
61    backup_stream: Optional[torch.cuda.Stream]このレイヤーの直前のレイヤーのリスト
63    prev_layer: List['Zero3Layer']このレイヤーの直後のレイヤーのリスト
65    next_layer: List['Zero3Layer']現在のレイヤーの位置。これをログのデバッグに使用しました
67    layer_idx: intパラメータが取得されているかどうか
70    is_fetched: boolレイヤーのデバイス
73    device: torch.deviceレイヤーのデータタイプ
75    dtype: torch.dtypeラップするモジュール
77    module: nn.Moduleデータがシャーディングされるノード/デバイスの数
79    world_size: intmodule
ラップするモジュール。rank
現在のノードのランク。world_size
データがシャーディングされるノード/デバイスの数。device
レイヤーのデバイス。dtype
レイヤーのデータタイプ。81    def __init__(self, module: nn.Module, rank: int, world_size: int, device: torch.device, dtype: torch.dtype):89        super().__init__()プロパティを初期化
92        self.device = device
93        self.dtype = dtype
94        self.module = module
95        self.prev_layer = []
96        self.next_layer = []
97        self.is_fetched = False
98        self.world_size = world_size
99        self.layer_idx = -1
100        self.fetch_stream = None
101        self.backup_stream = None
102
103        with torch.no_grad():レイヤーのすべてのパラメーターを収集します。
105            all_param_refs = [p for p in self.parameters()]パラメータの形状を保存しておきます。後で再構築する必要があるからです。
108            for p in all_param_refs:
109                p._orig_shape = p.shapeすべてのパラメータは同じタイプでなければなりません
112            for p in all_param_refs:
113                assert p.dtype == dtype, "All parameters should have same dtype"トレーニング可能なパラメータと固定パラメータを分離
116            self.param_refs = [[p for p in all_param_refs if p.requires_grad],
117                               [p for p in all_param_refs if not p.requires_grad]]
118            del all_param_refsrank = 0
ノードは、各デバイス/ノードが保存するサイズを計算し、それに応じてパラメータを分散します。
122            if rank == 0:トレーニング可能 (merged_params[0]
) パラメーターと固定 (merged_params[1]
) パラメーターをマージしてパディングする
124                merged_params = [self._merge_and_pad_params(ps) for ps in self.param_refs]トレーニング可能なパラメータと固定パラメータのチャンクサイズを計算します
126                self.chunk_size = [(len(p) // world_size if p is not None else 0) for p in merged_params]サイズをブロードキャスト
128                dist.broadcast(torch.tensor(self.chunk_size, device=device), src=0)
129            else:空のテンソルを作成してサイズを受け取る
131                chunk_size = torch.tensor([0, 0], device=device)サイズを受け取る
133                dist.broadcast(chunk_size, src=0)
134                self.chunk_size = chunk_size.tolist()trainable (self.chunk[0]
) パラメーターと fixed () パラメーターのパラメーターを作成して、現在のデバイス/ノードに保存します self.chunk[1]
138            self.chunk = [nn.Parameter(self._empty((s,)), requires_grad=i == self.TRAINING_PARAMS_IDX)
139                          for i, s in enumerate(self.chunk_size)]トレーニング可能なパラメーターと固定パラメーターを組み合わせて受け取る空のテンソル
142            chunk = self._empty((sum(self.chunk_size),))
143
144            if rank == 0:トレーニング可能なパラメータと固定パラメータの両方を連結する
146                all_params = torch.cat([p.view(world_size, -1) for p in merged_params], dim=-1).view(-1)
147                del merged_paramsそれらをすべてのノード/デバイスに散らす
150                dist.scatter(chunk, list(all_params.split(sum(self.chunk_size))))
151                del all_params
152            else:パラメータを受け取る
154                dist.scatter(chunk)チャンクデータを収集する
157            chunk = chunk.split(self.chunk_size)
158            for i, c in enumerate(chunk):
159                self.chunk[i].data[:] = c
160            del chunk標準パラメータをクリーンアップ
163            self._cleanup_params()後方フックを追加。これは、モジュールを基準としたグラデーションが計算されるときに呼び出されます
。166            self._backward_hook_ref = self.register_full_backward_hook(self._backward_hook)  # type: ignoreworld_size
次の数で割り切れるようにパディングします。168    def _merge_and_pad_params(self, params: List[nn.Parameter]) -> torch.Tensor:パラメータの総数
173        size = sum(p.shape.numel() for p in params)world_size
割り切れない場合はパディングしてください
176        if size % self.world_size != 0:
177            padding_fixed = self.world_size - (size % self.world_size)それ以外の場合は、パッドする必要はありません
179        else:
180            padding_fixed = 0空のパディングテンソルの作成
182        padding = self._empty((padding_fixed,))すべてのパラメータを連結してパディングします。
184        return torch.cat([p.view(-1) for p in params] + [padding], dim=0)186    def get_trainable_chunk(self) -> List[nn.Parameter]:トレーニング可能なパラメータがない場合はリストを返して空にする
193        if len(self.chunk[self.TRAINING_PARAMS_IDX]) == 0:
194            return []トレーニング可能なチャンクをリストとして返す
197        return [self.chunk[self.TRAINING_PARAMS_IDX]]199    def _empty(self, shape: Tuple[int, ...]) -> torch.Tensor:203        return torch.empty(shape, device=self.device, dtype=self.dtype)205    @torch.no_grad()
206    def _cleanup_params(self):パラメータがフェッチされないことを示すフラグを設定します。
214        self.is_fetched = Falseすべてのパラメータを反復処理
217        for ps in self.param_refs:
218            for p in ps:パラメータの操作が完了するまで待ってから、新しい操作を行います。
220                p.data.record_stream(torch.cuda.current_stream())パラメータが他のものとストレージを共有していないことを確認してください
222                assert p.data.storage_offset() == 0, "The tensor is not the sole occupant of the storage."ストレージのサイズをに変更します。これにより、パラメータが使用していたメモリが解放されます。
autograd p.data
 グラフはメモリへの参照を保持するので、設定してもメモリは解放されません。
226                p.data.storage().resize_(0)  # This is what actually clears the memoryパラメータに勾配データがないことを確認してください
228                assert p.grad is None, 'Gradients should be None'230    @torch.no_grad()
231    def fetch_params(self):スキップはすでに取得されています
239        if self.is_fetched:
240            returnフラグを設定
243        self.is_fetched = True取得または共有するものがない場合はスキップしてください。
246        if sum(self.chunk_size) == 0:
247            returnfetch_stream
を使用してすべてのシャードからパラメータを取得します。
250        with torch.cuda.stream(self.fetch_stream):空のテンソルを作成してパラメーターを受け取る
252            buffer = self._empty((self.world_size * sum(self.chunk_size),))連続バッファをノード数に分割します。これらの分割は「buffer」のビューです
。254            buffers = list(buffer.split(sum(self.chunk_size)))トレーニング可能なチャンクと固定チャンクの両方を連結する
257            chunk = torch.cat(self.chunk, dim=0)すべてのノード/デバイスからパラメータを収集します
260            dist.all_gather(buffers, chunk)収集したパラメーターをトレーニング可能なチャンクと固定チャンクに分割します
263            params = buffer.view(-1, sum(self.chunk_size)).split(self.chunk_size, dim=1)収集操作が完了するのを待ってから、バッファへの参照をクリアします。
265            buffer.record_stream(self.fetch_stream)
266            for b in buffers:
267                b.record_stream(self.fetch_stream)
268            buffer.record_stream(self.fetch_stream)
269            del buffer
270            del buffersトレーニング可能なパラメーターと固定パラメーターを連続テンソルにリシェイプ
273            params = [p.reshape(-1) for p in params]個々のパラメータテンソルを収集する
276            for cont, ps in zip(params, self.param_refs):パラメータがない場合はスキップしてください
278                if not ps:
279                    continue連続テンソルのオフセット
282                offset = 0モデルパラメーターを繰り返し処理し、連続テンソルから値を割り当てます
284                for p in ps:オリジナルのパラメータシェイプ
286                    shape = p._orig_shape  # type: ignore[attr-defined]パラメータのストレージサイズを変更します。これは、パラメータをクリーンアップしたときに設定されました。
288                    p.data.storage().resize_(shape.numel())連続テンソルから値を割り当てます
290                    p.data[:] = cont[offset: offset + shape.numel()].reshape(shape)操作が完了するまで待ってから、他の操作を実行してください
292                    p.data.record_stream(self.fetch_stream)オフセットの更新
294                    offset += shape.numel()操作が完了するまで待ってから、他の操作を実行してください
297                cont.record_stream(self.fetch_stream)300            del params302    def forward(self, *args, **kwargs):現在のノードのすべてのパラメータを取得します。これは前のレイヤーから呼び出されるので、この呼び出しはパラメーターがフェッチされていることを確認するためだけのものです
。309        self.fetch_params()パラメータの取得が完了するまでお待ちください。
312        torch.cuda.current_stream().wait_stream(self.fetch_stream)処理中のレイヤーのパラメーターの取得を開始します。そうすれば、現在のレイヤーが計算を行うパラメーターが取得されます。
316        for layer in self.next_layer:
317            layer.fetch_params()autograd が有効になっている場合は、現在のレイヤーのパラメーターに逆方向フックを追加します。
320        if torch.is_grad_enabled():
321            self._add_backward_hooks()現在のレイヤーの出力を計算します
324        res = self.module(*args, **kwargs)レイヤーのパラメーターをクリーンアップします。
autograd が有効になっていて、これがネットワークの最後のレイヤーである場合は、後方パスのパラメーターを再度取得する必要があるため、クリーンアップをスキップしてください。
330        if not torch.is_grad_enabled() or self.next_layer:
331            self._cleanup_params()
332
333        return res335    def _add_backward_hooks(self):追加された後方フックの数
341        self._backward_hook_handles = 0現在のレイヤーのトレーニング可能なパラメーターをループスルーする
344        for p in self.param_refs[self.TRAINING_PARAMS_IDX]:フックがまだ追加されていないことを確認してください
346            assert not hasattr(p, "_hook_handle"), 'Parameter has already been hooked'expand_as
インターセプトできるオートグラードのステップを作るのに使う
348            p_tmp = p.expand_as(p)後ろ向きフックを取り付けるためのハンドルを用意してください。このブログでは、. について説明しますgrad_acc
351            grad_acc = p_tmp.grad_fn.next_functions[0][0]後方フックを追加
353            handle = grad_acc.register_hook(
354                functools.partial(self._post_backward_hook, p))ハンドルへの参照を忘れずに
356            p._hook_handle = handle追加するフックの数を増やしてください
358            self._backward_hook_handles += 1360    def _backward_event(self):フックカウンターをデクリメントしてください
368        self._backward_hook_handles -= 1すべてのフック (モジュールフックを含む) が呼び出されたら、グラデーションをバックアップしてパラメータをクリーンアップできます。
372        if self._backward_hook_handles == -1:
373            self._backup_grads()
374            self._cleanup_params()autograd が次にそのレイヤーのグラデーションを処理するので、前のレイヤーのパラメーターの取得を開始します。
377        for layer in self.prev_layer:
378            layer.fetch_params()380    def _post_backward_hook(self, p: nn.Parameter, *args):パラメータからハンドルを削除します
385        p._hook_handle.remove()  # type: ignore[attr-defined]
386        delattr(p, "_hook_handle")逆方向イベントの処理
389        self._backward_event()391    def _backward_hook(self, *args, **kwargs):逆方向イベントの処理
396        self._backward_event()前のレイヤーがグラデーションの計算を開始します。パラメータの取得が完了したことを確認する必要があります
。399        torch.cuda.current_stream().wait_stream(self.fetch_stream)402        return None404    @torch.no_grad()
405    def _backup_grads(self):トレーニング可能なパラメータがない場合はスキップ
410        if self.chunk_size[self.TRAINING_PARAMS_IDX] == 0:
411            returnバックアップストリームを使用してグラデーションをバックアップします
414        with torch.cuda.stream(self.backup_stream):グラデーションを保存するバッファ
416            buffer = self._empty((self.world_size * self.chunk_size[self.TRAINING_PARAMS_IDX],))連続バッファを複数のノードに分割します。これらの分割は「buffer」のビューです
。418            buffers = list(buffer.split(self.chunk_size[self.TRAINING_PARAMS_IDX]))連続バッファのオフセット
421            offset = 0トレーニング可能なパラメーターを繰り返し処理
423            for p in self.param_refs[self.TRAINING_PARAMS_IDX]:グラデーションを集める
425                shape = p._orig_shape  # type: ignore[attr-defined]
426                buffer[offset: offset + shape.numel()] = p.grad.view(-1)オフセットの更新
428                offset += shape.numel()グラデーションをきれいにする
430                p.grad = Noneテンソルを空にすると、現在のシャードの勾配が蓄積されます
433            grad = self._empty((self.chunk_size[self.TRAINING_PARAMS_IDX],))各シャードのグラデーションを累積します。バッファをノード全体に分散させ、各ノードは受け取ったテンソルを蓄積(削減)します
。436            dist.reduce_scatter(grad, buffers)操作が完了するのを待ってから、バッファへの参照をクリアします。
439            for b in buffers:
440                b.record_stream(self.fetch_stream)
441            buffer.record_stream(self.fetch_stream)
442            del buffer
443            del buffersチャンクのグラデーションを設定します。これがオプティマイザーの見解です
。446            self.chunk[self.TRAINING_PARAMS_IDX].grad = grad
447            del gradZero3Layer
レイヤー用シーケンシャルモジュール450class Zero3Sequential(nn.Module):modules
Zero3Layer
レイヤーのリスト454    def __init__(self, modules: List[Zero3Layer]):458        super().__init__()パラメータを取得するための CUDA ストリーム
461        self.fetch_stream = torch.cuda.Stream()グラデーションをバックアップ(蓄積)するCUDAストリーム
463        self.backup_stream = torch.cuda.Stream()各レイヤーのストリームと先行レイヤーと後続レイヤーを設定します Zero3Layer
466        for i in range(len(modules)):レイヤーインデックスを設定
468            modules[i].layer_idx = iストリームを設定
470            modules[i].fetch_stream = self.fetch_stream
471            modules[i].backup_stream = self.backup_stream進行中のレイヤーを設定
473            if i + 1 < len(modules):
474                modules[i].next_layer.append(modules[i + 1])前のレイヤーを設定
476            if i - 1 >= 0:
477                modules[i].prev_layer.append(modules[i - 1])モジュール一覧を保存
480        self.module_list = nn.ModuleList(modules)482    def get_trainable_chunk(self):各レイヤーからトレーニング可能なチャンクのリストを返します
484        return sum([m.get_trainable_chunk() for m in self.module_list], [])486    def forward(self, x: torch.Tensor):グラデーションのバックアップが完了していることを確認してください
488        torch.cuda.current_stream().wait_stream(self.backup_stream)フォワードパス
491        for m in self.module_list:
492            x = m(x)495        return x