mirror of
https://github.com/APIParkLab/APIPark.git
synced 2026-06-14 20:41:15 +08:00
81 lines
1.6 KiB
Go
81 lines
1.6 KiB
Go
package ai_provider_local
|
|
|
|
import (
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"testing"
|
|
|
|
"github.com/gin-contrib/gzip"
|
|
|
|
"github.com/eolinker/eosc/log"
|
|
|
|
"github.com/google/uuid"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
)
|
|
|
|
func TestPullModel(t *testing.T) {
|
|
// 创建 Gin 引擎
|
|
r := gin.Default()
|
|
r.Use(gzip.Gzip(gzip.DefaultCompression))
|
|
// 设置路由,监听 "/stream" 路径
|
|
r.GET("/stream", streamHandler)
|
|
r.GET("/stop", stopPull)
|
|
r.GET("/models", models)
|
|
|
|
// 启动 HTTP 服务器
|
|
r.Run(":11180")
|
|
}
|
|
|
|
func streamHandler(c *gin.Context) {
|
|
// 创建一个通道,用于监测客户端关闭连接的信号
|
|
model := c.Query("model")
|
|
p, err := PullModel(model, uuid.NewString(), nil)
|
|
if err != nil {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
done := make(chan struct{})
|
|
// 启动一个 goroutine 监听客户端关闭连接
|
|
go func() {
|
|
select {
|
|
case <-c.Writer.CloseNotify():
|
|
log.Info("client closed connection,close pipeline")
|
|
taskExecutor.ClosePipeline(model, p.id)
|
|
case <-done:
|
|
}
|
|
}()
|
|
|
|
c.Stream(func(w io.Writer) bool {
|
|
select {
|
|
case msg, ok := <-p.channel:
|
|
if !ok {
|
|
return false
|
|
}
|
|
_, err := w.Write([]byte(fmt.Sprintf("%s\n", msg.Msg)))
|
|
if err != nil {
|
|
log.Error("write message error: %v", err)
|
|
return false
|
|
}
|
|
return true
|
|
}
|
|
})
|
|
done <- struct{}{}
|
|
}
|
|
|
|
func stopPull(c *gin.Context) {
|
|
model := c.Query("model")
|
|
StopPull(model)
|
|
c.JSON(http.StatusOK, gin.H{"message": "stop pull model"})
|
|
}
|
|
|
|
func models(c *gin.Context) {
|
|
ms, err := ModelsInstalled()
|
|
if err != nil {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
c.JSON(http.StatusOK, gin.H{"models": ms})
|
|
}
|