api.go 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
  1. package api
  2. import (
  3. "bufio"
  4. "bytes"
  5. "encoding/json"
  6. "fmt"
  7. "io"
  8. "net/http"
  9. "github.com/wecisecode/util/merrs"
  10. )
  11. type GenerateRequestOptions struct {
  12. Seed int64 `json:"seed,omitempty"`
  13. }
  14. // 定义请求结构体
  15. type GenerateRequest struct {
  16. Model string `json:"model"`
  17. Prompt string `json:"prompt"`
  18. Stream bool `json:"stream"`
  19. Context []int64 `json:"context,omitempty"`
  20. Options GenerateRequestOptions `json:"options,omitempty"`
  21. }
  22. type CallFunction struct {
  23. Name string `json:"name"`
  24. Arguments map[string]any `json:"arguments,omitempty"`
  25. }
  26. type ChatToolCall struct {
  27. Function CallFunction
  28. }
  29. type ChatFunctionParameter struct {
  30. Type string `json:"type"`
  31. Description string `json:"description,omitempty"`
  32. Enum []string `json:"enum,omitempty"`
  33. }
  34. type ChatFunctionParameters struct {
  35. Type string `json:"type"` // object
  36. Properties map[string]ChatFunctionParameter `json:"properties,omitempty"`
  37. Required []string `json:"required,omitempty"`
  38. }
  39. type ChatFunction struct {
  40. Name string `json:"name"`
  41. Description string `json:"description,omitempty"`
  42. Parameters *ChatFunctionParameters `json:"parameters,omitempty"`
  43. }
  44. type ChatTool struct {
  45. Type string `json:"type"` // "function"
  46. Function *ChatFunction `json:"function"`
  47. }
  48. type ChatMessage struct {
  49. Role string `json:"role"`
  50. Content string `json:"content"`
  51. Images []string `json:"images,omitempty"`
  52. ToolCalls []*ChatToolCall `json:"tool_calls,omitempty"`
  53. }
  54. // 定义请求结构体
  55. type ChatRequest struct {
  56. Model string `json:"model"`
  57. Messages []*ChatMessage `json:"messages"`
  58. Stream bool `json:"stream"`
  59. Tools []*ChatTool `json:"tools,omitempty"`
  60. }
  61. type ChatResponse struct {
  62. Role string `json:"role"`
  63. Content string `json:"content"`
  64. Images []string `json:"images,omitempty"`
  65. Tools []*ChatTool `json:"tool_calls,omitempty"`
  66. }
  67. // 定义响应结构体
  68. type GenerateResponse struct {
  69. Response string `json:"response,omitempty"`
  70. Context []int64 `json:"context,omitempty"`
  71. Done bool `json:"done,omitempty"`
  72. }
  73. type result struct {
  74. Error error
  75. Response *GenerateResponse
  76. }
  77. type Result <-chan *result
  78. type ChanResult chan *result
  79. func newResult() ChanResult {
  80. return make(ChanResult, 10)
  81. }
  82. func (ret ChanResult) Error(e error) Result {
  83. ret <- &result{Error: e}
  84. return (chan *result)(ret)
  85. }
  86. func (ret ChanResult) Response(v *GenerateResponse) Result {
  87. ret <- &result{Response: v}
  88. return (chan *result)(ret)
  89. }
  90. func (ret ChanResult) Result() Result {
  91. return (chan *result)(ret)
  92. }
  93. func (ret ChanResult) Close() {
  94. close(ret)
  95. }
  96. func Request(context []int64, msg string) Result {
  97. ret := newResult()
  98. go func() {
  99. defer ret.Close()
  100. // Ollama 服务地址
  101. url := "http://127.0.0.1:11434/api/generate"
  102. // 创建请求体
  103. requestData := GenerateRequest{
  104. Model: "deepseek-r1:7b", // 使用的模型名称
  105. Prompt: msg, // 输入的提示
  106. Stream: true, // 流式响应
  107. Context: context, // 上下文
  108. Options: GenerateRequestOptions{
  109. Seed: 54321,
  110. },
  111. }
  112. // 将结构体转换为 JSON
  113. jsonData, err := json.Marshal(requestData)
  114. if err != nil {
  115. ret.Error(merrs.New("JSON 编码错误:", err))
  116. return
  117. }
  118. // 创建 HTTP 请求
  119. resp, err := http.Post(url, "application/json", bytes.NewBuffer(jsonData))
  120. if err != nil {
  121. ret.Error(merrs.New("请求失败:", err))
  122. return
  123. }
  124. defer resp.Body.Close()
  125. // 检查状态码
  126. if resp.StatusCode != http.StatusOK {
  127. body, _ := io.ReadAll(resp.Body)
  128. ret.Error(merrs.New(fmt.Sprintf("错误响应: %s\n状态码: %d\n", body, resp.StatusCode)))
  129. return
  130. }
  131. // 流式处理响应
  132. scanner := bufio.NewScanner(resp.Body)
  133. for scanner.Scan() {
  134. var chunk GenerateResponse
  135. if err := json.Unmarshal(scanner.Bytes(), &chunk); err != nil {
  136. ret.Error(merrs.New("解析分块失败:", err))
  137. break
  138. }
  139. // fmt.Print(chunk.Response) // 逐块打印响应
  140. ret.Response(&chunk)
  141. if chunk.Done {
  142. break
  143. }
  144. }
  145. }()
  146. return ret.Result()
  147. }