支持 DeepSeek R1 & GPT 4 · 基于文件问答 · LLM本地部署 · 联网搜索 · Agent 助理 · 支持 Fine-tune
视频教程 · 2.0介绍视频 || 在线体验 · 一键部署{action_name}: {action_input}\n
class ChuanhuCallbackHandler(BaseCallbackHandler):
    """Streams LangChain agent events to the UI through a single text callback.

    ``callback`` is called with plain-text fragments (agent action
    descriptions, tool output, and individual LLM tokens) as they arrive.
    """

    def __init__(self, callback) -> None:
        """Initialize callback handler.

        Args:
            callback: a ``callback(text: str)`` sink that renders the fragment.
        """
        self.callback = callback

    def on_agent_action(
        self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
    ) -> Any:
        # Forward a human-readable description of the agent's chosen action.
        self.callback(get_action_description(action))

    def on_tool_end(
        self,
        output: str,
        color: Optional[str] = None,
        observation_prefix: Optional[str] = None,
        llm_prefix: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """If not the final action, print out observation."""
        # Prefixes are only logged; the raw tool output is streamed to the UI.
        if observation_prefix is not None:
            logging.info(observation_prefix)
        self.callback(output)
        if llm_prefix is not None:
            logging.info(llm_prefix)

    def on_agent_finish(
        self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
    ) -> None:
        # The finish log is recorded but not streamed to the chat window.
        logging.info(finish.log)

    def on_llm_new_token(
        self,
        token: str,
        *,
        chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Run on new LLM token. Only available when streaming is enabled.

        Args:
            token (str): The new token.
            chunk (GenerationChunk | ChatGenerationChunk): The new generated chunk,
                containing content and other information.
        """
        # Demoted from info to debug: this fires once per token and floods
        # the info log during streaming.
        logging.debug(f"### CHUNK ###: {chunk}")
        self.callback(token)


class ModelType(Enum):
    """Coarse model families used to select the matching client implementation."""

    Unknown = -1
    OpenAI = 0
    ChatGLM = 1
    LLaMA = 2
    XMChat = 3
    StableLM = 4
    MOSS = 5
    YuanAI = 6
    Minimax = 7
    ChuanhuAgent = 8
    GooglePaLM = 9
    LangchainChat = 10
    Midjourney = 11
    Spark = 12
    OpenAIInstruct = 13
    Claude = 14
    Qwen = 15
    OpenAIVision = 16
    ERNIE = 17
    DALLE3 = 18
    GoogleGemini = 19
    GoogleGemma = 20
    Ollama = 21
    Groq = 22
    DeepSeek = 23

    @classmethod
    def get_type(cls, model_name: str):
        """Resolve ``model_name`` to a ModelType.

        Resolution order:
        1. explicit ``model_type`` field in MODEL_METADATA (if present);
        2. substring heuristics on the lower-cased name;
        3. fallback to LLaMA for anything unrecognized.
        """
        # 1. get model type from model metadata (if exists).
        # Use .get() so an unknown model name falls through to the name
        # heuristics below instead of raising KeyError.
        model_type = MODEL_METADATA.get(model_name, {}).get("model_type")
        if model_type is not None:
            for member in cls:
                if member.name == model_type:
                    return member
        # 2. infer model type from model name
        model_type = None
        model_name_lower = model_name.lower()
        if "gpt" in model_name_lower:
            # Explicit check instead of `assert ... == True` inside try/except:
            # asserts are stripped under `python -O`, which would silently
            # classify every GPT model as multimodal.
            if MODEL_METADATA.get(model_name, {}).get("multimodal") is True:
                model_type = ModelType.OpenAIVision
            elif "instruct" in model_name_lower:
                model_type = ModelType.OpenAIInstruct
            elif "vision" in model_name_lower:
                model_type = ModelType.OpenAIVision
            else:
                model_type = ModelType.OpenAI
        elif "chatglm" in model_name_lower:
            model_type = ModelType.ChatGLM
        elif "groq" in model_name_lower:
            model_type = ModelType.Groq
        elif "ollama" in model_name_lower:
            model_type = ModelType.Ollama
        elif "llama" in model_name_lower or "alpaca" in model_name_lower:
            model_type = ModelType.LLaMA
        elif "xmchat" in model_name_lower:
            model_type = ModelType.XMChat
        elif "stablelm" in model_name_lower:
            model_type = ModelType.StableLM
        elif "moss" in model_name_lower:
            model_type = ModelType.MOSS
        elif "yuanai" in model_name_lower:
            model_type = ModelType.YuanAI
        elif "minimax" in model_name_lower:
            model_type = ModelType.Minimax
        elif "川虎助理" in model_name_lower:
            model_type = ModelType.ChuanhuAgent
        elif "palm" in model_name_lower:
            model_type = ModelType.GooglePaLM
        elif "gemini" in model_name_lower:
            model_type = ModelType.GoogleGemini
        elif "midjourney" in model_name_lower:
            model_type = ModelType.Midjourney
        elif "azure" in model_name_lower or "api" in model_name_lower:
            model_type = ModelType.LangchainChat
        elif "讯飞星火" in model_name_lower:
            model_type = ModelType.Spark
        elif "claude" in model_name_lower:
            model_type = ModelType.Claude
        elif "qwen" in model_name_lower:
            model_type = ModelType.Qwen
        elif "ernie" in model_name_lower:
            model_type = ModelType.ERNIE
        elif "dall" in model_name_lower:
            model_type = ModelType.DALLE3
        elif "gemma" in model_name_lower:
            model_type = ModelType.GoogleGemma
        elif "deepseek" in model_name_lower:
            model_type = ModelType.DeepSeek
        else:
            model_type = ModelType.LLaMA
        return model_type
def download(repo_id, filename, retry=10):
    """Download ``filename`` from Hugging Face Hub repo ``repo_id``.

    Paths of previously downloaded models are memoized in
    ``./models/downloaded_models.json`` and returned without re-downloading.
    Retries up to ``retry`` times, then raises.

    Returns:
        str: local filesystem path of the downloaded file.

    Raises:
        Exception: if the download keeps failing, or ``retry`` <= 0.
    """
    cache_index = "./models/downloaded_models.json"
    if os.path.exists(cache_index):
        with open(cache_index, "r") as f:
            downloaded_models = json.load(f)
        if repo_id in downloaded_models:
            return downloaded_models[repo_id]["path"]
    else:
        downloaded_models = {}
    model_path = None
    while retry > 0:
        try:
            model_path = hf_hub_download(
                repo_id=repo_id,
                filename=filename,
                cache_dir="models",
                resume_download=True,
            )
            downloaded_models[repo_id] = {"path": model_path}
            with open(cache_index, "w") as f:
                json.dump(downloaded_models, f)
            break
        except Exception:
            # Use the module's logging convention instead of bare print;
            # logging.exception also records the underlying cause.
            logging.exception("Error downloading model, retrying...")
            retry -= 1
            if retry == 0:
                raise Exception("Error downloading model, please try again later.")
    if model_path is None:
        # Previously this path raised UnboundLocalError when retry <= 0 on
        # entry; fail with the same explicit error instead.
        raise Exception("Error downloading model, please try again later.")
    return model_path


class BaseLLMModel:
    """Shared state and default generation parameters for all model clients."""

    def __init__(
        self,
        model_name,
        user="",
        config=None,
    ) -> None:
        """Merge model metadata with an optional user ``config`` and set defaults.

        Args:
            model_name: key into MODEL_METADATA.
            user: user name, used for per-user history lookup.
            config: optional per-user overrides; metadata values that differ
                from DEFAULT_METADATA are treated as model-specific and take
                precedence over ``config``.
        """
        if config is not None:
            temp = MODEL_METADATA[model_name].copy()
            # Model-specific metadata overrides win over the user config.
            keys_with_diff_values = {
                key: temp[key]
                for key in temp
                if key in DEFAULT_METADATA and temp[key] != DEFAULT_METADATA[key]
            }
            config.update(keys_with_diff_values)
            temp.update(config)
            config = temp
        else:
            # Copy so later mutation cannot corrupt the shared MODEL_METADATA
            # table (the branch above already copies).
            config = MODEL_METADATA[model_name].copy()
        self.model_name = config["model_name"]
        self.multimodal = config["multimodal"]
        self.description = config["description"]
        self.placeholder = config["placeholder"]
        self.token_upper_limit = config["token_limit"]
        self.system_prompt = config["system"]
        self.api_key = config["api_key"]
        self.api_host = config["api_host"]
        self.stream = config["stream"]
        self.interrupted = False
        self.need_api_key = self.api_key is not None
        self.history = []
        self.all_token_counts = []
        self.model_type = ModelType.get_type(model_name)
        self.history_file_path = get_first_history_name(user)
        self.user_name = user
        self.chatbot = []
        # Defaults are kept separately so the live values can be reset later.
        self.default_single_turn = config["single_turn"]
        self.default_temperature = config["temperature"]
        self.default_top_p = config["top_p"]
        self.default_n_choices = config["n_choices"]
        self.default_stop_sequence = config["stop"]
        self.default_max_generation_token = config["max_generation"]
        self.default_presence_penalty = config["presence_penalty"]
        self.default_frequency_penalty = config["frequency_penalty"]
        self.default_logit_bias = config["logit_bias"]
        self.default_user_identifier = user
        self.default_stream = config["stream"]
        self.single_turn = self.default_single_turn
        self.temperature = self.default_temperature
        self.top_p = self.default_top_p
        self.n_choices = self.default_n_choices
        self.stop_sequence = self.default_stop_sequence
        self.max_generation_token = self.default_max_generation_token
        self.presence_penalty = self.default_presence_penalty
        self.frequency_penalty = self.default_frequency_penalty
        self.logit_bias = self.default_logit_bias
        self.user_identifier = user
        self.metadata = config["metadata"]

    def get_answer_stream_iter(self):
        """Implement stream prediction.
        Conversations are stored in self.history, with the most recent question in OpenAI format.
        Should return a generator that yields the next word (str) in the answer.
        """
        logging.warning(
            "Stream prediction is not implemented. Using at once prediction instead."
        )
        response, _ = self.get_answer_at_once()
        yield response

    def get_answer_at_once(self):
        """predict at once, need to be implemented
        conversations are stored in self.history, with the most recent question, in OpenAI format
        Should return:
        the answer (str)
        total token count (int)
        """
        logging.warning("at once predict not implemented, using stream predict instead")
        response_iter = self.get_answer_stream_iter()
        count = 0
        # `count` approximates generated tokens by number of streamed chunks.
        for response in response_iter:
            count += 1
        return response, sum(self.all_token_counts) + count

    def billing_info(self):
        """get billing information, implement if needed"""
        return BILLING_NOT_APPLICABLE_MSG

    def count_token(self, user_input):
        """get token count from input, implement if needed"""
        # Default: character count as a rough token estimate.
        return len(user_input)

    # NOTE(review): stream_next_chatbot was truncated mid-statement in the
    # reviewed chunk and could not be reconstructed safely; restore it from
    # the original source.