package mcp import ( "context" "encoding/json" "fmt" "strings" "testing" "time" "github.com/bradleyjkemp/cupaloy/v2" "github.com/vektah/gqlparser/v2/ast" "gogs.dmsc.dev/arp/auth" "gogs.dmsc.dev/arp/graph" "gogs.dmsc.dev/arp/graph/testutil" "gorm.io/gorm" ) var snapshotter = cupaloy.New(cupaloy.SnapshotSubdirectory("testdata/snapshots")) // MCPTestClient wraps the MCP server for testing type MCPTestClient struct { server *Server db *gorm.DB schema *ast.Schema session *Session user *auth.UserContext } // IDTracker tracks entity IDs created during tests type IDTracker struct { Permissions map[string]string Roles map[string]string Users map[string]string TaskStatuses map[string]string Services map[string]string Tasks map[string]string Notes map[string]string Messages []string } func NewIDTracker() *IDTracker { return &IDTracker{ Permissions: make(map[string]string), Roles: make(map[string]string), Users: make(map[string]string), TaskStatuses: make(map[string]string), Services: make(map[string]string), Tasks: make(map[string]string), Notes: make(map[string]string), Messages: make([]string, 0), } } // setupMCPTestClient creates a test client with bootstrapped database func setupMCPTestClient(t *testing.T) (*MCPTestClient, *IDTracker) { db, err := testutil.SetupAndBootstrapTestDB() if err != nil { t.Fatalf("Failed to setup test database: %v", err) } resolver := graph.NewResolver(db) schema := graph.NewExecutableSchema(graph.Config{Resolvers: resolver}) astSchema := schema.Schema() server := NewServer(resolver, astSchema) adminUser := &auth.UserContext{ ID: 1, Email: "admin@example.com", Roles: []auth.RoleClaim{{ID: 1, Name: "admin"}}, Permissions: []string{ "user:read", "user:write", "task:read", "task:write", "service:read", "service:write", "note:read", "note:write", }, } session := &Session{ ID: "test-session-id", User: adminUser, Events: make(chan []byte, 100), Done: make(chan struct{}), Subscriptions: make(map[string]context.CancelFunc), } tracker := NewIDTracker() tracker.Users["admin@example.com"] = "1" var perms []struct { ID uint Code string } db.Table("permissions").Find(&perms) for _, perm := range perms { tracker.Permissions[perm.Code] = fmt.Sprintf("%d", perm.ID) } var roles []struct { ID uint Name string } db.Table("roles").Find(&roles) for _, role := range roles { tracker.Roles[role.Name] = fmt.Sprintf("%d", role.ID) } var statuses []struct { ID uint Code string } db.Table("task_statuses").Find(&statuses) for _, status := range statuses { tracker.TaskStatuses[status.Code] = fmt.Sprintf("%d", status.ID) } return &MCPTestClient{ server: server, db: db, schema: astSchema, session: session, user: adminUser, }, tracker } func (tc *MCPTestClient) callTool(ctx context.Context, name string, args map[string]interface{}) *JSONRPCResponse { argsJSON, _ := json.Marshal(args) params := json.RawMessage(fmt.Sprintf(`{"name": "%s", "arguments": %s}`, name, string(argsJSON))) req := &JSONRPCRequest{ JSONRPC: "2.0", ID: fmt.Sprintf("test-%d", time.Now().UnixNano()), Method: "tools/call", Params: params, } // Inject user into context ctxWithUser := auth.WithUser(ctx, tc.user) return tc.server.handleRequest(ctxWithUser, req, tc.session) } func (tc *MCPTestClient) callMethod(ctx context.Context, method string, params json.RawMessage) *JSONRPCResponse { req := &JSONRPCRequest{ JSONRPC: "2.0", ID: fmt.Sprintf("test-%d", time.Now().UnixNano()), Method: method, Params: params, } // Inject user into context ctxWithUser := auth.WithUser(ctx, tc.user) return tc.server.handleRequest(ctxWithUser, req, tc.session) } func (tc *MCPTestClient) subscribeToResource(ctx context.Context, uri string) *JSONRPCResponse { params := json.RawMessage(fmt.Sprintf(`{"uri": "%s"}`, uri)) return tc.callMethod(ctx, "resources/subscribe", params) } func (tc *MCPTestClient) unsubscribeFromResource(ctx context.Context, uri string) *JSONRPCResponse { params := json.RawMessage(fmt.Sprintf(`{"uri": "%s"}`, uri)) return tc.callMethod(ctx, "resources/unsubscribe", params) } func normalizeJSON(jsonStr string) string { var data interface{} if err := json.Unmarshal([]byte(jsonStr), &data); err != nil { return jsonStr } normalizeData(data) bytes, _ := json.MarshalIndent(data, "", " ") return string(bytes) } func normalizeData(data interface{}) { switch v := data.(type) { case map[string]interface{}: delete(v, "id") delete(v, "ID") delete(v, "createdAt") delete(v, "updatedAt") delete(v, "sentAt") delete(v, "createdByID") delete(v, "userId") delete(v, "serviceId") delete(v, "statusId") delete(v, "assigneeId") delete(v, "conversationId") delete(v, "senderId") delete(v, "password") // Remove password hashes for key, val := range v { // Normalize embedded JSON strings in "text" field if key == "text" { if strVal, ok := val.(string); ok { var embedded interface{} if err := json.Unmarshal([]byte(strVal), &embedded); err == nil { normalizeData(embedded) if embeddedBytes, err := json.Marshal(embedded); err == nil { v[key] = string(embeddedBytes) continue } } } } normalizeData(val) } case []interface{}: for _, item := range v { normalizeData(item) } } } func snapshotResult(t *testing.T, name string, response *JSONRPCResponse) { jsonBytes, _ := json.MarshalIndent(response, "", " ") normalized := normalizeJSON(string(jsonBytes)) snapshotter.SnapshotT(t, name, normalized) } // TestMCP_Initialize tests the initialize method func TestMCP_Initialize(t *testing.T) { tc, _ := setupMCPTestClient(t) ctx := context.Background() params := json.RawMessage(`{"protocolVersion": "2024-11-05", "capabilities": {}, "clientInfo": {"name": "test-client", "version": "1.0.0"}}`) response := tc.callMethod(ctx, "initialize", params) if response.Error != nil { t.Fatalf("Initialize failed: %v", response.Error) } result, ok := response.Result.(InitializeResult) if !ok { t.Fatalf("Expected InitializeResult, got %T", response.Result) } if result.ProtocolVersion != ProtocolVersion { t.Errorf("Expected protocol version %s, got %s", ProtocolVersion, result.ProtocolVersion) } if result.Capabilities.Tools == nil { t.Error("Expected tools capability to be present") } if result.Capabilities.Resources == nil { t.Error("Expected resources capability to be present") } if !result.Capabilities.Resources.Subscribe { t.Error("Expected resources.subscribe to be true") } snapshotResult(t, "initialize", response) } // TestMCP_ToolsList tests the tools/list method func TestMCP_ToolsList(t *testing.T) { tc, _ := setupMCPTestClient(t) ctx := context.Background() response := tc.callMethod(ctx, "tools/list", nil) if response.Error != nil { t.Fatalf("tools/list failed: %v", response.Error) } snapshotResult(t, "tools_list", response) } // TestMCP_ResourcesList tests the resources/list method func TestMCP_ResourcesList(t *testing.T) { tc, _ := setupMCPTestClient(t) ctx := context.Background() response := tc.callMethod(ctx, "resources/list", nil) if response.Error != nil { t.Fatalf("resources/list failed: %v", response.Error) } snapshotResult(t, "resources_list", response) } // TestMCP_Introspect tests the introspect tool func TestMCP_Introspect(t *testing.T) { tc, _ := setupMCPTestClient(t) ctx := context.Background() t.Run("FullSchema", func(t *testing.T) { response := tc.callTool(ctx, "introspect", map[string]interface{}{}) if response.Error != nil { t.Fatalf("introspect failed: %v", response.Error) } // Verify the response contains expected content (skip snapshot due to non-deterministic ordering) jsonBytes, _ := json.Marshal(response.Result) var result map[string]interface{} if err := json.Unmarshal(jsonBytes, &result); err != nil { t.Fatalf("Failed to unmarshal result: %v", err) } content, ok := result["content"].([]interface{}) if !ok || len(content) == 0 { t.Fatal("Expected content array with at least one item") } text, ok := content[0].(map[string]interface{})["text"].(string) if !ok { t.Fatal("Expected text field in content") } // Verify key sections are present expectedSections := []string{"Query Type", "Mutation Type", "Subscription Type", "Object Types", "Input Types"} for _, section := range expectedSections { if !strings.Contains(text, section) { t.Errorf("Expected section '%s' in introspection result", section) } } }) t.Run("QueryType", func(t *testing.T) { response := tc.callTool(ctx, "introspect", map[string]interface{}{ "typeName": "Query", }) if response.Error != nil { t.Fatalf("introspect Query failed: %v", response.Error) } snapshotResult(t, "introspect_query", response) }) t.Run("UserType", func(t *testing.T) { response := tc.callTool(ctx, "introspect", map[string]interface{}{ "typeName": "User", }) if response.Error != nil { t.Fatalf("introspect User failed: %v", response.Error) } snapshotResult(t, "introspect_user", response) }) } // TestMCP_Query tests the query tool func TestMCP_Query(t *testing.T) { tc, _ := setupMCPTestClient(t) ctx := context.Background() t.Run("Users", func(t *testing.T) { response := tc.callTool(ctx, "query", map[string]interface{}{ "query": "query { users { email roles { name } } }", }) if response.Error != nil { t.Fatalf("query users failed: %v", response.Error) } snapshotResult(t, "query_users", response) }) t.Run("Tasks", func(t *testing.T) { response := tc.callTool(ctx, "query", map[string]interface{}{ "query": "query { tasks { title content priority } }", }) if response.Error != nil { t.Fatalf("query tasks failed: %v", response.Error) } snapshotResult(t, "query_tasks", response) }) t.Run("Services", func(t *testing.T) { response := tc.callTool(ctx, "query", map[string]interface{}{ "query": "query { services { name description } }", }) if response.Error != nil { t.Fatalf("query services failed: %v", response.Error) } snapshotResult(t, "query_services", response) }) t.Run("Roles", func(t *testing.T) { response := tc.callTool(ctx, "query", map[string]interface{}{ "query": "query { roles { name description permissions { code } } }", }) if response.Error != nil { t.Fatalf("query roles failed: %v", response.Error) } snapshotResult(t, "query_roles", response) }) t.Run("InvalidQuery", func(t *testing.T) { response := tc.callTool(ctx, "query", map[string]interface{}{ "query": "query { nonexistent { id } }", }) if response.Error != nil { t.Fatalf("query failed: %v", response.Error) } // The result is a CallToolResult with isError=true jsonBytes, _ := json.Marshal(response.Result) var result map[string]interface{} if err := json.Unmarshal(jsonBytes, &result); err != nil { t.Fatalf("Failed to unmarshal result: %v", err) } if isError, ok := result["isError"].(bool); !ok || !isError { t.Error("Expected isError to be true for invalid query") } snapshotResult(t, "query_invalid", response) }) } // TestMCP_Mutate tests the mutate tool func TestMCP_Mutate(t *testing.T) { tc, tracker := setupMCPTestClient(t) ctx := context.Background() t.Run("CreateUser", func(t *testing.T) { response := tc.callTool(ctx, "mutate", map[string]interface{}{ "mutation": fmt.Sprintf(`mutation { createUser(input: {email: "newuser@example.com", password: "password123", roles: ["%s"]}) { email } }`, tracker.Roles["admin"]), }) if response.Error != nil { t.Fatalf("createUser failed: %v", response.Error) } snapshotResult(t, "mutate_create_user", response) }) t.Run("CreateTask", func(t *testing.T) { response := tc.callTool(ctx, "mutate", map[string]interface{}{ "mutation": fmt.Sprintf(`mutation { createTask(input: {title: "Test Task", content: "Test content", createdById: "%s", statusId: "%s", priority: "high"}) { title content priority } }`, tracker.Users["admin@example.com"], tracker.TaskStatuses["open"]), }) if response.Error != nil { t.Fatalf("createTask failed: %v", response.Error) } snapshotResult(t, "mutate_create_task", response) }) t.Run("CreateNote", func(t *testing.T) { response := tc.callTool(ctx, "mutate", map[string]interface{}{ "mutation": fmt.Sprintf(`mutation { createNote(input: {title: "Test Note", content: "Note content", userId: "%s"}) { title content } }`, tracker.Users["admin@example.com"]), }) if response.Error != nil { t.Fatalf("createNote failed: %v", response.Error) } snapshotResult(t, "mutate_create_note", response) }) } // TestMCP_Resources tests resource operations func TestMCP_Resources(t *testing.T) { tc, _ := setupMCPTestClient(t) ctx := context.Background() t.Run("Read", func(t *testing.T) { params := json.RawMessage(`{"uri": "graphql://subscription/taskCreated"}`) response := tc.callMethod(ctx, "resources/read", params) if response.Error != nil { t.Fatalf("resources/read failed: %v", response.Error) } snapshotResult(t, "resources_read", response) }) t.Run("Subscribe", func(t *testing.T) { response := tc.subscribeToResource(ctx, "graphql://subscription/taskCreated") if response.Error != nil { t.Fatalf("resources/subscribe failed: %v", response.Error) } // Verify subscription was registered tc.session.SubsMu.RLock() _, ok := tc.session.Subscriptions["graphql://subscription/taskCreated"] tc.session.SubsMu.RUnlock() if !ok { t.Error("Expected subscription to be registered in session") } snapshotResult(t, "resources_subscribe", response) }) t.Run("Unsubscribe", func(t *testing.T) { // First subscribe tc.subscribeToResource(ctx, "graphql://subscription/taskUpdated") // Then unsubscribe response := tc.unsubscribeFromResource(ctx, "graphql://subscription/taskUpdated") if response.Error != nil { t.Fatalf("resources/unsubscribe failed: %v", response.Error) } // Verify subscription was removed tc.session.SubsMu.RLock() _, ok := tc.session.Subscriptions["graphql://subscription/taskUpdated"] tc.session.SubsMu.RUnlock() if ok { t.Error("Expected subscription to be removed from session") } snapshotResult(t, "resources_unsubscribe", response) }) } // TestMCP_SubscriptionNotifications tests that subscription notifications are sent func TestMCP_SubscriptionNotifications(t *testing.T) { tc, tracker := setupMCPTestClient(t) ctx := context.Background() // Subscribe to taskCreated _ = tc.subscribeToResource(ctx, "graphql://subscription/taskCreated") // Create a task assigned to admin user (ID 1) _ = tc.callTool(ctx, "mutate", map[string]interface{}{ "mutation": fmt.Sprintf(`mutation { createTask(input: {title: "Notification Test Task", content: "Testing notifications", createdById: "%s", assigneeId: "%s", statusId: "%s", priority: "medium"}) { title } }`, tracker.Users["admin@example.com"], tracker.Users["admin@example.com"], tracker.TaskStatuses["open"]), }) // Wait for notification select { case event := <-tc.session.Events: var notification JSONRPCNotification if err := json.Unmarshal(event, ¬ification); err != nil { t.Fatalf("Failed to unmarshal notification: %v", err) } // Verify it's a resource update notification if notification.Method != "notifications/resources/updated" { t.Errorf("Expected method 'notifications/resources/updated', got '%s'", notification.Method) } t.Logf("Received notification: %s", string(event)) case <-time.After(2 * time.Second): t.Error("Timeout waiting for taskCreated notification") } } // TestMCP_Ping tests the ping method func TestMCP_Ping(t *testing.T) { tc, _ := setupMCPTestClient(t) ctx := context.Background() response := tc.callMethod(ctx, "ping", nil) if response.Error != nil { t.Fatalf("ping failed: %v", response.Error) } snapshotResult(t, "ping", response) } // TestMCP_Unauthenticated tests operations without authentication func TestMCP_Unauthenticated(t *testing.T) { db, err := testutil.SetupAndBootstrapTestDB() if err != nil { t.Fatalf("Failed to setup test database: %v", err) } resolver := graph.NewResolver(db) schema := graph.NewExecutableSchema(graph.Config{Resolvers: resolver}) astSchema := schema.Schema() server := NewServer(resolver, astSchema) // Create session without user session := &Session{ ID: "unauth-session", User: nil, // No user Events: make(chan []byte, 100), Done: make(chan struct{}), Subscriptions: make(map[string]context.CancelFunc), } ctx := context.Background() t.Run("SubscribeRequiresAuth", func(t *testing.T) { params := json.RawMessage(`{"uri": "graphql://subscription/taskCreated"}`) req := &JSONRPCRequest{ JSONRPC: "2.0", ID: "test-unauth", Method: "resources/subscribe", Params: params, } response := server.handleRequest(ctx, req, session) if response.Error == nil { t.Error("Expected error for unauthenticated subscribe") } if !strings.Contains(response.Error.Message, "Authentication required") { t.Errorf("Expected authentication error, got: %s", response.Error.Message) } }) t.Run("QueryWithoutAuth", func(t *testing.T) { // Query should work without auth (no permission checks in current implementation) argsJSON, _ := json.Marshal(map[string]interface{}{ "query": "query { users { email } }", }) params := json.RawMessage(fmt.Sprintf(`{"name": "query", "arguments": %s}`, string(argsJSON))) req := &JSONRPCRequest{ JSONRPC: "2.0", ID: "test-unauth-query", Method: "tools/call", Params: params, } response := server.handleRequest(ctx, req, session) // Query should succeed (no auth required for basic queries) if response.Error != nil { t.Logf("Query returned error (may be expected): %v", response.Error) } }) }