From cb80f7e0f3f55e444a8478a8f19df17c93fca3f5 Mon Sep 17 00:00:00 2001 From: LzSkyline Date: Fri, 21 Feb 2025 20:10:36 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=E5=8F=AF=E9=85=8D?= =?UTF-8?q?=E7=BD=AE=E7=9A=84=E4=BC=9A=E8=AF=9D=E8=AE=B0=E5=BF=86=E5=8A=9F?= =?UTF-8?q?=E8=83=BD.=20Resolves=20#1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .env.example | 3 + README.md | 20 ++++- main.py | 210 ++++++++++++++++++++++++++++++++++++++++++++++----- 3 files changed, 213 insertions(+), 20 deletions(-) diff --git a/.env.example b/.env.example index c546124..1213de1 100644 --- a/.env.example +++ b/.env.example @@ -5,6 +5,9 @@ DIFY_API_KEYS=app-xxxxxxxxxxxxxxxx,app-yyyyyyyyyyyyyyyy,app-zzzzzzzzzzzzzzzz # Dify API Base URL DIFY_API_BASE="https://api.dify.example.com/v1" +# Keep Conversation(true/false) +ENABLE_CONVERSATION_MEMORY=true + # Server Configuration SERVER_HOST="127.0.0.1" SERVER_PORT=5000 \ No newline at end of file diff --git a/README.md b/README.md index 6b9a2a8..893e29b 100644 --- a/README.md +++ b/README.md @@ -132,7 +132,25 @@ for chunk in response: print(chunk.choices[0].delta.content or "", end="") ``` -## 特性说明 +## 特性 + +### 会话记忆功能 + +该代理支持自动记忆会话上下文,无需客户端进行额外处理。当启用此功能时: + +- 在每个新会话的第一条回复中,会自动嵌入不可见的会话ID +- 后续的消息会自动继承会话上下文,保持对话连贯性 +- 使用零宽字符编码,(大部分情况下)不会影响消息的正常显示 + +可以通过环境变量控制此功能: + +```shell +# 在 .env 文件中设置 +ENABLE_CONVERSATION_MEMORY=true # 启用会话记忆功能 +ENABLE_CONVERSATION_MEMORY=false # 禁用会话记忆功能 +``` + +默认情况下此功能是启用的。如果您的应用场景不需要保持会话上下文,可以选择关闭此功能。 ### 流式输出优化 diff --git a/main.py b/main.py index 4079ab2..2e50e6e 100644 --- a/main.py +++ b/main.py @@ -10,17 +10,20 @@ import ast # 配置日志 logging.basicConfig( - level=logging.DEBUG, + level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s' ) logger = logging.getLogger(__name__) -# 设置 httpx 的日志级别 -logging.getLogger("httpx").setLevel(logging.DEBUG) +# 设置httpx的日志级别为WARNING,减少不必要的输出 +logging.getLogger("httpx").setLevel(logging.WARNING) # 加载环境变量 load_dotenv() +# 获取会话ID记忆功能开关配置 +ENABLE_CONVERSATION_MEMORY = os.getenv('ENABLE_CONVERSATION_MEMORY', 'true').lower() == 'true' + class DifyModelManager: def __init__(self): self.api_keys = [] @@ -109,24 +112,26 @@ def transform_openai_to_dify(openai_request, endpoint): messages = openai_request.get("messages", []) stream = openai_request.get("stream", False) + # 尝试从历史消息中提取conversation_id + conversation_id = None + if len(messages) > 1: + # 遍历历史消息,找到最近的assistant消息 + for msg in reversed(messages[:-1]): # 除了最后一条消息 + if msg.get("role") == "assistant": + content = msg.get("content", "") + # 尝试解码conversation_id + conversation_id = decode_conversation_id(content) + if conversation_id: + break + dify_request = { "inputs": {}, "query": messages[-1]["content"] if messages else "", "response_mode": "streaming" if stream else "blocking", - "conversation_id": openai_request.get("conversation_id", None), + "conversation_id": conversation_id, "user": openai_request.get("user", "default_user") } - # 添加历史消息 - if len(messages) > 1: - history = [] - for msg in messages[:-1]: # 除了最后一条消息 - history.append({ - "role": msg["role"], - "content": msg["content"] - }) - dify_request["conversation_history"] = history - return dify_request return None @@ -135,16 +140,40 @@ def transform_dify_to_openai(dify_response, model="claude-3-5-sonnet-v2", stream """将Dify格式的响应转换为OpenAI格式""" if not stream: + answer = dify_response.get("answer", "") + + # 只在启用会话记忆功能时处理conversation_id + if ENABLE_CONVERSATION_MEMORY: + conversation_id = dify_response.get("conversation_id", "") + history = dify_response.get("conversation_history", []) + + # 检查历史消息中是否已经有会话ID + has_conversation_id = False + if history: + for msg in history: + if msg.get("role") == "assistant": + content = msg.get("content", "") + if decode_conversation_id(content) is not None: + has_conversation_id = True + break + + # 只在新会话且历史消息中没有会话ID时插入 + if conversation_id and not has_conversation_id: + logger.info(f"[Debug] Inserting conversation_id: {conversation_id}, history_length: {len(history)}") + encoded = encode_conversation_id(conversation_id) + answer = answer + encoded + logger.info(f"[Debug] Response content after insertion: {repr(answer)}") + return { "id": dify_response.get("message_id", ""), "object": "chat.completion", "created": dify_response.get("created", int(time.time())), - "model": model, # 使用实际使用的模型 + "model": model, "choices": [{ "index": 0, "message": { "role": "assistant", - "content": dify_response.get("answer", "") + "content": answer }, "finish_reason": "stop" }] @@ -169,6 +198,117 @@ def create_openai_stream_response(content, message_id, model="claude-3-5-sonnet- }] } +def encode_conversation_id(conversation_id): + """将conversation_id编码为不可见的字符序列""" + if not conversation_id: + return "" + + # 使用Base64编码减少长度 + import base64 + encoded = base64.b64encode(conversation_id.encode()).decode() + + # 使用8种不同的零宽字符表示3位数字 + # 这样可以将编码长度进一步减少 + char_map = { + '0': '\u200b', # 零宽空格 + '1': '\u200c', # 零宽非连接符 + '2': '\u200d', # 零宽连接符 + '3': '\ufeff', # 零宽非断空格 + '4': '\u2060', # 词组连接符 + '5': '\u180e', # 蒙古语元音分隔符 + '6': '\u2061', # 函数应用 + '7': '\u2062', # 不可见乘号 + } + + # 将Base64字符串转换为八进制数字 + result = [] + for c in encoded: + # 将每个字符转换为8进制数字(0-7) + if c.isalpha(): + if c.isupper(): + val = ord(c) - ord('A') + else: + val = ord(c) - ord('a') + 26 + elif c.isdigit(): + val = int(c) + 52 + elif c == '+': + val = 62 + elif c == '/': + val = 63 + else: # '=' + val = 0 + + # 每个Base64字符可以产生2个3位数字 + first = (val >> 3) & 0x7 + second = val & 0x7 + result.append(char_map[str(first)]) + if c != '=': # 不编码填充字符的后半部分 + result.append(char_map[str(second)]) + + return ''.join(result) + +def decode_conversation_id(content): + """从消息内容中解码conversation_id""" + try: + # 零宽字符到3位数字的映射 + char_to_val = { + '\u200b': '0', # 零宽空格 + '\u200c': '1', # 零宽非连接符 + '\u200d': '2', # 零宽连接符 + '\ufeff': '3', # 零宽非断空格 + '\u2060': '4', # 词组连接符 + '\u180e': '5', # 蒙古语元音分隔符 + '\u2061': '6', # 函数应用 + '\u2062': '7', # 不可见乘号 + } + + # 提取最后一段零宽字符序列 + space_chars = [] + for c in reversed(content): + if c not in char_to_val: + break + space_chars.append(c) + + if not space_chars: + return None + + # 将零宽字符转换回Base64字符串 + space_chars.reverse() + base64_chars = [] + for i in range(0, len(space_chars), 2): + first = int(char_to_val[space_chars[i]], 8) + if i + 1 < len(space_chars): + second = int(char_to_val[space_chars[i + 1]], 8) + val = (first << 3) | second + else: + val = first << 3 + + # 转换回Base64字符 + if val < 26: + base64_chars.append(chr(val + ord('A'))) + elif val < 52: + base64_chars.append(chr(val - 26 + ord('a'))) + elif val < 62: + base64_chars.append(str(val - 52)) + elif val == 62: + base64_chars.append('+') + else: + base64_chars.append('/') + + # 添加Base64填充 + padding = len(base64_chars) % 4 + if padding: + base64_chars.extend(['='] * (4 - padding)) + + # 解码Base64字符串 + import base64 + base64_str = ''.join(base64_chars) + return base64.b64decode(base64_str).decode() + + except Exception as e: + logger.debug(f"Failed to decode conversation_id: {e}") + return None + @app.route('/v1/chat/completions', methods=['POST']) def chat_completions(): try: @@ -176,7 +316,6 @@ def chat_completions(): logger.info(f"Received request: {json.dumps(openai_request, ensure_ascii=False)}") model = openai_request.get("model", "claude-3-5-sonnet-v2") - logger.info(f"Using model: {model}") # 验证模型是否支持 api_key = get_api_key(model) @@ -192,7 +331,6 @@ def chat_completions(): }, 404 dify_request = transform_openai_to_dify(openai_request, "/chat/completions") - logger.info(f"Transformed request: {json.dumps(dify_request, ensure_ascii=False)}") if not dify_request: logger.error("Failed to transform request") @@ -319,6 +457,28 @@ def chat_completions(): yield send_char(char, msg_id) time.sleep(0.001) # 固定使用最小延迟快速输出剩余内容 + # 只在启用会话记忆功能时处理conversation_id + if ENABLE_CONVERSATION_MEMORY: + conversation_id = dify_chunk.get("conversation_id") + history = dify_chunk.get("conversation_history", []) + + has_conversation_id = False + if history: + for msg in history: + if msg.get("role") == "assistant": + content = msg.get("content", "") + if decode_conversation_id(content) is not None: + has_conversation_id = True + break + + # 只在新会话且历史消息中没有会话ID时插入 + if conversation_id and not has_conversation_id: + logger.info(f"[Debug] Inserting conversation_id in stream: {conversation_id}") + encoded = encode_conversation_id(conversation_id) + logger.info(f"[Debug] Stream encoded content: {repr(encoded)}") + for char in encoded: + yield send_char(char, generate.message_id) + final_chunk = { "id": generate.message_id, "object": "chat.completion.chunk", @@ -379,8 +539,20 @@ def chat_completions(): dify_response = response.json() logger.info(f"Received response from Dify: {json.dumps(dify_response, ensure_ascii=False)}") + logger.info(f"[Debug] Response content: {repr(dify_response.get('answer', ''))}") openai_response = transform_dify_to_openai(dify_response, model=model) - return openai_response + conversation_id = dify_response.get("conversation_id") + if conversation_id: + # 在响应头中传递conversation_id + return Response( + json.dumps(openai_response), + content_type='application/json', + headers={ + 'Conversation-Id': conversation_id + } + ) + else: + return openai_response except httpx.RequestError as e: error_msg = f"Failed to connect to Dify: {str(e)}" logger.error(error_msg)