Преглед на файлове

refactor: Extract request copying logic into separate function with test

Yeuoly преди 4 месеца
родител
ревизия
24315ec568
променени са 2 файла, в които са добавени 65 реда и са изтрити 30 реда
  1. 39 30
      internal/service/endpoint.go
  2. 26 0
      internal/service/endpoint_test.go

+ 39 - 30
internal/service/endpoint.go

@@ -7,6 +7,7 @@ import (
 	"errors"
 	"fmt"
 	"io"
+	"net/http"
 	"sync/atomic"
 	"time"
 
@@ -27,59 +28,67 @@ import (
 	"github.com/langgenius/dify-plugin-daemon/pkg/entities/requests"
 )
 
-func Endpoint(
-	ctx *gin.Context,
-	endpoint *models.Endpoint,
-	pluginInstallation *models.PluginInstallation,
-	maxExecutionTime time.Duration,
-	path string,
-) {
-	if !endpoint.Enabled {
-		ctx.JSON(404, exception.NotFoundError(errors.New("endpoint not found")).ToResponse())
-		return
-	}
-
-	req := ctx.Request.Clone(context.Background())
+func copyRequest(req *http.Request, hookId string, path string) (*bytes.Buffer, error) {
+	newReq := req.Clone(context.Background())
 	// get query params
 	queryParams := req.URL.Query()
 
 	// replace path with endpoint path
-	req.URL.Path = path
+	newReq.URL.Path = path
 	// set query params
-	req.URL.RawQuery = queryParams.Encode()
+	newReq.URL.RawQuery = queryParams.Encode()
 
 	// read request body until complete, max 10MB
 	body, err := io.ReadAll(io.LimitReader(req.Body, 10*1024*1024))
 	if err != nil {
-		ctx.JSON(500, exception.InternalServerError(err).ToResponse())
-		return
+		return nil, err
 	}
 
 	// replace with a new reader
-	req.Body = io.NopCloser(bytes.NewReader(body))
-	req.ContentLength = int64(len(body))
-	req.TransferEncoding = nil
+	newReq.Body = io.NopCloser(bytes.NewReader(body))
+	newReq.ContentLength = int64(len(body))
+	newReq.TransferEncoding = nil
 
 	// remove ip traces for security
-	req.Header.Del("X-Forwarded-For")
-	req.Header.Del("X-Real-IP")
-	req.Header.Del("X-Forwarded")
-	req.Header.Del("X-Original-Forwarded-For")
-	req.Header.Del("X-Original-Url")
-	req.Header.Del("X-Original-Host")
+	newReq.Header.Del("X-Forwarded-For")
+	newReq.Header.Del("X-Real-IP")
+	newReq.Header.Del("X-Forwarded")
+	newReq.Header.Del("X-Original-Forwarded-For")
+	newReq.Header.Del("X-Original-Url")
+	newReq.Header.Del("X-Original-Host")
 
 	// setup hook id to request
-	req.Header.Set("Dify-Hook-Id", endpoint.HookID)
+	newReq.Header.Set("Dify-Hook-Id", hookId)
 	// check if Dify-Hook-Url is set
 	if url := req.Header.Get("Dify-Hook-Url"); url == "" {
-		req.Header.Set(
+		newReq.Header.Set(
 			"Dify-Hook-Url",
-			fmt.Sprintf("http://%s:%s/e/%s%s", req.Host, req.URL.Port(), endpoint.HookID, path),
+			fmt.Sprintf("http://%s/e/%s%s", req.Host, hookId, path),
 		)
 	}
 
 	var buffer bytes.Buffer
-	err = req.Write(&buffer)
+	err = newReq.Write(&buffer)
+	if err != nil {
+		return nil, err
+	}
+
+	return &buffer, nil
+}
+
+func Endpoint(
+	ctx *gin.Context,
+	endpoint *models.Endpoint,
+	pluginInstallation *models.PluginInstallation,
+	maxExecutionTime time.Duration,
+	path string,
+) {
+	if !endpoint.Enabled {
+		ctx.JSON(404, exception.NotFoundError(errors.New("endpoint not found")).ToResponse())
+		return
+	}
+
+	buffer, err := copyRequest(ctx.Request, endpoint.HookID, path)
 	if err != nil {
 		ctx.JSON(500, exception.InternalServerError(err).ToResponse())
 		return

+ 26 - 0
internal/service/endpoint_test.go

@@ -0,0 +1,26 @@
+package service
+
+import (
+	"bytes"
+	"io"
+	"net/http"
+	"testing"
+)
+
+func TestCopyRequest(t *testing.T) {
+	req, err := http.NewRequest("GET", "http://localhost:8080/test?test=123", nil)
+	req.Body = io.NopCloser(bytes.NewReader([]byte("test")))
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	buffer, err := copyRequest(req, "123", "/test")
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	str := buffer.String()
+	if str != "GET /test?test=123 HTTP/1.1\r\nHost: localhost:8080\r\nUser-Agent: Go-http-client/1.1\r\nContent-Length: 4\r\nDify-Hook-Id: 123\r\nDify-Hook-Url: http://localhost:8080/e/123/test\r\n\r\ntest" {
+		t.Fatal("request body is not equal, ", str)
+	}
+}