开源模型应用落地-工具使用篇-Spring AI-高阶用法(九)
作者:mmseoamin日期:2024-03-20

一、前言

    通过“开源模型应用落地-工具使用篇-Spring AI-Function Call(八)-CSDN博客”文章的学习,已经掌握了如何通过Spring AI集成OpenAI以及如何进行function call的调用,现在将进一步学习Spring AI更高阶的用法,如:传递历史上下文对话,调整模型参数等。


二、术语

2.1、Spring AI

  是 Spring 生态系统的一个新项目,它简化了 Java 中 AI 应用程序的创建。它提供以下功能:

  • 支持所有主要模型提供商,例如 OpenAI、Microsoft、Amazon、Google 和 Huggingface。
  • 支持的模型类型包括“聊天”和“文本到图像”,还有更多模型类型正在开发中。
  • 跨 AI 提供商的可移植 API,用于聊天和嵌入模型。
  • 支持同步和流 API 选项。
  • 支持下拉访问模型特定功能。
  • AI 模型输出到 POJO 的映射。

    三、前置条件

    3.1、JDK 17+

        下载地址:Java Downloads | Oracle

        开源模型应用落地-工具使用篇-Spring AI-高阶用法(九),第1张

      

    3.2、创建Maven项目

        SpringBoot版本为3.2.3

    
        org.springframework.boot
        spring-boot-starter-parent
        3.2.3
         
    

    3.3、导入Maven依赖包

    
    	org.projectlombok
    	lombok
    	true
    
    
    	ch.qos.logback
    	logback-core
    
    
    	ch.qos.logback
    	logback-classic
    
    
    	cn.hutool
    	hutool-core
    	5.8.24
    
    
    	org.springframework.ai
    	spring-ai-openai-spring-boot-starter
    	0.8.0
    
    

    3.4、 科学上网的软件


    四、技术实现

    4.1、新增配置

    spring:
      ai:
        openai:
          api-key: sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx

    PS:

    1.   openai要替换自己的api-key
    2.   模型参数根据实际情况调整

     4.2、历史上下文对话传递

      # 方式一

      使用 UserMessage 和 AssistantMessage 指定上下文

    开源模型应用落地-工具使用篇-Spring AI-高阶用法(九),第2张

      # 方式二

        使用 ChatMessage 指定上下文

    开源模型应用落地-工具使用篇-Spring AI-高阶用法(九),第3张

     4.3、 调整模型参数

      # 方式一

      在配置文件中指定

      开源模型应用落地-工具使用篇-Spring AI-高阶用法(九),第4张

      # 方式二

      在代码中指定

    开源模型应用落地-工具使用篇-Spring AI-高阶用法(九),第5张


    五、测试

    在代码中指定的上下文:

    对话次数用户AI
    第一轮
    你好
    你好!很高兴能为你提供帮助。有什么问题可以问我吗?
    第二轮
    我家在广州,你呢?
    我是一个人工智能助手,没有具体的居住地。不过我可以帮助你解答问题和提供信息。有什么我可以帮你的吗?
    第三轮我家有什么特产

    浏览器返回的结果:

    开源模型应用落地-工具使用篇-Spring AI-高阶用法(九),第6张

    idea返回的结果:

    开源模型应用落地-工具使用篇-Spring AI-高阶用法(九),第7张

      结论:

      AI能识别出我家在广州,并给出广州的特产


    六、附带说明

    6.1、更多的模型参数配置

    OpenAI Chat :: Spring AI Reference

    开源模型应用落地-工具使用篇-Spring AI-高阶用法(九),第8张

    开源模型应用落地-工具使用篇-Spring AI-高阶用法(九),第9张

    开源模型应用落地-工具使用篇-Spring AI-高阶用法(九),第10张

    6.2、完整代码

    import cn.hutool.core.collection.CollUtil;
    import cn.hutool.core.map.MapUtil;
    import jakarta.servlet.http.HttpServletResponse;
    import lombok.extern.slf4j.Slf4j;
    import org.apache.commons.lang3.StringUtils;
    import org.springframework.ai.chat.Generation;
    import org.springframework.ai.chat.messages.*;
    import org.springframework.ai.chat.prompt.ChatOptions;
    import org.springframework.ai.chat.prompt.Prompt;
    import org.springframework.ai.chat.prompt.SystemPromptTemplate;
    import org.springframework.ai.openai.OpenAiChatClient;
    import org.springframework.ai.openai.OpenAiChatOptions;
    import org.springframework.beans.factory.annotation.Autowired;
    import org.springframework.web.bind.annotation.RequestMapping;
    import org.springframework.web.bind.annotation.RestController;
    import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
    import java.util.List;
    @Slf4j
    @RestController
    @RequestMapping("/api")
    public class OpenaiTestController {
        @Autowired
        private OpenAiChatClient openAiChatClient;
        @RequestMapping("/history")
        public SseEmitter history(HttpServletResponse response) {
            response.setContentType("text/event-stream");
            response.setCharacterEncoding("UTF-8");
            SseEmitter emitter = new SseEmitter();
            String systemPrompt = "{prompt}";
            SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemPrompt);
            String userPrompt = "我家有什么特产?";
            Message userMessage = new UserMessage(userPrompt);
            Message systemMessage = systemPromptTemplate.createMessage(MapUtil.of("prompt", "you are a helpful AI assistant"));
            UserMessage userChatMessage1 = new UserMessage("你好");
            AssistantMessage assistantChatMessage1 = new AssistantMessage("你好!很高兴能为你提供帮助。有什么问题可以问我吗?");
            UserMessage userChatMessage2 = new UserMessage("我家在广州,你呢?");
            AssistantMessage assistantChatMessage2 = new AssistantMessage("我是一个人工智能助手,没有具体的居住地。不过我可以帮助你解答问题和提供信息。有什么我可以帮你的吗?");
    //        ChatMessage userChatMessage2 = new ChatMessage(MessageType.USER, "你好");
    //        ChatMessage assistantChatMessage2 = new ChatMessage(MessageType.ASSISTANT, "你好!很高兴能为你提供帮助。有什么问题可以问我吗?");
    //
    //        ChatMessage userChatMessage2 = new ChatMessage(MessageType.USER, "我家在广州,你呢?");
    //        ChatMessage assistantChatMessage2 = new ChatMessage(MessageType.ASSISTANT, "我是一个人工智能助手,没有具体的居住地。不过我可以帮助你解答问题和提供信息。有什么我可以帮你的吗?");
            OpenAiChatOptions openAiChatOptions = OpenAiChatOptions.builder()
                    .withModel("gpt-3.5-turbo")
                    .withTemperature(0.7f)
                    .withMaxTokens(4096)
                    .withN(1)
                    .withTopP(0.9f)
                    .build();
            Prompt prompt = new Prompt(List.of(userChatMessage1, assistantChatMessage1, userChatMessage2, assistantChatMessage2, userMessage, systemMessage), openAiChatOptions);
            log.info(prompt.toString());
            openAiChatClient.stream(prompt).subscribe(x -> {
                try {
                    log.info("response: {}", x);
                    List generations = x.getResults();
                    if (CollUtil.isNotEmpty(generations)) {
                        for (Generation generation : generations) {
                            AssistantMessage assistantMessage = generation.getOutput();
                            String content = assistantMessage.getContent();
                            if (StringUtils.isNotEmpty(content)) {
                                emitter.send(content);
                            } else {
                                if (StringUtils.equals(content, "null"))
                                    emitter.complete(); // Complete the SSE connection
                            }
                        }
                    }
                } catch (Exception e) {
                    emitter.complete();
                    log.error("流式返回结果异常", e);
                }
            });
            return emitter;
        }
    }