handler.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366
  1. package mcp
  import (
	"context"
	"encoding/json"
	"unicode"
	"unicode/utf8"

	"gogs.dmsc.dev/arp/auth"
	"gogs.dmsc.dev/arp/graph"
	"gogs.dmsc.dev/arp/mcp/tools"
)
  9. // handleToolsList returns the list of available tools
  10. func (s *Server) handleToolsList(ctx context.Context, req *JSONRPCRequest) *JSONRPCResponse {
  11. toolList := []Tool{
  12. {
  13. Name: "introspect",
  14. Description: "Get GraphQL schema information - types, fields, queries, mutations. Use this to discover the API structure before making queries or mutations.",
  15. InputSchema: InputSchema{
  16. Type: "object",
  17. AdditionalProperties: false,
  18. Properties: map[string]Property{
  19. "typeName": {
  20. Type: "string",
  21. Description: "Optional - specific type to introspect (e.g., 'Query', 'Mutation', 'User', 'Task'). If omitted, returns full schema overview.",
  22. },
  23. },
  24. },
  25. },
  26. {
  27. Name: "query",
  28. Description: "Execute GraphQL queries (read operations). Use for fetching data from the API. The query must be a valid GraphQL query string.",
  29. InputSchema: InputSchema{
  30. Type: "object",
  31. AdditionalProperties: false,
  32. Properties: map[string]Property{
  33. "query": {
  34. Type: "string",
  35. Description: "GraphQL query string (e.g., 'query { users { id email } }')",
  36. },
  37. "variables": {
  38. Type: "object",
  39. Description: "Optional query variables as key-value pairs",
  40. },
  41. },
  42. Required: []string{"query"},
  43. },
  44. },
  45. {
  46. Name: "mutate",
  47. Description: "Execute GraphQL mutations (create/update/delete operations). Use for modifying data in the API. The mutation must be a valid GraphQL mutation string.",
  48. InputSchema: InputSchema{
  49. Type: "object",
  50. AdditionalProperties: false,
  51. Properties: map[string]Property{
  52. "mutation": {
  53. Type: "string",
  54. Description: "GraphQL mutation string (e.g., 'mutation { createUser(input: {email: \"test@example.com\", password: \"pass\", roles: []}) { id } }')",
  55. },
  56. "variables": {
  57. Type: "object",
  58. Description: "Optional mutation variables as key-value pairs",
  59. },
  60. },
  61. Required: []string{"mutation"},
  62. },
  63. },
  64. }
  65. return &JSONRPCResponse{
  66. JSONRPC: "2.0",
  67. ID: req.ID,
  68. Result: ListToolsResult{Tools: toolList},
  69. }
  70. }
  71. // handleToolsCall executes a tool call
  72. func (s *Server) handleToolsCall(ctx context.Context, req *JSONRPCRequest) *JSONRPCResponse {
  73. var params CallToolParams
  74. if req.Params != nil {
  75. if err := json.Unmarshal(req.Params, &params); err != nil {
  76. return &JSONRPCResponse{
  77. JSONRPC: "2.0",
  78. ID: req.ID,
  79. Error: ErrInvalidParams,
  80. }
  81. }
  82. }
  83. var result tools.CallToolResult
  84. var err error
  85. switch params.Name {
  86. case "introspect":
  87. result, err = tools.Introspect(ctx, s.schema, params.Arguments)
  88. case "query":
  89. result, err = tools.Query(ctx, s.resolver, s.schema, params.Arguments)
  90. case "mutate":
  91. result, err = tools.Mutate(ctx, s.resolver, s.schema, params.Arguments)
  92. default:
  93. return &JSONRPCResponse{
  94. JSONRPC: "2.0",
  95. ID: req.ID,
  96. Error: ErrMethodNotFound,
  97. }
  98. }
  99. if err != nil {
  100. return &JSONRPCResponse{
  101. JSONRPC: "2.0",
  102. ID: req.ID,
  103. Result: tools.CallToolResult{
  104. Content: []tools.ContentBlock{
  105. {Type: "text", Text: err.Error()},
  106. },
  107. IsError: true,
  108. },
  109. }
  110. }
  111. return &JSONRPCResponse{
  112. JSONRPC: "2.0",
  113. ID: req.ID,
  114. Result: result,
  115. }
  116. }
  117. // handleResourcesList returns the list of available subscription resources
  118. func (s *Server) handleResourcesList(ctx context.Context, req *JSONRPCRequest) *JSONRPCResponse {
  119. resources := []Resource{
  120. {
  121. URI: "graphql://subscription/taskCreated",
  122. Name: "taskCreated",
  123. Description: "Subscribe to task creation events. Receives Task objects when new tasks are created and assigned to you.",
  124. MimeType: "application/json",
  125. },
  126. {
  127. URI: "graphql://subscription/taskUpdated",
  128. Name: "taskUpdated",
  129. Description: "Subscribe to task update events. Receives Task objects when tasks assigned to you are updated.",
  130. MimeType: "application/json",
  131. },
  132. {
  133. URI: "graphql://subscription/taskDeleted",
  134. Name: "taskDeleted",
  135. Description: "Subscribe to task deletion events. Receives Task objects when tasks assigned to you are deleted.",
  136. MimeType: "application/json",
  137. },
  138. {
  139. URI: "graphql://subscription/messageAdded",
  140. Name: "messageAdded",
  141. Description: "Subscribe to new message events. Receives Message objects when messages are sent to you.",
  142. MimeType: "application/json",
  143. },
  144. }
  145. return &JSONRPCResponse{
  146. JSONRPC: "2.0",
  147. ID: req.ID,
  148. Result: ListResourcesResult{Resources: resources},
  149. }
  150. }
  151. // handleResourcesRead returns current state of a resource (for subscriptions, this is a description)
  152. func (s *Server) handleResourcesRead(ctx context.Context, req *JSONRPCRequest) *JSONRPCResponse {
  153. var params ReadResourceParams
  154. if req.Params != nil {
  155. if err := json.Unmarshal(req.Params, &params); err != nil {
  156. return &JSONRPCResponse{
  157. JSONRPC: "2.0",
  158. ID: req.ID,
  159. Error: ErrInvalidParams,
  160. }
  161. }
  162. }
  163. // For subscriptions, reading returns a description
  164. description := "This is a subscription resource. Use resources/subscribe to receive real-time updates."
  165. return &JSONRPCResponse{
  166. JSONRPC: "2.0",
  167. ID: req.ID,
  168. Result: ReadResourceResult{
  169. Contents: []ResourceContents{
  170. {
  171. URI: params.URI,
  172. MimeType: "text/plain",
  173. Text: description,
  174. },
  175. },
  176. },
  177. }
  178. }
  179. // handleResourcesSubscribe starts a subscription for real-time updates
  180. func (s *Server) handleResourcesSubscribe(ctx context.Context, req *JSONRPCRequest, session *Session) *JSONRPCResponse {
  181. var params SubscribeParams
  182. if req.Params != nil {
  183. if err := json.Unmarshal(req.Params, &params); err != nil {
  184. return &JSONRPCResponse{
  185. JSONRPC: "2.0",
  186. ID: req.ID,
  187. Error: ErrInvalidParams,
  188. }
  189. }
  190. }
  191. // Check authentication
  192. user, err := auth.CurrentUser(ctx)
  193. if err != nil {
  194. return &JSONRPCResponse{
  195. JSONRPC: "2.0",
  196. ID: req.ID,
  197. Error: &RPCError{Code: -32603, Message: "Authentication required for subscriptions"},
  198. }
  199. }
  200. // Create cancellable context for this subscription
  201. subCtx, cancel := context.WithCancel(context.Background())
  202. // Store the cancel function
  203. session.SubsMu.Lock()
  204. session.Subscriptions[params.URI] = cancel
  205. session.SubsMu.Unlock()
  206. // Subscribe synchronously BEFORE starting the goroutine to avoid race condition
  207. // where events are published before the subscription channel is registered
  208. var eventChan interface{}
  209. switch params.URI {
  210. case "graphql://subscription/taskCreated", "graphql://subscription/taskUpdated", "graphql://subscription/taskDeleted":
  211. eventChan = s.resolver.SubscribeToTasks(user.ID)
  212. case "graphql://subscription/messageAdded":
  213. eventChan = s.resolver.SubscribeToMessages(user.ID)
  214. default:
  215. cancel()
  216. return &JSONRPCResponse{
  217. JSONRPC: "2.0",
  218. ID: req.ID,
  219. Error: &RPCError{Code: -32602, Message: "Unknown subscription URI"},
  220. }
  221. }
  222. // Start the subscription based on URI (pass the already-created channel)
  223. go s.runSubscription(subCtx, params.URI, user.ID, session, eventChan)
  224. return &JSONRPCResponse{
  225. JSONRPC: "2.0",
  226. ID: req.ID,
  227. Result: map[string]interface{}{"subscribed": true, "uri": params.URI},
  228. }
  229. }
  230. // handleResourcesUnsubscribe stops a subscription
  231. func (s *Server) handleResourcesUnsubscribe(ctx context.Context, req *JSONRPCRequest, session *Session) *JSONRPCResponse {
  232. var params UnsubscribeParams
  233. if req.Params != nil {
  234. if err := json.Unmarshal(req.Params, &params); err != nil {
  235. return &JSONRPCResponse{
  236. JSONRPC: "2.0",
  237. ID: req.ID,
  238. Error: ErrInvalidParams,
  239. }
  240. }
  241. }
  242. session.SubsMu.Lock()
  243. if cancel, ok := session.Subscriptions[params.URI]; ok {
  244. cancel()
  245. delete(session.Subscriptions, params.URI)
  246. }
  247. session.SubsMu.Unlock()
  248. return &JSONRPCResponse{
  249. JSONRPC: "2.0",
  250. ID: req.ID,
  251. Result: map[string]interface{}{"unsubscribed": true, "uri": params.URI},
  252. }
  253. }
  254. // runSubscription handles the actual subscription event streaming
  255. func (s *Server) runSubscription(ctx context.Context, uri string, userID uint, session *Session, eventChan interface{}) {
  256. switch uri {
  257. case "graphql://subscription/taskCreated":
  258. s.streamTaskEvents(ctx, userID, session, "created", eventChan.(<-chan graph.TaskEvent))
  259. case "graphql://subscription/taskUpdated":
  260. s.streamTaskEvents(ctx, userID, session, "updated", eventChan.(<-chan graph.TaskEvent))
  261. case "graphql://subscription/taskDeleted":
  262. s.streamTaskEvents(ctx, userID, session, "deleted", eventChan.(<-chan graph.TaskEvent))
  263. case "graphql://subscription/messageAdded":
  264. s.streamMessageEvents(ctx, userID, session, eventChan.(<-chan graph.MessageEvent))
  265. }
  266. }
  267. // streamTaskEvents streams task events to the session
  268. func (s *Server) streamTaskEvents(ctx context.Context, userID uint, session *Session, eventType string, eventChan <-chan graph.TaskEvent) {
  269. for {
  270. select {
  271. case <-ctx.Done():
  272. return
  273. case <-session.Done:
  274. return
  275. case event, ok := <-eventChan:
  276. if !ok {
  277. return
  278. }
  279. if event.EventType == eventType && event.Task != nil {
  280. notification := CreateResourceNotification(
  281. "graphql://subscription/task"+capitalize(eventType),
  282. event.Task,
  283. )
  284. s.sendNotification(session, notification)
  285. }
  286. }
  287. }
  288. }
  289. // streamMessageEvents streams message events to the session
  290. func (s *Server) streamMessageEvents(ctx context.Context, userID uint, session *Session, eventChan <-chan graph.MessageEvent) {
  291. for {
  292. select {
  293. case <-ctx.Done():
  294. return
  295. case <-session.Done:
  296. return
  297. case event, ok := <-eventChan:
  298. if !ok {
  299. return
  300. }
  301. // Check if user is a receiver
  302. isReceiver := false
  303. for _, receiverID := range event.ReceiverIDs {
  304. if receiverID == userID {
  305. isReceiver = true
  306. break
  307. }
  308. }
  309. if isReceiver && event.Message != nil {
  310. notification := CreateResourceNotification(
  311. "graphql://subscription/messageAdded",
  312. event.Message,
  313. )
  314. s.sendNotification(session, notification)
  315. }
  316. }
  317. }
  318. }
  319. // sendNotification sends a JSON-RPC notification to the session
  320. func (s *Server) sendNotification(session *Session, notification *JSONRPCNotification) {
  321. notifBytes, err := json.Marshal(notification)
  322. if err != nil {
  323. return
  324. }
  325. select {
  326. case session.Events <- notifBytes:
  327. default:
  328. // Channel full, skip
  329. }
  330. }
  // capitalize returns s with its first character upper-cased, e.g.
// "created" -> "Created". It is used to rebuild subscription URIs such as
// "graphql://subscription/taskCreated" from an event type.
//
// The first rune is decoded as UTF-8 and upper-cased with unicode.ToUpper.
// Characters that are already upper-case, digits and punctuation pass through
// unchanged, and multi-byte first characters are handled correctly. An empty
// string, or one that does not start with valid UTF-8, is returned as-is.
func capitalize(s string) string {
	r, size := utf8.DecodeRuneInString(s)
	if r == utf8.RuneError {
		// Empty input (size 0) or an invalid leading byte: nothing safe to
		// upper-case, so leave the string untouched.
		return s
	}
	return string(unicode.ToUpper(r)) + s[size:]
}