Mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
Synced 2025-08-14 17:41:37 +08:00
cleanup some unused imports
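The imports removed in the hunks below are the kind that linters report as "imported but unused" (pyflakes/flake8 code F401). As a rough illustration of that idea only, and not a tool this repository ships, here is a small standard-library sketch that lists top-level imports a module never references again; the file name, helper name, and heuristic are assumptions for illustration (it ignores `__all__`, re-exports, and string references).

# find_unused_imports.py -- illustrative helper, not part of this repository.
import ast
import sys


def unused_imports(source: str):
    """Return (name, line) pairs for imports that are never referenced again."""
    tree = ast.parse(source)
    imported = {}  # bound name -> line number of the import statement
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            for alias in node.names:
                imported[alias.asname or alias.name.split('.')[0]] = node.lineno
        elif isinstance(node, ast.ImportFrom):
            for alias in node.names:
                imported[alias.asname or alias.name] = node.lineno
    # Names actually referenced anywhere in the module body
    used = {node.id for node in ast.walk(tree) if isinstance(node, ast.Name)}
    return sorted((name, line) for name, line in imported.items() if name not in used)


if __name__ == '__main__':
    path = sys.argv[1]
    with open(path) as f:
        for name, line in unused_imports(f.read()):
            print(f'{path}:{line}: unused import `{name}`')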
@@ -41,7 +41,6 @@ import numpy as np
 from labml import experiment
 from labml.configs import option
 from labml_nn.cfr import History as _History, InfoSet as _InfoSet, Action, Player, CFRConfigs
-from labml_nn.cfr.infoset_saver import InfoSetSaver
 
 # Kuhn poker actions are pass (`p`) or bet (`b`)
 ACTIONS = cast(List[Action], ['p', 'b'])
@@ -71,11 +71,7 @@
 }
 },
 "source": [
-"import torch\n",
-"import torch.nn as nn\n",
-"\n",
 "from labml import experiment\n",
-"from labml.configs import option\n",
 "from labml_nn.diffusion.ddpm.experiment import Configs"
 ],
 "outputs": [],
@@ -68,11 +68,7 @@
 }
 },
 "source": [
-"import torch\n",
-"import torch.nn as nn\n",
-"\n",
 "from labml import experiment\n",
-"from labml.configs import option\n",
 "from labml_nn.normalization.deep_norm.experiment import Configs"
 ],
 "outputs": [],
@@ -53,9 +53,6 @@
 "id": "0hJXx_g0wS2C"
 },
 "source": [
-"import torch\n",
-"import torch.nn as nn\n",
-"\n",
 "from labml import experiment\n",
 "from labml_nn.normalization.weight_standardization.experiment import CIFAR10Configs as Configs"
 ],
@@ -26,8 +26,8 @@ The experiment uses [Generalized Advantage Estimation](gae.html).
 """
 
 import torch
 
 from labml_nn.rl.ppo.gae import GAE
+from torch import nn
 
 
 class ClippedPPOLoss(nn.Module):
@@ -178,7 +178,7 @@ class ClippedPPOLoss(nn.Module):
 return -policy_reward.mean()
 
 
-class ClippedValueFunctionLoss(Module):
+class ClippedValueFunctionLoss(nn.Module):
 """
 ## Clipped Value Function Loss
 
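The hunk above changes `ClippedValueFunctionLoss` to subclass `nn.Module` directly, consistent with the `from torch import nn` added in the earlier hunk. For orientation, below is a minimal sketch of a clipped value-function loss written straight against `nn.Module`, assuming the usual PPO form where the new value prediction is clipped to stay within `clip` of the value recorded at sampling time; the class name, signature, and usage are illustrative, not the repository's exact code.

# Illustrative sketch only -- names and defaults are assumptions, not the repo's code.
import torch
from torch import nn


class ClippedValueFunctionLossSketch(nn.Module):
    """Clipped value-function loss in the usual PPO form."""

    def forward(self, value: torch.Tensor, sampled_value: torch.Tensor,
                sampled_return: torch.Tensor, clip: float) -> torch.Tensor:
        # Keep the new value prediction within `clip` of the value predicted
        # when the samples were collected.
        clipped_value = sampled_value + (value - sampled_value).clamp(-clip, clip)
        # Penalize whichever of the clipped / unclipped predictions is worse,
        # so the value head cannot move too far in a single update.
        vf_loss = torch.max((value - sampled_return) ** 2,
                            (clipped_value - sampled_return) ** 2)
        return 0.5 * vf_loss.mean()


# Example usage (per-batch 1-D tensors):
# loss = ClippedValueFunctionLossSketch()(value, old_value, returns, clip=0.1)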
@@ -71,11 +71,7 @@
 }
 },
 "source": [
-"import torch\n",
-"import torch.nn as nn\n",
-"\n",
 "from labml import experiment\n",
-"from labml.configs import option\n",
 "from labml_nn.transformers.basic.autoregressive_experiment import Configs"
 ],
 "outputs": [],
@@ -67,11 +67,7 @@
 "id": "0hJXx_g0wS2C"
 },
 "source": [
-"import torch\n",
-"import torch.nn as nn\n",
-"\n",
 "from labml import experiment\n",
-"from labml.configs import option\n",
 "from labml_nn.transformers.fast_weights.experiment import Configs"
 ],
 "outputs": [],
@@ -138,21 +134,21 @@
 },
 "source": [
 "experiment.configs(conf,\n",
 " # A dictionary of configurations to override\n",
 " {'tokenizer': 'character',\n",
 " 'text': 'tiny_shakespeare',\n",
 " 'optimizer.learning_rate': 1.0,\n",
 " 'optimizer.optimizer': 'Noam',\n",
 " 'prompt': 'It is',\n",
 " 'prompt_separator': '',\n",
 "\n",
 " 'train_loader': 'shuffled_train_loader',\n",
 " 'valid_loader': 'shuffled_valid_loader',\n",
 "\n",
 " 'seq_len': 128,\n",
 " 'epochs': 128,\n",
 " 'batch_size': 16,\n",
 " 'inner_iterations': 25})"
 ],
 "outputs": [],
 "execution_count": null
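The cell above only shows the configuration overrides. For context, in this repository's notebooks such a cell typically sits inside the usual labml experiment pattern sketched below: create an experiment, build the `Configs` object, apply the overrides, then start the run. The experiment name string is an assumption for illustration; the override dictionary is the one from the cell above.

from labml import experiment
from labml_nn.transformers.fast_weights.experiment import Configs

# Create the experiment; the name is illustrative
experiment.create(name="fast_weights_transformer")

# Build the configurations object and override the defaults
conf = Configs()
experiment.configs(conf, {
    'tokenizer': 'character',
    'text': 'tiny_shakespeare',
    'optimizer.learning_rate': 1.0,
    'optimizer.optimizer': 'Noam',
    'prompt': 'It is',
    'prompt_separator': '',
    'train_loader': 'shuffled_train_loader',
    'valid_loader': 'shuffled_valid_loader',
    'seq_len': 128,
    'epochs': 128,
    'batch_size': 16,
    'inner_iterations': 25,
})

# Start the experiment and run the training loop
with experiment.start():
    conf.run()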
@@ -13,13 +13,12 @@ You can find the download instructions
 Save the training images inside `carvana/train` folder and the masks in `carvana/train_masks` folder.
 """
 
-from torch import nn
 from pathlib import Path
 
-import torch.utils.data
 import torchvision.transforms.functional
 from PIL import Image
 
+import torch.utils.data
 from labml import lab
 
 