赞
踩
本博客是记录自己在学习大模型和数据分析遇到的问题以及解决思路
项目用到的数据是某销售平台的一部分订单信息,之后会依据这个数据库实现单表和多表的查询。
网上下载Navicat,不选择MySQL workbench的原因是上传数据没有navicat方便。
在Navicat连接到MySQL数据库后新建数据库“数据画像”,把csv数据导入到数据库。
- import mysql.connector
-
- def create_db_connection(host_name, user_name, user_password, db_name):
- connection = None
- try:
- connection = mysql.connector.connect(
- host=host_name,
- user=user_name,
- passwd=user_password,
- database=db_name
- )
- print("MySQL Database connection successful")
- except Exception as e:
- print(f"The error '{e}' occurred")
-
- return connection
-
- connection = create_db_connection("localhost", "root", db_key, "用户画像")

- from langchain_community.chat_models import ChatZhipuAI
- from dotenv import load_dotenv
- import os
-
- load_dotenv()
- key = os.getenv('ZHIPUAI_API_KEY')
-
- llm = ChatZhipuAI(
- temperature=0.1,
- api_key=key,
- model_name="glm-4",
- )
get_schema()用来获得数据库的详细信息
- def get_schema(connection):
- cursor = connection.cursor()
- query = """
- SELECT table_name, column_name, data_type
- FROM information_schema.columns
- WHERE table_schema = DATABASE();
- """
- cursor.execute(query)
- results = cursor.fetchall()
- cursor.close()
-
- schema_dict = {}
- for row in results:
- table_name = row[0]
- column_name = row[1]
- data_type = row[2]
-
- if table_name not in schema_dict:
- schema_dict[table_name] = {}
-
- schema_dict[table_name][column_name] = data_type
- return schema_dict

提取表名、列名和数据类型,然后将其组织成嵌套字典的形式。如果表名不在schema_dict
中,则添加一个新的字典条目。最终,将列名和数据类型作为键值对添加到对应的表名条目下。后面大模型就会基于这个信息实现text2sql
run_query() 用来在数据库中运行SQL代码,返回结果
- def run_query(connection, query):
- cursor = connection.cursor()
- cursor.execute(query)
- results = cursor.fetchall()
- cursor.close()
- return results
-
利用langchain框架,定义获得SQL提示模板,并且获取数据库模式信息
- from langchain_core.output_parsers import StrOutputParser
- from langchain_core.runnables import RunnablePassthrough
- from langchain_core.runnables import RunnableLambda
- from langchain.prompts import PromptTemplate
-
- template_sql = (
- "请通过写 MYSQL代码来回答对应问题,请确保你的代码不要使用,不要使用PostgreSQL特有的语法,选择mysql的语法"
- "并且需要基于如下数据库信息:{info},冒号之前是table的名字,后面是列名和数据格式\n"
- "需要回答的问题是:{question}\n"
- "注意仅需要通过sql代码回答,不需要文字\n"
- "代码形式如下:```sql\n...\n```"
- )
-
- t_sql=PromptTemplate.from_template(template_sql)
-
- schema_info = get_schema(connection)

get_sql() 用来从大模型生成的段落中提取SQL代码
再使用langchain的语言表达式(LCEL)这个处理链会把输入(info 和 question)依次传递给每个步骤进行处理,最后通过 RunnableLambda(get_sql)
获取 SQL 语句。
- def get_sql(x):
- return x.split("```sql")[1].split("```")[0]
- chain_sql=({"info":RunnablePassthrough(),"question":RunnablePassthrough()}
- |t_sql
- |llm
- |StrOutputParser()
- |RunnableLambda(get_sql))
上一步得到的sql代码,在数据库运行得到结果,依据结果和定义提示词模板模板
- template_sql0 = (
- "请通过综合如下的数据库信息回答问题。问题,sql代码,sql代码的执行结果给出问题的自然语言回答。\n"
- "数据库信息{info}\n"
- "需要回答的问题是:{question}\n"
- "sql代码: {query}\n"
- "sql代码执行结果: {res}"
- )
-
- t_sql0=PromptTemplate.from_template(template_sql0)
处理链接收数据库信息和查询问题,通过执行SQL查询获得结果,并生成一个包含这些信息的字符串,最后使用大语言模型和字符串解析器进行处理。
- chain_sql0=({"info":RunnablePassthrough(),"question":RunnablePassthrough(),"query":chain_sql}
- |RunnablePassthrough.assign(res=lambda x: run_query(connection,x["query"]))
- |t_sql0
- |llm
- |StrOutputParser())
使用两个处理链 chain_sql
和 chain_sql0
来处理输入数据 input_data
,并打印 chain_sql0
的处理结果
- question = "客户问题"
- input_data = {"info": schema_info, "question": question}
-
- chain_sql.invoke(input_data)
- result = chain_sql0.invoke(input_data)
-
- print(result)
基本查询:
排序:
聚合函数:
分组查询:
数据更新:
数据删除:
条件和组合查询:
如果一个数据库内有多个表,通过get_schema()函数来获得数据库的所有表名,以及表内的列名以及数据类型。这样让大模型根据问题,自动匹配数据库内的所需要的表以及列名,实现text2sql。
经过测试,对待下面的查询返回的SQL代码正确率达到95%以上
多表连接能力:
复杂查询构建能力:
数据汇总和聚合能力:
数据更新和删除操作:
子查询和嵌套查询能力:
处理多个数据集:
对于商品订单信息这个表格来说
对于这种情况,更改提示词,使其大模型了解数字对具体类别的映射:
- template_sql = (
- "请通过写 MYSQL代码来回答对应问题。请确保你的代码不要使用 PostgreSQL特有的语法,选择 mysql\n"
- "并且需要基于如下数据库信息: {info}。冒号之前是 table 的名字,后面是列名和数据格式。\n"
- "注意有些列用数字代表类别:\n"
- "- pay_type: 0代表银行卡, 1代表微信, 2代表支付宝\n"
- "- status: 0代表未支付, 1代表已支付,2代表退款\n"
- "需要回答的问题是: {question}\n"
- "注意你需要通过 sql代码回答,不需要文字\n"
- "代码形式如下: ```sql\n...```"
- )
在prompt中直接告诉大模型这种映射关系
结果:大模型成功的识别了提示词中的映射关系
方法二:
方法一虽然效果显著,但是很有局限性。假如数据库中表的数量巨多,总不可能把注释一个一个复制到prompt。
我们可以在数据库新创建一个表格(notation),把所有的映射关系都列上去。
从数据库中调出notation表的详细信息
- def get_comments(connection):
- try:
- cursor = connection.cursor(dictionary=True)
- # 执行SQL查询
- sql = "SELECT * FROM notation"
- cursor.execute(sql)
- result = cursor.fetchall()
-
- # 将结果格式化为字符串
- comments = "\n".join([f"{row['table_name']} | {row['mapping']}" for row in result])
- cursor.close()
- return comments
- except Exception as e:
- print(f"Error: {e}")
- return None
-
- comments = get_comments(connection)

修改一下prompt增加一个comment变量,在chain_sql中新加入一个输入
- template_sql = (
- "请通过写 MYSQL代码来回答问题。请确保你的代码不要使用 PostgreSQL特有的语法,选择 mysql。\n"
- "在回答问题之前,请先查看表格最后一列的注释信息,了解所有列的定义以及数字类别的含义。\n"
- "数据库信息如下:\n"
- "{info}\n"
- "注释信息如下:\n"
- "{comments}\n"
- "需要回答的问题是:{question}\n"
- "注意你需要通过 sql代码回答,不需要文字。\n"
- "代码格式如下:```sql\n"
- "..."
- )
-
- chain_sql=({"info":RunnablePassthrough(),"question":RunnablePassthrough(),"comments":RunnablePassthrough()}
- |t_sql
- |llm
- |StrOutputParser()
- |RunnableLambda(get_sql))
-
- input_data = {"info": schema_info, "question": question, "commetns": comments}
- chain_sql.invoke(input_data)

结果:大模型成功的识别了提示词中的映射关系
尝试一下使用方法二能不能解决问题:选出创建时间和修改时间不超过一天,评论成功的好评一共有几条。
这个问题既包括数字代表类别的数据,以及没有语义信息的字段,如果大模型成功解决这个text2sql问题。证明方法二对待问题3和4都有效果
结果表明大模型成功识别了所有映射信息。
- question = "分析历史销售数据,然后制定相应的销售策略和营销计划"
- input_data = {"info": schema_info, "question": question, "commetns": comments}
- chain_sql.invoke(input_data)
问题:分析历史销售数据,然后制定相应的销售策略和营销计划。 SQL代码: ```sql SELECT goods_id, goods_name, SUM(pay_amount) AS total_sales, COUNT(order_id) AS total_orders, AVG(pay_amount) AS average_order_value FROM 商品订单信息 WHERE status = 1 -- 已付款的订单 GROUP BY goods_id, goods_name ORDER BY total_sales DESC; ``` SQL代码执行结果: ``` [ ('510120', 'C3', 198.17999267578125, 1, 198.17999267578125),...
1. 重点关注销售总额较高的商品,如C3和C2类商品,增加库存和营销力度。
2. 分析销售总额较低的商品,找出原因,如价格、品牌知名度等,并采取相应措施提升其销售。
3. 针对平均订单价值较高的商品,可以考虑推出捆绑销售或推出会员专享折扣或优惠券,鼓励客户多次购买,从而增加总销售额。
结论:AI生成的回答能力是非常好的。它不仅能分析历史销售数据,还能提出合理的销售策略和营销计划。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。