// Files
// 2025-09-05 14:40:30 +10:00
//
// 213 lines
// 6.2 KiB
// Go

package plg_handler_mcp
import (
"context"
"encoding/json"
"fmt"
"net/http"
"time"
. "github.com/mickael-kerjean/filestash/server/common"
"github.com/mickael-kerjean/filestash/server/model"
. "github.com/mickael-kerjean/filestash/server/plugin/plg_handler_mcp/impl"
. "github.com/mickael-kerjean/filestash/server/plugin/plg_handler_mcp/types"
. "github.com/mickael-kerjean/filestash/server/plugin/plg_handler_mcp/utils"
"github.com/google/uuid"
)
// messageHandler accepts one JSON-RPC request over HTTP POST and hands it to
// the channel of the session named by the "sessionId" query parameter, where
// the SSE loop of that session picks it up.
//
// Responses:
//   - 400 "Invalid Request" when the HTTP method is not POST.
//   - 400 "ERR: <decode error>" when the body cannot be decoded as a request.
//   - 204 once the request has been handed over to the session channel.
//   - 503 when the HTTP request context ends before anybody received the
//     request (for instance the client went away while the send was pending).
func (this *Server) messageHandler(_ *App, w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		w.WriteHeader(http.StatusBadRequest)
		w.Write([]byte("Invalid Request"))
		return
	}
	request := JSONRPCRequest{}
	if err := json.NewDecoder(r.Body).Decode(&request); err != nil {
		w.WriteHeader(http.StatusBadRequest)
		w.Write([]byte("ERR: " + err.Error()))
		return
	}
	sessionID := r.URL.Query().Get("sessionId")
	// A bare channel send would park this goroutine forever when no SSE loop
	// is receiving on the session channel; racing it against the request
	// context bounds the wait by the lifetime of the HTTP request.
	select {
	case this.GetSession(sessionID).Chan <- request:
		w.WriteHeader(http.StatusNoContent)
	case <-r.Context().Done():
		w.WriteHeader(http.StatusServiceUnavailable)
	}
}
// sseHandler opens a Server-Sent Events stream for a client and serves the
// JSON-RPC requests that messageHandler pushes onto the session channel.
//
// Flow:
//  1. Anything but GET is answered with 400. A request whose Authorization
//     header does not yield a token is answered with 401 and a JSON-RPC error.
//  2. A ResponseWriter that cannot flush is answered with 500 and a JSON-RPC
//     error, before any session is created.
//  3. A session is created under a fresh UUID and the first event ("endpoint")
//     tells the client where to POST: "/messages?sessionId=<id>".
//  4. The loop then dispatches each incoming request by its method, sends a
//     ping whenever 15 seconds pass without any other event, and removes the
//     session when the client disconnects or when the last ping response is
//     older than 60 seconds.
func (this *Server) sseHandler(_ *App, w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		w.WriteHeader(http.StatusBadRequest)
		return
	}
	token := this.ValidateToken(r.Header.Get("Authorization"))
	if token == "" {
		w.Header().Add("Content-Type", "application/json")
		w.WriteHeader(http.StatusUnauthorized)
		json.NewEncoder(w).Encode(JSONRPCResponse{
			JSONRPC: "2.0",
			Error: &JSONRPCError{
				Code:    http.StatusUnauthorized,
				Message: "Missing or invalid access token",
			},
		})
		return
	}
	// Verify streaming support up front: an unchecked type assertion would
	// panic, and doing it after GetSession would leave a session behind.
	flusher, ok := w.(http.Flusher)
	if !ok {
		w.Header().Add("Content-Type", "application/json")
		w.WriteHeader(http.StatusInternalServerError)
		json.NewEncoder(w).Encode(JSONRPCResponse{
			JSONRPC: "2.0",
			Error: &JSONRPCError{
				Code:    http.StatusInternalServerError,
				Message: "Streaming unsupported",
			},
		})
		return
	}

	userSession := this.GetSession(uuid.New().String())
	userSession.Token = token
	// Best effort: when the backend cannot be built here the directories keep
	// whatever value GetSession gave them; the error from GetHome is ignored.
	if b, err := getBackend(userSession.Token); err == nil {
		userSession.HomeDir, _ = model.GetHome(b, "/")
		userSession.CurrDir = ToString(userSession.HomeDir, "/")
	}

	w.Header().Set("Content-Type", "text/event-stream")
	w.Header().Set("Cache-Control", "no-cache")
	w.Header().Set("Connection", "keep-alive")
	fmt.Fprintf(w, "event: endpoint\ndata: %s?sessionId=%s\n\n", "/messages", userSession.Id)
	flusher.Flush()

	for {
		select {
		case request := <-userSession.Chan:
			// The backend is rebuilt from the token for every request.
			b, err := getBackend(userSession.Token)
			if err != nil {
				if err == ErrNotAuthorized {
					err = JSONRPCError{
						Code:    ErrNotAuthorized.Status(),
						Message: "You aren't authenticated",
					}
				}
				SendError(w, request.ID, err)
				// Skip dispatching and wait for the next event.
				continue
			}
			userSession.Backend = b

			switch request.Method {
			case "initialize":
				SendMessage(w, request.ID, InitializeResult{
					ProtocolVersion: "2024-11-05",
					ServerInfo: ServerInfo{
						Name:    "Universal Storage Server",
						Version: "1.0.0",
					},
					Capabilities: Capabilities{
						Tools: map[string]interface{}{
							"listChanged": true,
						},
						Resources: map[string]interface{}{},
						Prompts:   map[string]interface{}{},
					},
				})
			case "resources/list":
				SendMessage(w, request.ID, &CallResourcesList{
					Resources: AllResources(),
				})
			case "resources/templates/list":
				SendMessage(w, request.ID, &CallResourceTemplatesList{
					ResourceTemplates: AllResourceTemplates(),
				})
			case "resources/read":
				SendMessage(w, request.ID, &CallResourceRead{
					Contents: ExecResourceRead(request.Params),
				})
			case "prompts/list":
				SendMessage(w, request.ID, &CallPromptsList{
					Prompts: AllPrompts(),
				})
			case "prompts/get":
				// "name" must be a string; anything else is reported back
				// as a bad request.
				if m, ok := request.Params["name"].(string); ok {
					res, err := ExecPromptGet(m, request.Params, &userSession)
					if err == nil {
						SendMessage(w, request.ID, CallPromptGet{
							Messages:    res,
							Description: ExecPromptDescription(request.Params),
						})
					} else {
						SendError(w, request.ID, err)
					}
				} else {
					SendError(w, request.ID, JSONRPCError{
						Code:    http.StatusBadRequest,
						Message: fmt.Sprintf("Unknown prompt name: %v", request.Params["name"]),
					})
				}
			case "tools/list":
				SendMessage(w, request.ID, &CallListTools{
					Tools: AllTools(),
				})
			case "tools/call":
				// Same contract as "prompts/get": "name" must be a string.
				if m, ok := request.Params["name"].(string); ok {
					res, err := ExecTool(m, request.Params, &userSession)
					if err == nil {
						SendMessage(w, request.ID, CallTool{
							Content: []TextContent{*res},
						})
					} else {
						SendError(w, request.ID, err)
					}
				} else {
					SendError(w, request.ID, JSONRPCError{
						Code:    http.StatusBadRequest,
						Message: fmt.Sprintf("Unknown tool name: %v", request.Params["name"]),
					})
				}
			case "notifications/initialized":
				SendMessage(w, request.ID, map[string]string{})
			case "completion/complete":
				SendMessage(w, request.ID, CallCompletionResult{
					Completion: ExecCompletion(request.Params, &userSession),
				})
			case "ping":
				SendMessage(w, request.ID, map[string]string{})
			default:
				// A message without a method whose ID matches the pending
				// ping is treated as the client's answer to that ping.
				if request.Method == "" && userSession.Ping.ID == request.ID { // response to ping
					userSession.Ping.LastResponse = time.Now()
					userSession.Ping.ID += 1
				} else {
					Log.Warning("plg_handler_mcp::sse message=unknown_method method=%s requestID=%d", request.Method, request.ID)
					SendError(w, request.ID, JSONRPCError{
						Code:    http.StatusMethodNotAllowed,
						Message: fmt.Sprintf("Unknown request: %s", request.Method),
					})
				}
			}
		case <-r.Context().Done():
			// The HTTP request ended (e.g. the client disconnected).
			this.RemoveSession(&userSession)
			return
		case <-time.After(15 * time.Second):
			// Nothing else happened for 15 seconds: probe the client.
			SendPing(w, userSession.Ping.ID)
			if time.Since(userSession.Ping.LastResponse) > 60*time.Second {
				// The client stopped answering pings: notify it, give the
				// message a moment to go out, then drop the session.
				SendMethod(w, userSession.Ping.ID+1, "notifications/cancelled", map[string]interface{}{
					"requestId": userSession.Ping.ID,
					"reason":    "Request timed out",
				})
				time.Sleep(2 * time.Second)
				this.RemoveSession(&userSession)
				// When the writer allows it, take over the connection to
				// close it explicitly. NOTE(review): "0\r\n\r\n" looks like
				// the final chunk of a chunked response body - confirm.
				if hi, ok := w.(http.Hijacker); ok {
					if conn, rw, err := hi.Hijack(); err == nil {
						rw.WriteString("0\r\n\r\n")
						rw.Flush()
						time.Sleep(1 * time.Second)
						conn.Close()
					}
				}
				return
			}
		}
	}
}
// getBackend turns an access token into a storage backend.
//
// The token is decrypted with SECRET_KEY_DERIVATE_FOR_USER, the plaintext is
// parsed as a JSON object of string values, and that object is passed to
// model.NewBackend together with an App carrying a background context.
//
// It returns ErrNotAuthorized when decryption fails, the JSON error when the
// plaintext cannot be parsed, and otherwise whatever model.NewBackend returns.
func getBackend(token string) (IBackend, error) {
	plaintext, decryptErr := DecryptString(SECRET_KEY_DERIVATE_FOR_USER, token)
	if decryptErr != nil {
		// The underlying cause is deliberately replaced by the sentinel.
		return nil, ErrNotAuthorized
	}

	params := make(map[string]string)
	if parseErr := json.Unmarshal([]byte(plaintext), &params); parseErr != nil {
		return nil, parseErr
	}

	app := App{Context: context.Background()}
	return model.NewBackend(&app, params)
}