Files
2023-07-26 17:05:50 +02:00

362 lines
13 KiB
Python

from typing import Any, Dict, List, Optional, Tuple, Type, Union
from pydantic import BaseModel, validator
from starlette.datastructures import FormData
from starlette.requests import Request
from tortoise import ForeignKeyFieldInstance, ManyToManyFieldInstance
from tortoise import Model as TortoiseModel
from tortoise.fields import BooleanField, DateField, DatetimeField, JSONField
from tortoise.fields.data import CharEnumFieldInstance, IntEnumFieldInstance, IntField, TextField
from tortoise.queryset import QuerySet
from fastapi_admin.enums import Method
from fastapi_admin.exceptions import NoSuchFieldFound
from fastapi_admin.i18n import _
from fastapi_admin.widgets import Widget, displays, inputs
from fastapi_admin.widgets.filters import Filter, Search
class Resource:
    """
    Common base class for everything that can be registered as an admin resource.

    It only declares the two presentation attributes shared by all resources.
    """

    # Text shown for the resource; no default, so subclasses are expected to set it.
    label: str
    # Icon identifier for the resource; an empty string when none is configured.
    icon: str = ""
class Link(Resource):
    """
    Resource that is a plain hyperlink.
    """

    # Address the link points to; no default.
    url: str
    # Link target; "_self" unless overridden.
    target: str = "_self"
class Field:
    """
    Pairs a field name with the widget used to display it and the widget used to edit it.
    """

    name: str
    label: str
    display: displays.Display
    input: inputs.Input

    def __init__(
        self,
        name: str,
        label: Optional[str] = None,
        display: Optional[displays.Display] = None,
        input_: Optional[Widget] = None,
    ):
        """
        :param name: field name; also written into the input widget's context as ``name``.
        :param label: human readable label; falls back to ``name.title()`` when falsy.
        :param display: widget used for read-only rendering; a plain ``displays.Display``
            is created when none is given.
        :param input_: widget used for editing; a plain ``inputs.Input`` is created
            when none is given.
        """
        self.name = name
        self.label = label or name.title()

        # Note: a widget passed in by the caller is mutated, since its context
        # receives the label (and, for the input widget, the name).
        display_widget = display or displays.Display()
        display_widget.context.update(label=self.label)
        self.display = display_widget

        input_widget = input_ or inputs.Input()
        input_widget.context.update(label=self.label, name=name)
        self.input = input_widget
class ComputeField(Field):
    """
    Field whose value is produced by ``get_value`` instead of being read directly from the row.
    """

    async def get_value(self, request: Request, obj: dict):
        """
        Return the value for this field from the row ``obj``.

        The default implementation just looks up ``self.name`` in ``obj`` and
        yields ``None`` when the key is missing.
        """
        field_name = self.name
        return obj.get(field_name)
class Action(BaseModel):
    """
    Description of an action offered for a resource.

    ``ajax=False`` is only accepted together with ``method=Method.GET``; every
    other combination is rejected by ``ajax_validate``.
    """

    # Icon class string, e.g. "ti ti-edit".
    icon: str
    # Text shown for the action.
    label: str
    # Identifier of the action, e.g. "update" or "delete".
    name: str
    # HTTP method used by the action.
    method: Method = Method.POST
    # Whether the action is sent as an ajax request.
    ajax: bool = True

    @validator("ajax")
    def ajax_validate(cls, v: bool, values: dict, **kwargs):
        """
        Reject ``ajax=False`` unless the action uses ``Method.GET``.

        :param v: the ``ajax`` value being validated.
        :param values: the fields validated before ``ajax``.
        :return: ``v`` unchanged.
        :raises ValueError: when ``ajax`` is false and ``method`` is not ``Method.GET``.
        """
        # ``values`` only holds fields that passed validation. If ``method`` was
        # rejected, pydantic already reports that error, so skip the cross-check
        # rather than crash with a KeyError.
        if "method" in values and not v and values["method"] != Method.GET:
            raise ValueError("ajax is False only available when method is Method.GET")
        # The return value becomes the field value; without it ``ajax`` would be None.
        return v
class ToolbarAction(Action):
    """
    Action variant that additionally carries a ``class_`` string.
    """

    # Extra class string for the action, e.g. "btn-dark".
    class_: Optional[str]
class Model(Resource):
    """
    Admin resource backed by a Tortoise ORM model.

    Subclasses set ``model`` and optionally ``fields`` / ``filters``. Entries given
    as plain strings are expanded into ``Field`` / ``Search`` objects on demand.
    """

    # Tortoise model class this resource manages.
    model: Type[TortoiseModel]
    # Fields to expose; when empty, every field of ``model`` is used.
    fields: List[Union[str, Field, ComputeField]] = []
    page_size: int = 10
    page_pre_title: Optional[str] = None
    page_title: Optional[str] = None
    # Filters offered for the resource; a plain string becomes a ``Search`` filter.
    filters: List[Union[str, Filter]] = []

    async def get_toolbar_actions(self, request: Request) -> List[ToolbarAction]:
        """Return the toolbar actions; by default a single non-ajax GET "create" action."""
        return [
            ToolbarAction(
                label=_("create"),
                icon="fas fa-plus",
                name="create",
                method=Method.GET,
                ajax=False,
                class_="btn-dark",
            )
        ]

    async def row_attributes(self, request: Request, obj: dict) -> dict:
        """Hook returning extra attributes for the row ``obj``; empty by default."""
        return {}

    async def column_attributes(self, request: Request, field: Field) -> dict:
        """Hook returning extra attributes for the column of ``field``; empty by default."""
        return {}

    async def cell_attributes(self, request: Request, obj: dict, field: Field) -> dict:
        """Hook returning extra attributes for the cell of ``field`` in row ``obj``; empty by default."""
        return {}

    async def get_actions(self, request: Request) -> List[Action]:
        """Return the per-row actions: a non-ajax GET "update" and a DELETE "delete"."""
        return [
            Action(
                label=_("update"), icon="ti ti-edit", name="update", method=Method.GET, ajax=False
            ),
            Action(label=_("delete"), icon="ti ti-trash", name="delete", method=Method.DELETE),
        ]

    async def get_bulk_actions(self, request: Request) -> List[Action]:
        """Return the bulk actions; by default only a DELETE "delete" action."""
        return [
            Action(
                label=_("delete_selected"),
                icon="ti ti-trash",
                name="delete",
                method=Method.DELETE,
            ),
        ]

    @classmethod
    async def get_inputs(cls, request: Request, obj: Optional[TortoiseModel] = None):
        """
        Render every editable input widget of the resource.

        :param request: current request, forwarded to each widget's ``render``.
        :param obj: instance whose values pre-fill the inputs, or ``None`` for none.
        :return: list with the rendered result of each input widget.
        """
        ret = []
        for field in cls.get_fields(is_display=False):
            input_ = field.input
            name = input_.context.get("name")
            if isinstance(input_, inputs.DisplayOnly):
                continue
            if isinstance(input_, inputs.File):
                # Side effect: the enctype is recorded on the class.
                cls.enctype = "multipart/form-data"
            is_fk_of_obj = (
                isinstance(input_, inputs.ForeignKey)
                and (obj is not None)
                and name in obj._meta.fk_fields
            )
            if is_fk_of_obj:
                await obj.fetch_related(name)
                # Value must be the string representation of the fk obj
                value = str(getattr(obj, name, None))
            else:
                # getattr with a default also covers obj being None.
                value = getattr(obj, name, None)
            ret.append(await input_.render(request, value))
        return ret

    @classmethod
    async def resolve_query_params(cls, request: Request, values: dict, qs: QuerySet):
        """
        Apply the submitted filter values to ``qs``.

        :param request: current request, forwarded to the filters.
        :param values: raw filter values keyed by filter name.
        :param qs: queryset to narrow down.
        :return: tuple ``(parsed_values, queryset)``; ``parsed_values`` only holds
            filters that received a non-empty value.
        """
        ret = {}
        for f in cls.filters:
            if isinstance(f, str):
                f = Search(name=f, label=f.title())
            name = f.context.get("name")
            v = values.get(name)
            # Filters with no value, or an empty string, are skipped.
            if v is not None and v != "":
                ret[name] = await f.parse_value(request, v)
                qs = await f.get_queryset(request, v, qs)
        return ret, qs

    @classmethod
    async def resolve_data(cls, request: Request, data: FormData):
        """
        Turn submitted form data into values for the model.

        :param request: current request, forwarded to each widget's ``parse_value``.
        :param data: submitted form data.
        :return: tuple ``(ret, m2m_ret)``: ``ret`` maps input names to parsed values,
            ``m2m_ret`` maps many-to-many input names to the related objects found.
        """
        ret = {}
        m2m_ret = {}
        for field in cls.get_fields(is_display=False):
            input_ = field.input
            if input_.context.get("disabled") or isinstance(input_, inputs.DisplayOnly):
                continue
            name = input_.context.get("name")
            if isinstance(input_, inputs.ForeignKey):
                # A form that omits the key gives an empty list; treat that like
                # an empty selection instead of failing with IndexError.
                submitted = data.getlist(name)
                v = submitted[0] if submitted else None
                ret[name] = int(v) if v else None
                continue
            if isinstance(input_, inputs.ManyToMany):
                v = data.getlist(name)
                value = await input_.parse_value(request, v)
                m2m_ret[name] = await input_.model.filter(pk__in=value)
            else:
                v = data.get(name)
                value = await input_.parse_value(request, v)
                # A value that parses to None is left out of the result.
                if value is None:
                    continue
                ret[name] = value
        return ret, m2m_ret

    @classmethod
    async def get_filters(cls, request: Request, values: Optional[dict] = None):
        """
        Render every filter widget, pre-filled from ``values``.

        :param request: current request, forwarded to each filter's ``render``.
        :param values: current filter values keyed by filter name; may be ``None``.
        :return: list with the rendered result of each filter.
        """
        values = values or {}
        ret = []
        for f in cls.filters:
            if isinstance(f, str):
                f = Search(name=f, label=f.title())
            name = f.context.get("name")
            ret.append(await f.render(request, values.get(name)))
        return ret

    @classmethod
    def _get_fields_attr(cls, attr: str, display: bool = True):
        """
        Collect attribute ``attr`` from every field returned by ``get_fields()``.

        When ``display`` is true, fields whose display widget is ``InputOnly`` are
        skipped. Falls back to the model's ``db_fields`` when nothing was collected.
        """
        ret = [
            getattr(field, attr)
            for field in cls.get_fields()
            if not (display and isinstance(field.display, displays.InputOnly))
        ]
        return ret or cls.model._meta.db_fields

    @classmethod
    def get_fields_name(cls, display: bool = True):
        """Return the ``name`` of each field, see ``_get_fields_attr``."""
        return cls._get_fields_attr("name", display)

    @classmethod
    def _get_display_input_field(cls, field_name: str) -> Field:
        """
        Build a ``Field`` with default widgets for the model field ``field_name``.

        The widgets are chosen from the type of the model field.

        :raises NoSuchFieldFound: when the model has no field called ``field_name``.
        """
        fields_map = cls.model._meta.fields_map
        field = fields_map.get(field_name)
        if not field:
            raise NoSuchFieldFound(f"Can't found field '{field_name}' in model {cls.model}")
        label = field_name
        null = field.null
        placeholder = field.description or ""
        # Generic fallback, used when none of the branches below matches.
        display, input_ = displays.Display(), inputs.Input(
            placeholder=placeholder, null=null, default=field.default
        )
        if field.pk or field.generated:
            # Primary keys and generated fields are shown but never edited.
            display, input_ = displays.Display(), inputs.DisplayOnly()
        elif isinstance(field, BooleanField):
            display, input_ = displays.Boolean(), inputs.Switch(null=null, default=field.default)
        elif isinstance(field, DatetimeField):
            # auto_now / auto_now_add fields are not editable.
            if field.auto_now or field.auto_now_add:
                input_ = inputs.DisplayOnly()
            else:
                input_ = inputs.DateTime(null=null, default=field.default)
            display, input_ = displays.DatetimeDisplay(), input_
        elif isinstance(field, DateField):
            display, input_ = displays.DateDisplay(), inputs.Date(null=null, default=field.default)
        elif isinstance(field, IntEnumFieldInstance):
            display, input_ = displays.Display(), inputs.Enum(
                field.enum_type, null=null, default=field.default
            )
        elif isinstance(field, CharEnumFieldInstance):
            display, input_ = displays.Display(), inputs.Enum(
                field.enum_type, enum_type=str, null=null, default=field.default
            )
        elif isinstance(field, JSONField):
            display, input_ = displays.Json(), inputs.Json(null=null)
        elif isinstance(field, TextField):
            display, input_ = displays.Display(), inputs.TextArea(
                placeholder=placeholder, null=null, default=field.default
            )
        elif isinstance(field, IntField):
            display, input_ = displays.Display(), inputs.Number(
                placeholder=placeholder, null=null, default=field.default
            )
        elif isinstance(field, ForeignKeyFieldInstance):
            display, input_ = displays.Display(), inputs.ForeignKey(
                field.related_model, null=null, default=field.default
            )
            # The Field is named after the source field; the label keeps the relation name.
            field_name = field.source_field
        elif isinstance(field, ManyToManyFieldInstance):
            display, input_ = displays.InputOnly(), inputs.ManyToMany(field.related_model)
        return Field(name=field_name, label=label.title(), display=display, input_=input_)

    @classmethod
    def get_fields(cls, is_display: bool = True):
        """
        Return the resource's fields as ``Field`` objects, primary key first.

        :param is_display: ``True`` to get the fields for display (``InputOnly``
            displays are dropped); ``False`` to get the fields for input
            (``ComputeField`` and ``DisplayOnly`` inputs are dropped).
        """
        ret = []
        pk_column = cls.model._meta.db_pk_column
        for field in cls.fields or cls.model._meta.fields:
            if isinstance(field, str):
                # The primary key is always added at the front, after the loop.
                if field == pk_column:
                    continue
                field = cls._get_display_input_field(field)
            if isinstance(field, ComputeField) and not is_display:
                continue
            elif isinstance(field, Field):
                if field.name == pk_column:
                    continue
                if (is_display and isinstance(field.display, displays.InputOnly)) or (
                    not is_display and isinstance(field.input, inputs.DisplayOnly)
                ):
                    continue
            # Skip fetch fields that are neither foreign keys nor many-to-many.
            if (
                field.name in cls.model._meta.fetch_fields
                and field.name not in cls.model._meta.fk_fields | cls.model._meta.m2m_fields
            ):
                continue
            ret.append(field)
        ret.insert(0, cls._get_display_input_field(pk_column))
        return ret

    @classmethod
    def get_fields_label(cls, display: bool = True):
        """Return the ``label`` of each field, see ``_get_fields_attr``."""
        return cls._get_fields_attr("label", display)

    @classmethod
    def get_m2m_field(cls):
        """Return the names of the configured fields that are many-to-many fields of the model."""
        m2m_fields = cls.model._meta.m2m_fields
        names = (
            field.name if isinstance(field, Field) else field
            for field in cls.fields or cls.model._meta.fields
        )
        return [name for name in names if name in m2m_fields]

    @classmethod
    def get_fk_field(cls):
        """Return the names of the configured fields that are foreign-key fields of the model."""
        fk_fields = cls.model._meta.fk_fields
        names = (
            field.name if isinstance(field, Field) else field
            for field in cls.fields or cls.model._meta.fields
        )
        return [name for name in names if name in fk_fields]
class Dropdown(Resource):
    """
    Resource that groups other resources.
    """

    # Resource classes grouped under this dropdown.
    resources: List[Type[Resource]]
async def render_values(
    request: Request,
    model: "Model",
    fields: List["Field"],
    values: List[Dict[str, Any]],
    display: bool = True,
) -> Tuple[List[List[Any]], List[dict], List[dict], List[List[dict]]]:
    """
    Render every value of every row through the widgets of its field.

    :param request: current request, forwarded to the widgets and attribute hooks.
    :param model: resource providing the row / column / cell attribute hooks.
    :param fields: fields to render, in column order.
    :param values: rows to render, each a dict keyed by field name.
    :param display: render with each field's display widget when true,
        with its input widget otherwise.
    :return: tuple ``(rendered_rows, row_attributes, column_attributes, cell_attributes)``.
    """
    # Column attributes depend only on the field, so they are gathered once up front.
    column_attributes: List[dict] = [
        await model.column_attributes(request, column) for column in fields
    ]

    rendered_rows = []
    row_attributes: List[dict] = []
    cell_attributes: List[List[dict]] = []
    for row in values:
        row_attributes.append(await model.row_attributes(request, row))
        rendered_cells = []
        attrs_of_cells = []
        for column in fields:
            # Computed fields produce their own value; plain fields read it from the row.
            if isinstance(column, ComputeField):
                raw = await column.get_value(request, row)
            else:
                raw = row.get(column.name)
            attrs_of_cells.append(await model.cell_attributes(request, row, column))
            widget = column.display if display else column.input
            rendered_cells.append(await widget.render(request, raw))
        rendered_rows.append(rendered_cells)
        cell_attributes.append(attrs_of_cells)
    return rendered_rows, row_attributes, column_attributes, cell_attributes