mirror of
				https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
				synced 2025-10-31 10:48:49 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			71 lines
		
	
	
		
			2.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			71 lines
		
	
	
		
			2.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import torch
 | |
| 
 | |
| from labml.configs import BaseConfigs, hyperparams, option
 | |
| 
 | |
| 
 | |
class DeviceInfo:
    """
    Resolve the ``torch.device`` to run on from the CUDA settings.

    Keyword Arguments:
        use_cuda (bool): request a CUDA device; ignored when
            ``torch.cuda.is_available()`` is false, in which case the CPU is used.
        cuda_device (int): requested CUDA device index. If it is not below the
            number of visible CUDA devices, the last visible device is used instead.
    """

    def __init__(self, *,
                 use_cuda: bool,
                 cuda_device: int):
        self.use_cuda = use_cuda
        self.cuda_device = cuda_device
        self.cuda_count = torch.cuda.device_count()

        # CUDA is used only when it was requested *and* torch can actually see it
        self.is_cuda = self.use_cuda and torch.cuda.is_available()
        if self.is_cuda:
            self.device = torch.device('cuda', self._resolved_index())
        else:
            self.device = torch.device('cpu')

    def _resolved_index(self) -> int:
        """Return the requested index, or the last visible device's index if it is out of range."""
        requested_fits = self.cuda_device < self.cuda_count
        return self.cuda_device if requested_fits else self.cuda_count - 1

    def __str__(self):
        if not self.is_cuda:
            return "CPU"

        index = self._resolved_index()
        name = torch.cuda.get_device_name(index)
        if index == self.cuda_device:
            # The requested device exists and is the one in use
            return f"GPU:{self.cuda_device} - {name}"
        # Fell back: show the device in use, with the requested index in parentheses
        return f"GPU:{index}({self.cuda_device}) - {name}"
| 
 | |
| 
 | |
class DeviceConfigs(BaseConfigs):
    r"""
    Configurable module that selects a single device to train a model on.

    A CUDA device is used when ``use_cuda`` is set and CUDA is available;
    otherwise it falls back to the CPU. If ``cuda_device`` is not below the
    number of visible CUDA devices, the last visible device is used instead.

    A side benefit is that the actual device name can be seen in the
    configurations view of
    `labml app <https://github.com/labmlai/labml/tree/master/app>`_

    Arguments:
        cuda_device (int): The CUDA device number. Defaults to ``0``.
        use_cuda (bool): Whether to use CUDA devices. Defaults to ``True``.
    """
    # Requested CUDA device index
    cuda_device: int = 0
    # Whether a CUDA device should be used when one is available
    use_cuda: bool = True

    # Resolved device information (built from use_cuda / cuda_device by an option in this module)
    device_info: DeviceInfo

    # The torch.device to run on (taken from device_info by an option in this module)
    device: torch.device

    def __init__(self):
        # Mark `device` as the primary config of this module
        # (presumably the value this config module resolves to when used as a nested config — confirm against labml docs)
        super().__init__(_primary='device')
| 
 | |
| 
 | |
@option(DeviceConfigs.device)
def _device(c: DeviceConfigs):
    """Default ``device`` option: the ``torch.device`` chosen by ``device_info``."""
    info = c.device_info
    return info.device
| 
 | |
| 
 | |
# Register the device settings with is_hyperparam=False
# (presumably so they are not reported as experiment hyperparameters — confirm against labml docs)
hyperparams(
    DeviceConfigs.cuda_device,
    DeviceConfigs.use_cuda,
    is_hyperparam=False,
)
| 
 | |
| 
 | |
@option(DeviceConfigs.device_info)
def _device_info(c: DeviceConfigs):
    """Default ``device_info`` option: a ``DeviceInfo`` built from ``use_cuda`` and ``cuda_device``."""
    settings = dict(use_cuda=c.use_cuda, cuda_device=c.cuda_device)
    return DeviceInfo(**settings)
