| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256 |
- package mcp
import (
	"context"
	"crypto/rand"
	"encoding/json"
	"fmt"
	"net/http"
	"sync"
	"time"

	"github.com/vektah/gqlparser/v2/ast"
	"gogs.dmsc.dev/arp/auth"
	"gogs.dmsc.dev/arp/graph"
)
- // Server represents the MCP server
- type Server struct {
- resolver *graph.Resolver
- schema *ast.Schema
- // Session management for SSE
- sessions map[string]*Session
- sessionsMu sync.RWMutex
- }
- // Session represents an SSE client session
- type Session struct {
- ID string
- User *auth.UserContext
- Events chan []byte
- Done chan struct{}
- Subscriptions map[string]context.CancelFunc // URI -> cancel function
- SubsMu sync.RWMutex
- }
- // NewServer creates a new MCP server
- func NewServer(resolver *graph.Resolver, schema *ast.Schema) *Server {
- return &Server{
- resolver: resolver,
- schema: schema,
- sessions: make(map[string]*Session),
- }
- }
// ServeHTTP opens a Server-Sent Events stream for one MCP client session.
//
// It works in three phases:
//  1. Setup: it registers a new Session under a fresh ID and announces where
//     the client should POST its JSON-RPC messages, by writing an "endpoint"
//     event whose data is "/message?sessionId=<id>".
//  2. Streaming: it forwards every payload queued on session.Events as a
//     "message" event. This continues until the client disconnects, a write
//     fails, or the session is closed.
//  3. Teardown: it unregisters the session and cancels any resource
//     subscriptions the session still holds. Only then does it close
//     session.Done.
//
// An OPTIONS request is answered as a CORS preflight (200, no stream).
// If the ResponseWriter cannot flush, the handler responds 500 without
// creating a session.
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	// SSE and CORS headers.
	w.Header().Set("Content-Type", "text/event-stream")
	w.Header().Set("Cache-Control", "no-cache")
	w.Header().Set("Connection", "keep-alive")
	w.Header().Set("Access-Control-Allow-Origin", "*")
	w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization")

	if r.Method == http.MethodOptions {
		w.WriteHeader(http.StatusOK)
		return
	}

	// Check streaming support before creating any session state. A writer
	// that cannot flush then never leaves a registered session behind.
	flusher, ok := w.(http.Flusher)
	if !ok {
		http.Error(w, "SSE not supported", http.StatusInternalServerError)
		return
	}

	// The user is attached to the request context by the auth middleware.
	// The lookup error is ignored, so an unauthenticated stream presumably
	// ends up with a nil User.
	userCtx, _ := auth.CurrentUser(r.Context())

	sessionID := generateSessionID()
	session := &Session{
		ID:            sessionID,
		User:          userCtx,
		Events:        make(chan []byte, 100),
		Done:          make(chan struct{}),
		Subscriptions: make(map[string]context.CancelFunc),
	}

	s.sessionsMu.Lock()
	s.sessions[sessionID] = session
	s.sessionsMu.Unlock()

	defer func() {
		// Unregister first, so no new message can look this session up.
		s.sessionsMu.Lock()
		delete(s.sessions, sessionID)
		s.sessionsMu.Unlock()

		// Release any subscriptions the client never unsubscribed from.
		// Calling a CancelFunc more than once is safe, so this cannot
		// conflict with an earlier unsubscribe. The assumption is that
		// these funcs stop per-subscription work started elsewhere.
		session.SubsMu.Lock()
		for uri, cancel := range session.Subscriptions {
			cancel()
			delete(session.Subscriptions, uri)
		}
		session.SubsMu.Unlock()

		close(session.Done)
	}()

	// Tell the client where to POST its messages.
	endpoint := fmt.Sprintf("/message?sessionId=%s", sessionID)
	if _, err := fmt.Fprintf(w, "event: endpoint\ndata: %s\n\n", endpoint); err != nil {
		return
	}
	flusher.Flush()

	for {
		select {
		case event := <-session.Events:
			if _, err := fmt.Fprintf(w, "event: message\ndata: %s\n\n", event); err != nil {
				// The connection is broken. Stop now rather than waiting
				// for the request context to be cancelled.
				return
			}
			flusher.Flush()
		case <-r.Context().Done():
			return
		case <-session.Done:
			return
		}
	}
}
// HandleMessage accepts one JSON-RPC message POSTed by an MCP client.
//
// The target session is named by the "sessionId" query parameter, which is
// the value ServeHTTP advertised in its "endpoint" event. The message runs
// with the user captured when that session's stream was opened, if there is
// one.
//
// Any response is delivered asynchronously, as an SSE "message" event on the
// session's stream. The POST itself is answered with 202 Accepted and no
// body.
//
// Some failures are reported directly on the POST via writeError:
//   - missing session ID: ErrInvalidParams
//   - unknown session: code -32001
//   - unreadable or oversized body: ErrParseError
//   - response that cannot be serialized: code -32603, internal error
func (s *Server) HandleMessage(w http.ResponseWriter, r *http.Request) {
	// maxMessageBytes caps a single message so an untrusted client cannot
	// make the decoder buffer an unbounded body.
	const maxMessageBytes = 4 << 20 // 4 MiB

	// Response content type and CORS headers.
	w.Header().Set("Content-Type", "application/json")
	w.Header().Set("Access-Control-Allow-Origin", "*")
	w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization")

	if r.Method == http.MethodOptions {
		w.WriteHeader(http.StatusOK)
		return
	}

	sessionID := r.URL.Query().Get("sessionId")
	if sessionID == "" {
		s.writeError(w, nil, ErrInvalidParams)
		return
	}

	s.sessionsMu.RLock()
	session, ok := s.sessions[sessionID]
	s.sessionsMu.RUnlock()
	if !ok {
		s.writeError(w, nil, &RPCError{Code: -32001, Message: "Session not found"})
		return
	}

	var req JSONRPCRequest
	body := http.MaxBytesReader(w, r.Body, maxMessageBytes)
	if err := json.NewDecoder(body).Decode(&req); err != nil {
		s.writeError(w, nil, ErrParseError)
		return
	}

	// Run as the user who opened the stream, when there is one.
	ctx := r.Context()
	if session.User != nil {
		ctx = auth.WithUser(ctx, session.User)
	}

	response := s.handleRequest(ctx, &req, session)
	if response != nil {
		respBytes, err := json.Marshal(response)
		if err != nil {
			// -32603 is the JSON-RPC 2.0 "Internal error" code.
			s.writeError(w, response.ID, &RPCError{Code: -32603, Message: "Internal error"})
			return
		}
		// Wait for room on the session queue instead of dropping the reply.
		// A dropped reply leaves the client waiting forever on its request
		// ID. Give up if the stream ends or this POST is abandoned.
		select {
		case session.Events <- respBytes:
		case <-session.Done:
			// The stream has closed, so the reply cannot be delivered.
		case <-ctx.Done():
			return
		}
	}

	w.WriteHeader(http.StatusAccepted)
}
// handleRequest routes a decoded JSON-RPC message to the handler for its
// method. It returns the response to deliver, or nil when nothing should be
// sent back.
//
// Notifications are one-way, and JSON-RPC 2.0 forbids answering them. The
// client-to-server notifications defined by MCP are therefore accepted
// silently instead of falling through to "method not found". Any other
// unrecognized method gets ErrMethodNotFound, echoing the request ID.
func (s *Server) handleRequest(ctx context.Context, req *JSONRPCRequest, session *Session) *JSONRPCResponse {
	switch req.Method {
	case "initialize":
		return s.handleInitialize(ctx, req)
	case "notifications/initialized",
		"notifications/cancelled",
		"notifications/progress",
		"notifications/roots/list_changed":
		// No server-side work is triggered yet; the only requirement is not
		// to reply.
		return nil
	case "ping":
		return s.handlePing(ctx, req)
	case "tools/list":
		return s.handleToolsList(ctx, req)
	case "tools/call":
		return s.handleToolsCall(ctx, req)
	case "resources/list":
		return s.handleResourcesList(ctx, req)
	case "resources/read":
		return s.handleResourcesRead(ctx, req)
	case "resources/subscribe":
		return s.handleResourcesSubscribe(ctx, req, session)
	case "resources/unsubscribe":
		return s.handleResourcesUnsubscribe(ctx, req, session)
	default:
		return &JSONRPCResponse{
			JSONRPC: "2.0",
			ID:      req.ID,
			Error:   ErrMethodNotFound,
		}
	}
}
// handleInitialize answers the MCP "initialize" handshake.
//
// Params, when present, must decode as InitializeParams; otherwise the reply
// is ErrInvalidParams. The decoded values are not otherwise used. The server
// always replies with:
//   - its own ProtocolVersion
//   - tool support, without list-changed notifications
//   - resource support, with subscriptions but without list-changed
//     notifications
//   - its name and version
//   - short usage instructions
//
// NOTE(review): the client's requested protocol version is not negotiated.
// Confirm this is intended.
func (s *Server) handleInitialize(ctx context.Context, req *JSONRPCRequest) *JSONRPCResponse {
	if req.Params != nil {
		var params InitializeParams
		if err := json.Unmarshal(req.Params, &params); err != nil {
			return &JSONRPCResponse{
				JSONRPC: "2.0",
				ID:      req.ID,
				Error:   ErrInvalidParams,
			}
		}
	}

	result := InitializeResult{
		ProtocolVersion: ProtocolVersion,
		Capabilities: ServerCapabilities{
			Tools:     &ToolsCapability{ListChanged: false},
			Resources: &ResourcesCapability{Subscribe: true, ListChanged: false},
		},
		ServerInfo: ImplementationInfo{
			Name:    "ARP MCP Server",
			Version: "1.0.0",
		},
		Instructions: "Use the introspect tool to discover the GraphQL schema, query tool for read operations, mutate tool for write operations, and resources for real-time subscriptions.",
	}

	return &JSONRPCResponse{
		JSONRPC: "2.0",
		ID:      req.ID,
		Result:  result,
	}
}
// handlePing answers a liveness check with an empty result that echoes the
// caller's request ID.
func (s *Server) handlePing(ctx context.Context, req *JSONRPCRequest) *JSONRPCResponse {
	resp := JSONRPCResponse{JSONRPC: "2.0", ID: req.ID}
	resp.Result = PingResult{}
	return &resp
}
// writeError writes err as a JSON-RPC error response directly to the HTTP
// response, with status 400 Bad Request.
//
// id is echoed as the response ID; callers pass nil when the request ID is
// not known.
func (s *Server) writeError(w http.ResponseWriter, id interface{}, err *RPCError) {
	// Set the content type here so the body is always labelled as JSON,
	// whatever headers the caller has already set.
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusBadRequest)
	// The status line has already gone out, so an encoding failure cannot be
	// reported to the client. The error is deliberately dropped.
	_ = json.NewEncoder(w).Encode(JSONRPCResponse{
		JSONRPC: "2.0",
		ID:      id,
		Error:   err,
	})
}
- // generateSessionID generates a unique session ID
- func generateSessionID() string {
- return fmt.Sprintf("%d", time.Now().UnixNano())
- }
|