feat: Add configurable conversation memory. Resolves #1

parent 055bd1b863
commit cb80f7e0f3
@@ -5,6 +5,9 @@ DIFY_API_KEYS=app-xxxxxxxxxxxxxxxx,app-yyyyyyyyyyyyyyyy,app-zzzzzzzzzzzzzzzz
 # Dify API Base URL
 DIFY_API_BASE="https://api.dify.example.com/v1"
 
+# Keep Conversation (true/false)
+ENABLE_CONVERSATION_MEMORY=true
+
 # Server Configuration
 SERVER_HOST="127.0.0.1"
 SERVER_PORT=5000
README.md
@@ -132,7 +132,25 @@ for chunk in response:
     print(chunk.choices[0].delta.content or "", end="")
 ```
 
-## Feature Notes
+## Features
 
+### Conversation Memory
+
+The proxy keeps track of conversation context automatically, with no extra handling required on the client side. When this feature is enabled:
+
+- An invisible conversation ID is embedded in the first reply of each new conversation
+- Subsequent messages automatically inherit the conversation context, keeping the dialogue coherent
+- The ID is encoded with zero-width characters, so (in most cases) it does not affect how the message is displayed
+
+The feature is controlled via an environment variable:
+
+```shell
+# Set in the .env file
+ENABLE_CONVERSATION_MEMORY=true   # enable conversation memory
+ENABLE_CONVERSATION_MEMORY=false  # disable conversation memory
+```
+
+The feature is enabled by default. If your use case does not need to keep conversation context, you can turn it off.
+
 ### Streaming Output Optimization
 
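To illustrate the README section above, here is a rough client-side sketch (not part of the commit): the client simply echoes the assistant reply back as history, and the invisible marker inside it is what lets the proxy resume the same Dify conversation. The base URL, API key, and model name are placeholders taken from the example configuration and the code's default model.

```python
from openai import OpenAI

# Proxy address and key are placeholders matching the example .env values.
client = OpenAI(base_url="http://127.0.0.1:5000/v1", api_key="app-xxxxxxxxxxxxxxxx")

messages = [{"role": "user", "content": "My name is Alice."}]
first = client.chat.completions.create(model="claude-3-5-sonnet-v2", messages=messages)

# The first reply carries an invisible zero-width suffix; sending it back
# unchanged is all the proxy needs to continue the same conversation.
messages.append({"role": "assistant", "content": first.choices[0].message.content})
messages.append({"role": "user", "content": "What is my name?"})
second = client.chat.completions.create(model="claude-3-5-sonnet-v2", messages=messages)
print(second.choices[0].message.content)
```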
main.py
@@ -10,17 +10,20 @@ import ast
 
 # Configure logging
 logging.basicConfig(
-    level=logging.DEBUG,
+    level=logging.INFO,
     format='%(asctime)s - %(levelname)s - %(message)s'
 )
 logger = logging.getLogger(__name__)
 
-# Set the httpx log level
-logging.getLogger("httpx").setLevel(logging.DEBUG)
+# Set the httpx log level to WARNING to reduce unnecessary output
+logging.getLogger("httpx").setLevel(logging.WARNING)
 
 # Load environment variables
 load_dotenv()
 
+# Read the conversation-memory feature flag
+ENABLE_CONVERSATION_MEMORY = os.getenv('ENABLE_CONVERSATION_MEMORY', 'true').lower() == 'true'
+
 class DifyModelManager:
     def __init__(self):
         self.api_keys = []
@@ -109,24 +112,26 @@ def transform_openai_to_dify(openai_request, endpoint):
         messages = openai_request.get("messages", [])
         stream = openai_request.get("stream", False)
 
+        # Try to extract a conversation_id from the message history
+        conversation_id = None
+        if len(messages) > 1:
+            # Walk the history backwards to find the most recent assistant message
+            for msg in reversed(messages[:-1]):  # everything except the last message
+                if msg.get("role") == "assistant":
+                    content = msg.get("content", "")
+                    # Try to decode a conversation_id from its content
+                    conversation_id = decode_conversation_id(content)
+                    if conversation_id:
+                        break
+
         dify_request = {
             "inputs": {},
             "query": messages[-1]["content"] if messages else "",
             "response_mode": "streaming" if stream else "blocking",
-            "conversation_id": openai_request.get("conversation_id", None),
+            "conversation_id": conversation_id,
             "user": openai_request.get("user", "default_user")
         }
 
-        # Append the message history
-        if len(messages) > 1:
-            history = []
-            for msg in messages[:-1]:  # everything except the last message
-                history.append({
-                    "role": msg["role"],
-                    "content": msg["content"]
-                })
-            dify_request["conversation_history"] = history
-
         return dify_request
 
     return None
@@ -135,16 +140,40 @@ def transform_dify_to_openai(dify_response, model="claude-3-5-sonnet-v2", stream
     """Convert a Dify-format response to the OpenAI format"""
 
     if not stream:
+        answer = dify_response.get("answer", "")
+
+        # Only handle the conversation_id when conversation memory is enabled
+        if ENABLE_CONVERSATION_MEMORY:
+            conversation_id = dify_response.get("conversation_id", "")
+            history = dify_response.get("conversation_history", [])
+
+            # Check whether the history already carries a conversation ID
+            has_conversation_id = False
+            if history:
+                for msg in history:
+                    if msg.get("role") == "assistant":
+                        content = msg.get("content", "")
+                        if decode_conversation_id(content) is not None:
+                            has_conversation_id = True
+                            break
+
+            # Insert only for a new conversation whose history has no conversation ID yet
+            if conversation_id and not has_conversation_id:
+                logger.info(f"[Debug] Inserting conversation_id: {conversation_id}, history_length: {len(history)}")
+                encoded = encode_conversation_id(conversation_id)
+                answer = answer + encoded
+                logger.info(f"[Debug] Response content after insertion: {repr(answer)}")
+
         return {
             "id": dify_response.get("message_id", ""),
             "object": "chat.completion",
             "created": dify_response.get("created", int(time.time())),
-            "model": model,  # the model actually used
+            "model": model,
             "choices": [{
                 "index": 0,
                 "message": {
                     "role": "assistant",
-                    "content": dify_response.get("answer", "")
+                    "content": answer
                 },
                 "finish_reason": "stop"
             }]
@@ -169,6 +198,117 @@ def create_openai_stream_response(content, message_id, model="claude-3-5-sonnet-
         }]
     }
 
+def encode_conversation_id(conversation_id):
+    """Encode the conversation_id as a sequence of invisible characters"""
+    if not conversation_id:
+        return ""
+
+    # Base64-encode first to keep the result short
+    import base64
+    encoded = base64.b64encode(conversation_id.encode()).decode()
+
+    # Use 8 different zero-width characters to represent 3-bit digits,
+    # which shrinks the encoded length further
+    char_map = {
+        '0': '\u200b',  # zero-width space
+        '1': '\u200c',  # zero-width non-joiner
+        '2': '\u200d',  # zero-width joiner
+        '3': '\ufeff',  # zero-width no-break space
+        '4': '\u2060',  # word joiner
+        '5': '\u180e',  # Mongolian vowel separator
+        '6': '\u2061',  # function application
+        '7': '\u2062',  # invisible times
+    }
+
+    # Convert the Base64 string into octal digits
+    result = []
+    for c in encoded:
+        # Map each character to its 6-bit Base64 value
+        if c.isalpha():
+            if c.isupper():
+                val = ord(c) - ord('A')
+            else:
+                val = ord(c) - ord('a') + 26
+        elif c.isdigit():
+            val = int(c) + 52
+        elif c == '+':
+            val = 62
+        elif c == '/':
+            val = 63
+        else:  # '='
+            val = 0
+
+        # Each Base64 character yields two 3-bit digits
+        first = (val >> 3) & 0x7
+        second = val & 0x7
+        result.append(char_map[str(first)])
+        if c != '=':  # do not encode the second half of the padding character
+            result.append(char_map[str(second)])
+
+    return ''.join(result)
+
+def decode_conversation_id(content):
+    """Decode a conversation_id from message content"""
+    try:
+        # Map zero-width characters back to 3-bit digits
+        char_to_val = {
+            '\u200b': '0',  # zero-width space
+            '\u200c': '1',  # zero-width non-joiner
+            '\u200d': '2',  # zero-width joiner
+            '\ufeff': '3',  # zero-width no-break space
+            '\u2060': '4',  # word joiner
+            '\u180e': '5',  # Mongolian vowel separator
+            '\u2061': '6',  # function application
+            '\u2062': '7',  # invisible times
+        }
+
+        # Extract the trailing run of zero-width characters
+        space_chars = []
+        for c in reversed(content):
+            if c not in char_to_val:
+                break
+            space_chars.append(c)
+
+        if not space_chars:
+            return None
+
+        # Convert the zero-width characters back into a Base64 string
+        space_chars.reverse()
+        base64_chars = []
+        for i in range(0, len(space_chars), 2):
+            first = int(char_to_val[space_chars[i]], 8)
+            if i + 1 < len(space_chars):
+                second = int(char_to_val[space_chars[i + 1]], 8)
+                val = (first << 3) | second
+            else:
+                val = first << 3
+
+            # Convert back to a Base64 character
+            if val < 26:
+                base64_chars.append(chr(val + ord('A')))
+            elif val < 52:
+                base64_chars.append(chr(val - 26 + ord('a')))
+            elif val < 62:
+                base64_chars.append(str(val - 52))
+            elif val == 62:
+                base64_chars.append('+')
+            else:
+                base64_chars.append('/')
+
+        # Add Base64 padding
+        padding = len(base64_chars) % 4
+        if padding:
+            base64_chars.extend(['='] * (4 - padding))
+
+        # Decode the Base64 string
+        import base64
+        base64_str = ''.join(base64_chars)
+        return base64.b64decode(base64_str).decode()
+
+    except Exception as e:
+        logger.debug(f"Failed to decode conversation_id: {e}")
+        return None
+
 @app.route('/v1/chat/completions', methods=['POST'])
 def chat_completions():
     try:
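A side note on the encoding above: since the marker is appended as a plain trailing run of zero-width characters, a client that wants clean text (for logging, diffing, or length checks) can strip it without understanding the encoding. A minimal sketch, not part of the commit, assuming only the character set shown in char_map; the helper name is illustrative.

```python
# The zero-width characters used by encode_conversation_id above.
ZERO_WIDTH_CHARS = {'\u200b', '\u200c', '\u200d', '\ufeff',
                    '\u2060', '\u180e', '\u2061', '\u2062'}

def strip_conversation_marker(text: str) -> str:
    """Remove the trailing zero-width marker while leaving visible text intact."""
    end = len(text)
    while end > 0 and text[end - 1] in ZERO_WIDTH_CHARS:
        end -= 1
    return text[:end]

# Usage: stripped = strip_conversation_marker(reply_text)
```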
@@ -176,7 +316,6 @@ def chat_completions():
         logger.info(f"Received request: {json.dumps(openai_request, ensure_ascii=False)}")
 
         model = openai_request.get("model", "claude-3-5-sonnet-v2")
-        logger.info(f"Using model: {model}")
 
         # Verify that the model is supported
         api_key = get_api_key(model)
@@ -192,7 +331,6 @@ def chat_completions():
             }, 404
 
         dify_request = transform_openai_to_dify(openai_request, "/chat/completions")
-        logger.info(f"Transformed request: {json.dumps(dify_request, ensure_ascii=False)}")
 
         if not dify_request:
             logger.error("Failed to transform request")
@@ -319,6 +457,28 @@ def chat_completions():
                                 yield send_char(char, msg_id)
                                 time.sleep(0.001)  # always use the minimum delay to flush the remaining content quickly
 
+                        # Only handle the conversation_id when conversation memory is enabled
+                        if ENABLE_CONVERSATION_MEMORY:
+                            conversation_id = dify_chunk.get("conversation_id")
+                            history = dify_chunk.get("conversation_history", [])
+
+                            has_conversation_id = False
+                            if history:
+                                for msg in history:
+                                    if msg.get("role") == "assistant":
+                                        content = msg.get("content", "")
+                                        if decode_conversation_id(content) is not None:
+                                            has_conversation_id = True
+                                            break
+
+                            # Insert only for a new conversation whose history has no conversation ID yet
+                            if conversation_id and not has_conversation_id:
+                                logger.info(f"[Debug] Inserting conversation_id in stream: {conversation_id}")
+                                encoded = encode_conversation_id(conversation_id)
+                                logger.info(f"[Debug] Stream encoded content: {repr(encoded)}")
+                                for char in encoded:
+                                    yield send_char(char, generate.message_id)
+
                         final_chunk = {
                             "id": generate.message_id,
                             "object": "chat.completion.chunk",
@@ -379,7 +539,19 @@ def chat_completions():
 
             dify_response = response.json()
             logger.info(f"Received response from Dify: {json.dumps(dify_response, ensure_ascii=False)}")
+            logger.info(f"[Debug] Response content: {repr(dify_response.get('answer', ''))}")
             openai_response = transform_dify_to_openai(dify_response, model=model)
+            conversation_id = dify_response.get("conversation_id")
+            if conversation_id:
+                # Pass the conversation_id back in a response header
+                return Response(
+                    json.dumps(openai_response),
+                    content_type='application/json',
+                    headers={
+                        'Conversation-Id': conversation_id
+                    }
+                )
-            return openai_response
+            else:
+                return openai_response
         except httpx.RequestError as e:
             error_msg = f"Failed to connect to Dify: {str(e)}"
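The Conversation-Id header added in the last hunk can also be observed directly. Below is a small sketch (not part of the commit) against the example host and port; the bearer key is a placeholder, and whether the proxy validates it is not shown in this diff.

```python
import httpx

resp = httpx.post(
    "http://127.0.0.1:5000/v1/chat/completions",
    headers={"Authorization": "Bearer app-xxxxxxxxxxxxxxxx"},  # placeholder key
    json={
        "model": "claude-3-5-sonnet-v2",
        "messages": [{"role": "user", "content": "Hello"}],
        "stream": False,
    },
    timeout=60,
)
print(resp.headers.get("Conversation-Id"))               # Dify conversation ID, if any
print(resp.json()["choices"][0]["message"]["content"])   # reply text (with invisible marker)
```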