Fix Image input.

This commit is contained in:
long2ice
2021-06-18 11:10:15 +08:00
parent 007aa37634
commit 2eb9abe124
4 changed files with 23 additions and 16 deletions

View File

@ -1,5 +1,5 @@
import os
from typing import List, Optional
from typing import Callable, List, Optional
import aiofiles
from starlette.datastructures import UploadFile
@ -11,32 +11,36 @@ class FileUpload:
def __init__(
self,
uploads_dir: str,
all_extensions: Optional[List[str]] = None,
prefix: str = "/static/uploads",
allow_extensions: Optional[List[str]] = None,
max_size: int = 1024 ** 3,
filename_generator: Optional[Callable] = None,
prefix: str = "/static/uploads",
):
self.max_size = max_size
self.all_extensions = all_extensions
self.allow_extensions = allow_extensions
self.uploads_dir = uploads_dir
self.filename_generator = filename_generator
self.prefix = prefix
def get_file_name(self, file: UploadFile):
return file.filename
async def save_file(self, filename: str, content: bytes):
file = os.path.join(self.uploads_dir, filename)
async with aiofiles.open(file, "wb") as f:
await f.write(content)
return os.path.join(self.prefix, filename)
async def upload(self, file: UploadFile):
filename = self.get_file_name(file)
if not filename:
return
if self.filename_generator:
filename = self.filename_generator(file)
else:
filename = file.filename
content = await file.read()
file_size = len(content)
if file_size > self.max_size:
raise FileMaxSizeLimit(f"File size {file_size} exceeds max size {self.max_size}")
if self.all_extensions:
for ext in self.all_extensions:
if self.allow_extensions:
for ext in self.allow_extensions:
if filename.endswith(ext):
raise FileExtNotAllowed(
f"File ext {ext} is not allowed of {self.all_extensions}"
f"File ext {ext} is not allowed of {self.allow_extensions}"
)
async with aiofiles.open(os.path.join(self.uploads_dir, filename), "wb") as f:
await f.write(content)
return os.path.join(self.prefix, filename)
return await self.save_file(filename, content)