integration_test.go 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638
  1. package mcp
  2. import (
  3. "context"
  4. "encoding/json"
  5. "fmt"
  6. "strings"
  7. "testing"
  8. "time"
  9. "github.com/bradleyjkemp/cupaloy/v2"
  10. "github.com/vektah/gqlparser/v2/ast"
  11. "gogs.dmsc.dev/arp/auth"
  12. "gogs.dmsc.dev/arp/graph"
  13. "gogs.dmsc.dev/arp/graph/testutil"
  14. "gorm.io/gorm"
  15. )
// snapshotter is the shared cupaloy instance used by snapshotResult. It is
// configured with the "testdata/snapshots" snapshot subdirectory.
var snapshotter = cupaloy.New(cupaloy.SnapshotSubdirectory("testdata/snapshots"))
// MCPTestClient wraps the MCP server for testing. It bundles the server with
// the session and user that the call helpers attach to every request.
type MCPTestClient struct {
	// server is the MCP server whose handleRequest method the helpers invoke.
	server *Server
	// db is the database handle returned by the test bootstrap.
	db *gorm.DB
	// schema is the GraphQL schema the server was constructed with.
	schema *ast.Schema
	// session is passed to handleRequest on every call.
	session *Session
	// user is injected into the request context via auth.WithUser.
	user *auth.UserContext
}
// IDTracker tracks entity IDs created during tests. IDs are stored as strings
// so they can be interpolated directly into GraphQL documents.
type IDTracker struct {
	// Permissions maps permission code to ID.
	Permissions map[string]string
	// Roles maps role name to ID.
	Roles map[string]string
	// Users maps user email to ID.
	Users map[string]string
	// TaskStatuses maps task status code to ID.
	TaskStatuses map[string]string
	// Services is not populated in this file; presumably keyed by service
	// name — confirm against callers.
	Services map[string]string
	// Tasks is not populated in this file; presumably keyed by task title —
	// confirm against callers.
	Tasks map[string]string
	// Notes is not populated in this file; presumably keyed by note title —
	// confirm against callers.
	Notes map[string]string
	// Messages is not populated in this file; presumably a list of message
	// IDs — confirm against callers.
	Messages []string
}
  36. func NewIDTracker() *IDTracker {
  37. return &IDTracker{
  38. Permissions: make(map[string]string),
  39. Roles: make(map[string]string),
  40. Users: make(map[string]string),
  41. TaskStatuses: make(map[string]string),
  42. Services: make(map[string]string),
  43. Tasks: make(map[string]string),
  44. Notes: make(map[string]string),
  45. Messages: make([]string, 0),
  46. }
  47. }
  48. // setupMCPTestClient creates a test client with bootstrapped database
  49. func setupMCPTestClient(t *testing.T) (*MCPTestClient, *IDTracker) {
  50. db, err := testutil.SetupAndBootstrapTestDB()
  51. if err != nil {
  52. t.Fatalf("Failed to setup test database: %v", err)
  53. }
  54. resolver := graph.NewResolver(db)
  55. schema := graph.NewExecutableSchema(graph.Config{Resolvers: resolver})
  56. astSchema := schema.Schema()
  57. server := NewServer(resolver, astSchema)
  58. adminUser := &auth.UserContext{
  59. ID: 1,
  60. Email: "admin@example.com",
  61. Roles: []auth.RoleClaim{{ID: 1, Name: "admin"}},
  62. Permissions: []string{
  63. "user:read", "user:write",
  64. "task:read", "task:write",
  65. "service:read", "service:write",
  66. "note:read", "note:write",
  67. },
  68. }
  69. session := &Session{
  70. ID: "test-session-id",
  71. User: adminUser,
  72. Events: make(chan []byte, 100),
  73. Done: make(chan struct{}),
  74. Subscriptions: make(map[string]context.CancelFunc),
  75. }
  76. tracker := NewIDTracker()
  77. tracker.Users["admin@example.com"] = "1"
  78. var perms []struct {
  79. ID uint
  80. Code string
  81. }
  82. db.Table("permissions").Find(&perms)
  83. for _, perm := range perms {
  84. tracker.Permissions[perm.Code] = fmt.Sprintf("%d", perm.ID)
  85. }
  86. var roles []struct {
  87. ID uint
  88. Name string
  89. }
  90. db.Table("roles").Find(&roles)
  91. for _, role := range roles {
  92. tracker.Roles[role.Name] = fmt.Sprintf("%d", role.ID)
  93. }
  94. var statuses []struct {
  95. ID uint
  96. Code string
  97. }
  98. db.Table("task_statuses").Find(&statuses)
  99. for _, status := range statuses {
  100. tracker.TaskStatuses[status.Code] = fmt.Sprintf("%d", status.ID)
  101. }
  102. return &MCPTestClient{
  103. server: server,
  104. db: db,
  105. schema: astSchema,
  106. session: session,
  107. user: adminUser,
  108. }, tracker
  109. }
  110. func (tc *MCPTestClient) callTool(ctx context.Context, name string, args map[string]interface{}) *JSONRPCResponse {
  111. argsJSON, _ := json.Marshal(args)
  112. params := json.RawMessage(fmt.Sprintf(`{"name": "%s", "arguments": %s}`, name, string(argsJSON)))
  113. req := &JSONRPCRequest{
  114. JSONRPC: "2.0",
  115. ID: fmt.Sprintf("test-%d", time.Now().UnixNano()),
  116. Method: "tools/call",
  117. Params: params,
  118. }
  119. // Inject user into context
  120. ctxWithUser := auth.WithUser(ctx, tc.user)
  121. return tc.server.handleRequest(ctxWithUser, req, tc.session)
  122. }
  123. func (tc *MCPTestClient) callMethod(ctx context.Context, method string, params json.RawMessage) *JSONRPCResponse {
  124. req := &JSONRPCRequest{
  125. JSONRPC: "2.0",
  126. ID: fmt.Sprintf("test-%d", time.Now().UnixNano()),
  127. Method: method,
  128. Params: params,
  129. }
  130. // Inject user into context
  131. ctxWithUser := auth.WithUser(ctx, tc.user)
  132. return tc.server.handleRequest(ctxWithUser, req, tc.session)
  133. }
  134. func (tc *MCPTestClient) subscribeToResource(ctx context.Context, uri string) *JSONRPCResponse {
  135. params := json.RawMessage(fmt.Sprintf(`{"uri": "%s"}`, uri))
  136. return tc.callMethod(ctx, "resources/subscribe", params)
  137. }
  138. func (tc *MCPTestClient) unsubscribeFromResource(ctx context.Context, uri string) *JSONRPCResponse {
  139. params := json.RawMessage(fmt.Sprintf(`{"uri": "%s"}`, uri))
  140. return tc.callMethod(ctx, "resources/unsubscribe", params)
  141. }
  142. func normalizeJSON(jsonStr string) string {
  143. var data interface{}
  144. if err := json.Unmarshal([]byte(jsonStr), &data); err != nil {
  145. return jsonStr
  146. }
  147. normalizeData(data)
  148. bytes, _ := json.MarshalIndent(data, "", " ")
  149. return string(bytes)
  150. }
  151. func normalizeData(data interface{}) {
  152. switch v := data.(type) {
  153. case map[string]interface{}:
  154. delete(v, "id")
  155. delete(v, "ID")
  156. delete(v, "createdAt")
  157. delete(v, "updatedAt")
  158. delete(v, "sentAt")
  159. delete(v, "createdByID")
  160. delete(v, "userId")
  161. delete(v, "serviceId")
  162. delete(v, "statusId")
  163. delete(v, "assigneeId")
  164. delete(v, "conversationId")
  165. delete(v, "senderId")
  166. delete(v, "password") // Remove password hashes
  167. for key, val := range v {
  168. // Normalize embedded JSON strings in "text" field
  169. if key == "text" {
  170. if strVal, ok := val.(string); ok {
  171. var embedded interface{}
  172. if err := json.Unmarshal([]byte(strVal), &embedded); err == nil {
  173. normalizeData(embedded)
  174. if embeddedBytes, err := json.Marshal(embedded); err == nil {
  175. v[key] = string(embeddedBytes)
  176. continue
  177. }
  178. }
  179. }
  180. }
  181. normalizeData(val)
  182. }
  183. case []interface{}:
  184. for _, item := range v {
  185. normalizeData(item)
  186. }
  187. }
  188. }
  189. func snapshotResult(t *testing.T, name string, response *JSONRPCResponse) {
  190. jsonBytes, _ := json.MarshalIndent(response, "", " ")
  191. normalized := normalizeJSON(string(jsonBytes))
  192. snapshotter.SnapshotT(t, name, normalized)
  193. }
  194. // TestMCP_Initialize tests the initialize method
  195. func TestMCP_Initialize(t *testing.T) {
  196. tc, _ := setupMCPTestClient(t)
  197. ctx := context.Background()
  198. params := json.RawMessage(`{"protocolVersion": "2024-11-05", "capabilities": {}, "clientInfo": {"name": "test-client", "version": "1.0.0"}}`)
  199. response := tc.callMethod(ctx, "initialize", params)
  200. if response.Error != nil {
  201. t.Fatalf("Initialize failed: %v", response.Error)
  202. }
  203. result, ok := response.Result.(InitializeResult)
  204. if !ok {
  205. t.Fatalf("Expected InitializeResult, got %T", response.Result)
  206. }
  207. if result.ProtocolVersion != ProtocolVersion {
  208. t.Errorf("Expected protocol version %s, got %s", ProtocolVersion, result.ProtocolVersion)
  209. }
  210. if result.Capabilities.Tools == nil {
  211. t.Error("Expected tools capability to be present")
  212. }
  213. if result.Capabilities.Resources == nil {
  214. t.Error("Expected resources capability to be present")
  215. }
  216. if !result.Capabilities.Resources.Subscribe {
  217. t.Error("Expected resources.subscribe to be true")
  218. }
  219. snapshotResult(t, "initialize", response)
  220. }
  221. // TestMCP_ToolsList tests the tools/list method
  222. func TestMCP_ToolsList(t *testing.T) {
  223. tc, _ := setupMCPTestClient(t)
  224. ctx := context.Background()
  225. response := tc.callMethod(ctx, "tools/list", nil)
  226. if response.Error != nil {
  227. t.Fatalf("tools/list failed: %v", response.Error)
  228. }
  229. snapshotResult(t, "tools_list", response)
  230. }
  231. // TestMCP_ResourcesList tests the resources/list method
  232. func TestMCP_ResourcesList(t *testing.T) {
  233. tc, _ := setupMCPTestClient(t)
  234. ctx := context.Background()
  235. response := tc.callMethod(ctx, "resources/list", nil)
  236. if response.Error != nil {
  237. t.Fatalf("resources/list failed: %v", response.Error)
  238. }
  239. snapshotResult(t, "resources_list", response)
  240. }
  241. // TestMCP_Introspect tests the introspect tool
  242. func TestMCP_Introspect(t *testing.T) {
  243. tc, _ := setupMCPTestClient(t)
  244. ctx := context.Background()
  245. t.Run("FullSchema", func(t *testing.T) {
  246. response := tc.callTool(ctx, "introspect", map[string]interface{}{})
  247. if response.Error != nil {
  248. t.Fatalf("introspect failed: %v", response.Error)
  249. }
  250. // Verify the response contains expected content (skip snapshot due to non-deterministic ordering)
  251. jsonBytes, _ := json.Marshal(response.Result)
  252. var result map[string]interface{}
  253. if err := json.Unmarshal(jsonBytes, &result); err != nil {
  254. t.Fatalf("Failed to unmarshal result: %v", err)
  255. }
  256. content, ok := result["content"].([]interface{})
  257. if !ok || len(content) == 0 {
  258. t.Fatal("Expected content array with at least one item")
  259. }
  260. text, ok := content[0].(map[string]interface{})["text"].(string)
  261. if !ok {
  262. t.Fatal("Expected text field in content")
  263. }
  264. // Verify key sections are present
  265. expectedSections := []string{"Query Type", "Mutation Type", "Subscription Type", "Object Types", "Input Types"}
  266. for _, section := range expectedSections {
  267. if !strings.Contains(text, section) {
  268. t.Errorf("Expected section '%s' in introspection result", section)
  269. }
  270. }
  271. })
  272. t.Run("QueryType", func(t *testing.T) {
  273. response := tc.callTool(ctx, "introspect", map[string]interface{}{
  274. "typeName": "Query",
  275. })
  276. if response.Error != nil {
  277. t.Fatalf("introspect Query failed: %v", response.Error)
  278. }
  279. snapshotResult(t, "introspect_query", response)
  280. })
  281. t.Run("UserType", func(t *testing.T) {
  282. response := tc.callTool(ctx, "introspect", map[string]interface{}{
  283. "typeName": "User",
  284. })
  285. if response.Error != nil {
  286. t.Fatalf("introspect User failed: %v", response.Error)
  287. }
  288. snapshotResult(t, "introspect_user", response)
  289. })
  290. }
  291. // TestMCP_Query tests the query tool
  292. func TestMCP_Query(t *testing.T) {
  293. tc, _ := setupMCPTestClient(t)
  294. ctx := context.Background()
  295. t.Run("Users", func(t *testing.T) {
  296. response := tc.callTool(ctx, "query", map[string]interface{}{
  297. "query": "query { users { email roles { name } } }",
  298. })
  299. if response.Error != nil {
  300. t.Fatalf("query users failed: %v", response.Error)
  301. }
  302. snapshotResult(t, "query_users", response)
  303. })
  304. t.Run("Tasks", func(t *testing.T) {
  305. response := tc.callTool(ctx, "query", map[string]interface{}{
  306. "query": "query { tasks { title content priority } }",
  307. })
  308. if response.Error != nil {
  309. t.Fatalf("query tasks failed: %v", response.Error)
  310. }
  311. snapshotResult(t, "query_tasks", response)
  312. })
  313. t.Run("Services", func(t *testing.T) {
  314. response := tc.callTool(ctx, "query", map[string]interface{}{
  315. "query": "query { services { name description } }",
  316. })
  317. if response.Error != nil {
  318. t.Fatalf("query services failed: %v", response.Error)
  319. }
  320. snapshotResult(t, "query_services", response)
  321. })
  322. t.Run("Roles", func(t *testing.T) {
  323. response := tc.callTool(ctx, "query", map[string]interface{}{
  324. "query": "query { roles { name description permissions { code } } }",
  325. })
  326. if response.Error != nil {
  327. t.Fatalf("query roles failed: %v", response.Error)
  328. }
  329. snapshotResult(t, "query_roles", response)
  330. })
  331. t.Run("InvalidQuery", func(t *testing.T) {
  332. response := tc.callTool(ctx, "query", map[string]interface{}{
  333. "query": "query { nonexistent { id } }",
  334. })
  335. if response.Error != nil {
  336. t.Fatalf("query failed: %v", response.Error)
  337. }
  338. // The result is a CallToolResult with isError=true
  339. jsonBytes, _ := json.Marshal(response.Result)
  340. var result map[string]interface{}
  341. if err := json.Unmarshal(jsonBytes, &result); err != nil {
  342. t.Fatalf("Failed to unmarshal result: %v", err)
  343. }
  344. if isError, ok := result["isError"].(bool); !ok || !isError {
  345. t.Error("Expected isError to be true for invalid query")
  346. }
  347. snapshotResult(t, "query_invalid", response)
  348. })
  349. }
  350. // TestMCP_Mutate tests the mutate tool
  351. func TestMCP_Mutate(t *testing.T) {
  352. tc, tracker := setupMCPTestClient(t)
  353. ctx := context.Background()
  354. t.Run("CreateUser", func(t *testing.T) {
  355. response := tc.callTool(ctx, "mutate", map[string]interface{}{
  356. "mutation": fmt.Sprintf(`mutation { createUser(input: {email: "newuser@example.com", password: "password123", roles: ["%s"]}) { email } }`, tracker.Roles["admin"]),
  357. })
  358. if response.Error != nil {
  359. t.Fatalf("createUser failed: %v", response.Error)
  360. }
  361. snapshotResult(t, "mutate_create_user", response)
  362. })
  363. t.Run("CreateTask", func(t *testing.T) {
  364. response := tc.callTool(ctx, "mutate", map[string]interface{}{
  365. "mutation": fmt.Sprintf(`mutation { createTask(input: {title: "Test Task", content: "Test content", createdById: "%s", statusId: "%s", priority: "high"}) { title content priority } }`, tracker.Users["admin@example.com"], tracker.TaskStatuses["open"]),
  366. })
  367. if response.Error != nil {
  368. t.Fatalf("createTask failed: %v", response.Error)
  369. }
  370. snapshotResult(t, "mutate_create_task", response)
  371. })
  372. t.Run("CreateNote", func(t *testing.T) {
  373. response := tc.callTool(ctx, "mutate", map[string]interface{}{
  374. "mutation": fmt.Sprintf(`mutation { createNote(input: {title: "Test Note", content: "Note content", userId: "%s"}) { title content } }`, tracker.Users["admin@example.com"]),
  375. })
  376. if response.Error != nil {
  377. t.Fatalf("createNote failed: %v", response.Error)
  378. }
  379. snapshotResult(t, "mutate_create_note", response)
  380. })
  381. }
  382. // TestMCP_Resources tests resource operations
  383. func TestMCP_Resources(t *testing.T) {
  384. tc, _ := setupMCPTestClient(t)
  385. ctx := context.Background()
  386. t.Run("Read", func(t *testing.T) {
  387. params := json.RawMessage(`{"uri": "graphql://subscription/taskCreated"}`)
  388. response := tc.callMethod(ctx, "resources/read", params)
  389. if response.Error != nil {
  390. t.Fatalf("resources/read failed: %v", response.Error)
  391. }
  392. snapshotResult(t, "resources_read", response)
  393. })
  394. t.Run("Subscribe", func(t *testing.T) {
  395. response := tc.subscribeToResource(ctx, "graphql://subscription/taskCreated")
  396. if response.Error != nil {
  397. t.Fatalf("resources/subscribe failed: %v", response.Error)
  398. }
  399. // Verify subscription was registered
  400. tc.session.SubsMu.RLock()
  401. _, ok := tc.session.Subscriptions["graphql://subscription/taskCreated"]
  402. tc.session.SubsMu.RUnlock()
  403. if !ok {
  404. t.Error("Expected subscription to be registered in session")
  405. }
  406. snapshotResult(t, "resources_subscribe", response)
  407. })
  408. t.Run("Unsubscribe", func(t *testing.T) {
  409. // First subscribe
  410. tc.subscribeToResource(ctx, "graphql://subscription/taskUpdated")
  411. // Then unsubscribe
  412. response := tc.unsubscribeFromResource(ctx, "graphql://subscription/taskUpdated")
  413. if response.Error != nil {
  414. t.Fatalf("resources/unsubscribe failed: %v", response.Error)
  415. }
  416. // Verify subscription was removed
  417. tc.session.SubsMu.RLock()
  418. _, ok := tc.session.Subscriptions["graphql://subscription/taskUpdated"]
  419. tc.session.SubsMu.RUnlock()
  420. if ok {
  421. t.Error("Expected subscription to be removed from session")
  422. }
  423. snapshotResult(t, "resources_unsubscribe", response)
  424. })
  425. }
  426. // TestMCP_SubscriptionNotifications tests that subscription notifications are sent
  427. func TestMCP_SubscriptionNotifications(t *testing.T) {
  428. tc, tracker := setupMCPTestClient(t)
  429. ctx := context.Background()
  430. // Subscribe to taskCreated
  431. _ = tc.subscribeToResource(ctx, "graphql://subscription/taskCreated")
  432. // Create a task assigned to admin user (ID 1)
  433. _ = tc.callTool(ctx, "mutate", map[string]interface{}{
  434. "mutation": fmt.Sprintf(`mutation { createTask(input: {title: "Notification Test Task", content: "Testing notifications", createdById: "%s", assigneeId: "%s", statusId: "%s", priority: "medium"}) { title } }`, tracker.Users["admin@example.com"], tracker.Users["admin@example.com"], tracker.TaskStatuses["open"]),
  435. })
  436. // Wait for notification
  437. select {
  438. case event := <-tc.session.Events:
  439. var notification JSONRPCNotification
  440. if err := json.Unmarshal(event, &notification); err != nil {
  441. t.Fatalf("Failed to unmarshal notification: %v", err)
  442. }
  443. // Verify it's a resource update notification
  444. if notification.Method != "notifications/resources/updated" {
  445. t.Errorf("Expected method 'notifications/resources/updated', got '%s'", notification.Method)
  446. }
  447. t.Logf("Received notification: %s", string(event))
  448. case <-time.After(2 * time.Second):
  449. t.Error("Timeout waiting for taskCreated notification")
  450. }
  451. }
  452. // TestMCP_Ping tests the ping method
  453. func TestMCP_Ping(t *testing.T) {
  454. tc, _ := setupMCPTestClient(t)
  455. ctx := context.Background()
  456. response := tc.callMethod(ctx, "ping", nil)
  457. if response.Error != nil {
  458. t.Fatalf("ping failed: %v", response.Error)
  459. }
  460. snapshotResult(t, "ping", response)
  461. }
  462. // TestMCP_Unauthenticated tests operations without authentication
  463. func TestMCP_Unauthenticated(t *testing.T) {
  464. db, err := testutil.SetupAndBootstrapTestDB()
  465. if err != nil {
  466. t.Fatalf("Failed to setup test database: %v", err)
  467. }
  468. resolver := graph.NewResolver(db)
  469. schema := graph.NewExecutableSchema(graph.Config{Resolvers: resolver})
  470. astSchema := schema.Schema()
  471. server := NewServer(resolver, astSchema)
  472. // Create session without user
  473. session := &Session{
  474. ID: "unauth-session",
  475. User: nil, // No user
  476. Events: make(chan []byte, 100),
  477. Done: make(chan struct{}),
  478. Subscriptions: make(map[string]context.CancelFunc),
  479. }
  480. ctx := context.Background()
  481. t.Run("SubscribeRequiresAuth", func(t *testing.T) {
  482. params := json.RawMessage(`{"uri": "graphql://subscription/taskCreated"}`)
  483. req := &JSONRPCRequest{
  484. JSONRPC: "2.0",
  485. ID: "test-unauth",
  486. Method: "resources/subscribe",
  487. Params: params,
  488. }
  489. response := server.handleRequest(ctx, req, session)
  490. if response.Error == nil {
  491. t.Error("Expected error for unauthenticated subscribe")
  492. }
  493. if !strings.Contains(response.Error.Message, "Authentication required") {
  494. t.Errorf("Expected authentication error, got: %s", response.Error.Message)
  495. }
  496. })
  497. t.Run("QueryWithoutAuth", func(t *testing.T) {
  498. // Query should work without auth (no permission checks in current implementation)
  499. argsJSON, _ := json.Marshal(map[string]interface{}{
  500. "query": "query { users { email } }",
  501. })
  502. params := json.RawMessage(fmt.Sprintf(`{"name": "query", "arguments": %s}`, string(argsJSON)))
  503. req := &JSONRPCRequest{
  504. JSONRPC: "2.0",
  505. ID: "test-unauth-query",
  506. Method: "tools/call",
  507. Params: params,
  508. }
  509. response := server.handleRequest(ctx, req, session)
  510. // Query should succeed (no auth required for basic queries)
  511. if response.Error != nil {
  512. t.Logf("Query returned error (may be expected): %v", response.Error)
  513. }
  514. })
  515. }