这是《零:训练一万亿个参数模型的内存优化》一文中介绍的零 DP 的实现,
它将优化器状态、梯度和参数的分片保存到多个设备/节点中。它减少了原始模型的内存消耗,其中是参数的数量,是分片的数量,是每个参数的优化器字节数。是假设精度为 16 位的参数和梯度存储器;即每个参数和梯度为 2 个字节。对于 Adam 优化器,因为它维护参数的副本,在 fp32 中每个参数两个时刻。
零 DP 的通信量为。比较而言,数据并行训练的通信量为。
尽管它被命名了Zero3
,但我们只实现了其中的零 DP 部分,没有实现针对剩余内存消耗的 Zero-R 内存优化。Out 实现仅支持训练一部分参数。
此实施的灵感来自公平规模的财务安全发展计划。
32import functools
33from typing import List, Optional, Tuple
34
35import torch
36import torch.distributed as dist
37from torch import nn40class Zero3Layer(nn.Module):每个分片都将参数保存在chunk
列表中。用于chunk[0]
可训练的参数,chunk[1]
用于固定参数。
49    chunk: List[nn.Parameter]这是chunk
列表中区块的大小。
51    chunk_size: List[int]第一个区块用于可训练的参数。
53    TRAINING_PARAMS_IDX = 0这是分为可训练参数和固定参数的列表的参数列表。
56    param_refs: List[List[nn.Parameter]]CUDA 流到精选参数
59    fetch_stream: Optional[torch.cuda.Stream]用于备份/累积梯度的 CUDA 流
61    backup_stream: Optional[torch.cuda.Stream]此图层之前的图层列表
63    prev_layer: List['Zero3Layer']紧随此图层之后的图层列表
65    next_layer: List['Zero3Layer']当前层的位置;用于调试日志
67    layer_idx: int参数是否已获取
70    is_fetched: bool该层的设备
73    device: torch.device图层的数据类型
75    dtype: torch.dtype要封装的模块
77    module: nn.Module分片数据的节点/设备数量
79    world_size: intmodule
要封装的模块。rank
当前节点的等级。world_size
分片数据的节点/设备数量。device
层的设备。dtype
图层的数据类型。81    def __init__(self, module: nn.Module, rank: int, world_size: int, device: torch.device, dtype: torch.dtype):89        super().__init__()初始化属性
92        self.device = device
93        self.dtype = dtype
94        self.module = module
95        self.prev_layer = []
96        self.next_layer = []
97        self.is_fetched = False
98        self.world_size = world_size
99        self.layer_idx = -1
100        self.fetch_stream = None
101        self.backup_stream = None
102
103        with torch.no_grad():收集图层的所有参数
105            all_param_refs = [p for p in self.parameters()]存储参数的形状,因为我们稍后需要它来重建它们
108            for p in all_param_refs:
109                p._orig_shape = p.shape所有参数都应具有相同的类型
112            for p in all_param_refs:
113                assert p.dtype == dtype, "All parameters should have same dtype"将参数分为可训练和固定
116            self.param_refs = [[p for p in all_param_refs if p.requires_grad],
117                               [p for p in all_param_refs if not p.requires_grad]]
118            del all_param_refs该rank = 0
节点将计算每个设备/节点应存储的大小,并相应地分配参数。
122            if rank == 0:合并和填充可训练 (merged_params[0]
) 和 fixed (merged_params[1]
) 参数
124                merged_params = [self._merge_and_pad_params(ps) for ps in self.param_refs]计算可训练参数和固定参数的区块大小
126                self.chunk_size = [(len(p) // world_size if p is not None else 0) for p in merged_params]广播尺寸
128                dist.broadcast(torch.tensor(self.chunk_size, device=device), src=0)
129            else:创建一个空张量来接收大小
131                chunk_size = torch.tensor([0, 0], device=device)收到尺码
133                dist.broadcast(chunk_size, src=0)
134                self.chunk_size = chunk_size.tolist()为要存储在当前设备/节点中的可训练 (self.chunk[0]
self.chunk[1]
) 和 fixed () 参数创建参数
138            self.chunk = [nn.Parameter(self._empty((s,)), requires_grad=i == self.TRAINING_PARAMS_IDX)
139                          for i, s in enumerate(self.chunk_size)]一个空张量,用于接收可训练参数和固定参数的组合
142            chunk = self._empty((sum(self.chunk_size),))
143
144            if rank == 0:连接可训练参数和固定参数
146                all_params = torch.cat([p.view(world_size, -1) for p in merged_params], dim=-1).view(-1)
147                del merged_params将它们分散到所有节点/设备
150                dist.scatter(chunk, list(all_params.split(sum(self.chunk_size))))
151                del all_params
152            else:接收参数
154                dist.scatter(chunk)收集区块数据
157            chunk = chunk.split(self.chunk_size)
158            for i, c in enumerate(chunk):
159                self.chunk[i].data[:] = c
160            del chunk清理普通参数
163            self._cleanup_params()添加一个向后钩子。当计算相对于模块的梯度时,会调用该函数。
166            self._backward_hook_ref = self.register_full_backward_hook(self._backward_hook)  # type: ignoreworld_size
。168    def _merge_and_pad_params(self, params: List[nn.Parameter]) -> torch.Tensor:参数总数
173        size = sum(p.shape.numel() for p in params)如果它不能被整除world_size
,请填充它
176        if size % self.world_size != 0:
177            padding_fixed = self.world_size - (size % self.world_size)否则,无需填充
179        else:
180            padding_fixed = 0创建一个空的填充张量
182        padding = self._empty((padding_fixed,))连接所有参数并填充它
184        return torch.cat([p.view(-1) for p in params] + [padding], dim=0)186    def get_trainable_chunk(self) -> List[nn.Parameter]:如果没有可训练的参数,则返回空列表
193        if len(self.chunk[self.TRAINING_PARAMS_IDX]) == 0:
194            return []将可训练区块作为列表返回
197        return [self.chunk[self.TRAINING_PARAMS_IDX]]199    def _empty(self, shape: Tuple[int, ...]) -> torch.Tensor:203        return torch.empty(shape, device=self.device, dtype=self.dtype)205    @torch.no_grad()
206    def _cleanup_params(self):设置标志以指示未读取参数
214        self.is_fetched = False遍历所有参数
217        for ps in self.param_refs:
218            for p in ps:在进行任何新操作之前,请等待对参数的操作完成
220                p.data.record_stream(torch.cuda.current_stream())检查以确保该参数不与其他任何内容共享存储
222                assert p.data.storage_offset() == 0, "The tensor is not the sole occupant of the storage."226                p.data.storage().resize_(0)  # This is what actually clears the memory确保参数没有梯度数据
228                assert p.grad is None, 'Gradients should be None'230    @torch.no_grad()
231    def fetch_params(self):已获取 Skip
239        if self.is_fetched:
240            return设置旗帜
243        self.is_fetched = True如果没有要获取或共享的内容,请跳过。
246        if sum(self.chunk_size) == 0:
247            returnfetch_stream
使用从所有分片中获取参数
250        with torch.cuda.stream(self.fetch_stream):创建一个空张量来接收参数
252            buffer = self._empty((self.world_size * sum(self.chunk_size),))将连续缓冲区拆分为节点数。这些拆分是 “缓冲区” 的视图。
254            buffers = list(buffer.split(sum(self.chunk_size)))连接可训练和固定区块
257            chunk = torch.cat(self.chunk, dim=0)从所有节点/设备收集参数
260            dist.all_gather(buffers, chunk)将收集的参数拆分为可训练的和固定的区块
263            params = buffer.view(-1, sum(self.chunk_size)).split(self.chunk_size, dim=1)等待收集操作完成,然后清除对缓冲区的引用
265            buffer.record_stream(self.fetch_stream)
266            for b in buffers:
267                b.record_stream(self.fetch_stream)
268            buffer.record_stream(self.fetch_stream)
269            del buffer
270            del buffers将可训练和固定参数重塑为连续张量
273            params = [p.reshape(-1) for p in params]收集单个参数张量
276            for cont, ps in zip(params, self.param_refs):如果没有参数,请跳过
278                if not ps:
279                    continue连续张量的偏移量
282                offset = 0遍历模型参数并分配来自连续张量的值
284                for p in ps:原始参数形状
286                    shape = p._orig_shape  # type: ignore[attr-defined]更改参数的存储大小。这是我们清理参数时设置的。
288                    p.data.storage().resize_(shape.numel())从连续张量中分配值
290                    p.data[:] = cont[offset: offset + shape.numel()].reshape(shape)等待操作完成后才能执行其他操作
292                    p.data.record_stream(self.fetch_stream)更新偏移量
294                    offset += shape.numel()等待操作完成后才能执行其他操作
297                cont.record_stream(self.fetch_stream)300            del params302    def forward(self, *args, **kwargs):获取当前节点的所有参数。这被前一层调用,所以这个调用只是为了确保参数被抓取。
309        self.fetch_params()等待参数提取完成。
312        torch.cuda.current_stream().wait_stream(self.fetch_stream)开始获取后续层的参数,以便它们将获取当前层进行计算的参数。
316        for layer in self.next_layer:
317            layer.fetch_params()启用了 autograd,则向当前层的参数添加向后挂钩。
320        if torch.is_grad_enabled():
321            self._add_backward_hooks()计算当前图层的输出
324        res = self.module(*args, **kwargs)330        if not torch.is_grad_enabled() or self.next_layer:
331            self._cleanup_params()
332
333        return res335    def _add_backward_hooks(self):添加的向后钩子数量
341        self._backward_hook_handles = 0循环浏览当前图层的可训练参数
344        for p in self.param_refs[self.TRAINING_PARAMS_IDX]:确保尚未添加挂钩
346            assert not hasattr(p, "_hook_handle"), 'Parameter has already been hooked'expand_as
用于创建我们可以拦截的 autograd 步骤
348            p_tmp = p.expand_as(p)获取一个手柄来添加向后钩。这篇博客讨论grad_acc
了.
351            grad_acc = p_tmp.grad_fn.next_functions[0][0]添加向后挂钩
353            handle = grad_acc.register_hook(
354                functools.partial(self._post_backward_hook, p))保留对手柄的引用
356            p._hook_handle = handle增加添加的钩子数量
358            self._backward_hook_handles += 1360    def _backward_event(self):减少钩子计数器
368        self._backward_hook_handles -= 1如果所有的钩子(包括模块钩子)都被调用了,那么我们可以备份渐变并清理参数。
372        if self._backward_hook_handles == -1:
373            self._backup_grads()
374            self._cleanup_params()开始获取前一图层的参数,因为 autograd 接下来将处理该图层的渐变。
377        for layer in self.prev_layer:
378            layer.fetch_params()380    def _post_backward_hook(self, p: nn.Parameter, *args):从参数中移除句柄
385        p._hook_handle.remove()  # type: ignore[attr-defined]
386        delattr(p, "_hook_handle")处理向后事件
389        self._backward_event()391    def _backward_hook(self, *args, **kwargs):处理向后事件
396        self._backward_event()上一层将开始计算梯度。我们需要确保它已经完成了参数的获取。
399        torch.cuda.current_stream().wait_stream(self.fetch_stream)402        return None404    @torch.no_grad()
405    def _backup_grads(self):如果没有可训练的参数,则跳过
410        if self.chunk_size[self.TRAINING_PARAMS_IDX] == 0:
411            return使用备份流备份渐变
414        with torch.cuda.stream(self.backup_stream):用于存储渐变的缓冲区
416            buffer = self._empty((self.world_size * self.chunk_size[self.TRAINING_PARAMS_IDX],))将连续缓冲区拆分为多个节点。这些拆分是 “缓冲区” 的视图。
418            buffers = list(buffer.split(self.chunk_size[self.TRAINING_PARAMS_IDX]))连续缓冲区的偏移量
421            offset = 0遍历可训练的参数
423            for p in self.param_refs[self.TRAINING_PARAMS_IDX]:收集渐变
425                shape = p._orig_shape  # type: ignore[attr-defined]
426                buffer[offset: offset + shape.numel()] = p.grad.view(-1)更新偏移量
428                offset += shape.numel()清理渐变
430                p.grad = None空张量累积当前分片的梯度
433            grad = self._empty((self.chunk_size[self.TRAINING_PARAMS_IDX],))累积每个分片的梯度。它将缓冲区分散到节点上,每个节点累积(减少)它收到的张量。
436            dist.reduce_scatter(grad, buffers)等待操作完成,然后清除对缓冲区的引用
439            for b in buffers:
440                b.record_stream(self.fetch_stream)
441            buffer.record_stream(self.fetch_stream)
442            del buffer
443            del buffers设置分块渐变。这就是优化器所看到的。
446            self.chunk[self.TRAINING_PARAMS_IDX].grad = grad
447            del gradZero3Layer
层的顺序模块450class Zero3Sequential(nn.Module):modules
Zero3Layer
图层列表454    def __init__(self, modules: List[Zero3Layer]):458        super().__init__()用于获取参数的 CUDA 流
461        self.fetch_stream = torch.cuda.Stream()用于备份(累积)梯度的 CUDA 流
463        self.backup_stream = torch.cuda.Stream()为每个层设置流以及前面和后面的Zero3Layer
层
466        for i in range(len(modules)):设置图层索引
468            modules[i].layer_idx = i设置直播
470            modules[i].fetch_stream = self.fetch_stream
471            modules[i].backup_stream = self.backup_stream设置后续图层
473            if i + 1 < len(modules):
474                modules[i].next_layer.append(modules[i + 1])设置前面的图层
476            if i - 1 >= 0:
477                modules[i].prev_layer.append(modules[i - 1])存储模块清单
480        self.module_list = nn.ModuleList(modules)482    def get_trainable_chunk(self):返回每层可训练区块的列表
484        return sum([m.get_trainable_chunk() for m in self.module_list], [])486    def forward(self, x: torch.Tensor):确保渐变备份已完成
488        torch.cuda.current_stream().wait_stream(self.backup_stream)向前传球
491        for m in self.module_list:
492            x = m(x)495        return x