这是《零:训练一万亿个参数模型的内存优化》一文中介绍的零 DP 的实现,
它将优化器状态、梯度和参数的分片保存到多个设备/节点中。它减少了原始模型的内存消耗,其中是参数的数量,是分片的数量,是每个参数的优化器字节数。是假设精度为 16 位的参数和梯度存储器;即每个参数和梯度为 2 个字节。对于 Adam 优化器,因为它维护参数的副本,在 fp32 中每个参数两个时刻。
零 DP 的通信量为。比较而言,数据并行训练的通信量为。
尽管它被命名了Zero3
,但我们只实现了其中的零 DP 部分,没有实现针对剩余内存消耗的 Zero-R 内存优化。Out 实现仅支持训练一部分参数。
此实施的灵感来自公平规模的财务安全发展计划。
32import functools
33from typing import List, Optional, Tuple
34
35import torch
36import torch.distributed as dist
37from torch import nn40class Zero3Layer(nn.Module):每个分片都将参数保存在chunk
列表中。用于chunk[0]
可训练的参数,chunk[1]
用于固定参数。
49 chunk: List[nn.Parameter]这是chunk
列表中区块的大小。
51 chunk_size: List[int]第一个区块用于可训练的参数。
53 TRAINING_PARAMS_IDX = 0这是分为可训练参数和固定参数的列表的参数列表。
56 param_refs: List[List[nn.Parameter]]CUDA 流到精选参数
59 fetch_stream: Optional[torch.cuda.Stream]用于备份/累积梯度的 CUDA 流
61 backup_stream: Optional[torch.cuda.Stream]此图层之前的图层列表
63 prev_layer: List['Zero3Layer']紧随此图层之后的图层列表
65 next_layer: List['Zero3Layer']当前层的位置;用于调试日志
67 layer_idx: int参数是否已获取
70 is_fetched: bool该层的设备
73 device: torch.device图层的数据类型
75 dtype: torch.dtype要封装的模块
77 module: nn.Module分片数据的节点/设备数量
79 world_size: intmodule
要封装的模块。rank
当前节点的等级。world_size
分片数据的节点/设备数量。device
层的设备。dtype
图层的数据类型。81 def __init__(self, module: nn.Module, rank: int, world_size: int, device: torch.device, dtype: torch.dtype):89 super().__init__()初始化属性
92 self.device = device
93 self.dtype = dtype
94 self.module = module
95 self.prev_layer = []
96 self.next_layer = []
97 self.is_fetched = False
98 self.world_size = world_size
99 self.layer_idx = -1
100 self.fetch_stream = None
101 self.backup_stream = None
102
103 with torch.no_grad():收集图层的所有参数
105 all_param_refs = [p for p in self.parameters()]存储参数的形状,因为我们稍后需要它来重建它们
108 for p in all_param_refs:
109 p._orig_shape = p.shape所有参数都应具有相同的类型
112 for p in all_param_refs:
113 assert p.dtype == dtype, "All parameters should have same dtype"将参数分为可训练和固定
116 self.param_refs = [[p for p in all_param_refs if p.requires_grad],
117 [p for p in all_param_refs if not p.requires_grad]]
118 del all_param_refs该rank = 0
节点将计算每个设备/节点应存储的大小,并相应地分配参数。
122 if rank == 0:合并和填充可训练 (merged_params[0]
) 和 fixed (merged_params[1]
) 参数
124 merged_params = [self._merge_and_pad_params(ps) for ps in self.param_refs]计算可训练参数和固定参数的区块大小
126 self.chunk_size = [(len(p) // world_size if p is not None else 0) for p in merged_params]广播尺寸
128 dist.broadcast(torch.tensor(self.chunk_size, device=device), src=0)
129 else:创建一个空张量来接收大小
131 chunk_size = torch.tensor([0, 0], device=device)收到尺码
133 dist.broadcast(chunk_size, src=0)
134 self.chunk_size = chunk_size.tolist()为要存储在当前设备/节点中的可训练 (self.chunk[0]
self.chunk[1]
) 和 fixed () 参数创建参数
138 self.chunk = [nn.Parameter(self._empty((s,)), requires_grad=i == self.TRAINING_PARAMS_IDX)
139 for i, s in enumerate(self.chunk_size)]一个空张量,用于接收可训练参数和固定参数的组合
142 chunk = self._empty((sum(self.chunk_size),))
143
144 if rank == 0:连接可训练参数和固定参数
146 all_params = torch.cat([p.view(world_size, -1) for p in merged_params], dim=-1).view(-1)
147 del merged_params将它们分散到所有节点/设备
150 dist.scatter(chunk, list(all_params.split(sum(self.chunk_size))))
151 del all_params
152 else:接收参数
154 dist.scatter(chunk)收集区块数据
157 chunk = chunk.split(self.chunk_size)
158 for i, c in enumerate(chunk):
159 self.chunk[i].data[:] = c
160 del chunk清理普通参数
163 self._cleanup_params()添加一个向后钩子。当计算相对于模块的梯度时,会调用该函数。
166 self._backward_hook_ref = self.register_full_backward_hook(self._backward_hook) # type: ignoreworld_size
。168 def _merge_and_pad_params(self, params: List[nn.Parameter]) -> torch.Tensor:参数总数
173 size = sum(p.shape.numel() for p in params)如果它不能被整除world_size
,请填充它
176 if size % self.world_size != 0:
177 padding_fixed = self.world_size - (size % self.world_size)否则,无需填充
179 else:
180 padding_fixed = 0创建一个空的填充张量
182 padding = self._empty((padding_fixed,))连接所有参数并填充它
184 return torch.cat([p.view(-1) for p in params] + [padding], dim=0)186 def get_trainable_chunk(self) -> List[nn.Parameter]:如果没有可训练的参数,则返回空列表
193 if len(self.chunk[self.TRAINING_PARAMS_IDX]) == 0:
194 return []将可训练区块作为列表返回
197 return [self.chunk[self.TRAINING_PARAMS_IDX]]199 def _empty(self, shape: Tuple[int, ...]) -> torch.Tensor:203 return torch.empty(shape, device=self.device, dtype=self.dtype)205 @torch.no_grad()
206 def _cleanup_params(self):设置标志以指示未读取参数
214 self.is_fetched = False遍历所有参数
217 for ps in self.param_refs:
218 for p in ps:在进行任何新操作之前,请等待对参数的操作完成
220 p.data.record_stream(torch.cuda.current_stream())检查以确保该参数不与其他任何内容共享存储
222 assert p.data.storage_offset() == 0, "The tensor is not the sole occupant of the storage."226 p.data.storage().resize_(0) # This is what actually clears the memory确保参数没有梯度数据
228 assert p.grad is None, 'Gradients should be None'230 @torch.no_grad()
231 def fetch_params(self):已获取 Skip
239 if self.is_fetched:
240 return设置旗帜
243 self.is_fetched = True如果没有要获取或共享的内容,请跳过。
246 if sum(self.chunk_size) == 0:
247 returnfetch_stream
使用从所有分片中获取参数
250 with torch.cuda.stream(self.fetch_stream):创建一个空张量来接收参数
252 buffer = self._empty((self.world_size * sum(self.chunk_size),))将连续缓冲区拆分为节点数。这些拆分是 “缓冲区” 的视图。
254 buffers = list(buffer.split(sum(self.chunk_size)))连接可训练和固定区块
257 chunk = torch.cat(self.chunk, dim=0)从所有节点/设备收集参数
260 dist.all_gather(buffers, chunk)将收集的参数拆分为可训练的和固定的区块
263 params = buffer.view(-1, sum(self.chunk_size)).split(self.chunk_size, dim=1)等待收集操作完成,然后清除对缓冲区的引用
265 buffer.record_stream(self.fetch_stream)
266 for b in buffers:
267 b.record_stream(self.fetch_stream)
268 buffer.record_stream(self.fetch_stream)
269 del buffer
270 del buffers将可训练和固定参数重塑为连续张量
273 params = [p.reshape(-1) for p in params]收集单个参数张量
276 for cont, ps in zip(params, self.param_refs):如果没有参数,请跳过
278 if not ps:
279 continue连续张量的偏移量
282 offset = 0遍历模型参数并分配来自连续张量的值
284 for p in ps:原始参数形状
286 shape = p._orig_shape # type: ignore[attr-defined]更改参数的存储大小。这是我们清理参数时设置的。
288 p.data.storage().resize_(shape.numel())从连续张量中分配值
290 p.data[:] = cont[offset: offset + shape.numel()].reshape(shape)等待操作完成后才能执行其他操作
292 p.data.record_stream(self.fetch_stream)更新偏移量
294 offset += shape.numel()等待操作完成后才能执行其他操作
297 cont.record_stream(self.fetch_stream)300 del params302 def forward(self, *args, **kwargs):获取当前节点的所有参数。这被前一层调用,所以这个调用只是为了确保参数被抓取。
309 self.fetch_params()等待参数提取完成。
312 torch.cuda.current_stream().wait_stream(self.fetch_stream)开始获取后续层的参数,以便它们将获取当前层进行计算的参数。
316 for layer in self.next_layer:
317 layer.fetch_params()启用了 autograd,则向当前层的参数添加向后挂钩。
320 if torch.is_grad_enabled():
321 self._add_backward_hooks()计算当前图层的输出
324 res = self.module(*args, **kwargs)330 if not torch.is_grad_enabled() or self.next_layer:
331 self._cleanup_params()
332
333 return res335 def _add_backward_hooks(self):添加的向后钩子数量
341 self._backward_hook_handles = 0循环浏览当前图层的可训练参数
344 for p in self.param_refs[self.TRAINING_PARAMS_IDX]:确保尚未添加挂钩
346 assert not hasattr(p, "_hook_handle"), 'Parameter has already been hooked'expand_as
用于创建我们可以拦截的 autograd 步骤
348 p_tmp = p.expand_as(p)获取一个手柄来添加向后钩。这篇博客讨论grad_acc
了.
351 grad_acc = p_tmp.grad_fn.next_functions[0][0]添加向后挂钩
353 handle = grad_acc.register_hook(
354 functools.partial(self._post_backward_hook, p))保留对手柄的引用
356 p._hook_handle = handle增加添加的钩子数量
358 self._backward_hook_handles += 1360 def _backward_event(self):减少钩子计数器
368 self._backward_hook_handles -= 1如果所有的钩子(包括模块钩子)都被调用了,那么我们可以备份渐变并清理参数。
372 if self._backward_hook_handles == -1:
373 self._backup_grads()
374 self._cleanup_params()开始获取前一图层的参数,因为 autograd 接下来将处理该图层的渐变。
377 for layer in self.prev_layer:
378 layer.fetch_params()380 def _post_backward_hook(self, p: nn.Parameter, *args):从参数中移除句柄
385 p._hook_handle.remove() # type: ignore[attr-defined]
386 delattr(p, "_hook_handle")处理向后事件
389 self._backward_event()391 def _backward_hook(self, *args, **kwargs):处理向后事件
396 self._backward_event()上一层将开始计算梯度。我们需要确保它已经完成了参数的获取。
399 torch.cuda.current_stream().wait_stream(self.fetch_stream)402 return None404 @torch.no_grad()
405 def _backup_grads(self):如果没有可训练的参数,则跳过
410 if self.chunk_size[self.TRAINING_PARAMS_IDX] == 0:
411 return使用备份流备份渐变
414 with torch.cuda.stream(self.backup_stream):用于存储渐变的缓冲区
416 buffer = self._empty((self.world_size * self.chunk_size[self.TRAINING_PARAMS_IDX],))将连续缓冲区拆分为多个节点。这些拆分是 “缓冲区” 的视图。
418 buffers = list(buffer.split(self.chunk_size[self.TRAINING_PARAMS_IDX]))连续缓冲区的偏移量
421 offset = 0遍历可训练的参数
423 for p in self.param_refs[self.TRAINING_PARAMS_IDX]:收集渐变
425 shape = p._orig_shape # type: ignore[attr-defined]
426 buffer[offset: offset + shape.numel()] = p.grad.view(-1)更新偏移量
428 offset += shape.numel()清理渐变
430 p.grad = None空张量累积当前分片的梯度
433 grad = self._empty((self.chunk_size[self.TRAINING_PARAMS_IDX],))累积每个分片的梯度。它将缓冲区分散到节点上,每个节点累积(减少)它收到的张量。
436 dist.reduce_scatter(grad, buffers)等待操作完成,然后清除对缓冲区的引用
439 for b in buffers:
440 b.record_stream(self.fetch_stream)
441 buffer.record_stream(self.fetch_stream)
442 del buffer
443 del buffers设置分块渐变。这就是优化器所看到的。
446 self.chunk[self.TRAINING_PARAMS_IDX].grad = grad
447 del gradZero3Layer
层的顺序模块450class Zero3Sequential(nn.Module):modules
Zero3Layer
图层列表454 def __init__(self, modules: List[Zero3Layer]):458 super().__init__()用于获取参数的 CUDA 流
461 self.fetch_stream = torch.cuda.Stream()用于备份(累积)梯度的 CUDA 流
463 self.backup_stream = torch.cuda.Stream()为每个层设置流以及前面和后面的Zero3Layer
层
466 for i in range(len(modules)):设置图层索引
468 modules[i].layer_idx = i设置直播
470 modules[i].fetch_stream = self.fetch_stream
471 modules[i].backup_stream = self.backup_stream设置后续图层
473 if i + 1 < len(modules):
474 modules[i].next_layer.append(modules[i + 1])设置前面的图层
476 if i - 1 >= 0:
477 modules[i].prev_layer.append(modules[i - 1])存储模块清单
480 self.module_list = nn.ModuleList(modules)482 def get_trainable_chunk(self):返回每层可训练区块的列表
484 return sum([m.get_trainable_chunk() for m in self.module_list], [])486 def forward(self, x: torch.Tensor):确保渐变备份已完成
488 torch.cuda.current_stream().wait_stream(self.backup_stream)向前传球
491 for m in self.module_list:
492 x = m(x)495 return x