相关推荐recommended
开源模型应用落地-工具使用篇-Spring AI-Function Call(八)
作者:mmseoamin日期:2024-03-20

​​​​​​​一、前言

    通过“开源模型应用落地-工具使用篇-Spring AI(七)-CSDN博客”文章的学习,已经掌握了如何通过Spring AI集成OpenAI和Ollama系列的模型,现在将通过进一步的学习,让Spring AI集成大语言模型更高阶的用法,使得我们能完成更复杂的需求。


二、术语

2.1、Spring AI

  是 Spring 生态系统的一个新项目,它简化了 Java 中 AI 应用程序的创建。它提供以下功能:

  • 支持所有主要模型提供商,例如 OpenAI、Microsoft、Amazon、Google 和 Huggingface。
  • 支持的模型类型包括“聊天”和“文本到图像”,还有更多模型类型正在开发中。
  • 跨 AI 提供商的可移植 API,用于聊天和嵌入模型。
  • 支持同步和流 API 选项。
  • 支持下拉访问模型特定功能。
  • AI 模型输出到 POJO 的映射。

    2.2、Function Call

         是 GPT API 中的一项新功能。它可以让开发者在调用 GPT系列模型时,描述函数并让模型智能地输出一个包含调用这些函数所需参数的 JSON 对象。这种功能可以更可靠地将 GPT 的能力与外部工具和 API 进行连接。

        简单来说就是开放了自定义插件的接口,通过接入外部工具,增强模型的能力。

    Spring AI集成Function Call:

    Function Calling :: Spring AI Reference

    开源模型应用落地-工具使用篇-Spring AI-Function Call(八),第1张


    三、前置条件

    3.1、JDK 17+

        下载地址:Java Downloads | Oracle

        开源模型应用落地-工具使用篇-Spring AI-Function Call(八),第2张

      

    3.2、创建Maven项目

        SpringBoot版本为3.2.3

    
        org.springframework.boot
        spring-boot-starter-parent
        3.2.3
         
    

    3.3、导入Maven依赖包

    
    	org.projectlombok
    	lombok
    	true
    
    
    	ch.qos.logback
    	logback-core
    
    
    	ch.qos.logback
    	logback-classic
    
    
    	cn.hutool
    	hutool-core
    	5.8.24
    
    
    	org.springframework.ai
    	spring-ai-openai-spring-boot-starter
    	0.8.0
    
    

    3.4、 科学上网的软件


    四、技术实现

    4.1、新增配置

    spring:
      ai:
        openai:
          api-key: sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
          chat:
            options:
              model: gpt-3.5-turbo
              temperature: 0.45
              max_tokens: 4096
              top-p: 0.9

      PS:

    1.   openai要替换自己的api-key
    2.   模型参数根据实际情况调整

     4.2、新增本地方法类(用于本地回调的function)

    import com.fasterxml.jackson.annotation.JsonClassDescription;
    import com.fasterxml.jackson.annotation.JsonInclude;
    import com.fasterxml.jackson.annotation.JsonProperty;
    import com.fasterxml.jackson.annotation.JsonPropertyDescription;
    import lombok.extern.slf4j.Slf4j;
    import java.util.function.Function;
    @Slf4j
    public class WeatherService implements Function {
        /**
         * Weather Function request.
         */
        @JsonInclude(JsonInclude.Include.NON_NULL)
        @JsonClassDescription("Weather API request")
        public record Request(@JsonProperty(required = true,
                value = "location") @JsonPropertyDescription("The city and state e.g.广州") String location) {
        }
        /**
         * Weather Function response.
         */
        public record Response(String weather) {
        }
        @Override
        public WeatherService.Response apply(WeatherService.Request request) {
            log.info("location: {}", request.location);
            String weather = "";
            if (request.location().contains("广州")) {
                weather = "小雨转阴 13~19°C";
            } else if (request.location().contains("深圳")) {
                weather = "阴 15~26°C";
            } else {
                weather = "热到中暑 99~100°C";
            }
            return new WeatherService.Response(weather);
        }
    }
    

     4.3、新增配置类

    import org.springframework.ai.model.function.FunctionCallback;
    import org.springframework.ai.model.function.FunctionCallbackWrapper;
    import org.springframework.context.annotation.Bean;
    import org.springframework.context.annotation.Configuration;
    import org.springframework.context.annotation.Description;
    import java.util.function.Function;
    @Configuration
    public class FunctionConfig {
        @Bean
        public FunctionCallback weatherFunctionInfo() {
            return new FunctionCallbackWrapper("currentWeather", // (1) function name
                    "Get the weather in location", // (2) function description
                    new WeatherService()); // function code
        }
    }
    

     4.4、新增Controller类

    import cn.hutool.core.collection.CollUtil;
    import cn.hutool.core.map.MapUtil;
    import jakarta.servlet.http.HttpServletResponse;
    import lombok.extern.slf4j.Slf4j;
    import org.apache.commons.lang3.StringUtils;
    import org.springframework.ai.chat.Generation;
    import org.springframework.ai.chat.messages.AssistantMessage;
    import org.springframework.ai.chat.messages.Message;
    import org.springframework.ai.chat.messages.UserMessage;
    import org.springframework.ai.chat.prompt.Prompt;
    import org.springframework.ai.chat.prompt.SystemPromptTemplate;
    import org.springframework.ai.openai.OpenAiChatClient;
    import org.springframework.ai.openai.OpenAiChatOptions;
    import org.springframework.beans.factory.annotation.Autowired;
    import org.springframework.web.bind.annotation.RequestMapping;
    import org.springframework.web.bind.annotation.RestController;
    import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
    import java.util.List;
    @Slf4j
    @RestController
    @RequestMapping("/api")
    public class OpenaiTestController {
        @Autowired
        private OpenAiChatClient openAiChatClient;
        @RequestMapping("/function_call")
        public String function_call(){
            String systemPrompt = "{prompt}";
            SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemPrompt);
            String userPrompt = "广州的天气如何?";
            Message userMessage = new UserMessage(userPrompt);
            Message systemMessage = systemPromptTemplate.createMessage(MapUtil.of("prompt", "你是一个有用的人工智能助手"));
            Prompt prompt = new Prompt(List.of(userMessage, systemMessage), OpenAiChatOptions.builder().withFunction("currentWeather").build());
            List response = openAiChatClient.call(prompt).getResults();
            String result = "";
            for (Generation generation : response){
                String content = generation.getOutput().getContent();
                result += content;
            }
            return result;
        }
    }
    

    五、测试

    调用结果:

      浏览器输出:

    开源模型应用落地-工具使用篇-Spring AI-Function Call(八),第3张

      idea输出:

    开源模型应用落地-工具使用篇-Spring AI-Function Call(八),第4张


    六、附带说明

    6.1、流式模式不支持Function Call

    开源模型应用落地-工具使用篇-Spring AI-Function Call(八),第5张

    6.2、更多的模型参数配置

    OpenAI Chat :: Spring AI Reference

    开源模型应用落地-工具使用篇-Spring AI-Function Call(八),第6张

    6.3、qwen系列模型如何支持function call

     通过vllm启动兼容openai接口的api_server,命令如下:

    python -m vllm.entrypoints.openai.api_server --served-model-name Qwen1.5-7B-Chat --model Qwen/Qwen1.5-7B-Chat 

       详细教程参见:

      使用以下代码进行测试:

    # Reference: https://openai.com/blog/function-calling-and-other-api-updates
    import json
    from pprint import pprint
    import openai
    # To start an OpenAI-like Qwen server, use the following commands:
    #   git clone https://github.com/QwenLM/Qwen-7B;
    #   cd Qwen-7B;
    #   pip install fastapi uvicorn openai pydantic sse_starlette;
    #   python openai_api.py;
    #
    # Then configure the api_base and api_key in your client:
    openai.api_base = 'http://localhost:8000/v1'
    openai.api_key = 'none'
    def call_qwen(messages, functions=None):
        print('input:')
        pprint(messages, indent=2)
        if functions:
            response = openai.ChatCompletion.create(model='Qwen',
     messages=messages,
     functions=functions)
        else:
            response = openai.ChatCompletion.create(model='Qwen',
     messages=messages)
        response = response.choices[0]['message']
        response = json.loads(json.dumps(response,
                                         ensure_ascii=False))  # fix zh rendering
        print('output:')
        pprint(response, indent=2)
        print()
        return response
    def test_1():
        messages = [{'role': 'user', 'content': '你好'}]
        call_qwen(messages)
        messages.append({'role': 'assistant', 'content': '你好!很高兴为你提供帮助。'})
        messages.append({
            'role': 'user',
            'content': '给我讲一个年轻人奋斗创业最终取得成功的故事。故事只能有一句话。'
        })
        call_qwen(messages)
        messages.append({
            'role':
            'assistant',
            'content':
            '故事的主人公叫李明,他来自一个普通的家庭,父母都是普通的工人。李明想要成为一名成功的企业家。……',
        })
        messages.append({'role': 'user', 'content': '给这个故事起一个标题'})
        call_qwen(messages)
    def test_2():
        functions = [
            {
                'name_for_human':
                '谷歌搜索',
                'name_for_model':
                'google_search',
                'description_for_model':
                '谷歌搜索是一个通用搜索引擎,可用于访问互联网、查询百科知识、了解时事新闻等。' +
                ' Format the arguments as a JSON object.',
                'parameters': [{
                    'name': 'search_query',
                    'description': '搜索关键词或短语',
                    'required': True,
                    'schema': {
                        'type': 'string'
                    },
                }],
            },
            {
                'name_for_human':
                '文生图',
                'name_for_model':
                'image_gen',
                'description_for_model':
                '文生图是一个AI绘画(图像生成)服务,输入文本描述,返回根据文本作画得到的图片的URL。' +
                ' Format the arguments as a JSON object.',
                'parameters': [{
                    'name': 'prompt',
                    'description': '英文关键词,描述了希望图像具有什么内容',
                    'required': True,
                    'schema': {
                        'type': 'string'
                    },
                }],
            },
        ]
        messages = [{'role': 'user', 'content': '(请不要调用工具)\n\n你好'}]
        call_qwen(messages, functions)
        messages.append({
            'role': 'assistant',
            'content': '你好!很高兴见到你。有什么我可以帮忙的吗?'
        }, )
        messages.append({'role': 'user', 'content': '搜索一下谁是周杰伦'})
        call_qwen(messages, functions)
        messages.append({
            'role': 'assistant',
            'content': '我应该使用Google搜索查找相关信息。',
            'function_call': {
                'name': 'google_search',
                'arguments': '{"search_query": "周杰伦"}',
            },
        })
        messages.append({
            'role': 'function',
            'name': 'google_search',
            'content': 'Jay Chou is a Taiwanese singer.',
        })
        call_qwen(messages, functions)
        messages.append(
            {
                'role': 'assistant',
                'content': '周杰伦(Jay Chou)是一位来自台湾的歌手。',
            }, )
        messages.append({'role': 'user', 'content': '搜索一下他老婆是谁'})
        call_qwen(messages, functions)
        messages.append({
            'role': 'assistant',
            'content': '我应该使用Google搜索查找相关信息。',
            'function_call': {
                'name': 'google_search',
                'arguments': '{"search_query": "周杰伦 老婆"}',
            },
        })
        messages.append({
            'role': 'function',
            'name': 'google_search',
            'content': 'Hannah Quinlivan'
        })
        call_qwen(messages, functions)
        messages.append(
            {
                'role': 'assistant',
                'content': '周杰伦的老婆是Hannah Quinlivan。',
            }, )
        messages.append({'role': 'user', 'content': '用文生图工具画个可爱的小猫吧,最好是黑猫'})
        call_qwen(messages, functions)
        messages.append({
            'role': 'assistant',
            'content': '我应该使用文生图API来生成一张可爱的小猫图片。',
            'function_call': {
                'name': 'image_gen',
                'arguments': '{"prompt": "cute black cat"}',
            },
        })
        messages.append({
            'role':
            'function',
            'name':
            'image_gen',
            'content':
            '{"image_url": "https://image.pollinations.ai/prompt/cute%20black%20cat"}',
        })
        call_qwen(messages, functions)
    def test_3():
        functions = [{
            'name': 'get_current_weather',
            'description': 'Get the current weather in a given location.',
            'parameters': {
                'type': 'object',
                'properties': {
                    'location': {
                        'type': 'string',
                        'description':
                        'The city and state, e.g. San Francisco, CA',
                    },
                    'unit': {
                        'type': 'string',
                        'enum': ['celsius', 'fahrenheit']
                    },
                },
                'required': ['location'],
            },
        }]
        messages = [{
            'role': 'user',
            # Note: The current version of Qwen-7B-Chat (as of 2023.08) performs okay with Chinese tool-use prompts,
            # but performs terribly when it comes to English tool-use prompts, due to a mistake in data collecting.
            'content': '波士顿天气如何?',
        }]
        call_qwen(messages, functions)
        messages.append(
            {
                'role': 'assistant',
                'content': None,
                'function_call': {
                    'name': 'get_current_weather',
                    'arguments': '{"location": "Boston, MA"}',
                },
            }, )
        messages.append({
            'role':
            'function',
            'name':
            'get_current_weather',
            'content':
            '{"temperature": "22", "unit": "celsius", "description": "Sunny"}',
        })
        call_qwen(messages, functions)
    def test_4():
        from langchain.agents import AgentType, initialize_agent, load_tools
        from langchain.chat_models import ChatOpenAI
        llm = ChatOpenAI(
            model_name='Qwen',
            openai_api_base='http://localhost:8000/v1',
            openai_api_key='EMPTY',
            streaming=False,
        )
        tools = load_tools(['arxiv'], )
        agent_chain = initialize_agent(
            tools,
            llm,
            agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
            verbose=True,
        )
        # TODO: The performance is okay with Chinese prompts, but not so good when it comes to English.
        agent_chain.run('查一下论文 1605.08386 的信息')
    if __name__ == '__main__':
        print('### Test Case 1 - No Function Calling (普通问答、无函数调用) ###')
        test_1()
        print('### Test Case 2 - Use Qwen-Style Functions (函数调用,千问格式) ###')
        test_2()
        print('### Test Case 3 - Use GPT-Style Functions (函数调用,GPT格式) ###')
        test_3()
        print('### Test Case 4 - Use LangChain (接入Langchain) ###')
        test_4()