Преглед на файлове

Merge pull request #176 from langgenius/fix/lost-query-params-in-endpoint

Fix/lost query params in endpoint
Yeuoly преди 6 месеца
родител
ревизия
5c75366d95
променени са 5 файла, в които са добавени 67 реда и са изтрити 1 реда
  1. 11 1
      internal/cluster/redirect.go
  2. 39 0
      internal/cluster/redirect_test.go
  3. 6 0
      internal/server/endpoint.go
  4. 7 0
      internal/service/endpoint.go
  5. 4 0
      pkg/entities/endpoint_entities/endpoint.go

+ 11 - 1
internal/cluster/redirect.go

@@ -6,6 +6,14 @@ import (
 	"net/http"
 )
 
+func constructRedirectUrl(ip address, request *http.Request) string {
+	url := "http://" + ip.fullAddress() + request.URL.Path
+	if request.URL.RawQuery != "" {
+		url += "?" + request.URL.RawQuery
+	}
+	return url
+}
+
 // RedirectRequest redirects the request to the specified node
 func (c *Cluster) RedirectRequest(
 	node_id string, request *http.Request,
@@ -22,10 +30,12 @@ func (c *Cluster) RedirectRequest(
 
 	ip := ips[0]
 
+	url := constructRedirectUrl(ip, request)
+
 	// create a new request
 	redirectedRequest, err := http.NewRequest(
 		request.Method,
-		"http://"+ip.fullAddress()+request.URL.Path,
+		url,
 		request.Body,
 	)
 

+ 39 - 0
internal/cluster/redirect_test.go

@@ -11,6 +11,7 @@ import (
 
 	"github.com/gin-gonic/gin"
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/network"
+	"github.com/langgenius/dify-plugin-daemon/pkg/entities/endpoint_entities"
 )
 
 type SimulationCheckServer struct {
@@ -180,3 +181,41 @@ func TestRedirectTraffic(t *testing.T) {
 		}
 	}
 }
+
+func TestRedirectTrafficWithQueryParams(t *testing.T) {
+	request, err := http.NewRequest("GET", "http://localhost:8080/plugin/invoke/tool?a=1&b=2", nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	request.Header.Set(endpoint_entities.HeaderXOriginalHost, "localhost:8080")
+
+	ip := address{
+		Ip:   "127.0.0.1",
+		Port: 8080,
+	}
+
+	redirectedRequest := constructRedirectUrl(ip, request)
+	if redirectedRequest != "http://127.0.0.1:8080/plugin/invoke/tool?a=1&b=2" {
+		t.Fatal("redirected request is not correct")
+	}
+}
+
+func TestRedirectTrafficWithOutQueryParams(t *testing.T) {
+	request, err := http.NewRequest("GET", "http://localhost:8080/plugin/invoke/tool", nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	request.Header.Set(endpoint_entities.HeaderXOriginalHost, "localhost:8080")
+
+	ip := address{
+		Ip:   "127.0.0.1",
+		Port: 8080,
+	}
+
+	redirectedRequest := constructRedirectUrl(ip, request)
+	if redirectedRequest != "http://127.0.0.1:8080/plugin/invoke/tool" {
+		t.Fatal("redirected request is not correct")
+	}
+}

+ 6 - 0
internal/server/endpoint.go

@@ -13,6 +13,7 @@ import (
 	"github.com/langgenius/dify-plugin-daemon/internal/types/exception"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/models"
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/log"
+	"github.com/langgenius/dify-plugin-daemon/pkg/entities/endpoint_entities"
 	"github.com/langgenius/dify-plugin-daemon/pkg/entities/plugin_entities"
 )
 
@@ -28,6 +29,11 @@ func (app *App) Endpoint(config *app.Config) func(c *gin.Context) {
 		hookId := c.Param("hook_id")
 		path := c.Param("path")
 
+		// set X-Original-Host
+		if c.Request.Header.Get(endpoint_entities.HeaderXOriginalHost) == "" {
+			c.Request.Header.Set(endpoint_entities.HeaderXOriginalHost, c.Request.Host)
+		}
+
 		if app.endpointHandler != nil {
 			app.endpointHandler(c, hookId, time.Duration(config.PluginMaxExecutionTimeout)*time.Second, path)
 		} else {

+ 7 - 0
internal/service/endpoint.go

@@ -24,6 +24,7 @@ import (
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/encryption"
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/routine"
 	"github.com/langgenius/dify-plugin-daemon/pkg/entities"
+	"github.com/langgenius/dify-plugin-daemon/pkg/entities/endpoint_entities"
 	"github.com/langgenius/dify-plugin-daemon/pkg/entities/plugin_entities"
 	"github.com/langgenius/dify-plugin-daemon/pkg/entities/requests"
 )
@@ -57,6 +58,12 @@ func copyRequest(req *http.Request, hookId string, path string) (*bytes.Buffer,
 	newReq.Header.Del("X-Original-Url")
 	newReq.Header.Del("X-Original-Host")
 
+	// check if X-Original-Host is set
+	if originalHost := req.Header.Get(endpoint_entities.HeaderXOriginalHost); originalHost != "" {
+		// replace host with original host
+		newReq.Host = originalHost
+	}
+
 	// setup hook id to request
 	newReq.Header.Set("Dify-Hook-Id", hookId)
 	// check if Dify-Hook-Url is set

+ 4 - 0
pkg/entities/endpoint_entities/endpoint.go

@@ -5,3 +5,7 @@ type EndpointResponseChunk struct {
 	Headers map[string]string `json:"headers" validate:"omitempty"`
 	Result  *string           `json:"result" validate:"omitempty"`
 }
+
+const (
+	HeaderXOriginalHost = "X-Original-Host"
+)