// llm.go — OpenAI LLM wrapper with tool-calling support.
  1. package main
import (
	"context"
	"encoding/json"
	"errors"
	"fmt"

	"github.com/sashabaranov/go-openai"
)
// LLM is an OpenAI LLM wrapper with tool-calling support.
type LLM struct {
	client      *openai.Client // configured API client; may point at a non-OpenAI base URL
	model       string         // model identifier sent with every request
	temperature float32        // sampling temperature applied to chat requests
	maxTokens   int            // completion-token cap applied to chat requests
}
  15. // NewLLM creates a new LLM instance
  16. func NewLLM(apiKey, model string, temperature float32, baseURL string, maxTokens int) *LLM {
  17. config := openai.DefaultConfig(apiKey)
  18. if baseURL != "" {
  19. config.BaseURL = baseURL
  20. }
  21. return &LLM{
  22. client: openai.NewClientWithConfig(config),
  23. model: model,
  24. temperature: temperature,
  25. maxTokens: maxTokens,
  26. }
  27. }
// ChatCompletionRequest is a request for chat completion.
//
// NOTE(review): not referenced by the code visible in this file (Chat builds
// an openai.ChatCompletionRequest directly) — presumably used by callers
// elsewhere; verify before removing.
type ChatCompletionRequest struct {
	Messages []openai.ChatCompletionMessage
	Tools    []openai.Tool
}
// ChatCompletionResponse is a response from chat completion.
//
// NOTE(review): not referenced by the code visible in this file (Chat returns
// *openai.ChatCompletionMessage directly) — presumably used by callers
// elsewhere; verify before removing.
type ChatCompletionResponse struct {
	Message openai.ChatCompletionMessage
}
  37. // Chat sends a chat completion request
  38. func (l *LLM) Chat(ctx context.Context, messages []openai.ChatCompletionMessage, tools []openai.Tool) (*openai.ChatCompletionMessage, error) {
  39. req := openai.ChatCompletionRequest{
  40. Model: l.model,
  41. Messages: messages,
  42. Temperature: l.temperature,
  43. MaxTokens: l.maxTokens,
  44. }
  45. if len(tools) > 0 {
  46. req.Tools = tools
  47. }
  48. resp, err := l.client.CreateChatCompletion(ctx, req)
  49. if err != nil {
  50. return nil, fmt.Errorf("failed to create chat completion: %w", err)
  51. }
  52. if len(resp.Choices) == 0 {
  53. return nil, fmt.Errorf("no response choices returned")
  54. }
  55. // Log warning if finish reason indicates an issue
  56. choice := resp.Choices[0]
  57. if choice.FinishReason == "length" {
  58. // Model hit token limit - may have incomplete response
  59. // This is common with reasoning models that need more tokens
  60. return nil, fmt.Errorf("response truncated: model hit token limit (finish_reason: length). Consider increasing OPENAI_MAX_TOKENS (current: %d). Usage: prompt=%d, completion=%d, total=%d",
  61. l.maxTokens, resp.Usage.PromptTokens, resp.Usage.CompletionTokens, resp.Usage.TotalTokens)
  62. }
  63. return &choice.Message, nil
  64. }
  65. // ConvertMCPToolsToOpenAI converts MCP tools to OpenAI tool format
  66. func ConvertMCPToolsToOpenAI(mcpTools []Tool) []openai.Tool {
  67. tools := make([]openai.Tool, len(mcpTools))
  68. for i, t := range mcpTools {
  69. // Convert InputSchema to JSON schema format using map[string]interface{}
  70. props := make(map[string]interface{})
  71. for name, prop := range t.InputSchema.Properties {
  72. propMap := map[string]interface{}{
  73. "type": prop.Type,
  74. "description": prop.Description,
  75. }
  76. // For object types without explicit nested properties,
  77. // allow additionalProperties so the LLM can pass any key-value pairs
  78. // This is important for tools like 'query' and 'mutate' that accept
  79. // arbitrary variables objects
  80. if prop.Type == "object" {
  81. propMap["additionalProperties"] = true
  82. }
  83. props[name] = propMap
  84. }
  85. // Build parameters map, omitting empty required array
  86. params := map[string]interface{}{
  87. "type": t.InputSchema.Type,
  88. "properties": props,
  89. }
  90. // Only include required if it has elements - empty slice marshals as null
  91. if len(t.InputSchema.Required) > 0 {
  92. params["required"] = t.InputSchema.Required
  93. }
  94. tools[i] = openai.Tool{
  95. Type: openai.ToolTypeFunction,
  96. Function: &openai.FunctionDefinition{
  97. Name: t.Name,
  98. Description: t.Description,
  99. Parameters: params,
  100. },
  101. }
  102. }
  103. return tools
  104. }
  105. // ParseToolCall parses a tool call from the LLM response
  106. func ParseToolCall(toolCall openai.ToolCall) (string, map[string]interface{}, error) {
  107. name := toolCall.Function.Name
  108. var args map[string]interface{}
  109. if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &args); err != nil {
  110. return name, nil, fmt.Errorf("failed to parse tool arguments: %w", err)
  111. }
  112. return name, args, nil
  113. }
  114. // TestConnection tests the connection to OpenAI API
  115. func (l *LLM) TestConnection(ctx context.Context) error {
  116. // Simple test request - use enough tokens for reasoning models
  117. // Reasoning models need more tokens for their thinking process
  118. req := openai.ChatCompletionRequest{
  119. Model: l.model,
  120. Messages: []openai.ChatCompletionMessage{
  121. {
  122. Role: openai.ChatMessageRoleUser,
  123. Content: "Hello",
  124. },
  125. },
  126. MaxTokens: 100,
  127. }
  128. _, err := l.client.CreateChatCompletion(ctx, req)
  129. if err != nil {
  130. return fmt.Errorf("failed to connect to OpenAI API: %w", err)
  131. }
  132. return nil
  133. }