소스 검색

feat: http requests

Yeuoly 1 년 전
부모
커밋
e426996478

+ 3 - 0
.env.example

@@ -15,6 +15,9 @@ REDIS_PASS=difyai123456
 LIFETIME_COLLECTION_HEARTBEAT_INTERVAL=5
 LIFETIME_COLLECTION_CG_INTERVAL=60
 LIFETIME_STATE_GC_INTERVAL=300
+
+DIFY_INVOCATION_CONNECTION_IDLE_TIMEOUT=120
+
 STORAGE_PATH=examples
 
 PLATFORM=local

+ 1 - 0
cmd/server/main.go

@@ -33,6 +33,7 @@ func setDefault(config *app.Config) {
 	setDefaultInt(&config.LifetimeCollectionGCInterval, 60)
 	setDefaultInt(&config.LifetimeCollectionHeartbeatInterval, 5)
 	setDefaultInt(&config.LifetimeStateGCInterval, 300)
+	setDefaultInt(&config.DifyInvocationConnectionIdleTimeout, 120)
 }
 
 func setDefaultInt[T constraints.Integer](value *T, defaultValue T) {

+ 2 - 2
cmd/tests/main.go

@@ -6,11 +6,11 @@ import (
 	"sync/atomic"
 	"time"
 
-	"github.com/langgenius/dify-plugin-daemon/internal/types/entities"
+	"github.com/langgenius/dify-plugin-daemon/internal/utils/stream"
 )
 
 func main() {
-	response := entities.NewInvocationResponse[string](1024)
+	response := stream.NewStreamResponse[string](1024)
 
 	random_string := func() string {
 		return fmt.Sprintf("%d", rand.Intn(100000))

+ 33 - 0
internal/core/dify_invocation/http_client.go

@@ -0,0 +1,33 @@
+package dify_invocation
+
+import (
+	"net"
+	"net/http"
+	"net/url"
+	"time"
+)
+
+var (
+	baseurl *url.URL
+	client  *http.Client
+)
+
+func InitDifyInvocationDaemon(base string) error {
+	var err error
+	baseurl, err = url.Parse(base)
+	if err != nil {
+		return err
+	}
+
+	client = &http.Client{
+		Transport: &http.Transport{
+			Dial: (&net.Dialer{
+				Timeout:   5 * time.Second,
+				KeepAlive: 15 * time.Second,
+			}).Dial,
+			IdleConnTimeout: 120 * time.Second,
+		},
+	}
+
+	return nil
+}

+ 7 - 0
internal/core/dify_invocation/http_request.go

@@ -0,0 +1,7 @@
+package dify_invocation
+
+import "github.com/langgenius/dify-plugin-daemon/internal/utils/requests"
+
+func Request[T any](method string, path string, options ...requests.HttpOptions) (*T, error) {
+	return requests.RequestAndParse[T](client, difyPath(path), method, options...)
+}

+ 5 - 0
internal/core/dify_invocation/path.go

@@ -0,0 +1,5 @@
+package dify_invocation
+
+func difyPath(path ...string) string {
+	return baseurl.JoinPath(path...).String()
+}

+ 3 - 3
internal/core/plugin_daemon/daemon.go

@@ -5,23 +5,23 @@ import (
 
 	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager"
 	"github.com/langgenius/dify-plugin-daemon/internal/core/session_manager"
-	"github.com/langgenius/dify-plugin-daemon/internal/types/entities"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/log"
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
+	"github.com/langgenius/dify-plugin-daemon/internal/utils/stream"
 )
 
 type ToolResponseChunk = plugin_entities.InvokeToolResponseChunk
 
 func InvokeTool(session *session_manager.Session, provider_name string, tool_name string, tool_parameters map[string]any) (
-	*entities.InvocationResponse[ToolResponseChunk], error,
+	*stream.StreamResponse[ToolResponseChunk], error,
 ) {
 	runtime := plugin_manager.Get(session.PluginIdentity())
 	if runtime == nil {
 		return nil, errors.New("plugin not found")
 	}
 
-	response := entities.NewInvocationResponse[ToolResponseChunk](512)
+	response := stream.NewStreamResponse[ToolResponseChunk](512)
 
 	listener := runtime.Listen(session.ID())
 	listener.AddListener(func(message []byte) {

+ 2 - 0
internal/types/app/config.go

@@ -21,6 +21,8 @@ type Config struct {
 	LifetimeCollectionHeartbeatInterval int `envconfig:"LIFETIME_COLLECTION_HEARTBEAT_INTERVAL"`
 	LifetimeCollectionGCInterval        int `envconfig:"LIFETIME_COLLECTION_GC_INTERVAL"`
 	LifetimeStateGCInterval             int `envconfig:"LIFETIME_STATE_GC_INTERVAL"`
+
+	DifyInvocationConnectionIdleTimeout int `envconfig:"DIFY_INVOCATION_CONNECTION_IDLE_TIMEOUT"`
 }
 
 const (

+ 43 - 0
internal/utils/requests/http_options.go

@@ -0,0 +1,43 @@
+package requests
+
+type HttpOptions struct {
+	Type  string
+	Value interface{}
+}
+
+// milliseconds
+func HttpTimeout(timeout int64) HttpOptions {
+	return HttpOptions{"timeout", timeout}
+}
+
+func HttpHeader(header map[string]string) HttpOptions {
+	return HttpOptions{"header", header}
+}
+
+// which is used for params with in url
+func HttpParams(params map[string]string) HttpOptions {
+	return HttpOptions{"params", params}
+}
+
+// which is used for POST method only
+func HttpPayload(payload map[string]string) HttpOptions {
+	return HttpOptions{"payload", payload}
+}
+
+// which is used for POST method only
+func HttpPayloadText(payload string) HttpOptions {
+	return HttpOptions{"payloadText", payload}
+}
+
+// which is used for POST method only
+func HttpPayloadJson(payload interface{}) HttpOptions {
+	return HttpOptions{"payloadJson", payload}
+}
+
+func HttpWithDirectReferer() HttpOptions {
+	return HttpOptions{"directReferer", true}
+}
+
+func HttpWithRetCode(retCode *int) HttpOptions {
+	return HttpOptions{"retCode", retCode}
+}

+ 71 - 0
internal/utils/requests/http_request.go

@@ -0,0 +1,71 @@
+package requests
+
+import (
+	"bytes"
+	"context"
+	"encoding/json"
+	"io"
+	"net/http"
+	"strings"
+	"time"
+)
+
+func buildHttpRequest(method string, url string, options ...HttpOptions) (*http.Request, error) {
+	req, err := http.NewRequest(method, url, nil)
+	if err != nil {
+		return nil, err
+	}
+
+	for _, option := range options {
+		switch option.Type {
+		case "timeout":
+			timeout := time.Second * time.Duration(option.Value.(int64))
+			ctx, cancel := context.WithTimeout(context.Background(), timeout)
+			defer cancel()
+			req = req.WithContext(ctx)
+		case "header":
+			for k, v := range option.Value.(map[string]string) {
+				req.Header.Set(k, v)
+			}
+		case "params":
+			q := req.URL.Query()
+			for k, v := range option.Value.(map[string]string) {
+				q.Add(k, v)
+			}
+			req.URL.RawQuery = q.Encode()
+		case "payload":
+			q := req.URL.Query()
+			for k, v := range option.Value.(map[string]string) {
+				q.Add(k, v)
+			}
+			req.Body = io.NopCloser(strings.NewReader(q.Encode()))
+		case "payloadText":
+			req.Body = io.NopCloser(strings.NewReader(option.Value.(string)))
+		case "payloadJson":
+			jsonStr, err := json.Marshal(option.Value)
+			if err != nil {
+				return nil, err
+			}
+			req.Body = io.NopCloser(bytes.NewBuffer(jsonStr))
+		case "directReferer":
+			req.Header.Set("Referer", url)
+		}
+	}
+
+	return req, nil
+}
+
+func Request(client *http.Client, url string, method string, options ...HttpOptions) (*http.Response, error) {
+	req, err := buildHttpRequest(method, url, options...)
+
+	if err != nil {
+		return nil, err
+	}
+
+	resp, err := client.Do(req)
+	if err != nil {
+		return nil, err
+	}
+
+	return resp, nil
+}

+ 97 - 0
internal/utils/requests/http_warpper.go

@@ -0,0 +1,97 @@
+package requests
+
+import (
+	"bufio"
+	"bytes"
+	"encoding/json"
+	"net/http"
+
+	"github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
+	"github.com/langgenius/dify-plugin-daemon/internal/utils/routine"
+	"github.com/langgenius/dify-plugin-daemon/internal/utils/stream"
+)
+
+func parseJsonBody(resp *http.Response, ret interface{}) error {
+	defer resp.Body.Close()
+	json_decoder := json.NewDecoder(resp.Body)
+	return json_decoder.Decode(ret)
+}
+
+func RequestAndParse[T any](client *http.Client, url string, method string, options ...HttpOptions) (*T, error) {
+	var ret T
+
+	resp, err := Request(client, url, method, options...)
+	if err != nil {
+		return nil, err
+	}
+
+	err = parseJsonBody(resp, &ret)
+	if err != nil {
+		return nil, err
+	}
+
+	return &ret, nil
+}
+
+func GetAndParse[T any](client *http.Client, url string, options ...HttpOptions) (*T, error) {
+	return RequestAndParse[T](client, url, "GET", options...)
+}
+
+func PostAndParse[T any](client *http.Client, url string, options ...HttpOptions) (*T, error) {
+	return RequestAndParse[T](client, url, "POST", options...)
+}
+
+func PutAndParse[T any](client *http.Client, url string, options ...HttpOptions) (*T, error) {
+	return RequestAndParse[T](client, url, "PUT", options...)
+}
+
+func DeleteAndParse[T any](client *http.Client, url string, options ...HttpOptions) (*T, error) {
+	return RequestAndParse[T](client, url, "DELETE", options...)
+}
+
+func PatchAndParse[T any](client *http.Client, url string, options ...HttpOptions) (*T, error) {
+	return RequestAndParse[T](client, url, "PATCH", options...)
+}
+
+func RequestAndParseStream[T any](client *http.Client, url string, method string, options ...HttpOptions) (*stream.StreamResponse[T], error) {
+	resp, err := Request(client, url, method, options...)
+	if err != nil {
+		return nil, err
+	}
+
+	ch := stream.NewStreamResponse[T](1024)
+
+	routine.Submit(func() {
+		scanner := bufio.NewScanner(resp.Body)
+		for scanner.Scan() {
+			data := scanner.Bytes()
+			if bytes.HasPrefix(data, []byte("data: ")) {
+				// split
+				data = data[6:]
+				// unmarshal
+				t, err := parser.UnmarshalJsonBytes[T](data)
+				if err != nil {
+					continue
+				}
+
+				ch.Write(t)
+			}
+		}
+
+		ch.Close()
+	})
+
+	return ch, nil
+}
+
+func GetAndParseStream[T any](client *http.Client, url string, options ...HttpOptions) (*stream.StreamResponse[T], error) {
+	return RequestAndParseStream[T](client, url, "GET", options...)
+}
+
+func PostAndParseStream[T any](client *http.Client, url string, options ...HttpOptions) (*stream.StreamResponse[T], error) {
+	return RequestAndParseStream[T](client, url, "POST", options...)
+}
+
+func PutAndParseStream[T any](client *http.Client, url string, options ...HttpOptions) (*stream.StreamResponse[T], error) {
+	return RequestAndParseStream[T](client, url, "PUT", options...)
+}

+ 11 - 16
internal/types/entities/session.go

@@ -1,4 +1,4 @@
-package entities
+package stream
 
 import (
 	"errors"
@@ -7,12 +7,7 @@ import (
 	"github.com/gammazero/deque"
 )
 
-type InvocationSession struct {
-	ID             string
-	PluginIdentity string
-}
-
-type InvocationResponse[T any] struct {
+type StreamResponse[T any] struct {
 	q         deque.Deque[T]
 	l         *sync.Mutex
 	sig       chan bool
@@ -22,19 +17,19 @@ type InvocationResponse[T any] struct {
 	onClose   func()
 }
 
-func NewInvocationResponse[T any](max int) *InvocationResponse[T] {
-	return &InvocationResponse[T]{
+func NewStreamResponse[T any](max int) *StreamResponse[T] {
+	return &StreamResponse[T]{
 		l:   &sync.Mutex{},
 		sig: make(chan bool),
 		max: max,
 	}
 }
 
-func (r *InvocationResponse[T]) OnClose(f func()) {
+func (r *StreamResponse[T]) OnClose(f func()) {
 	r.onClose = f
 }
 
-func (r *InvocationResponse[T]) Next() bool {
+func (r *StreamResponse[T]) Next() bool {
 	r.l.Lock()
 	if r.closed {
 		r.l.Unlock()
@@ -55,7 +50,7 @@ func (r *InvocationResponse[T]) Next() bool {
 	return <-r.sig
 }
 
-func (r *InvocationResponse[T]) Read() (T, error) {
+func (r *StreamResponse[T]) Read() (T, error) {
 	r.l.Lock()
 	defer r.l.Unlock()
 
@@ -68,7 +63,7 @@ func (r *InvocationResponse[T]) Read() (T, error) {
 	}
 }
 
-func (r *InvocationResponse[T]) Write(data T) error {
+func (r *StreamResponse[T]) Write(data T) error {
 	r.l.Lock()
 	if r.closed {
 		r.l.Unlock()
@@ -90,7 +85,7 @@ func (r *InvocationResponse[T]) Write(data T) error {
 	return nil
 }
 
-func (r *InvocationResponse[T]) Close() {
+func (r *StreamResponse[T]) Close() {
 	r.l.Lock()
 	if r.closed {
 		r.l.Unlock()
@@ -109,14 +104,14 @@ func (r *InvocationResponse[T]) Close() {
 	}
 }
 
-func (r *InvocationResponse[T]) IsClosed() bool {
+func (r *StreamResponse[T]) IsClosed() bool {
 	r.l.Lock()
 	defer r.l.Unlock()
 
 	return r.closed
 }
 
-func (r *InvocationResponse[T]) Size() int {
+func (r *StreamResponse[T]) Size() int {
 	r.l.Lock()
 	defer r.l.Unlock()