Merge pull request #4 from yoursbest/main

add api-key based authentication
This commit is contained in:
LzSkyline 2025-03-08 22:40:07 +08:00 committed by GitHub
commit 69864b26e3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 47 additions and 2 deletions

View File

@ -11,3 +11,6 @@ ENABLE_CONVERSATION_MEMORY=true
# Server Configuration # Server Configuration
SERVER_HOST="127.0.0.1" SERVER_HOST="127.0.0.1"
SERVER_PORT=5000 SERVER_PORT=5000
# OpenAI compatable API Keys
VALID_API_KEYS="sk-abc123,sk-def456"

44
main.py
View File

@ -1,7 +1,7 @@
import json import json
import logging import logging
import asyncio import asyncio
from flask import Flask, request, Response, stream_with_context from flask import Flask, request, Response, stream_with_context, jsonify
import httpx import httpx
import time import time
from dotenv import load_dotenv from dotenv import load_dotenv
@ -21,6 +21,9 @@ logging.getLogger("httpx").setLevel(logging.WARNING)
# 加载环境变量 # 加载环境变量
load_dotenv() load_dotenv()
# 从环境变量读取有效的API密钥逗号分隔
VALID_API_KEYS = [key.strip() for key in os.getenv("VALID_API_KEYS", "").split(",") if key]
# 获取会话ID记忆功能开关配置 # 获取会话ID记忆功能开关配置
ENABLE_CONVERSATION_MEMORY = os.getenv('ENABLE_CONVERSATION_MEMORY', 'true').lower() == 'true' ENABLE_CONVERSATION_MEMORY = os.getenv('ENABLE_CONVERSATION_MEMORY', 'true').lower() == 'true'
@ -312,6 +315,41 @@ def decode_conversation_id(content):
@app.route('/v1/chat/completions', methods=['POST']) @app.route('/v1/chat/completions', methods=['POST'])
def chat_completions(): def chat_completions():
try: try:
# 新增验证API密钥
auth_header = request.headers.get('Authorization')
if not auth_header:
return jsonify({
"error": {
"message": "Missing Authorization header",
"type": "invalid_request_error",
"param": None,
"code": "invalid_api_key"
}
}), 401
parts = auth_header.split()
if len(parts) != 2 or parts[0].lower() != 'bearer':
return jsonify({
"error": {
"message": "Invalid Authorization header format. Expected: Bearer <API_KEY>",
"type": "invalid_request_error",
"param": None,
"code": "invalid_api_key"
}
}), 401
provided_api_key = parts[1]
if provided_api_key not in VALID_API_KEYS:
return jsonify({
"error": {
"message": "Invalid API key",
"type": "invalid_request_error",
"param": None,
"code": "invalid_api_key"
}
}), 401
# 继续处理原始逻辑
openai_request = request.get_json() openai_request = request.get_json()
logger.info(f"Received request: {json.dumps(openai_request, ensure_ascii=False)}") logger.info(f"Received request: {json.dumps(openai_request, ensure_ascii=False)}")
@ -593,7 +631,11 @@ def list_models():
logger.info(f"Available models: {json.dumps(response, ensure_ascii=False)}") logger.info(f"Available models: {json.dumps(response, ensure_ascii=False)}")
return response return response
# 在main.py的最后初始化时添加环境变量检查
if __name__ == '__main__': if __name__ == '__main__':
if not VALID_API_KEYS:
print("Warning: No API keys configured. Set the VALID_API_KEYS environment variable with comma-separated keys.")
# 启动时初始化模型信息 # 启动时初始化模型信息
asyncio.run(model_manager.refresh_model_info()) asyncio.run(model_manager.refresh_model_info())