package tools

import (
	"context"
	"encoding/json"
	"fmt"
	"strconv"

	"github.com/vektah/gqlparser/v2"
	"github.com/vektah/gqlparser/v2/ast"
	"gogs.dmsc.dev/arp/graph"
)

// Query executes the query operations of a GraphQL document against the
// application's resolvers and returns the result as an indented JSON text
// block.
//
// args must contain a "query" string. An optional "variables" object supplies
// values for operation variables; a value of any other type is ignored.
//
// Only query operations are executed; other operations in the document (such
// as mutations) are skipped. A document without any query operation is
// reported as an error. With exactly one query operation the response is that
// operation's result object. With several, it is an array of result objects in
// document order.
//
// Request problems (missing query, parse/validation errors, resolver errors,
// encoding errors) are reported in-band as a CallToolResult with IsError set.
// The returned Go error is always nil.
func Query(ctx context.Context, resolver *graph.Resolver, schema *ast.Schema, args map[string]interface{}) (CallToolResult, error) {
	queryStr, ok := args["query"].(string)
	if !ok {
		return queryErrorResult("Missing required 'query' parameter"), nil
	}

	// Variables are optional.
	variables := make(map[string]interface{})
	if v, ok := args["variables"].(map[string]interface{}); ok {
		variables = v
	}

	// LoadQuery parses the document and validates it against the schema.
	queryDoc, err := gqlparser.LoadQuery(schema, queryStr)
	if err != nil {
		return queryErrorResult(fmt.Sprintf("Query parse error: %v", err)), nil
	}

	var results []interface{}
	for _, op := range queryDoc.Operations {
		if op.Operation != ast.Query {
			continue
		}
		result, errMsg := executeQuery(ctx, resolver, schema, queryDoc, op, variables)
		if errMsg != "" {
			return queryErrorResult(errMsg), nil
		}
		results = append(results, result)
	}
	if len(results) == 0 {
		// Previously this produced a successful "null" response, hiding the fact
		// that nothing was executed.
		return queryErrorResult("No query operation found in document (this tool only executes GraphQL queries)"), nil
	}

	// A single operation's result is returned unwrapped.
	var payload interface{} = results
	if len(results) == 1 {
		payload = results[0]
	}
	// A distinct name is used here because err above has LoadQuery's list type.
	text, jsonErr := json.MarshalIndent(payload, "", " ")
	if jsonErr != nil {
		return queryErrorResult(fmt.Sprintf("Failed to encode query result: %v", jsonErr)), nil
	}

	return CallToolResult{
		Content: []ContentBlock{{Type: "text", Text: string(text)}},
	}, nil
}

// queryErrorResult wraps msg in a single text block flagged as an error.
func queryErrorResult(msg string) CallToolResult {
	return CallToolResult{
		Content: []ContentBlock{{Type: "text", Text: msg}},
		IsError: true,
	}
}

// executeQuery resolves the top-level selections of one query operation and
// returns them keyed by each field's alias (its response key). On failure it
// returns a nil map and a non-empty error message.
func executeQuery(ctx context.Context, resolver *graph.Resolver, schema *ast.Schema, doc *ast.QueryDocument, op *ast.OperationDefinition, variables map[string]interface{}) (map[string]interface{}, string) {
	result := make(map[string]interface{})
	if errMsg := collectQueryFields(ctx, resolver, schema, doc, op.SelectionSet, variables, result); errMsg != "" {
		return nil, errMsg
	}
	return result, ""
}

// collectQueryFields resolves every field in set into result. It descends
// into inline fragments and named fragment spreads, so root fields selected
// through fragments are not dropped.
//
// It assumes doc has been validated (as Query does via LoadQuery), which
// rejects fragment cycles; the recursion relies on that. It returns a
// non-empty message on the first failure.
func collectQueryFields(ctx context.Context, resolver *graph.Resolver, schema *ast.Schema, doc *ast.QueryDocument, set ast.SelectionSet, variables map[string]interface{}, result map[string]interface{}) string {
	for _, sel := range set {
		switch s := sel.(type) {
		case *ast.Field:
			value, errMsg := resolveQueryField(ctx, resolver, schema, s, variables)
			if errMsg != "" {
				return errMsg
			}
			result[s.Alias] = value
		case *ast.InlineFragment:
			if errMsg := collectQueryFields(ctx, resolver, schema, doc, s.SelectionSet, variables, result); errMsg != "" {
				return errMsg
			}
		case *ast.FragmentSpread:
			frag := doc.Fragments.ForName(s.Name)
			if frag == nil {
				return fmt.Sprintf("unknown fragment: %s", s.Name)
			}
			if errMsg := collectQueryFields(ctx, resolver, schema, doc, frag.SelectionSet, variables, result); errMsg != "" {
				return errMsg
			}
		}
	}
	return ""
}

// resolveQueryField evaluates a single root Query field by dispatching to the
// matching method of resolver.Query().
//
// Argument values are evaluated against variables. Single-object fields
// ("user", "note", ...) require an "id" argument. The resolver's value is
// returned as-is; the field's sub-selection is not applied to it. On failure
// the value is nil and the message is non-empty.
func resolveQueryField(ctx context.Context, resolver *graph.Resolver, schema *ast.Schema, field *ast.Field, variables map[string]interface{}) (interface{}, string) {
	args := make(map[string]interface{}, len(field.Arguments))
	for _, arg := range field.Arguments {
		value, err := arg.Value.Value(variables)
		if err != nil {
			return nil, fmt.Sprintf("failed to evaluate argument %s: %v", arg.Name, err)
		}
		args[arg.Name] = value
	}

	// __typename is valid on every object type, including the root.
	if field.Name == "__typename" {
		if schema != nil && schema.Query != nil {
			return schema.Query.Name, ""
		}
		return "Query", ""
	}

	q := resolver.Query()

	// Root fields that take no arguments.
	listFields := map[string]func() (interface{}, error){
		"users":        func() (interface{}, error) { return q.Users(ctx) },
		"notes":        func() (interface{}, error) { return q.Notes(ctx) },
		"roles":        func() (interface{}, error) { return q.Roles(ctx) },
		"permissions":  func() (interface{}, error) { return q.Permissions(ctx) },
		"services":     func() (interface{}, error) { return q.Services(ctx) },
		"tasks":        func() (interface{}, error) { return q.Tasks(ctx) },
		"taskStatuses": func() (interface{}, error) { return q.TaskStatuses(ctx) },
		"messages":     func() (interface{}, error) { return q.Messages(ctx) },
	}
	// Root fields that look up a single object by its id argument.
	byIDFields := map[string]func(string) (interface{}, error){
		"user":       func(id string) (interface{}, error) { return q.User(ctx, id) },
		"note":       func(id string) (interface{}, error) { return q.Note(ctx, id) },
		"role":       func(id string) (interface{}, error) { return q.Role(ctx, id) },
		"permission": func(id string) (interface{}, error) { return q.Permission(ctx, id) },
		"service":    func(id string) (interface{}, error) { return q.Service(ctx, id) },
		"task":       func(id string) (interface{}, error) { return q.Task(ctx, id) },
		"taskStatus": func(id string) (interface{}, error) { return q.TaskStatus(ctx, id) },
		"message":    func(id string) (interface{}, error) { return q.Message(ctx, id) },
	}

	if fetch, ok := listFields[field.Name]; ok {
		return queryResolved(fetch())
	}
	if fetch, ok := byIDFields[field.Name]; ok {
		id, errMsg := queryIDArg(args)
		if errMsg != "" {
			return nil, errMsg
		}
		return queryResolved(fetch(id))
	}
	return nil, fmt.Sprintf("unknown field: %s", field.Name)
}

// queryResolved converts a resolver's (value, error) pair into the
// (value, message) convention used by this file.
func queryResolved(value interface{}, err error) (interface{}, string) {
	if err != nil {
		return nil, err.Error()
	}
	return value, ""
}

// queryIDArg extracts the required "id" argument as a string.
//
// GraphQL ID inputs may be strings or integers. An integer literal is
// evaluated to int64. A numeric variable may arrive as float64 or json.Number,
// depending on how the caller decoded its JSON. Integral numbers are converted
// to their decimal form. A missing id, a non-integral number, or any other
// type yields an error message instead of silently using "".
func queryIDArg(args map[string]interface{}) (string, string) {
	switch id := args["id"].(type) {
	case string:
		return id, ""
	case int:
		return strconv.Itoa(id), ""
	case int64:
		return strconv.FormatInt(id, 10), ""
	case json.Number:
		return id.String(), ""
	case float64:
		if id == float64(int64(id)) {
			return strconv.FormatInt(int64(id), 10), ""
		}
		return "", fmt.Sprintf("argument id must be a string or integer, got %v", id)
	case nil:
		return "", "missing required argument: id"
	default:
		return "", fmt.Sprintf("argument id must be a string or integer, got %T", id)
	}
}