diff --git a/tutorials/03-advanced/image_captioning/sample.py b/tutorials/03-advanced/image_captioning/sample.py index 23e07ef..b41cc6f 100644 --- a/tutorials/03-advanced/image_captioning/sample.py +++ b/tutorials/03-advanced/image_captioning/sample.py @@ -14,7 +14,7 @@ from PIL import Image device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') def load_image(image_path, transform=None): - image = Image.open(image_path) + image = Image.open(image_path).convert('RGB') image = image.resize([224, 224], Image.LANCZOS) if transform is not None: @@ -78,4 +78,4 @@ if __name__ == '__main__': parser.add_argument('--hidden_size', type=int , default=512, help='dimension of lstm hidden states') parser.add_argument('--num_layers', type=int , default=1, help='number of layers in lstm') args = parser.parse_args() - main(args) \ No newline at end of file + main(args)