diff --git a/main.py b/main.py index ae7d3e5..e0483a1 100644 --- a/main.py +++ b/main.py @@ -56,7 +56,7 @@ class DifyModelManager: headers=headers, params={"user": "default_user"} ) - + if response.status_code == 200: app_info = response.json() return app_info.get("name", "Unknown App") @@ -71,7 +71,7 @@ class DifyModelManager: """刷新所有应用信息""" self.name_to_api_key.clear() self.api_key_to_name.clear() - + for api_key in self.api_keys: app_name = await self.fetch_app_info(api_key) if app_name: @@ -112,14 +112,14 @@ def get_api_key(model_name): def transform_openai_to_dify(openai_request, endpoint): """将OpenAI格式的请求转换为Dify格式""" - + if endpoint == "/chat/completions": messages = openai_request.get("messages", []) stream = openai_request.get("stream", False) - + # 尝试从历史消息中提取conversation_id conversation_id = None - + # 提取system消息内容 system_content = "" system_messages = [msg for msg in messages if msg.get("role") == "system"] @@ -127,7 +127,7 @@ def transform_openai_to_dify(openai_request, endpoint): system_content = system_messages[0].get("content", "") # 记录找到的system消息 logger.info(f"Found system message: {system_content[:100]}{'...' if len(system_content) > 100 else ''}") - + if CONVERSATION_MEMORY_MODE == 2: # 零宽字符模式 if len(messages) > 1: # 遍历历史消息,找到最近的assistant消息 @@ -138,15 +138,15 @@ def transform_openai_to_dify(openai_request, endpoint): conversation_id = decode_conversation_id(content) if conversation_id: break - + # 获取最后一条用户消息 user_query = messages[-1]["content"] if messages and messages[-1].get("role") != "system" else "" - + # 如果有system消息且是首次对话(没有conversation_id),则将system内容添加到用户查询前 if system_content and not conversation_id: user_query = f"系统指令: {system_content}\n\n用户问题: {user_query}" logger.info(f"[零宽字符模式] 首次对话,添加system内容到查询前") - + dify_request = { "inputs": {}, "query": user_query, @@ -157,12 +157,12 @@ def transform_openai_to_dify(openai_request, endpoint): else: # history_message模式(默认) # 获取最后一条用户消息 user_query = messages[-1]["content"] if messages and messages[-1].get("role") != "system" else "" - + # 构造历史消息 if len(messages) > 1: history_messages = [] has_system_in_history = False - + # 检查历史消息中是否已经包含system消息 for msg in messages[:-1]: # 除了最后一条消息 role = msg.get("role", "") @@ -171,12 +171,12 @@ def transform_openai_to_dify(openai_request, endpoint): if role == "system": has_system_in_history = True history_messages.append(f"{role}: {content}") - + # 如果历史中没有system消息但现在有system消息,则添加到历史的最前面 if system_content and not has_system_in_history: history_messages.insert(0, f"system: {system_content}") logger.info(f"[history_message模式] 添加system内容到历史消息前") - + # 将历史消息添加到查询中 if history_messages: history_context = "\n\n".join(history_messages) @@ -184,7 +184,7 @@ def transform_openai_to_dify(openai_request, endpoint): elif system_content: # 没有历史消息但有system消息 user_query = f"系统指令: {system_content}\n\n用户问题: {user_query}" logger.info(f"[history_message模式] 首次对话,添加system内容到查询前") - + dify_request = { "inputs": {}, "query": user_query, @@ -193,21 +193,21 @@ def transform_openai_to_dify(openai_request, endpoint): } return dify_request - + return None def transform_dify_to_openai(dify_response, model="claude-3-5-sonnet-v2", stream=False): """将Dify格式的响应转换为OpenAI格式""" - + if not stream: # 首先获取回答内容,支持不同的响应模式 answer = "" mode = dify_response.get("mode", "") - + # 普通聊天模式 if "answer" in dify_response: answer = dify_response.get("answer", "") - + # 如果是Agent模式,需要从agent_thoughts中提取回答 elif "agent_thoughts" in dify_response: # Agent模式下通常最后一个thought包含最终答案 @@ -216,12 +216,12 @@ def transform_dify_to_openai(dify_response, model="claude-3-5-sonnet-v2", stream for thought in agent_thoughts: if thought.get("thought"): answer = thought.get("thought", "") - + # 只在零宽字符会话记忆模式时处理conversation_id if CONVERSATION_MEMORY_MODE == 2: conversation_id = dify_response.get("conversation_id", "") history = dify_response.get("conversation_history", []) - + # 检查历史消息中是否已经有会话ID has_conversation_id = False if history: @@ -231,14 +231,14 @@ def transform_dify_to_openai(dify_response, model="claude-3-5-sonnet-v2", stream if decode_conversation_id(content) is not None: has_conversation_id = True break - + # 只在新会话且历史消息中没有会话ID时插入 if conversation_id and not has_conversation_id: logger.info(f"[Debug] Inserting conversation_id: {conversation_id}, history_length: {len(history)}") encoded = encode_conversation_id(conversation_id) answer = answer + encoded logger.info(f"[Debug] Response content after insertion: {repr(answer)}") - + return { "id": dify_response.get("message_id", ""), "object": "chat.completion", @@ -277,11 +277,11 @@ def encode_conversation_id(conversation_id): """将conversation_id编码为不可见的字符序列""" if not conversation_id: return "" - + # 使用Base64编码减少长度 import base64 encoded = base64.b64encode(conversation_id.encode()).decode() - + # 使用8种不同的零宽字符表示3位数字 # 这样可以将编码长度进一步减少 char_map = { @@ -294,7 +294,7 @@ def encode_conversation_id(conversation_id): '6': '\u2061', # 函数应用 '7': '\u2062', # 不可见乘号 } - + # 将Base64字符串转换为八进制数字 result = [] for c in encoded: @@ -312,14 +312,14 @@ def encode_conversation_id(conversation_id): val = 63 else: # '=' val = 0 - + # 每个Base64字符可以产生2个3位数字 first = (val >> 3) & 0x7 second = val & 0x7 result.append(char_map[str(first)]) if c != '=': # 不编码填充字符的后半部分 result.append(char_map[str(second)]) - + return ''.join(result) def decode_conversation_id(content): @@ -336,17 +336,17 @@ def decode_conversation_id(content): '\u2061': '6', # 函数应用 '\u2062': '7', # 不可见乘号 } - + # 提取最后一段零宽字符序列 space_chars = [] for c in reversed(content): if c not in char_to_val: break space_chars.append(c) - + if not space_chars: return None - + # 将零宽字符转换回Base64字符串 space_chars.reverse() base64_chars = [] @@ -357,7 +357,7 @@ def decode_conversation_id(content): val = (first << 3) | second else: val = first << 3 - + # 转换回Base64字符 if val < 26: base64_chars.append(chr(val + ord('A'))) @@ -369,17 +369,17 @@ def decode_conversation_id(content): base64_chars.append('+') else: base64_chars.append('/') - + # 添加Base64填充 padding = len(base64_chars) % 4 if padding: base64_chars.extend(['='] * (4 - padding)) - + # 解码Base64字符串 import base64 base64_str = ''.join(base64_chars) return base64.b64decode(base64_str).decode() - + except Exception as e: logger.debug(f"Failed to decode conversation_id: {e}") return None @@ -424,9 +424,9 @@ def chat_completions(): # 继续处理原始逻辑 openai_request = request.get_json() logger.info(f"Received request: {json.dumps(openai_request, ensure_ascii=False)}") - + model = openai_request.get("model", "claude-3-5-sonnet-v2") - + # 验证模型是否支持 api_key = get_api_key(model) if not api_key: @@ -439,9 +439,9 @@ def chat_completions(): "code": "model_not_found" } }, 404 - + dify_request = transform_openai_to_dify(openai_request, "/chat/completions") - + if not dify_request: logger.error("Failed to transform request") return { @@ -463,11 +463,11 @@ def chat_completions(): if stream: def generate(): client = httpx.Client(timeout=None) - + def flush_chunk(chunk_data): """Helper function to flush chunks immediately""" return chunk_data.encode('utf-8') - + def calculate_delay(buffer_size): """ 根据缓冲区大小动态计算延迟 @@ -481,7 +481,7 @@ def chat_completions(): return 0.01 # 20ms延迟 else: # 内容很少,使用较慢的速度 return 0.02 # 30ms延迟 - + def send_char(char, message_id): """Helper function to send single character""" openai_chunk = { @@ -499,10 +499,10 @@ def chat_completions(): } chunk_data = f"data: {json.dumps(openai_chunk)}\n\n" return flush_chunk(chunk_data) - + # 初始化缓冲区 output_buffer = [] - + try: with client.stream( 'POST', @@ -517,74 +517,76 @@ def chat_completions(): ) as response: generate.message_id = None buffer = "" - + for raw_bytes in response.iter_raw(): if not raw_bytes: continue - + try: buffer += raw_bytes.decode('utf-8') - + while '\n' in buffer: line, buffer = buffer.split('\n', 1) line = line.strip() - + if not line or not line.startswith('data: '): continue - + try: json_str = line[6:] dify_chunk = json.loads(json_str) - + if dify_chunk.get("event") == "message" and "answer" in dify_chunk: current_answer = dify_chunk["answer"] if not current_answer: continue - + message_id = dify_chunk.get("message_id", "") if not generate.message_id: generate.message_id = message_id - - # 将当前批次的字符添加到输出缓冲区 - for char in current_answer: - output_buffer.append((char, generate.message_id)) - - # 根据缓冲区大小动态调整输出速度 - while output_buffer: - char, msg_id = output_buffer.pop(0) - yield send_char(char, msg_id) - # 根据剩余缓冲区大小计算延迟 - delay = calculate_delay(len(output_buffer)) - time.sleep(delay) - + + # # 将当前批次的字符添加到输出缓冲区 + # for char in current_answer: + # output_buffer.append((char, generate.message_id)) + + # # 根据缓冲区大小动态调整输出速度 + # while output_buffer: + # char, msg_id = output_buffer.pop(0) + # yield send_char(char, msg_id) + # # 根据剩余缓冲区大小计算延迟 + # delay = calculate_delay(len(output_buffer)) + # time.sleep(delay) + yield send_char(current_answer, message_id) + # 立即继续处理下一个请求 continue - + # 处理Agent模式的消息事件 elif dify_chunk.get("event") == "agent_message" and "answer" in dify_chunk: current_answer = dify_chunk["answer"] if not current_answer: continue - + message_id = dify_chunk.get("message_id", "") if not generate.message_id: generate.message_id = message_id - - # 将当前批次的字符添加到输出缓冲区 - for char in current_answer: - output_buffer.append((char, generate.message_id)) - - # 根据缓冲区大小动态调整输出速度 - while output_buffer: - char, msg_id = output_buffer.pop(0) - yield send_char(char, msg_id) - # 根据剩余缓冲区大小计算延迟 - delay = calculate_delay(len(output_buffer)) - time.sleep(delay) - + + # # 将当前批次的字符添加到输出缓冲区 + # for char in current_answer: + # output_buffer.append((char, generate.message_id)) + + # # 根据缓冲区大小动态调整输出速度 + # while output_buffer: + # char, msg_id = output_buffer.pop(0) + # yield send_char(char, msg_id) + # # 根据剩余缓冲区大小计算延迟 + # delay = calculate_delay(len(output_buffer)) + # time.sleep(delay) + yield send_char(current_answer, message_id) + # 立即继续处理下一个请求 continue - + # 处理Agent的思考过程,记录日志但不输出给用户 elif dify_chunk.get("event") == "agent_thought": thought_id = dify_chunk.get("id", "") @@ -592,7 +594,7 @@ def chat_completions(): tool = dify_chunk.get("tool", "") tool_input = dify_chunk.get("tool_input", "") observation = dify_chunk.get("observation", "") - + logger.info(f"[Agent Thought] ID: {thought_id}, Tool: {tool}") if thought: logger.info(f"[Agent Thought] Thought: {thought}") @@ -600,35 +602,35 @@ def chat_completions(): logger.info(f"[Agent Thought] Tool Input: {tool_input}") if observation: logger.info(f"[Agent Thought] Observation: {observation}") - + # 获取message_id以关联思考和最终输出 message_id = dify_chunk.get("message_id", "") if not generate.message_id and message_id: generate.message_id = message_id - + continue - + # 处理消息中的文件(如图片),记录日志但不直接输出给用户 elif dify_chunk.get("event") == "message_file": file_id = dify_chunk.get("id", "") file_type = dify_chunk.get("type", "") file_url = dify_chunk.get("url", "") - + logger.info(f"[Message File] ID: {file_id}, Type: {file_type}, URL: {file_url}") continue - + elif dify_chunk.get("event") == "message_end": # 快速输出剩余内容 while output_buffer: char, msg_id = output_buffer.pop(0) yield send_char(char, msg_id) time.sleep(0.001) # 固定使用最小延迟快速输出剩余内容 - + # 只在零宽字符会话记忆模式时处理conversation_id if CONVERSATION_MEMORY_MODE == 2: conversation_id = dify_chunk.get("conversation_id") history = dify_chunk.get("conversation_history", []) - + has_conversation_id = False if history: for msg in history: @@ -637,7 +639,7 @@ def chat_completions(): if decode_conversation_id(content) is not None: has_conversation_id = True break - + # 只在新会话且历史消息中没有会话ID时插入 if conversation_id and not has_conversation_id: logger.info(f"[Debug] Inserting conversation_id in stream: {conversation_id}") @@ -645,7 +647,7 @@ def chat_completions(): logger.info(f"[Debug] Stream encoded content: {repr(encoded)}") for char in encoded: yield send_char(char, generate.message_id) - + final_chunk = { "id": generate.message_id, "object": "chat.completion.chunk", @@ -659,11 +661,11 @@ def chat_completions(): } yield flush_chunk(f"data: {json.dumps(final_chunk)}\n\n") yield flush_chunk("data: [DONE]\n\n") - + except json.JSONDecodeError as e: logger.error(f"JSON decode error: {str(e)}") continue - + except Exception as e: logger.error(f"Error processing chunk: {str(e)}") continue @@ -692,7 +694,7 @@ def chat_completions(): json=dify_request, headers=headers ) - + if response.status_code != 200: error_msg = f"Dify API error: {response.text}" logger.error(f"Request failed: {error_msg}") @@ -746,13 +748,13 @@ def chat_completions(): def list_models(): """返回可用的模型列表""" logger.info("Listing available models") - + # 刷新模型信息 asyncio.run(model_manager.refresh_model_info()) - + # 获取可用模型列表 available_models = model_manager.get_available_models() - + response = { "object": "list", "data": available_models @@ -764,10 +766,10 @@ def list_models(): if __name__ == '__main__': if not VALID_API_KEYS: print("Warning: No API keys configured. Set the VALID_API_KEYS environment variable with comma-separated keys.") - + # 启动时初始化模型信息 asyncio.run(model_manager.refresh_model_info()) - + host = os.getenv("SERVER_HOST", "127.0.0.1") port = int(os.getenv("SERVER_PORT", 5000)) logger.info(f"Starting server on http://{host}:{port}")