当前位置:   article > 正文

VannaAI 介绍及使用 - 第四篇_vanna 实战

vanna 实战

使用案例

前言

本篇主要记录实际使用的案例,仅供参考,请大家继续关注指正。

一、项目背景

1.项目需要使用Vanna的text-to-sql能力,但是呢gpt3.5能力稍弱一些,所以我们想要使用更强大的gpt4o。

2.因为gpt4o国内使用受限,且涉及公司数据,所以我们是需要另外部门提供的api,通过将请求参数发送到他们的接口,他们再去请求gpt4o,将结果返给我们,所以需要自定义请求方式,不能使用Vanna源码的请求。

3.我们发现在实际使用过程中,用户的有些问题是没法生成sql回答的,那么我们需要大模型能正常回答问题,不止是生成sql,但是呢又要不影响正常sql的生成,那可能涉及对prompt提示词的改造。

 二、案例

1.初始化

这里因为是样例,没有完整项目的代码那么完善,仅作为举例说明展示。

  1. from vanna.openai import OpenAI_Chat
  2. from vanna.vannadb import VannaDB_VectorStore
  3. from openai import OpenAI
  4. import time
  5. import json
  6. def ask(
  7. question: str,
  8. user: str,
  9. password: str,
  10. host: str,
  11. port: int,
  12. database: str,
  13. ):
  14. vn.connect_to_mysql(user=user, password=password, host=host, port=port, dbname=database)
  15. # 记录开始时间
  16. start_time = time.time()
  17. # 调用新的 ask 方法
  18. sql, data, fig, llm_response = vn.ask(question=question, )
  19. print(f"调用 vn.ask 的耗时: {time.time() - start_time:.4f} 秒")
  20. # 处理返回的结果
  21. if llm_response is not None:
  22. result = {
  23. "sql": None,
  24. "data": llm_response,
  25. "chart": None,
  26. }
  27. else:
  28. if data is not None:
  29. image = fig.to_json() if fig else None
  30. result = {
  31. "sql": sql,
  32. "data": data.to_json(orient="records"),
  33. "chart": image
  34. }
  35. else:
  36. result = {
  37. "sql": sql,
  38. "data": None,
  39. "chart": None
  40. }
  41. print(json.dumps(result, ensure_ascii=False))
  42. if __name__ == "__main__":
  43. # 使用自己的vanna模型及api key
  44. MY_VANNA_MODEL = "my_vanna_model"
  45. MY_VANNA_API_KEY = "my_vanna_api_key"
  46. # 使用自己的llm模型及所需配置
  47. my_llm_api_key = 'EMPTY' #因为是第三方提供的接口,暂不需要apikey
  48. my_llm_base_url = 'http://0.0.0.0:8000/xx/xx/xx' # 第三方提供的完整接口,后续不需要拼接
  49. my_llm_name = 'gpt4o' #告知第三方我们需要的是gpt4o
  50. # 初始化客户端
  51. client = OpenAI(
  52. api_key=my_llm_api_key ,
  53. base_url=my_llm_base_url
  54. )
  55. class MyVanna(VannaDB_VectorStore, OpenAI_Chat):
  56. def __init__(self, client=None, config=None):
  57. VannaDB_VectorStore.__init__(self, vanna_model=MY_VANNA_MODEL, vanna_api_key=MY_VANNA_API_KEY,config=config)
  58. OpenAI_Chat.__init__(self, client=client, config=config)
  59. # 使用自定义的大模型及vanna提供的向量库
  60. vn = MyVanna(client=client, config={"model": my_llm_name, })
  61. # 自定义问题,用于测试
  62. question = "中国有哪些省份"
  63. user = "root"
  64. password = "1234"
  65. host = "127.0.0.1"
  66. port = 3306
  67. database = "自己的库名"
  68. ask(question, user, password, host, port, database)

这里设置一个main是为了模拟传入不同的参数和问题,这个可以自己根据业务需求去调整,比如从数据库中取出相应参数,或者配置文件读取,从前端请求传入问题之类的。这里只是便于测试,简单举例。

还定义了一个ask方法,这是为了方便处理返回结果,其中也做了数据库初始化连接,这里也可以根据自己的业务情况,调整结果处理方式,我这里vn.ask返回四个参数是因为我对Vanna的ask源码做了改动,以为我需要额外返回自然回答的情况。详情见后续的ask代码。

 2.ask方法改造

ask方法在Vanna的base.py文件中,改造后的ask方法如下,直接上代码:

  1. def ask(
  2. self,
  3. question: Union[str, None] = None,
  4. # session_id: str = None, # 添加 session_id 参数
  5. print_results: bool = True,
  6. # auto_train: bool = True,
  7. auto_train: bool = False, # 关闭默认训练,随着时间的推移,prompt会越来越多,token消耗会越来越快
  8. visualize: bool = True, # if False, will not generate plotly code
  9. ) -> Union[
  10. Tuple[
  11. Union[str, None],
  12. Union[pd.DataFrame, None],
  13. Union[plotly.graph_objs.Figure, None],
  14. Union[str, None]
  15. ],
  16. None,
  17. ]:
  18. """
  19. **Example:**
  20. ```python
  21. vn.ask("What are the top 10 customers by sales?")
  22. ```
  23. Ask Vanna.AI a question and get the SQL query that answers it.
  24. Args:
  25. question (str): The question to ask.
  26. print_results (bool): Whether to print the results of the SQL query.
  27. auto_train (bool): Whether to automatically train Vanna.AI on the question and SQL query.
  28. visualize (bool): Whether to generate plotly code and display the plotly figure.
  29. Returns:
  30. Tuple[str, pd.DataFrame, plotly.graph_objs.Figure]: The SQL query, the results of the SQL query, and the plotly figure.
  31. """
  32. # 判断否text2sql对话,默认为是
  33. is_text2sql = 1
  34. if question is None:
  35. question = input("Enter a question: ")
  36. try:
  37. sql = self.generate_sql(question=question)
  38. if self.is_sql_valid(sql) is False:
  39. return None, None, None, sql
  40. except Exception as e:
  41. print(e)
  42. is_text2sql = 0
  43. llm_response = self.ask_llm(question=question, is_text2sql=is_text2sql)
  44. return None, None, None, llm_r
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/爱喝兽奶帝天荒/article/detail/959933
推荐阅读
相关标签
  

闽ICP备14008679号