server.go 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256
  1. package mcp
import (
	"context"
	"crypto/rand"
	"encoding/json"
	"fmt"
	"net/http"
	"sync"
	"time"

	"github.com/vektah/gqlparser/v2/ast"
	"gogs.dmsc.dev/arp/auth"
	"gogs.dmsc.dev/arp/graph"
)
  13. // Server represents the MCP server
  14. type Server struct {
  15. resolver *graph.Resolver
  16. schema *ast.Schema
  17. // Session management for SSE
  18. sessions map[string]*Session
  19. sessionsMu sync.RWMutex
  20. }
  21. // Session represents an SSE client session
  22. type Session struct {
  23. ID string
  24. User *auth.UserContext
  25. Events chan []byte
  26. Done chan struct{}
  27. Subscriptions map[string]context.CancelFunc // URI -> cancel function
  28. SubsMu sync.RWMutex
  29. }
  30. // NewServer creates a new MCP server
  31. func NewServer(resolver *graph.Resolver, schema *ast.Schema) *Server {
  32. return &Server{
  33. resolver: resolver,
  34. schema: schema,
  35. sessions: make(map[string]*Session),
  36. }
  37. }
  38. // ServeHTTP handles MCP requests over SSE
  39. func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
  40. // Set SSE headers
  41. w.Header().Set("Content-Type", "text/event-stream")
  42. w.Header().Set("Cache-Control", "no-cache")
  43. w.Header().Set("Connection", "keep-alive")
  44. w.Header().Set("Access-Control-Allow-Origin", "*")
  45. w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
  46. // Handle CORS preflight
  47. if r.Method == "OPTIONS" {
  48. w.WriteHeader(http.StatusOK)
  49. return
  50. }
  51. // Get user context from request (set by auth middleware)
  52. userCtx, _ := auth.CurrentUser(r.Context())
  53. // Create session
  54. sessionID := generateSessionID()
  55. session := &Session{
  56. ID: sessionID,
  57. User: userCtx,
  58. Events: make(chan []byte, 100),
  59. Done: make(chan struct{}),
  60. Subscriptions: make(map[string]context.CancelFunc),
  61. }
  62. // Register session
  63. s.sessionsMu.Lock()
  64. s.sessions[sessionID] = session
  65. s.sessionsMu.Unlock()
  66. defer func() {
  67. s.sessionsMu.Lock()
  68. delete(s.sessions, sessionID)
  69. s.sessionsMu.Unlock()
  70. close(session.Done)
  71. }()
  72. // Flush helper
  73. flusher, ok := w.(http.Flusher)
  74. if !ok {
  75. http.Error(w, "SSE not supported", http.StatusInternalServerError)
  76. return
  77. }
  78. // Send endpoint event
  79. endpoint := fmt.Sprintf("/message?sessionId=%s", sessionID)
  80. fmt.Fprintf(w, "event: endpoint\ndata: %s\n\n", endpoint)
  81. flusher.Flush()
  82. // Stream events
  83. for {
  84. select {
  85. case event := <-session.Events:
  86. fmt.Fprintf(w, "event: message\ndata: %s\n\n", string(event))
  87. flusher.Flush()
  88. case <-r.Context().Done():
  89. return
  90. case <-session.Done:
  91. return
  92. }
  93. }
  94. }
  95. // HandleMessage handles incoming JSON-RPC messages
  96. func (s *Server) HandleMessage(w http.ResponseWriter, r *http.Request) {
  97. // Set CORS headers
  98. w.Header().Set("Content-Type", "application/json")
  99. w.Header().Set("Access-Control-Allow-Origin", "*")
  100. w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
  101. // Handle CORS preflight
  102. if r.Method == "OPTIONS" {
  103. w.WriteHeader(http.StatusOK)
  104. return
  105. }
  106. // Get session ID from query
  107. sessionID := r.URL.Query().Get("sessionId")
  108. if sessionID == "" {
  109. s.writeError(w, nil, ErrInvalidParams)
  110. return
  111. }
  112. // Get session
  113. s.sessionsMu.RLock()
  114. session, ok := s.sessions[sessionID]
  115. s.sessionsMu.RUnlock()
  116. if !ok {
  117. s.writeError(w, nil, &RPCError{Code: -32001, Message: "Session not found"})
  118. return
  119. }
  120. // Parse request
  121. var req JSONRPCRequest
  122. if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
  123. s.writeError(w, nil, ErrParseError)
  124. return
  125. }
  126. // Handle request
  127. ctx := r.Context()
  128. if session.User != nil {
  129. ctx = auth.WithUser(ctx, session.User)
  130. }
  131. response := s.handleRequest(ctx, &req, session)
  132. // Send response via SSE if there's a session
  133. if response != nil && session != nil {
  134. respBytes, _ := json.Marshal(response)
  135. select {
  136. case session.Events <- respBytes:
  137. default:
  138. // Channel full, skip
  139. }
  140. }
  141. // Also write response to HTTP
  142. w.WriteHeader(http.StatusAccepted)
  143. }
  144. // handleRequest processes a JSON-RPC request
  145. func (s *Server) handleRequest(ctx context.Context, req *JSONRPCRequest, session *Session) *JSONRPCResponse {
  146. switch req.Method {
  147. case "initialize":
  148. return s.handleInitialize(ctx, req)
  149. case "notifications/initialized":
  150. // Notification, no response needed
  151. return nil
  152. case "ping":
  153. return s.handlePing(ctx, req)
  154. case "tools/list":
  155. return s.handleToolsList(ctx, req)
  156. case "tools/call":
  157. return s.handleToolsCall(ctx, req)
  158. case "resources/list":
  159. return s.handleResourcesList(ctx, req)
  160. case "resources/read":
  161. return s.handleResourcesRead(ctx, req)
  162. case "resources/subscribe":
  163. return s.handleResourcesSubscribe(ctx, req, session)
  164. case "resources/unsubscribe":
  165. return s.handleResourcesUnsubscribe(ctx, req, session)
  166. default:
  167. return &JSONRPCResponse{
  168. JSONRPC: "2.0",
  169. ID: req.ID,
  170. Error: ErrMethodNotFound,
  171. }
  172. }
  173. }
  174. // handleInitialize handles the initialize request
  175. func (s *Server) handleInitialize(ctx context.Context, req *JSONRPCRequest) *JSONRPCResponse {
  176. var params InitializeParams
  177. if req.Params != nil {
  178. if err := json.Unmarshal(req.Params, &params); err != nil {
  179. return &JSONRPCResponse{
  180. JSONRPC: "2.0",
  181. ID: req.ID,
  182. Error: ErrInvalidParams,
  183. }
  184. }
  185. }
  186. result := InitializeResult{
  187. ProtocolVersion: ProtocolVersion,
  188. Capabilities: ServerCapabilities{
  189. Tools: &ToolsCapability{ListChanged: false},
  190. Resources: &ResourcesCapability{Subscribe: true, ListChanged: false},
  191. },
  192. ServerInfo: ImplementationInfo{
  193. Name: "ARP MCP Server",
  194. Version: "1.0.0",
  195. },
  196. Instructions: "Use the introspect tool to discover the GraphQL schema, query tool for read operations, mutate tool for write operations, and resources for real-time subscriptions.",
  197. }
  198. return &JSONRPCResponse{
  199. JSONRPC: "2.0",
  200. ID: req.ID,
  201. Result: result,
  202. }
  203. }
  204. // handlePing handles the ping request
  205. func (s *Server) handlePing(ctx context.Context, req *JSONRPCRequest) *JSONRPCResponse {
  206. return &JSONRPCResponse{
  207. JSONRPC: "2.0",
  208. ID: req.ID,
  209. Result: PingResult{},
  210. }
  211. }
  212. // writeError writes a JSON-RPC error response
  213. func (s *Server) writeError(w http.ResponseWriter, id interface{}, err *RPCError) {
  214. w.WriteHeader(http.StatusBadRequest)
  215. json.NewEncoder(w).Encode(JSONRPCResponse{
  216. JSONRPC: "2.0",
  217. ID: id,
  218. Error: err,
  219. })
  220. }
  221. // generateSessionID generates a unique session ID
  222. func generateSessionID() string {
  223. return fmt.Sprintf("%d", time.Now().UnixNano())
  224. }