feat: 添加可配置的会话记忆功能. Resolves #1

This commit is contained in:
LzSkyline 2025-02-21 20:10:36 +08:00
parent 055bd1b863
commit cb80f7e0f3
3 changed files with 213 additions and 20 deletions

View File

@ -5,6 +5,9 @@ DIFY_API_KEYS=app-xxxxxxxxxxxxxxxx,app-yyyyyyyyyyyyyyyy,app-zzzzzzzzzzzzzzzz
# Dify API Base URL # Dify API Base URL
DIFY_API_BASE="https://api.dify.example.com/v1" DIFY_API_BASE="https://api.dify.example.com/v1"
# Enable conversation memory (true/false, default: true)
ENABLE_CONVERSATION_MEMORY=true
# Server Configuration # Server Configuration
SERVER_HOST="127.0.0.1" SERVER_HOST="127.0.0.1"
SERVER_PORT=5000 SERVER_PORT=5000

View File

@ -132,7 +132,25 @@ for chunk in response:
print(chunk.choices[0].delta.content or "", end="") print(chunk.choices[0].delta.content or "", end="")
``` ```
## 特性说明 ## 特性
### 会话记忆功能
该代理支持自动记忆会话上下文,无需客户端进行额外处理。当启用此功能时:
- 在每个新会话的第一条回复中会自动嵌入不可见的会话ID
- 后续的消息会自动继承会话上下文,保持对话连贯性
- 使用零宽字符编码,(大部分情况下)不会影响消息的正常显示
可以通过环境变量控制此功能:
```shell
# 在 .env 文件中设置
ENABLE_CONVERSATION_MEMORY=true # 启用会话记忆功能
ENABLE_CONVERSATION_MEMORY=false # 禁用会话记忆功能
```
默认情况下此功能是启用的。如果您的应用场景不需要保持会话上下文,可以选择关闭此功能。
### 流式输出优化 ### 流式输出优化

210
main.py
View File

@ -10,17 +10,20 @@ import ast
# 配置日志 # 配置日志
logging.basicConfig( logging.basicConfig(
level=logging.DEBUG, level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s' format='%(asctime)s - %(levelname)s - %(message)s'
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# 设置 httpx 的日志级别 # 设置httpx的日志级别为WARNING减少不必要的输出
logging.getLogger("httpx").setLevel(logging.DEBUG) logging.getLogger("httpx").setLevel(logging.WARNING)
# 加载环境变量 # 加载环境变量
load_dotenv() load_dotenv()
# 获取会话ID记忆功能开关配置
ENABLE_CONVERSATION_MEMORY = os.getenv('ENABLE_CONVERSATION_MEMORY', 'true').lower() == 'true'
class DifyModelManager: class DifyModelManager:
def __init__(self): def __init__(self):
self.api_keys = [] self.api_keys = []
@ -109,24 +112,26 @@ def transform_openai_to_dify(openai_request, endpoint):
messages = openai_request.get("messages", []) messages = openai_request.get("messages", [])
stream = openai_request.get("stream", False) stream = openai_request.get("stream", False)
# 尝试从历史消息中提取conversation_id
conversation_id = None
if len(messages) > 1:
# 遍历历史消息找到最近的assistant消息
for msg in reversed(messages[:-1]): # 除了最后一条消息
if msg.get("role") == "assistant":
content = msg.get("content", "")
# 尝试解码conversation_id
conversation_id = decode_conversation_id(content)
if conversation_id:
break
dify_request = { dify_request = {
"inputs": {}, "inputs": {},
"query": messages[-1]["content"] if messages else "", "query": messages[-1]["content"] if messages else "",
"response_mode": "streaming" if stream else "blocking", "response_mode": "streaming" if stream else "blocking",
"conversation_id": openai_request.get("conversation_id", None), "conversation_id": conversation_id,
"user": openai_request.get("user", "default_user") "user": openai_request.get("user", "default_user")
} }
# 添加历史消息
if len(messages) > 1:
history = []
for msg in messages[:-1]: # 除了最后一条消息
history.append({
"role": msg["role"],
"content": msg["content"]
})
dify_request["conversation_history"] = history
return dify_request return dify_request
return None return None
@ -135,16 +140,40 @@ def transform_dify_to_openai(dify_response, model="claude-3-5-sonnet-v2", stream
"""将Dify格式的响应转换为OpenAI格式""" """将Dify格式的响应转换为OpenAI格式"""
if not stream: if not stream:
answer = dify_response.get("answer", "")
# 只在启用会话记忆功能时处理conversation_id
if ENABLE_CONVERSATION_MEMORY:
conversation_id = dify_response.get("conversation_id", "")
history = dify_response.get("conversation_history", [])
# 检查历史消息中是否已经有会话ID
has_conversation_id = False
if history:
for msg in history:
if msg.get("role") == "assistant":
content = msg.get("content", "")
if decode_conversation_id(content) is not None:
has_conversation_id = True
break
# 只在新会话且历史消息中没有会话ID时插入
if conversation_id and not has_conversation_id:
logger.info(f"[Debug] Inserting conversation_id: {conversation_id}, history_length: {len(history)}")
encoded = encode_conversation_id(conversation_id)
answer = answer + encoded
logger.info(f"[Debug] Response content after insertion: {repr(answer)}")
return { return {
"id": dify_response.get("message_id", ""), "id": dify_response.get("message_id", ""),
"object": "chat.completion", "object": "chat.completion",
"created": dify_response.get("created", int(time.time())), "created": dify_response.get("created", int(time.time())),
"model": model, # 使用实际使用的模型 "model": model,
"choices": [{ "choices": [{
"index": 0, "index": 0,
"message": { "message": {
"role": "assistant", "role": "assistant",
"content": dify_response.get("answer", "") "content": answer
}, },
"finish_reason": "stop" "finish_reason": "stop"
}] }]
@ -169,6 +198,117 @@ def create_openai_stream_response(content, message_id, model="claude-3-5-sonnet-
}] }]
} }
def encode_conversation_id(conversation_id):
    """Encode a conversation_id as a run of invisible (zero-width) characters.

    The id is UTF-8 encoded, converted to Base64, and every Base64 symbol
    (a 6-bit value) is written as two octal digits, each digit represented
    by one of eight zero-width characters.

    Base64 '=' padding is NOT emitted: it carries no data, and the decoder
    re-derives it from the number of symbols. Emitting a marker for it made
    ids whose byte length is not a multiple of 3 decode with a spurious
    trailing NUL character.

    Args:
        conversation_id: The id string to hide. Falsy values (None, "")
            produce an empty result.

    Returns:
        A string made only of zero-width characters, two per Base64 symbol,
        or "" when there is nothing to encode.
    """
    if not conversation_id:
        return ""

    import base64

    # Index in this tuple == octal digit (0-7) the character stands for.
    zero_width = (
        '\u200b',  # 0: zero width space
        '\u200c',  # 1: zero width non-joiner
        '\u200d',  # 2: zero width joiner
        '\ufeff',  # 3: zero width no-break space
        '\u2060',  # 4: word joiner
        '\u180e',  # 5: Mongolian vowel separator
        '\u2061',  # 6: function application
        '\u2062',  # 7: invisible times
    )
    # Standard Base64 alphabet; a symbol's position is its 6-bit value.
    alphabet = (
        "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
        "abcdefghijklmnopqrstuvwxyz"
        "0123456789+/"
    )
    value_of = {symbol: value for value, symbol in enumerate(alphabet)}

    # Drop the '=' padding up front so only real data symbols are encoded.
    b64 = base64.b64encode(conversation_id.encode()).decode().rstrip('=')

    pieces = []
    for symbol in b64:
        value = value_of[symbol]
        pieces.append(zero_width[(value >> 3) & 0x7])  # high octal digit
        pieces.append(zero_width[value & 0x7])         # low octal digit
    return ''.join(pieces)
def decode_conversation_id(content):
    """Recover a conversation_id hidden at the end of a message.

    Reads the trailing run of zero-width marker characters in ``content``,
    turns every pair of markers (two octal digits) back into one Base64
    symbol, restores the Base64 padding and decodes the result as UTF-8.

    Two layouts are accepted:
      * padding-free: only data symbols are present (two markers each);
      * legacy: each Base64 '=' was written as one extra U+200B marker.
        Those markers carry no data and are discarded here; treating them
        as data used to append a NUL character to the decoded id.

    Args:
        content: Message text that may end with an encoded id. Anything
            that is not a string (None, a list of content parts, ...)
            yields None.

    Returns:
        The decoded conversation_id, or None when no marker run is present
        or the run does not decode to a non-empty string.
    """
    # Only plain strings can carry the marker run.
    if not isinstance(content, str):
        return None

    import base64

    # Zero-width character -> octal digit it stands for.
    marker_to_digit = {
        '\u200b': 0,  # zero width space
        '\u200c': 1,  # zero width non-joiner
        '\u200d': 2,  # zero width joiner
        '\ufeff': 3,  # zero width no-break space
        '\u2060': 4,  # word joiner
        '\u180e': 5,  # Mongolian vowel separator
        '\u2061': 6,  # function application
        '\u2062': 7,  # invisible times
    }
    # Standard Base64 alphabet; index == 6-bit symbol value.
    alphabet = (
        "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
        "abcdefghijklmnopqrstuvwxyz"
        "0123456789+/"
    )

    # Walk backwards to find where the trailing marker run begins.
    run_start = len(content)
    while run_start > 0 and content[run_start - 1] in marker_to_digit:
        run_start -= 1
    markers = content[run_start:]
    if not markers:
        return None

    # Data symbols always occupy two markers, so an odd run means the last
    # marker is a legacy single-'=' padding marker: drop it.
    legacy_single_pad = len(markers) % 2 == 1
    if legacy_single_pad:
        markers = markers[:-1]

    symbols = [
        (marker_to_digit[markers[i]] << 3) | marker_to_digit[markers[i + 1]]
        for i in range(0, len(markers), 2)
    ]

    # Legacy double padding ("==") shows up as one extra all-zero symbol
    # after a 2-symbol tail. In canonical Base64 that tail's second symbol
    # has its low 4 bits clear, which distinguishes it from a genuine
    # 3-symbol tail (unless the id really ends with a NUL byte).
    if (
        not legacy_single_pad
        and len(symbols) % 4 == 3
        and symbols[-1] == 0
        and symbols[-2] & 0xF == 0
    ):
        symbols.pop()

    if not symbols:
        return None

    b64 = ''.join(alphabet[value] for value in symbols)
    b64 += '=' * (-len(b64) % 4)  # restore padding to a multiple of 4

    try:
        # binascii.Error and UnicodeDecodeError are both ValueError subclasses.
        decoded = base64.b64decode(b64).decode()
    except ValueError as e:
        logger.debug(f"Failed to decode conversation_id: {e}")
        return None
    # Callers test the result with `is not None`, so never return "".
    return decoded or None
@app.route('/v1/chat/completions', methods=['POST']) @app.route('/v1/chat/completions', methods=['POST'])
def chat_completions(): def chat_completions():
try: try:
@ -176,7 +316,6 @@ def chat_completions():
logger.info(f"Received request: {json.dumps(openai_request, ensure_ascii=False)}") logger.info(f"Received request: {json.dumps(openai_request, ensure_ascii=False)}")
model = openai_request.get("model", "claude-3-5-sonnet-v2") model = openai_request.get("model", "claude-3-5-sonnet-v2")
logger.info(f"Using model: {model}")
# 验证模型是否支持 # 验证模型是否支持
api_key = get_api_key(model) api_key = get_api_key(model)
@ -192,7 +331,6 @@ def chat_completions():
}, 404 }, 404
dify_request = transform_openai_to_dify(openai_request, "/chat/completions") dify_request = transform_openai_to_dify(openai_request, "/chat/completions")
logger.info(f"Transformed request: {json.dumps(dify_request, ensure_ascii=False)}")
if not dify_request: if not dify_request:
logger.error("Failed to transform request") logger.error("Failed to transform request")
@ -319,6 +457,28 @@ def chat_completions():
yield send_char(char, msg_id) yield send_char(char, msg_id)
time.sleep(0.001) # 固定使用最小延迟快速输出剩余内容 time.sleep(0.001) # 固定使用最小延迟快速输出剩余内容
# 只在启用会话记忆功能时处理conversation_id
if ENABLE_CONVERSATION_MEMORY:
conversation_id = dify_chunk.get("conversation_id")
history = dify_chunk.get("conversation_history", [])
has_conversation_id = False
if history:
for msg in history:
if msg.get("role") == "assistant":
content = msg.get("content", "")
if decode_conversation_id(content) is not None:
has_conversation_id = True
break
# 只在新会话且历史消息中没有会话ID时插入
if conversation_id and not has_conversation_id:
logger.info(f"[Debug] Inserting conversation_id in stream: {conversation_id}")
encoded = encode_conversation_id(conversation_id)
logger.info(f"[Debug] Stream encoded content: {repr(encoded)}")
for char in encoded:
yield send_char(char, generate.message_id)
final_chunk = { final_chunk = {
"id": generate.message_id, "id": generate.message_id,
"object": "chat.completion.chunk", "object": "chat.completion.chunk",
@ -379,8 +539,20 @@ def chat_completions():
dify_response = response.json() dify_response = response.json()
logger.info(f"Received response from Dify: {json.dumps(dify_response, ensure_ascii=False)}") logger.info(f"Received response from Dify: {json.dumps(dify_response, ensure_ascii=False)}")
logger.info(f"[Debug] Response content: {repr(dify_response.get('answer', ''))}")
openai_response = transform_dify_to_openai(dify_response, model=model) openai_response = transform_dify_to_openai(dify_response, model=model)
return openai_response conversation_id = dify_response.get("conversation_id")
if conversation_id:
# 在响应头中传递conversation_id
return Response(
json.dumps(openai_response),
content_type='application/json',
headers={
'Conversation-Id': conversation_id
}
)
else:
return openai_response
except httpx.RequestError as e: except httpx.RequestError as e:
error_msg = f"Failed to connect to Dify: {str(e)}" error_msg = f"Failed to connect to Dify: {str(e)}"
logger.error(error_msg) logger.error(error_msg)