相关推荐recommended
基于 LangChain SQL 和 ChatGLM 实现自然语言查询 MySQL 数据库
作者:mmseoamin　日期:2024-01-25

首先部署一个 ChatGLM 模型服务(基于 Flask 提供 HTTP 接口),具体代码如下:

import os

import json

from flask import Flask

from flask import request

from transformers import AutoTokenizer, AutoModel

# System parameters: expose only GPU 0 to this process. This is set before the
# model is moved onto the GPU below.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Directory holding the ChatGLM2-6B INT4 checkpoint. Built with os.path.join
# so it resolves on Linux/macOS as well as Windows (a literal backslash path
# such as r".\chatglm2-6b-int4" only works on Windows). Set the
# CHATGLM_MODEL_PATH environment variable to load the weights from elsewhere.
MODEL_PATH = os.environ.get(
    "CHATGLM_MODEL_PATH", os.path.join(".", "chatglm2-6b-int4")
)

# trust_remote_code=True lets transformers run the custom model/tokenizer code
# that ships with the checkpoint.
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)

# .half() casts the weights to fp16 and .cuda() moves the model to the GPU.
model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True).half().cuda()

# Inference only: put the model in evaluation mode.
model.eval()

app = Flask(__name__)

@app.route("/chat", methods=["POST"])
def chat():
    """Generate one ChatGLM reply for the posted prompt.

    Request body: JSON object ``{"human_input": "<prompt>"}``.
    Response body: JSON object ``{"response": "<reply>"}``, serialized with
    ``ensure_ascii=False`` so Chinese text is sent as-is. Every call passes an
    empty ``history``, so requests do not share conversation state.

    A body that is not valid JSON, is not a JSON object, or lacks a string
    ``human_input`` is answered with HTTP 400 and ``{"error": ...}``. Without
    this check the error would surface as an unhandled exception (HTTP 500).
    """

    def _reply(payload, status=200):
        # Serialize as JSON and label the response so clients see
        # application/json rather than Flask's default text/html.
        body = json.dumps(payload, ensure_ascii=False)
        return app.response_class(body, status=status, mimetype="application/json")

    try:
        # json.loads accepts the raw bytes; invalid JSON / bad encoding raise
        # ValueError subclasses, a non-object body raises TypeError on lookup.
        data_dict = json.loads(request.get_data())
        human_input = data_dict["human_input"]
    except (ValueError, KeyError, TypeError) as err:
        return _reply({"error": f"invalid request body: {err!r}"}, status=400)

    if not isinstance(human_input, str):
        return _reply({"error": "'human_input' must be a string"}, status=400)

    response, _ = model.chat(tokenizer, human_input, history=[])

    return _reply({"response": response})

if __name__ == "__main__":
    # Listen on every interface at port 8595 (the port the LangChain client's
    # ChatGLM.url points at), with Flask debug mode off.
    run_options = dict(host="0.0.0.0", port=8595, debug=False)
    app.run(**run_options)

然后就可以基于 LangChain 的 SQLDatabaseChain,用自然语言查询 MySQL 数据库,具体代码如下:

openai_api_key = "xxxx"

import os

import openai

# !pip install langchain langchain-experimental openai -q

from langchain import OpenAI, SQLDatabase

from langchain_experimental.sql import SQLDatabaseChain

import time

import logging

import requests

from typing import Optional, List, Dict, Mapping, Any

import langchain

from langchain.llms.base import LLM

from langchain.cache import InMemoryCache

# Turn on LangChain's global LLM cache, kept in this process's memory.
langchain.llm_cache = InMemoryCache()

# Emit INFO-level (and above) log records via the root logger.
logging.basicConfig(level=logging.INFO)

class ChatGLM(LLM):
    """LangChain ``LLM`` that delegates text generation to a ChatGLM HTTP service.

    Each prompt is POSTed as JSON ``{"human_input": prompt}`` to ``url`` and
    the reply is read from the ``"response"`` key of the JSON answer, matching
    the Flask ``/chat`` service defined in the first script of this file.
    """

    # Endpoint of the model service. Declared with a type annotation so it is a
    # regular pydantic field and can be overridden per instance, e.g.
    # ChatGLM(url="http://gpu-host:8595/chat").
    url: str = "http://127.0.0.1:8595/chat"
    # Seconds to wait for the service before requests gives up with a timeout.
    timeout: float = 60

    @property
    def _llm_type(self) -> str:
        """Return the type name LangChain reports for this LLM."""
        return "chatglm"

    def _construct_query(self, prompt: str) -> Dict:
        """Build the JSON request body expected by the ``/chat`` service."""
        return {"human_input": prompt}

    @classmethod
    def _post(cls, url: str, query: Dict, timeout: float = 60) -> Any:
        """POST ``query`` as JSON to ``url`` and return the ``requests.Response``.

        Passing ``json=`` makes requests serialize the body and set the
        ``Content-Type: application/json`` header itself, so no manual header
        is needed.
        """
        return requests.post(url, json=query, timeout=timeout)

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[Any] = None,
        **kwargs: Any,
    ) -> str:
        """Send ``prompt`` to the ChatGLM service and return its reply.

        Args:
            prompt: Full prompt text assembled by LangChain.
            stop: Optional stop sequences. The service's ``model.chat`` call
                takes no stop words, so the reply is truncated here, at the
                earliest occurrence of any non-empty stop sequence.
            run_manager: LangChain callback manager; accepted for API
                compatibility and not used.
            **kwargs: Extra call arguments from LangChain; ignored.

        Returns:
            The model's reply text, truncated at the first stop sequence.

        Raises:
            requests.HTTPError: if the service answers with a status other
                than 200. This replaces returning a placeholder string that
                the chain would otherwise treat as model output.
            requests.RequestException: on connection errors or timeouts.
        """
        query = self._construct_query(prompt=prompt)
        resp = self._post(url=self.url, query=query, timeout=self.timeout)

        if resp.status_code != 200:
            raise requests.HTTPError(
                f"ChatGLM service at {self.url} returned HTTP {resp.status_code}",
                response=resp,
            )

        text = resp.json()["response"]

        # Cutting at each stop sequence in turn leaves the text before the
        # earliest one; empty sequences are skipped so they cannot wipe it.
        for token in stop or []:
            if not token:
                continue
            idx = text.find(token)
            if idx != -1:
                text = text[:idx]
        return text

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Return the parameters that identify this LLM instance to LangChain."""
        return {"url": self.url}

# llm = OpenAI(temperature=0, openai_api_key="")

if __name__ == "__main__":
    import sys

    llm = ChatGLM()

    # SQLAlchemy URI of the database to query. Replace the placeholders
    # (user name, password, host, port, database name, charset) or set the
    # MYSQL_URI environment variable, e.g.
    #   mysql://user:password@127.0.0.1:3306/mydb?charset=utf8mb4
    # For a quick local test a SQLite file also works: "sqlite:///./chinook.db".
    db_uri = os.environ.get(
        "MYSQL_URI", "mysql://用户名:密码@ip:端口号/数据库名?charset=数据库编码"
    )
    db = SQLDatabase.from_uri(db_uri)

    # from_llm is the supported constructor; it picks the SQL prompt for the
    # database's dialect (passing llm= directly to the class is deprecated).
    db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True)

    # Natural-language question: taken from the command-line arguments, or
    # read interactively when none are given.
    question = " ".join(sys.argv[1:]) or input("Question: ")
    print(db_chain.run(question))