相关推荐recommended
springboot集成websocket持久连接(权限过滤+拦截)
作者:mmseoamin日期:2024-02-02

文章目录

  • 1、为什么要使用WebSocket?
  • 2、配置方式一:实现ServletContextInitializer+@ServerEndpoint注解
    • 2.1、WebSocket配置
    • 2.2、WebSocket连接,@ServerEndpoint
    • 2.3、WebSocket请求过滤
    • 2.4、postman建立客户端连接
    • 3、配置方式二:实现WebSocketConfigurer+继承TextWebSocketHandler
      • 3.1、配置:实现WebSocketConfigurer
      • 3.2、配置:WebSocket握手,实现对websocket请求的拦截
      • 3.3、实现WebSocket服务,监听socket客户端的连接
      • 3.4、postman测试,发起websocket请求
      • 参考链接

                springboot关于websocket依赖,pom.xml:集成springboot最小依赖,不进行任何服务连接,单独启动项目

          
          
              org.springframework.boot
              spring-boot-starter-parent
              2.2.5.RELEASE
          
        
        
          
              org.springframework.boot
              spring-boot-starter
          
          
          
            org.springframework.boot
            spring-boot-starter-web
            
                
                
                    spring-boot-starter-tomcat
                    org.springframework.boot
                
            
          
          
          
              org.springframework.boot
              spring-boot-starter-undertow
          
          
          
              org.springframework.boot
              spring-boot-starter-websocket
          
        
        

        1、为什么要使用WebSocket?

                因为一般的请求都是HTTP请求(单向通信),HTTP是一个短连接(非持久化),且通信只能由客户端发起,HTTP协议做不到服务器主动向客户端推送消息。举个例子:前后端交互就是前端发送请求,从后端拿到数据后展示到页面,如果前端没有主动请求接口,那后端就不能发送数据给前端。然而,WebSocket确能很好的解决这个问题,服务端可以主动向客户端推送消息,客户端也可以主动向服务端发送消息,实现了服务端和客户端真正的平等。

        springboot集成websocket持久连接(权限过滤+拦截),image.png,第1张

        2、配置方式一:实现ServletContextInitializer+@ServerEndpoint注解

        2.1、WebSocket配置

        package com.chengfu.config;
        import org.springframework.boot.web.servlet.ServletContextInitializer;
        import org.springframework.context.annotation.Bean;
        import org.springframework.context.annotation.Configuration;
        import org.springframework.web.socket.server.standard.ServerEndpointExporter;
        import javax.servlet.ServletContext;
        import javax.servlet.ServletException;
        @Configuration
        public class WebSocketConfig implements ServletContextInitializer {
            //    原文链接:https://blog.csdn.net/weixin_44185837/article/details/124942482
            /**
             * 这个bean的注册,用于扫描带有@ServerEndpoint的注解成为websocket,如果你使用外置的tomcat就不需要该配置文件
             */
            @Bean
            public ServerEndpointExporter serverEndpointExporter() {
                return new ServerEndpointExporter();
            }
            @Override
            public void onStartup(ServletContext servletContext) throws ServletException {
            }
        }
        

        2.2、WebSocket连接,@ServerEndpoint

        package com.chengfu.socket;
        import lombok.extern.slf4j.Slf4j;
        import org.springframework.stereotype.Component;
        import javax.websocket.*;
        import javax.websocket.server.PathParam;
        import javax.websocket.server.ServerEndpoint;
        import java.io.IOException;
        import java.util.concurrent.ConcurrentHashMap;
        import java.util.concurrent.CopyOnWriteArraySet;
        /**
         * websocket服务端,接收websocket客户端长连接
         */
        @ServerEndpoint("/websocket/api1/{id}")
        @Component
        @Slf4j
        public class WebSocketServer {
            // 与某个客户端的连接会话,需要通过它来给客户端发送数据
            private Session session;
            // session集合,存放对应的session
            private static ConcurrentHashMap sessionPool = new ConcurrentHashMap<>();
            // concurrent包的线程安全Set,用来存放每个客户端对应的WebSocket对象。
            private static CopyOnWriteArraySet webSocketSet = new CopyOnWriteArraySet<>();
            /**
             * 建立WebSocket连接
             *
             * @param session
             * @param userId  用户ID
             */
            @OnOpen
            public void onOpen(Session session, @PathParam(value = "id") Integer userId) {
                log.info("WebSocket建立连接中,连接用户ID:{}", userId);
                try {
                    Session historySession = sessionPool.get(userId);
                    // historySession不为空,说明已经有人登陆账号,应该删除登陆的WebSocket对象
                    if (historySession != null) {
                        webSocketSet.remove(historySession);
                        historySession.close();
                    }
                } catch (IOException e) {
                    log.error("重复登录异常,错误信息:" + e.getMessage(), e);
                }
                // 建立连接
                this.session = session;
                webSocketSet.add(this);
                sessionPool.put(userId, session);
                log.info("建立连接完成,当前在线人数为:{}", webSocketSet.size());
            }
            /**
             * 发生错误
             *
             * @param throwable e
             */
            @OnError
            public void onError(Throwable throwable) {
                throwable.printStackTrace();
            }
            /**
             * 连接关闭
             */
            @OnClose
            public void onClose() {
                webSocketSet.remove(this);
                log.info("连接断开,当前在线人数为:{}", webSocketSet.size());
            }
            /**
             * 接收客户端消息
             *
             * @param message 接收的消息
             */
            @OnMessage
            public void onMessage(String message) {
                log.info("收到客户端发来的消息:{}", message);
            }
            /**
             * 推送消息到指定用户
             *
             * @param userId  用户ID
             * @param message 发送的消息
             */
            public static void sendMessageByUser(Integer userId, String message) {
                log.info("用户ID:" + userId + ",推送内容:" + message);
                Session session = sessionPool.get(userId);
                try {
                    session.getBasicRemote().sendText(message);
                } catch (IOException e) {
                    log.error("推送消息到指定用户发生错误:" + e.getMessage(), e);
                }
            }
            /**
             * 群发消息
             *
             * @param message 发送的消息
             */
            public static void sendAllMessage(String message) {
                log.info("发送消息:{}", message);
                for (WebSocketServer webSocket : webSocketSet) {
                    try {
                        webSocket.session.getBasicRemote().sendText(message);
                    } catch (IOException e) {
                        log.error("群发消息发生错误:" + e.getMessage(), e);
                    }
                }
            }
        }
        

        2.3、WebSocket请求过滤

        找了一些资料没发现springboot有处理@ServerEndpoint注解的websocket请求,但是可以使用servlet容器技术来实现ws请求的过滤。ws请求的拦截可以通过配置方式二来实现,详见下。

        package com.chengfu.config;
        import org.apache.commons.lang3.StringUtils;
        import org.springframework.stereotype.Component;
        import javax.servlet.*;
        import javax.servlet.http.HttpServletRequest;
        import java.io.IOException;
        import java.util.Arrays;
        /**
         * servlet级别请求过滤
         */
        @Component
        public class WebSocketFilterConfig implements Filter {
            // websocket请求过滤清单
            private static final String[] FILTER_LIST = {"/websocket/api1/"};
            @Override
            public void init(FilterConfig filterConfig) throws ServletException {
                Filter.super.init(filterConfig);
            }
            @Override
            public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException {
                HttpServletRequest httpServletRequest = (HttpServletRequest) servletRequest;
                String servletPath = httpServletRequest.getServletPath();
                // anyMatch: 有一个条件满足就返回true
                boolean match = Arrays.stream(FILTER_LIST).anyMatch(servletPath::startsWith);
                // boolean match = Arrays.stream(FILTER_LIST).anyMatch(filterUrl -> servletPath.startsWith(filterUrl));
                // 符合ws请求的url开始校验,其他请求一概放过
                if (match) {
                    String token = httpServletRequest.getHeader("token");
                    if (StringUtils.isNotBlank(token)) {
                        filterChain.doFilter(servletRequest, servletResponse);
                    }
        //            else {
        //                JSONObject result = new JSONObject();
        //                HttpServletResponse httpServletResponse = (HttpServletResponse) servletResponse;
        //                httpServletResponse.setContentType("application/json;charset=utf-8");
        //                httpServletResponse.setCharacterEncoding("utf-8");
        //                PrintWriter writer = httpServletResponse.getWriter();
        //                writer.write(result.toJSONString());
        //                writer.flush();
        //                writer.close();
        //            }
                } else {
                    // 不包含过滤清单,直接放过请求
                    filterChain.doFilter(servletRequest, servletResponse);
                }
            }
            @Override
            public void destroy() {
                Filter.super.destroy();
            }
        }
        

        2.4、postman建立客户端连接

        springboot集成websocket持久连接(权限过滤+拦截),image.png,第2张

        springboot集成websocket持久连接(权限过滤+拦截),image.png,第3张

        连接成功!

        springboot集成websocket持久连接(权限过滤+拦截),image.png,第4张

        3、配置方式二:实现WebSocketConfigurer+继承TextWebSocketHandler

        3.1、配置:实现WebSocketConfigurer

        package com.chengfu.config;
        import com.chengfu.interceptor.WebSocketAuthInterceptor;
        import com.chengfu.socket.WebSocketServer2;
        import org.springframework.context.annotation.Configuration;
        import org.springframework.web.socket.config.annotation.EnableWebSocket;
        import org.springframework.web.socket.config.annotation.WebSocketConfigurer;
        import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry;
        // 启用配置
        @Configuration
        // 启用websocket服务端
        @EnableWebSocket
        public class WebSocketConfig2 implements WebSocketConfigurer {
            // 原文链接:https://blog.csdn.net/weixin_44185837/article/details/124942482
            // 实现自: WebSocketConfigurer
            // 注册websocket拦截器
            @Override
            public void registerWebSocketHandlers(WebSocketHandlerRegistry webSocketHandlerRegistry) {
                webSocketHandlerRegistry
                        // 只有符合"/websocket/api2/**", "/websocket/api3/**"的请求url,才能进入WebSocketServer2的服务端连接进行数据处理
                        .addHandler(new WebSocketServer2(), "/websocket/api2/**", "/websocket/api3/**")
                        // WebSocketServer2的握手拦截器处理:尽量避免被无用的请求攻击,在建立连接的时候通过检查授权成功之后才能进行访问
                        .addInterceptors(new WebSocketAuthInterceptor())
                        // 允许跨域访问
                        .setAllowedOrigins("*");
            }
        }
        

        3.2、配置:WebSocket握手,实现对websocket请求的拦截

        package com.chengfu.interceptor;
        import lombok.extern.slf4j.Slf4j;
        import org.apache.commons.lang3.StringUtils;
        import org.springframework.http.HttpHeaders;
        import org.springframework.http.server.ServerHttpRequest;
        import org.springframework.http.server.ServerHttpResponse;
        import org.springframework.stereotype.Component;
        import org.springframework.web.socket.WebSocketHandler;
        import org.springframework.web.socket.server.HandshakeInterceptor;
        import java.util.List;
        import java.util.Map;
        /**
         * websocket握手拦截
         */
        @Component
        @Slf4j
        public class WebSocketAuthInterceptor implements HandshakeInterceptor {
            /**
             * 返回true允许直接通过,返回false拒绝连接
             *
             * @param serverHttpRequest
             * @param serverHttpResponse
             * @param webSocketHandler
             * @param map
             * @return
             * @throws Exception
             */
            @Override
            public boolean beforeHandshake(ServerHttpRequest serverHttpRequest, ServerHttpResponse serverHttpResponse, WebSocketHandler webSocketHandler, Map map) throws Exception {
                log.error("握手开始!");
                // 确保授权正确才能进行websocket连接
                HttpHeaders headers = serverHttpRequest.getHeaders();
                List header = headers.get("token");
                if (header == null || header.size() == 0) {
                    return false;
                }
                String token = header.get(0);
                if (StringUtils.isBlank(token)) {
                    return false;
                }
                // TODO: token do something
                log.error("握手token:{}", token);
                return true;
            }
            @Override
            public void afterHandshake(ServerHttpRequest serverHttpRequest, ServerHttpResponse serverHttpResponse, WebSocketHandler webSocketHandler, Exception e) {
                log.error("握手结束");
            }
        }
        

        3.3、实现WebSocket服务,监听socket客户端的连接

        package com.chengfu.socket;
        import lombok.extern.slf4j.Slf4j;
        import org.springframework.http.HttpHeaders;
        import org.springframework.stereotype.Component;
        import org.springframework.web.socket.CloseStatus;
        import org.springframework.web.socket.TextMessage;
        import org.springframework.web.socket.WebSocketSession;
        import org.springframework.web.socket.handler.TextWebSocketHandler;
        import java.io.IOException;
        import java.time.LocalDateTime;
        import java.util.concurrent.ConcurrentHashMap;
        // 配合webSocketConfig2使用
        @Component
        @Slf4j
        public class WebSocketServer2 extends TextWebSocketHandler {
            /**
             * socket 建立成功事件 @OnOpen
             *
             * @param session
             * @throws Exception
             */
            @Override
            public void afterConnectionEstablished(WebSocketSession session) throws Exception {
                // websocket入参
                // ===============从url上面获取参数
                String rawPath = session.getUri().getRawPath(); //
                String rawQuery = session.getUri().getRawQuery();//  从url上面获取参数
                String query = session.getUri().getQuery(); //  从url上面获取参数
                // =================
                // ================从header上面获取参数
                HttpHeaders headers = session.getHandshakeHeaders(); // 从header上获取参数
                // ================
                String token = headers.get("token").get(0);
                if (token != null) {
                    WebSocketSession s = WebSessionManager.get(token);
                    if (s != null) {
                        // 当前用户之前已经建立连接,关闭
                        WebSessionManager.remove(token);
                    }
                    // 重新建立session
                    WebSessionManager.add(token, session);
                    log.error("建立websocket连接:{}", token);
                }
            }
            /**
             * 接收消息事件 @OnMessage
             *
             * @param session
             * @param message
             * @throws Exception
             */
            @Override
            protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception {
                // 获得客户端传来的消息
                String payload = message.getPayload();
                System.out.println("server 接收到发送的消息 " + payload);
                // 服务端发送回去
                session.sendMessage(new TextMessage("server 发送消息 " + payload + " " + LocalDateTime.now()));
            }
            /**
             * socket 断开连接时 @OnClose
             *
             * @param session
             * @param status
             * @throws Exception
             */
            @Override
            public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception {
                Object token = session.getAttributes().get("token");
                if (token != null) {
                    // 用户退出,移除缓存
                    WebSessionManager.remove(token.toString());
                }
            }
        }
        // 来源:https://www.cnblogs.com/meow-world/articles/16283492.html
        // websocket的session管理器
        class WebSessionManager {
            /**
             * 保存连接 session 的地方
             */
            private static final ConcurrentHashMap SESSION_POOL = new ConcurrentHashMap<>();
            /**
             * 添加 session
             *
             * @param key
             */
            public static void add(String key, WebSocketSession session) {
                // 添加 session
                SESSION_POOL.put(key, session);
            }
            /**
             * 删除 session,会返回删除的 session
             *
             * @param key
             * @return
             */
            public static WebSocketSession remove(String key) {
                // 删除 session
                return SESSION_POOL.remove(key);
            }
            /**
             * 删除并同步关闭连接
             *
             * @param key
             */
            public static void removeAndClose(String key) {
                WebSocketSession session = remove(key);
                if (session != null) {
                    try {
                        // 关闭连接
                        session.close();
                    } catch (IOException e) {
                        // todo: 关闭出现异常处理
                        e.printStackTrace();
                    }
                }
            }
            /**
             * 获得 session
             *
             * @param key
             * @return
             */
            public static WebSocketSession get(String key) {
                // 获得 session
                return SESSION_POOL.get(key);
            }
        }
        

        3.4、postman测试,发起websocket请求

        注!

        1. postman版本要升级到10以上,我这里测试版本为v10.17.1;

        2. Ctrl+N,打开创建请求的小窗,选择WebSocket;

        springboot集成websocket持久连接(权限过滤+拦截),image.png,第5张

        postman请求地址:

        springboot集成websocket持久连接(权限过滤+拦截),image.png,第6张

                建立websocket连接之前,先进行握手处理,检查headers参数之后,如果符合要求,开始建立连接。

        springboot集成websocket持久连接(权限过滤+拦截),image.png,第7张

                拦截发现headers没有token标识,拒绝客户端连接请求,请求头添加token后,重新请求:

        springboot集成websocket持久连接(权限过滤+拦截),image.png,第8张

        springboot集成websocket持久连接(权限过滤+拦截),image.png,第9张

                检测到token之后,成功建立连接:

        springboot集成websocket持久连接(权限过滤+拦截),image.png,第10张

        参考链接

        1. SpringBoot使用WebSocket_springboot websocket_仰望银河系的博客-CSDN博客
        2. Springboot——拦截器_springboot 拦截器_我爱布朗熊的博客-CSDN博客
        3. 实现Websocket集群及通信的第二种方式(含拦截器)_websockethandler_Second小二的博客-CSDN博客
        4. springboot整合websocket握手拦截器 - meow_world - 博客园
        5. Postman 使用 WebSocket 请求_w3cschool