# llm.py
import json
import os

import requests
from openai import OpenAI, AsyncOpenAI

from config import model_name_vllm_url_dict
from utils.get_logger import setup_logger
  6. logger = setup_logger(__name__)
  7. class VllmApi():
  8. def __init__(self, chat_json):
  9. openai_api_key = "sk-72f9d0e5bc894e1d828d73bdcc50ff0a"
  10. # model = chat_json.get("model")
  11. model = "/opt/vllm/models/Qwen/Qwen3-30B-A3B-Instruct-2507"
  12. vllm_url = model_name_vllm_url_dict.get(model)
  13. openai_api_base = vllm_url
  14. self.vllm_chat_url = f"{vllm_url}/chat/completions"
  15. self.vllm_generate_url = f"{vllm_url}/completions"
  16. self.client = AsyncOpenAI(
  17. # defaults to os.environ.get("OPENAI_API_KEY")
  18. api_key=openai_api_key,
  19. base_url=openai_api_base,
  20. )
  21. async def chat(self,
  22. prompt : str = "",
  23. model: str = "deepseek-r1:7b",
  24. stream: bool = False,
  25. top_p: float = 0.9,
  26. temperature: float = 0.6,
  27. max_tokens: int = 1024,
  28. history: list = [],
  29. enable_think: bool = False
  30. ):
  31. if history:
  32. messages = history
  33. messages.insert(0, {"role": "system","content": "你是一个基于检索增强生成(RAG)的专业问答助手。"})
  34. else:
  35. messages = [{"role": "user", "content": prompt}]
  36. if model == "Qwen3-30B":
  37. if enable_think:
  38. temperature = 0.6
  39. top_p=0.95
  40. else:
  41. temperature = 0.7
  42. top_p=0.8
  43. extra_body = {
  44. "chat_template_kwargs": {"enable_thinking": enable_think},
  45. "top_k": 20
  46. }
  47. else:
  48. extra_body = {}
  49. chat_response = await self.client.chat.completions.create(
  50. model=model,
  51. messages=messages,
  52. stream=stream,
  53. top_p=top_p,
  54. temperature=temperature,
  55. max_tokens=max_tokens,
  56. # extra_body={
  57. # # "chat_template_kwargs": {"enable_thinking": enable_think},
  58. # "enable_thinking": enable_think,
  59. # "top_k":top_k
  60. # },
  61. extra_body=extra_body
  62. )
  63. # 针对deepseek的模型,是否输出think部分
  64. yield_reasoning_content = True
  65. yield_content = True
  66. has_reason = ""
  67. if stream:
  68. try:
  69. async for chunk in chat_response:
  70. logger.info(f"vllm返回的chunk信息:{chunk}")
  71. reasoning_content = None
  72. content = None
  73. chat_id = chunk.id
  74. # Check the content is reasoning_content or content
  75. if chunk.choices[0].delta.role == "assistant":
  76. continue
  77. # elif hasattr(chunk.choices[0].delta, "reasoning_content"):
  78. # reasoning_content = chunk.choices[0].delta.reasoning_content
  79. # if reasoning_content:
  80. # has_reason += reasoning_content
  81. elif hasattr(chunk.choices[0].delta, "content"):
  82. content = chunk.choices[0].delta.content
  83. if reasoning_content is not None:
  84. if yield_reasoning_content:
  85. yield_reasoning_content = False
  86. reasoning_content = "```think" + reasoning_content
  87. # print("reasoning_content:", end="", flush=True)
  88. # print(reasoning_content, end="", flush=True)
  89. # yield reasoning_content
  90. yield {"id": chat_id, "event": "add", "data": reasoning_content}
  91. elif content is not None:
  92. if yield_content:
  93. yield_content = False
  94. if has_reason:
  95. content = "think```" + content
  96. else:
  97. content = content
  98. # print("\ncontent:", end="", flush=True)
  99. # print(content, end="", flush=True)
  100. # yield content
  101. logger.info(f"返回的信息是id: {chat_id}, event: add, data: {content}")
  102. yield {"id": chat_id, "event": "add", "data": content}
  103. if chunk.choices[0].finish_reason:
  104. yield {"id": chat_id, "event": "finish", "data": ""}
  105. finally:
  106. # await self.client.close()
  107. pass
  108. else:
  109. # print(f"chat response: {chat_response.model_dump_json()}")
  110. yield {"id": "", "event": "finish", "data": chat_response.choices[0].message.content}
  111. async def generate_non_stream_async(
  112. self,
  113. prompt: list,
  114. model: str,
  115. return_full_response: bool = False,
  116. ):
  117. # messages = [
  118. # {"role": "user", "content": prompt}
  119. # ]
  120. completion = await self.client.chat.completions.create(
  121. model=model,
  122. messages=prompt,
  123. stream=False,
  124. )
  125. if return_full_response:
  126. return completion
  127. return completion.choices[0].message.content
  128. def generate(self,
  129. prompt: str,
  130. model: str = "deepseek-r1:7b",
  131. history: list = [],
  132. stream: bool = False
  133. ):
  134. completion = self.client.completions.create(
  135. model=model,
  136. prompt=prompt,
  137. max_tokens=1024,
  138. stream=stream
  139. )
  140. if stream:
  141. for chunk in completion:
  142. print(f"generate chunk: {chunk}")
  143. yield chunk
  144. else:
  145. return completion
  146. def request_generate(self, model, prompt, max_tokens: int = 1024, temperature: float = 0.6, stream: bool = False):
  147. json_data = {
  148. "model": model,
  149. "prompt": prompt,
  150. "max_tokens": max_tokens,
  151. "temperature": temperature,
  152. "stream": stream
  153. }
  154. response = requests.post(self.vllm_generate_url,json=json_data, stream=stream)
  155. response.raise_for_status()
  156. if stream:
  157. for line in response.iter_lines():
  158. if line:
  159. line_str = line.decode("utf-8")
  160. if line_str.startswith("data: "):
  161. json_str = line_str[len("data: "):]
  162. if json_str == "[DONE]":
  163. break
  164. print(f"返回的数据:{json.loads(json_str)}")
  165. yield json.loads(json_str)
  166. else:
  167. logger.info(f"直接返回结果:{response.json()}")
  168. yield response.json()
  169. def request_chat(self,
  170. model,
  171. prompt,
  172. history: list = [],
  173. temperature: float = 0.6,
  174. stream: bool = False,
  175. top_p: float = 0.7):
  176. history.append({"role": "user", "content": prompt})
  177. json_data = {
  178. "model": model,
  179. "messages": history,
  180. "temperature": temperature,
  181. "stream": stream,
  182. "top_p": top_p
  183. }
  184. response = requests.post(self.vllm_chat_url,json=json_data, stream=stream)
  185. response.raise_for_status()
  186. if stream:
  187. for line in response.iter_lines():
  188. if line:
  189. line_str = line.decode("utf-8")
  190. if line_str.startswith("data: "):
  191. json_str = line_str[len("data: "):]
  192. if json_str == "[DONE]":
  193. break
  194. print(f"chat模式返回的数据:{json.loads(json_str)}")
  195. yield json.loads(json_str)
  196. else:
  197. print(f"聊天模式直接返回结果:{response.json()}")
  198. return response.json()
  199. def main():
  200. history = [{"role": "system", "content": "你是一个非常有帮助的助手,在回答用户问题的时候请以<think>开头。"}]
  201. # prompt = "请帮我计算鸡兔同笼的问题。从上面数有35个头,从下面数有94只脚,请问分别多少只兔子多少只鸡?"
  202. prompt = "请帮我将下面提供的中文翻译成日文,要求:1、直接输出翻译的结果,2、不要进行任何解释。需要翻译的内容:我下飞机的时候行李丢了。"
  203. model = "DeepSeek-R1-Distill-Qwen-14B"
  204. vllm_chat_resp = VllmApi().request_chat(prompt=prompt, model=model, history=history, stream=True)
  205. # print("vllm 回复:")
  206. for chunk in vllm_chat_resp:
  207. pass
  208. # print(chunk, end='', flush=True)
  209. if __name__=="__main__":
  210. main()