llm.py

from openai import OpenAI
import requests
import json
from utils.get_logger import setup_logger
from config import model_name_vllm_url_dict

logger = setup_logger(__name__)
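
# Illustrative assumption: `model_name_vllm_url_dict` (defined in config.py, not shown here)
# is expected to map each model name to the base URL of its vLLM OpenAI-compatible server,
# for example:
#   model_name_vllm_url_dict = {
#       "DeepSeek-R1-Distill-Qwen-14B": "http://localhost:8000/v1",
#   }
# The actual hosts, ports and model names depend on your deployment.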


class VllmApi:
    def __init__(self, chat_json):
        # Look up the target vLLM server from the model name carried in the request payload.
        openai_api_key = "EMPTY"
        model = chat_json.get("model")
        vllm_url = model_name_vllm_url_dict.get(model)
        openai_api_base = vllm_url
        self.vllm_chat_url = f"{vllm_url}/chat/completions"
        self.vllm_generate_url = f"{vllm_url}/completions"
        self.client = OpenAI(
            # defaults to os.environ.get("OPENAI_API_KEY")
            api_key=openai_api_key,
            base_url=openai_api_base,
        )
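
    # Illustrative construction (an assumption mirroring main() below): the payload only
    # needs a "model" key whose value is a key of model_name_vllm_url_dict, e.g.
    #   api = VllmApi({"model": "DeepSeek-R1-Distill-Qwen-14B"})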

    def chat(self,
             prompt: str = "",
             model: str = "deepseek-r1:7b",
             stream: bool = False,
             top_p: float = 0.9,
             temperature: float = 0.6,
             max_tokens: int = 1024,
             history: list = None
             ):
        # Use the supplied conversation history if given, otherwise start a fresh one
        # (a None default avoids the mutable-default-argument pitfall).
        if history:
            messages = history
        else:
            messages = [{"role": "user", "content": prompt}]
        chat_response = self.client.chat.completions.create(
            model=model,
            messages=messages,
            stream=stream,
            top_p=top_p,
            temperature=temperature,
            max_tokens=max_tokens
        )
        # For DeepSeek-style models, decide whether the "think" (reasoning) part is emitted.
        yield_reasoning_content = True
        yield_content = True
        has_reason = ""
        if stream:
            for chunk in chat_response:
                logger.info(f"chunk returned by vllm: {chunk}")
                content = None
                chat_id = chunk.id
                delta = chunk.choices[0].delta
                # Skip the initial role-only chunk, then check whether this delta carries
                # reasoning_content or regular content.
                if delta.role == "assistant":
                    continue
                reasoning_content = getattr(delta, "reasoning_content", None)
                if reasoning_content:
                    has_reason += reasoning_content
                else:
                    reasoning_content = None
                    content = getattr(delta, "content", None)
                if reasoning_content is not None:
                    if yield_reasoning_content:
                        yield_reasoning_content = False
                        # Mark the start of the reasoning segment.
                        reasoning_content = "```think" + reasoning_content
                    yield {"id": chat_id, "event": "add", "data": reasoning_content}
                elif content is not None:
                    if yield_content:
                        yield_content = False
                        if has_reason:
                            # Close the reasoning segment before the answer begins.
                            content = "think```" + content
                    yield {"id": chat_id, "event": "add", "data": content}
                if chunk.choices[0].finish_reason:
                    yield {"id": chat_id, "event": "finish", "data": ""}
        else:
            yield chat_response.choices[0].message.content
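
    # Illustrative consumption of the events yielded by chat() above; the dict shape
    # {"id", "event", "data"} matches what chat() produces, while the loop itself is
    # only a sketch and not part of this module:
    #   api = VllmApi({"model": "deepseek-r1:7b"})
    #   for event in api.chat(prompt="Hello", model="deepseek-r1:7b", stream=True):
    #       if event["event"] == "add":
    #           print(event["data"], end="", flush=True)
    #       elif event["event"] == "finish":
    #           break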

    def generate(self,
                 prompt: str,
                 model: str = "deepseek-r1:7b",
                 history: list = None,
                 stream: bool = False
                 ):
        # Plain (non-chat) completion endpoint; `history` is accepted but currently unused.
        completion = self.client.completions.create(
            model=model,
            prompt=prompt,
            max_tokens=1024,
            stream=stream
        )
        if stream:
            for chunk in completion:
                logger.info(f"generate chunk: {chunk}")
                yield chunk
        else:
            # This method is a generator, so the non-streaming result is yielded as well.
            yield completion
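
    # Illustrative call (an assumption, not exercised elsewhere in this file):
    #   api = VllmApi({"model": "deepseek-r1:7b"})
    #   for chunk in api.generate("Hello", model="deepseek-r1:7b", stream=True):
    #       print(chunk)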

    def request_generate(self, model, prompt, max_tokens: int = 1024, temperature: float = 0.6, stream: bool = False):
        json_data = {
            "model": model,
            "prompt": prompt,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "stream": stream
        }
        response = requests.post(self.vllm_generate_url, json=json_data, stream=stream)
        response.raise_for_status()
        if stream:
            # The streaming body arrives as SSE-style lines: "data: {json}", terminated by "data: [DONE]".
            for line in response.iter_lines():
                if line:
                    line_str = line.decode("utf-8")
                    if line_str.startswith("data: "):
                        json_str = line_str[len("data: "):]
                        if json_str == "[DONE]":
                            break
                        chunk = json.loads(json_str)
                        logger.info(f"data returned: {chunk}")
                        yield chunk
        else:
            logger.info(f"non-streaming result returned directly: {response.json()}")
            yield response.json()
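
    # Illustrative assumption about the stream parsed above: vLLM's OpenAI-compatible
    # /completions endpoint emits lines such as
    #   data: {"id": "cmpl-...", "object": "text_completion", "choices": [{"index": 0, "text": "..."}]}
    #   data: [DONE]
    # Only the "data: " framing and the "[DONE]" sentinel are relied on here; the parsed
    # JSON chunks are passed through to the caller untouched.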

    def request_chat(self,
                     model,
                     prompt,
                     history: list = None,
                     temperature: float = 0.6,
                     stream: bool = False,
                     top_p: float = 0.7):
        # Avoid a mutable default argument; a caller-supplied history list is extended in place.
        history = history if history is not None else []
        history.append({"role": "user", "content": prompt})
        json_data = {
            "model": model,
            "messages": history,
            "temperature": temperature,
            "stream": stream,
            "top_p": top_p
        }
        response = requests.post(self.vllm_chat_url, json=json_data, stream=stream)
        response.raise_for_status()
        if stream:
            for line in response.iter_lines():
                if line:
                    line_str = line.decode("utf-8")
                    if line_str.startswith("data: "):
                        json_str = line_str[len("data: "):]
                        if json_str == "[DONE]":
                            break
                        chunk = json.loads(json_str)
                        logger.info(f"chat-mode stream data: {chunk}")
                        yield chunk
        else:
            logger.info(f"chat-mode non-streaming result: {response.json()}")
            # This method is a generator, so the non-streaming result is yielded rather than returned.
            yield response.json()


def main():
    history = [{"role": "system", "content": "You are a very helpful assistant. Please begin your answers to user questions with <think>."}]
    # prompt = "Please solve a chickens-and-rabbits-in-a-cage problem: counting from the top there are 35 heads, counting from the bottom there are 94 legs. How many rabbits and how many chickens are there?"
    prompt = "Please translate the Chinese below into Japanese. Requirements: 1. output only the translation; 2. do not add any explanation. Text to translate: 我下飞机的时候行李丢了。"
    model = "DeepSeek-R1-Distill-Qwen-14B"
    # VllmApi needs the request payload so it can look up the vLLM server for this model.
    vllm_chat_resp = VllmApi({"model": model}).request_chat(prompt=prompt, model=model, history=history, stream=True)
    for chunk in vllm_chat_resp:
        pass
        # print(chunk, end='', flush=True)


if __name__ == "__main__":
    main()
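
# Running this file directly (python llm.py) assumes that config.model_name_vllm_url_dict
# maps "DeepSeek-R1-Distill-Qwen-14B" to a reachable vLLM OpenAI-compatible server.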