@@ -1,228 +0,0 @@
""" LLM 聊天服务 """
import asyncio
import json
from pathlib import Path
from typing import Any , Dict , List , Optional
from heurams . context import config_var
from heurams . providers . llm import providers as prov
from heurams . services . logger import get_logger
logger = get_logger ( __name__ )
class ChatSession :
""" 聊天会话,管理单个对话的历史和参数 """
def __init__ (
self , session_id : str , llm_provider , system_prompt : str = " " , * * default_params
) :
""" 初始化聊天会话
Args:
session_id: 会话唯一标识符
llm_provider: LLM 提供者实例
system_prompt: 系统提示词
**default_params: 默认参数( temperature, max_tokens, model 等)
"""
self . session_id = session_id
self . llm_provider = llm_provider
self . system_prompt = system_prompt
self . default_params = default_params
# 消息历史
self . messages : List [ Dict [ str , str ] ] = [ ]
if system_prompt :
self . messages . append ( { " role " : " system " , " content " : system_prompt } )
logger . debug ( " 创建聊天会话: id= %s " , session_id )
def add_message ( self , role : str , content : str ) :
""" 添加消息到历史 """
self . messages . append ( { " role " : role , " content " : content } )
logger . debug (
" 会话 %s 添加消息: role= %s , length= %d " , self . session_id , role , len ( content )
)
def clear_history ( self ) :
""" 清空消息历史(保留系统提示) """
self . messages = [ ]
if self . system_prompt :
self . messages . append ( { " role " : " system " , " content " : self . system_prompt } )
logger . debug ( " 会话 %s 清空历史 " , self . session_id )
def set_system_prompt ( self , prompt : str ) :
""" 设置系统提示词 """
self . system_prompt = prompt
# 更新消息历史中的系统消息
if self . messages and self . messages [ 0 ] [ " role " ] == " system " :
self . messages [ 0 ] [ " content " ] = prompt
elif prompt :
self . messages . insert ( 0 , { " role " : " system " , " content " : prompt } )
logger . debug ( " 会话 %s 设置系统提示: length= %d " , self . session_id , len ( prompt ) )
async def send_message ( self , message : str , * * override_params ) - > str :
""" 发送消息并获取响应
Args:
message: 用户消息内容
**override_params: 覆盖默认参数
Returns:
模型响应内容
"""
# 添加用户消息
self . add_message ( " user " , message )
# 合并参数
params = { * * self . default_params , * * override_params }
# 发送请求
logger . debug ( " 会话 %s 发送消息: length= %d " , self . session_id , len ( message ) )
response = await self . llm_provider . chat ( self . messages , * * params )
# 添加助手响应
self . add_message ( " assistant " , response )
return response
async def send_message_stream ( self , message : str , * * override_params ) :
""" 流式发送消息
Args:
message: 用户消息内容
**override_params: 覆盖默认参数
Yields:
流式响应的文本块
"""
# 添加用户消息
self . add_message ( " user " , message )
# 合并参数
params = { * * self . default_params , * * override_params }
# 发送流式请求
logger . debug ( " 会话 %s 发送流式消息: length= %d " , self . session_id , len ( message ) )
full_response = " "
async for chunk in self . llm_provider . chat_stream ( self . messages , * * params ) :
yield chunk
full_response + = chunk
# 添加完整的助手响应到历史
self . add_message ( " assistant " , full_response )
def get_history ( self ) - > List [ Dict [ str , str ] ] :
""" 获取消息历史(不包括系统消息) """
# 返回用户和助手的消息,可选排除系统消息
return [ msg for msg in self . messages if msg [ " role " ] != " system " ]
def save_to_file ( self , file_path : Path ) :
""" 保存会话到文件 """
data = {
" session_id " : self . session_id ,
" system_prompt " : self . system_prompt ,
" default_params " : self . default_params ,
" messages " : self . messages ,
}
with open ( file_path , " w " , encoding = " utf-8 " ) as f :
json . dump ( data , f , ensure_ascii = False , indent = 2 )
logger . debug ( " 会话 %s 保存到: %s " , self . session_id , file_path )
@classmethod
def load_from_file ( cls , file_path : Path , llm_provider ) :
""" 从文件加载会话 """
with open ( file_path , " r " , encoding = " utf-8 " ) as f :
data = json . load ( f )
session = cls (
session_id = data [ " session_id " ] ,
llm_provider = llm_provider ,
system_prompt = data . get ( " system_prompt " , " " ) ,
* * data . get ( " default_params " , { } )
)
session . messages = data [ " messages " ]
logger . debug ( " 从文件加载会话: %s " , file_path )
return session
class ChatManager :
""" 聊天管理器,管理多个会话 """
def __init__ ( self ) :
self . sessions : Dict [ str , ChatSession ] = { }
self . default_session_id = " default "
logger . debug ( " 聊天管理器初始化完成 " )
def get_session (
self ,
session_id : Optional [ str ] = None ,
create_if_missing : bool = True ,
* * session_params
) - > Optional [ ChatSession ] :
""" 获取或创建聊天会话
Args:
session_id: 会话标识符, None 则使用默认会话
create_if_missing: 如果会话不存在则创建
**session_params: 传递给 ChatSession 的参数
Returns:
聊天会话实例,如果不存在且不创建则返回 None
"""
if session_id is None :
session_id = self . default_session_id
if session_id in self . sessions :
return self . sessions [ session_id ]
if create_if_missing :
# 获取 LLM 提供者
provider_name = config_var . get ( ) [ " services " ] [ " llm " ]
provider_config = config_var . get ( ) [ " providers " ] [ " llm " ] [ provider_name ]
llm_provider = prov [ provider_name ] ( provider_config )
session = ChatSession (
session_id = session_id , llm_provider = llm_provider , * * session_params
)
self . sessions [ session_id ] = session
logger . debug ( " 创建新会话: id= %s " , session_id )
return session
return None
def delete_session ( self , session_id : str ) :
""" 删除会话 """
if session_id in self . sessions :
del self . sessions [ session_id ]
logger . debug ( " 删除会话: id= %s " , session_id )
def list_sessions ( self ) - > List [ str ] :
""" 列出所有会话ID """
return list ( self . sessions . keys ( ) )
# 全局聊天管理器实例
_chat_manager : Optional [ ChatManager ] = None
def get_chat_manager ( ) - > ChatManager :
""" 获取全局聊天管理器实例 """
global _chat_manager
if _chat_manager is None :
_chat_manager = ChatManager ( )
logger . debug ( " 创建全局聊天管理器 " )
return _chat_manager
def create_chat_session (
session_id : Optional [ str ] = None , * * session_params
) - > ChatSession :
""" 创建或获取聊天会话(便捷函数) """
manager = get_chat_manager ( )
return manager . get_session ( session_id , True , * * session_params )
logger . debug ( " LLM 服务初始化完成 " )