diff --git a/server/ctrl/session.go b/server/ctrl/session.go index 21468258..da19ebc9 100644 --- a/server/ctrl/session.go +++ b/server/ctrl/session.go @@ -377,21 +377,7 @@ func SessionAuthMiddleware(ctx *App, res http.ResponseWriter, req *http.Request) } tmpl, err := template. New("ctrl::session::auth_middleware"). - Funcs(map[string]interface{}{ - "contains": func(str string, match string) bool { - splits := strings.Split(str, ",") - for _, split := range splits { - if split == match { - return true - } - } - return false - }, - "encryptGCM": func(str string, key string) (string, error) { - data, err := EncryptAESGCM([]byte(key), []byte(str)) - return base64.StdEncoding.EncodeToString(data), err - }, - }). + Funcs(tmplFuncs). Parse(str) mappingToUse[k] = str if err != nil { diff --git a/server/ctrl/tmpl.go b/server/ctrl/tmpl.go new file mode 100644 index 00000000..ad5802ab --- /dev/null +++ b/server/ctrl/tmpl.go @@ -0,0 +1,73 @@ +package ctrl + +import ( + "encoding/base64" + "strings" + "text/template" + + . "github.com/mickael-kerjean/filestash/server/common" +) + +var tmplFuncs = template.FuncMap{ + "split": func(s, sep string) []string { + return strings.Split(sep, s) + }, + "get": func(i int, arr any) (string, error) { + switch v := arr.(type) { + case string: + splits := strings.Split(v, ",") + if i < len(splits) && i >= 0 { + return strings.TrimSpace(splits[i]), nil + } + return "", nil + case []string: + if i < len(v) && i >= 0 { + return v[i], nil + } + return "", nil + default: + return "", ErrNotImplemented + } + }, + "contains": func(match string, opts ...any) (bool, error) { + exact := true + var input any + if len(opts) == 0 { + return false, ErrNotValid + } else if len(opts) == 1 { + input = opts[0] + } else if len(opts) == 2 { + exact = opts[0].(bool) + input = opts[1] + } + switch in := input.(type) { + case string: + splits := strings.Split(in, ",") + for _, split := range splits { + split = strings.TrimSpace(split) + if exact && split == match { + return true, nil + } else if !exact && strings.Contains(split, match) { + return true, nil + } + } + return false, nil + case []string: + for _, split := range in { + split = strings.TrimSpace(split) + if exact && split == match { + return true, nil + } else if !exact && strings.Contains(split, match) { + return true, nil + } + } + return false, nil + default: + return false, ErrNotImplemented + } + }, + "encryptGCM": func(str string, key string) (string, error) { + data, err := EncryptAESGCM([]byte(key), []byte(str)) + return base64.StdEncoding.EncodeToString(data), err + }, +}