vanna:基于RAG的text2sql框架
作者:mmseoamin日期:2024-04-27

文章目录

    • vanna简介及使用
    • vanna的原理
    • vanna的源码理解
    • 总结
    • 参考资料

      vanna简介及使用

      vanna是一个开源的利用了RAG的SQL生成python框架,在2024年3月已经有了5.8k的star数。

      Vanna is an MIT-licensed open-source Python RAG (Retrieval-Augmented Generation) framework for SQL generation and related functionality.

      Chat with your SQL database 📊. Accurate Text-to-SQL Generation via LLMs using RAG

      使用pip即可安装vanna:pip install vanna。

      vanna的使用主要分为三步:1. 确认所用的大模型和向量数据库;2. 将已有数据库的建表语句、文档、常用SQL及其自然语言查询问题进行向量编码存储到向量数据库(只用进行一次,除非数据有更改);3. 使用自然语言查询数据库。

      ## 第一步,假设使用 OpenAI LLM + ChromaDB 向量数据库
      from vanna.openai.openai_chat import OpenAI_Chat
      from vanna.chromadb.chromadb_vector import ChromaDB_VectorStore
      class MyVanna(ChromaDB_VectorStore, OpenAI_Chat):
          def __init__(self, config=None):
              ChromaDB_VectorStore.__init__(self, config=config)
              OpenAI_Chat.__init__(self, config=config)
      vn = MyVanna(config={'api_key': 'sk-...', 'model': 'gpt-4-...'})
      ## 第二步,将已有数据库相关信息存储起来
      # 建表语句ddl
      vn.train(ddl="""
          CREATE TABLE IF NOT EXISTS my-table (
              id INT PRIMARY KEY,
              name VARCHAR(100),
              age INT
          )
      """)
      # 数据库相关文档 documentation
      vn.train(documentation="Our business defines XYZ as ...")
      # 常用SQL
      vn.train(sql="SELECT name, age FROM my-table WHERE name = 'John Doe'")
      ## 第三步,就可以直接使用自然语言来查询数据了
      vn.ask("What are the top 10 customers by sales?")
      

      常用vanna函数(更多参见vanna 文档)

      # 训练(实际是添加数据到向量数据库)
      vn.train(ddl="")  #建表语句
      vn.train(documentation="") #文档
      vn.train(sql="", question="") #问题和sql对
      vn.train(sql="") #只有sql没有提供问题,会使用LLM来生成相应的问题
      vn.train(plan="") #一般是根据提供的数据库来生成训练计划,最终写入到向量数据库的还是ddl、documentation、sql/question三类
      # 查看已经加入到向量数据库的数据
      vn.get_training_data() #所有数据
      vn.get_related_sql()   #sql
      vn.get_related_ddl()   #ddl
      # 查询
      vn.ask()
      # 查询实际上是由下面四个函数依次执行的
      vn.generate_sql()  #生成sql语句
      vn.run_sql() #执行sql语句
      vn.generate_plotly_code() #根据执行结果生成plotly绘图代码
      vn.get_plotly_figure() #使用plotly绘图
      

      vanna的原理

      下图是来自vanna文档,用来解释vanna的原理。

      vanna:基于RAG的text2sql框架,在这里插入图片描述,第1张

      vanna是基于检索增强(RAG)的sql生成框架,会先用向量数据库将待查询数据库的建表语句、文档、常用SQL及其自然语言查询问题存储起来。在用户发起查询请求时,会先从向量数据库中检索出相关的建表语句、文档、SQL问答对放入到prompt里(DDL和文档作为上下文、SQL问答对作为few-shot样例),LLM根据prompt生成查询SQL并执行,框架会进一步将查询结果使用plotly可视化出来或用LLM生成后续问题。

      如果用户反馈LLM生成的结果是正确的,可以将这一问答对存储到向量数据库,可以使得以后的生成结果更准确。

      这篇博客记录了vanna尝试不同LLM和添加不同的上下文到prompt时生成SQL的准确率,表明在prompt中加入相关SQL问答对作为few-shot对于提升结果准确性很重要,GPT-4是效果最好的LLM。

      vanna:基于RAG的text2sql框架,在这里插入图片描述,第2张

      vanna的源码理解

      vanna所谓的训练(即vn.train())最终分为三类数据:ddldocumentationsql/question。使用向量数据库chromadb的实现时创建了三个collection,也就是三类数据将分别存储和检索。对于sql/question会将数据变成{"question": question,"sql": sql}json字符串存储。如果用户在训练时只提供了sql没有提供问题,会使用LLM来生成相应的问题(使用的prompt为"The user will give you SQL and you will try to guess what the business question this query is answering. Return just the question without any additional explanation. Do not reference the table name in the question.")。

      在查询阶段的vn.ask()由vn.generate_sql() 、vn.run_sql() 、vn.generate_plotly_code() 、vn.get_plotly_figure() 四个函数组成。其中最关键的是vn.generate_sql(),它分为以下关键几步:

      • get_similar_question_sql(question, **kwargs)去向量数据库中检索与问题相似的sql/question对

      • get_related_ddl(question, **kwargs) 去向量数据库中检索与问题相似的建表语句ddl

      • get_related_documentation(question, **kwargs) 去向量数据库中检索与问题相似的文档

      • get_sql_prompt(question,question_sql_list,ddl_list,doc_list, **kwargs) 生成prompt,

        ##以openAI的API调用为例, prompt生成的流程 分为下面几个部分
        initial_prompt = """
        The user provides a question and you provide SQL. You will only respond with SQL code and not with any explanations.\n\nRespond with only SQL code. Do not answer with any explanations -- just the code.\n"
        """
        ## 如果有相关ddl,且没超过上下文窗口大小
        if len(ddd_list)>0:
          initial_prompt += "You may use the following DDL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
          for ddl in ddl_list:
            initial_prompt += f"{ddl}\n\n"
        ## 如果有相关documentation,且没超过上下文窗口大小
        if len(doc_list)>0:
        		initial_prompt += f"\nYou may use the following documentation as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
        		for documentation in doc_list:
            	initial_prompt += f"{documentation}\n\n"
        message_log = [{"role": "system", "content": initial_prompt}]
        ## 如果有相关query/sql问答对,作为few-shot
        for example in question_sql_list:
            if example is None:
                print("example is None")
            else:
                if example is not None and "question" in example and "sql" in example:
                    message_log.append({"role": "user", "content": user_message(example["question"])})
                    message_log.append( {"role": "assistant", "content":example["sql"]})
           
         message_log.append({"role": "user", "content": question})
        

        注:生成followup question时的system_message是下面的样子

        ## prompt 分为下面几个部分
        initial_prompt = """
        The user provides a question and you provide SQL. You will only respond with SQL code and not with any explanations.\n\nRespond with only SQL code. Do not answer with any explanations -- just the code.\n"
        """
        ## 如果有相关ddl,且没超过上下文窗口大小
        if len(ddd_list)>0:
          initial_prompt += "You may use the following DDL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
          for ddl in ddl_list:
            initial_prompt += f"{ddl}\n\n"
        ## 如果有相关documentation,且没超过上下文窗口大小
        if len(doc_list)>0:
        		initial_prompt += f"\nYou may use the following documentation as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
        		for documentation in doc_list:
            	initial_prompt += f"{documentation}\n\n"
        ## 如果有相关query/sql问答对,且没超过上下文窗口大小
        if len(question_sql_list)>0:
          	initial_prompt += f"\nYou may use the following SQL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
        		for question in question_sql_list:
              initial_prompt += f"{question['question']}\n{question['sql']}\n\n"
        
      • submit_prompt(prompt, **kwargs) 提交prompt到大模型生成sql

      • extract_sql(llm_response) 使用正则从LLM的回复中获取sql

        总结

        vanna使用RAG的方式来提高text2sql的准确性,个人觉得将prompt中的上下文分为DDL(建表语句schema)、数据库文档、相关问题和sql三大类是vanna框架里很重要的一个思路。从代码来看,对这三类数据编码和检索的向量模型是同一个,这对向量模型的通用表征能力要求很高。在实际使用时,与其他RAG应用一样,document的分块对于检索准确率同样有很大影响。

        参考资料

        1. vanna github
        2. vanna 文档