import torch

from labml.configs import BaseConfigs, hyperparams, option
class DeviceInfo:
    def __init__(self, *,
                 use_cuda: bool,
                 cuda_device: int):
        self.use_cuda = use_cuda
        self.cuda_device = cuda_device
        self.cuda_count = torch.cuda.device_count()

        self.is_cuda = self.use_cuda and torch.cuda.is_available()
        if not self.is_cuda:
            # CUDA was not requested or is not available; use the CPU
            self.device = torch.device('cpu')
        else:
            if self.cuda_device < self.cuda_count:
                self.device = torch.device('cuda', self.cuda_device)
            else:
                # Requested device index is out of range; fall back to the last GPU
                self.device = torch.device('cuda', self.cuda_count - 1)

    def __str__(self):
        if not self.is_cuda:
            return "CPU"

        if self.cuda_device < self.cuda_count:
            return f"GPU:{self.cuda_device} - {torch.cuda.get_device_name(self.cuda_device)}"
        else:
            return (f"GPU:{self.cuda_count - 1}({self.cuda_device}) "
                    f"- {torch.cuda.get_device_name(self.cuda_count - 1)}")

This is a configurable module that picks a single device to train the model on. It picks a CUDA device when one is requested and available, and falls back to the CPU otherwise.

It also has small advantages, such as showing the actual device name in the configurations view of the labml app (https://github.com/labmlai/labml/tree/master/app).

Arguments:
    cuda_device (int): The CUDA device number. Defaults to 0.
    use_cuda (bool): Whether to use CUDA devices. Defaults to True.

A usage sketch is given at the end of this section.

class DeviceConfigs(BaseConfigs):
    cuda_device: int = 0
    use_cuda: bool = True

    device_info: DeviceInfo

    device: torch.device

    def __init__(self):
        # `device` is the primary configuration of this configs class
        super().__init__(_primary='device')
@option(DeviceConfigs.device)
def _device(c: DeviceConfigs):
    # Pick the `torch.device` computed by `DeviceInfo`
    return c.device_info.device


# Mark `cuda_device` and `use_cuda` as not hyper-parameters
hyperparams(DeviceConfigs.cuda_device, DeviceConfigs.use_cuda,
            is_hyperparam=False)


@option(DeviceConfigs.device_info)
def _device_info(c: DeviceConfigs):
    return DeviceInfo(use_cuda=c.use_cuda,
                      cuda_device=c.cuda_device)
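
As referenced above, here is a minimal usage sketch. It assumes the usual labml pattern of assigning a `DeviceConfigs` instance to a `device` option of another configs class, so that it resolves through the primary `device` option; the `Configs` class and `_model` option below are hypothetical examples, not part of this module:

    class Configs(BaseConfigs):
        # Hypothetical experiment configs: `DeviceConfigs` resolves to its
        # primary option `device`, so `c.device` is computed through it.
        device: torch.device = DeviceConfigs()
        model: torch.nn.Module


    @option(Configs.model)
    def _model(c: Configs):
        # Hypothetical model option that moves the model to the picked device
        return torch.nn.Linear(10, 10).to(c.device)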