Files
APIPark/ai-provider/local/executor_test.go
T
2025-02-14 15:34:41 +08:00

81 lines
1.6 KiB
Go

package ai_provider_local
import (
"fmt"
"io"
"net/http"
"testing"
"github.com/gin-contrib/gzip"
"github.com/eolinker/eosc/log"
"github.com/google/uuid"
"github.com/gin-gonic/gin"
)
// TestPullModel is a manual test harness: it starts a Gin HTTP server on
// port 11180 that exposes the stream, stop and models handlers, and blocks
// for as long as the server keeps running.
func TestPullModel(t *testing.T) {
	// Create the Gin engine and enable gzip compression for responses.
	r := gin.Default()
	r.Use(gzip.Gzip(gzip.DefaultCompression))

	// Register the routes served by this harness.
	r.GET("/stream", streamHandler)
	r.GET("/stop", stopPull)
	r.GET("/models", models)

	// Start the HTTP server. Run blocks until the server stops with an error
	// (for example when the address cannot be bound), so surface that error
	// instead of letting the test pass silently.
	if err := r.Run(":11180"); err != nil {
		t.Fatalf("run test server on :11180: %v", err)
	}
}
// streamHandler starts pulling the model named by the "model" query parameter
// and streams every message received from the pull pipeline to the client,
// one message per line.
//
// It answers with HTTP 500 and a JSON error body when PullModel fails.
// Otherwise it keeps streaming until the pipeline channel is closed or a write
// to the client fails. A watcher goroutine asks taskExecutor to close the
// pipeline if the client disconnects first.
func streamHandler(c *gin.Context) {
	model := c.Query("model")
	p, err := PullModel(model, uuid.NewString(), nil)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	// done tells the watcher goroutine that the handler has finished. It is
	// closed rather than sent on: the watcher may already have exited through
	// the client-disconnect branch, and a send on an unbuffered channel with
	// no receiver would block this handler forever.
	done := make(chan struct{})
	defer close(done)

	// Watch for the client closing the connection and close the pipeline.
	go func() {
		select {
		case <-c.Writer.CloseNotify():
			log.Info("client closed connection,close pipeline")
			taskExecutor.ClosePipeline(model, p.id)
		case <-done:
		}
	}()

	c.Stream(func(w io.Writer) bool {
		msg, ok := <-p.channel
		if !ok {
			// The pipeline channel was closed; stop streaming.
			return false
		}
		if _, err := fmt.Fprintf(w, "%s\n", msg.Msg); err != nil {
			log.Error("write message error: %v", err)
			return false
		}
		return true
	})
}
// stopPull calls StopPull for the model named by the "model" query parameter
// and then replies with HTTP 200 and a fixed JSON message.
func stopPull(c *gin.Context) {
	StopPull(c.Query("model"))

	reply := gin.H{"message": "stop pull model"}
	c.JSON(http.StatusOK, reply)
}
// models replies with the result of ModelsInstalled under the "models" key,
// or with HTTP 500 and the error text under the "error" key when that call
// fails.
func models(c *gin.Context) {
	installed, err := ModelsInstalled()
	if err != nil {
		failure := gin.H{"error": err.Error()}
		c.JSON(http.StatusInternalServerError, failure)
		return
	}

	payload := gin.H{"models": installed}
	c.JSON(http.StatusOK, payload)
}