package mcp import ( "context" "encoding/json" "fmt" "net/http" "sync" "time" "github.com/vektah/gqlparser/v2/ast" "gogs.dmsc.dev/arp/auth" "gogs.dmsc.dev/arp/graph" ) // Server represents the MCP server type Server struct { resolver *graph.Resolver schema *ast.Schema // Session management for SSE sessions map[string]*Session sessionsMu sync.RWMutex } // Session represents an SSE client session type Session struct { ID string User *auth.UserContext Events chan []byte Done chan struct{} Subscriptions map[string]context.CancelFunc // URI -> cancel function SubsMu sync.RWMutex } // NewServer creates a new MCP server func NewServer(resolver *graph.Resolver, schema *ast.Schema) *Server { return &Server{ resolver: resolver, schema: schema, sessions: make(map[string]*Session), } } // ServeHTTP handles MCP requests over SSE func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Set SSE headers w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Cache-Control", "no-cache") w.Header().Set("Connection", "keep-alive") w.Header().Set("Access-Control-Allow-Origin", "*") w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization") // Handle CORS preflight if r.Method == "OPTIONS" { w.WriteHeader(http.StatusOK) return } // Get user context from request (set by auth middleware) userCtx, _ := auth.CurrentUser(r.Context()) // Create session sessionID := generateSessionID() session := &Session{ ID: sessionID, User: userCtx, Events: make(chan []byte, 100), Done: make(chan struct{}), Subscriptions: make(map[string]context.CancelFunc), } // Register session s.sessionsMu.Lock() s.sessions[sessionID] = session s.sessionsMu.Unlock() defer func() { s.sessionsMu.Lock() delete(s.sessions, sessionID) s.sessionsMu.Unlock() close(session.Done) }() // Flush helper flusher, ok := w.(http.Flusher) if !ok { http.Error(w, "SSE not supported", http.StatusInternalServerError) return } // Send endpoint event endpoint := fmt.Sprintf("/message?sessionId=%s", sessionID) fmt.Fprintf(w, "event: endpoint\ndata: %s\n\n", endpoint) flusher.Flush() // Stream events for { select { case event := <-session.Events: fmt.Fprintf(w, "event: message\ndata: %s\n\n", string(event)) flusher.Flush() case <-r.Context().Done(): return case <-session.Done: return } } } // HandleMessage handles incoming JSON-RPC messages func (s *Server) HandleMessage(w http.ResponseWriter, r *http.Request) { // Set CORS headers w.Header().Set("Content-Type", "application/json") w.Header().Set("Access-Control-Allow-Origin", "*") w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization") // Handle CORS preflight if r.Method == "OPTIONS" { w.WriteHeader(http.StatusOK) return } // Get session ID from query sessionID := r.URL.Query().Get("sessionId") if sessionID == "" { s.writeError(w, nil, ErrInvalidParams) return } // Get session s.sessionsMu.RLock() session, ok := s.sessions[sessionID] s.sessionsMu.RUnlock() if !ok { s.writeError(w, nil, &RPCError{Code: -32001, Message: "Session not found"}) return } // Parse request var req JSONRPCRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { s.writeError(w, nil, ErrParseError) return } // Handle request ctx := r.Context() if session.User != nil { ctx = auth.WithUser(ctx, session.User) } response := s.handleRequest(ctx, &req, session) // Send response via SSE if there's a session if response != nil && session != nil { respBytes, _ := json.Marshal(response) select { case session.Events <- respBytes: default: // Channel full, skip } } // Also write response to HTTP w.WriteHeader(http.StatusAccepted) } // handleRequest processes a JSON-RPC request func (s *Server) handleRequest(ctx context.Context, req *JSONRPCRequest, session *Session) *JSONRPCResponse { switch req.Method { case "initialize": return s.handleInitialize(ctx, req) case "notifications/initialized": // Notification, no response needed return nil case "ping": return s.handlePing(ctx, req) case "tools/list": return s.handleToolsList(ctx, req) case "tools/call": return s.handleToolsCall(ctx, req) case "resources/list": return s.handleResourcesList(ctx, req) case "resources/read": return s.handleResourcesRead(ctx, req) case "resources/subscribe": return s.handleResourcesSubscribe(ctx, req, session) case "resources/unsubscribe": return s.handleResourcesUnsubscribe(ctx, req, session) default: return &JSONRPCResponse{ JSONRPC: "2.0", ID: req.ID, Error: ErrMethodNotFound, } } } // handleInitialize handles the initialize request func (s *Server) handleInitialize(ctx context.Context, req *JSONRPCRequest) *JSONRPCResponse { var params InitializeParams if req.Params != nil { if err := json.Unmarshal(req.Params, ¶ms); err != nil { return &JSONRPCResponse{ JSONRPC: "2.0", ID: req.ID, Error: ErrInvalidParams, } } } result := InitializeResult{ ProtocolVersion: ProtocolVersion, Capabilities: ServerCapabilities{ Tools: &ToolsCapability{ListChanged: false}, Resources: &ResourcesCapability{Subscribe: true, ListChanged: false}, }, ServerInfo: ImplementationInfo{ Name: "ARP MCP Server", Version: "1.0.0", }, Instructions: "Use the introspect tool to discover the GraphQL schema, query tool for read operations, mutate tool for write operations, and resources for real-time subscriptions.", } return &JSONRPCResponse{ JSONRPC: "2.0", ID: req.ID, Result: result, } } // handlePing handles the ping request func (s *Server) handlePing(ctx context.Context, req *JSONRPCRequest) *JSONRPCResponse { return &JSONRPCResponse{ JSONRPC: "2.0", ID: req.ID, Result: PingResult{}, } } // writeError writes a JSON-RPC error response func (s *Server) writeError(w http.ResponseWriter, id interface{}, err *RPCError) { w.WriteHeader(http.StatusBadRequest) json.NewEncoder(w).Encode(JSONRPCResponse{ JSONRPC: "2.0", ID: id, Error: err, }) } // generateSessionID generates a unique session ID func generateSessionID() string { return fmt.Sprintf("%d", time.Now().UnixNano()) }