llm.go
  1. package main
import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"log"
	"strings"
	"time"

	"github.com/sashabaranov/go-openai"
)
// LLM is an OpenAI LLM wrapper with tool-calling support
type LLM struct {
	client      *openai.Client // underlying OpenAI API client (base URL may be overridden)
	model       string         // model name sent with every request
	temperature float32        // sampling temperature applied to each completion
	maxTokens   int            // per-request completion token budget
	// Retry configuration
	maxRetries int           // number of additional attempts after the first failure
	retryDelay time.Duration // initial backoff delay; doubled after each retry
}
  21. // NewLLM creates a new LLM instance
  22. func NewLLM(apiKey, model string, temperature float32, baseURL string, maxTokens int) *LLM {
  23. return NewLLMWithRetry(apiKey, model, temperature, baseURL, maxTokens, 3, 1*time.Second)
  24. }
  25. // NewLLMWithRetry creates a new LLM instance with custom retry configuration
  26. func NewLLMWithRetry(apiKey, model string, temperature float32, baseURL string, maxTokens int, maxRetries int, retryDelay time.Duration) *LLM {
  27. config := openai.DefaultConfig(apiKey)
  28. if baseURL != "" {
  29. config.BaseURL = baseURL
  30. }
  31. return &LLM{
  32. client: openai.NewClientWithConfig(config),
  33. model: model,
  34. temperature: temperature,
  35. maxTokens: maxTokens,
  36. maxRetries: maxRetries,
  37. retryDelay: retryDelay,
  38. }
  39. }
// ChatCompletionRequest is a request for chat completion
// NOTE(review): not referenced anywhere in this file; presumably consumed by
// callers elsewhere in the package — confirm before removing.
type ChatCompletionRequest struct {
	Messages []openai.ChatCompletionMessage // conversation history to send
	Tools    []openai.Tool                  // optional tools the model may call
}
// ChatCompletionResponse is a response from chat completion
// NOTE(review): not referenced anywhere in this file; presumably consumed by
// callers elsewhere in the package — confirm before removing.
type ChatCompletionResponse struct {
	Message openai.ChatCompletionMessage // the assistant message returned by the model
}
  49. // Chat sends a chat completion request with retry logic
  50. func (l *LLM) Chat(ctx context.Context, messages []openai.ChatCompletionMessage, tools []openai.Tool) (*openai.ChatCompletionMessage, error) {
  51. var lastErr error
  52. delay := l.retryDelay
  53. for attempt := 0; attempt <= l.maxRetries; attempt++ {
  54. select {
  55. case <-ctx.Done():
  56. return nil, fmt.Errorf("context canceled: %w", ctx.Err())
  57. default:
  58. }
  59. req := openai.ChatCompletionRequest{
  60. Model: l.model,
  61. Messages: messages,
  62. Temperature: l.temperature,
  63. MaxTokens: l.maxTokens,
  64. }
  65. if len(tools) > 0 {
  66. req.Tools = tools
  67. }
  68. resp, err := l.client.CreateChatCompletion(ctx, req)
  69. if err == nil {
  70. if len(resp.Choices) == 0 {
  71. return nil, fmt.Errorf("no response choices returned")
  72. }
  73. // Log warning if finish reason indicates an issue
  74. choice := resp.Choices[0]
  75. if choice.FinishReason == "length" {
  76. // Model hit token limit - may have incomplete response
  77. // This is common with reasoning models that need more tokens
  78. return nil, fmt.Errorf("response truncated: model hit token limit (finish_reason: length). Consider increasing OPENAI_MAX_TOKENS (current: %d). Usage: prompt=%d, completion=%d, total=%d",
  79. l.maxTokens, resp.Usage.PromptTokens, resp.Usage.CompletionTokens, resp.Usage.TotalTokens)
  80. }
  81. return &choice.Message, nil
  82. }
  83. lastErr = err
  84. // Check if this error is retryable
  85. if !isRetryableError(err) {
  86. return nil, fmt.Errorf("failed to create chat completion: %w", err)
  87. }
  88. // Don't wait after the last attempt
  89. if attempt < l.maxRetries {
  90. log.Printf("LLM request failed (attempt %d/%d): %v. Retrying in %v...",
  91. attempt+1, l.maxRetries, err, delay)
  92. select {
  93. case <-ctx.Done():
  94. return nil, fmt.Errorf("context canceled during retry wait: %w", ctx.Err())
  95. case <-time.After(delay):
  96. }
  97. // Exponential backoff
  98. delay *= 2
  99. }
  100. }
  101. return nil, fmt.Errorf("failed to create chat completion after %d retries: %w", l.maxRetries+1, lastErr)
  102. }
  103. // isRetryableError checks if an error is transient and worth retrying
  104. func isRetryableError(err error) bool {
  105. if err == nil {
  106. return false
  107. }
  108. errStr := err.Error()
  109. // Context cancellation - retry if not explicitly canceled
  110. if strings.Contains(errStr, "context canceled") || strings.Contains(errStr, "context deadline exceeded") {
  111. // These can be transient if the context was canceled due to connection issues
  112. // But we should check if it's a genuine cancellation vs. a timeout
  113. return true
  114. }
  115. // Network-related errors
  116. retryablePatterns := []string{
  117. "connection refused",
  118. "connection reset",
  119. "connection closed",
  120. "network is unreachable",
  121. "no route to host",
  122. "timeout",
  123. "i/o timeout",
  124. "temporary failure",
  125. "server misbehaving",
  126. "service unavailable",
  127. "too many requests",
  128. "rate limit",
  129. "429",
  130. "500",
  131. "502",
  132. "503",
  133. "504",
  134. }
  135. lowerErr := strings.ToLower(errStr)
  136. for _, pattern := range retryablePatterns {
  137. if strings.Contains(lowerErr, strings.ToLower(pattern)) {
  138. return true
  139. }
  140. }
  141. return false
  142. }
  143. // ConvertMCPToolsToOpenAI converts MCP tools to OpenAI tool format
  144. func ConvertMCPToolsToOpenAI(mcpTools []Tool) []openai.Tool {
  145. tools := make([]openai.Tool, len(mcpTools))
  146. for i, t := range mcpTools {
  147. // Convert InputSchema to JSON schema format using map[string]interface{}
  148. props := make(map[string]interface{})
  149. for name, prop := range t.InputSchema.Properties {
  150. propMap := map[string]interface{}{
  151. "type": prop.Type,
  152. "description": prop.Description,
  153. }
  154. // For object types without explicit nested properties,
  155. // allow additionalProperties so the LLM can pass any key-value pairs
  156. // This is important for tools like 'query' and 'mutate' that accept
  157. // arbitrary variables objects
  158. if prop.Type == "object" {
  159. propMap["additionalProperties"] = true
  160. }
  161. props[name] = propMap
  162. }
  163. // Build parameters map, omitting empty required array
  164. params := map[string]interface{}{
  165. "type": t.InputSchema.Type,
  166. "properties": props,
  167. }
  168. // Only include required if it has elements - empty slice marshals as null
  169. if len(t.InputSchema.Required) > 0 {
  170. params["required"] = t.InputSchema.Required
  171. }
  172. tools[i] = openai.Tool{
  173. Type: openai.ToolTypeFunction,
  174. Function: &openai.FunctionDefinition{
  175. Name: t.Name,
  176. Description: t.Description,
  177. Parameters: params,
  178. },
  179. }
  180. }
  181. return tools
  182. }
  183. // ParseToolCall parses a tool call from the LLM response
  184. func ParseToolCall(toolCall openai.ToolCall) (string, map[string]interface{}, error) {
  185. name := toolCall.Function.Name
  186. var args map[string]interface{}
  187. if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &args); err != nil {
  188. return name, nil, fmt.Errorf("failed to parse tool arguments: %w", err)
  189. }
  190. return name, args, nil
  191. }
  192. // TestConnection tests the connection to OpenAI API
  193. func (l *LLM) TestConnection(ctx context.Context) error {
  194. // Simple test request - use enough tokens for reasoning models
  195. // Reasoning models need more tokens for their thinking process
  196. req := openai.ChatCompletionRequest{
  197. Model: l.model,
  198. Messages: []openai.ChatCompletionMessage{
  199. {
  200. Role: openai.ChatMessageRoleUser,
  201. Content: "Hello",
  202. },
  203. },
  204. MaxTokens: 100,
  205. }
  206. _, err := l.client.CreateChatCompletion(ctx, req)
  207. if err != nil {
  208. return fmt.Errorf("failed to connect to OpenAI API: %w", err)
  209. }
  210. return nil
  211. }