package com.qkdata.biz.wsMessage.config;

import com.auth0.jwt.interfaces.DecodedJWT;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.qkdata.common.jwt.JWTService;
import com.qkdata.common.oauth.AuthorizationException;
import com.qkdata.common.oauth.AuthorizationResponseEnum;
import com.qkdata.common.oauth.AuthorizedUser;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Lazy;
import org.springframework.http.HttpStatus;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.messaging.simp.config.MessageBrokerRegistry;
import org.springframework.util.StringUtils;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.config.annotation.EnableWebSocketMessageBroker;
import org.springframework.web.socket.config.annotation.StompEndpointRegistry;
import org.springframework.web.socket.config.annotation.WebSocketMessageBrokerConfigurer;
import org.springframework.web.socket.server.HandshakeInterceptor;
import org.springframework.web.socket.server.support.DefaultHandshakeHandler;

import java.io.IOException;
import java.security.Principal;
import java.util.Map;

/**
 * STOMP-over-WebSocket configuration.
 *
 * <p>Registers the {@code /websocket} SockJS endpoint and authenticates every handshake with the
 * JWT passed in the {@code token} query parameter. The authenticated {@link Principal} is bound to
 * the WebSocket session so that user destinations ({@code /user/...}) can be resolved. Handshakes
 * that cannot be authenticated are rejected with HTTP 401.
 */
@Configuration
@EnableWebSocketMessageBroker
public class WebSocketConfig implements WebSocketMessageBrokerConfigurer {

    /** Handshake attribute key under which the authenticated principal is stored. */
    private static final String USER_ATTRIBUTE = "user";

    /** Query parameter on the handshake URL that carries the JWT. */
    private static final String TOKEN_PARAMETER = "token";

    /** JWT claim holding the JSON-serialized {@link AuthorizedUser}. */
    private static final String USER_CLAIM = "user";

    // @Lazy injects a lazy-resolution proxy. NOTE(review): presumably this is done to break a
    // circular bean dependency; confirm before switching to constructor injection.
    @Autowired
    @Lazy
    private JWTService jwtService;
    @Autowired
    @Lazy
    private ObjectMapper objectMapper;

    /**
     * Registers the STOMP endpoint and maps it to the {@code /websocket} URL.
     *
     * @param registry the endpoint registry supplied by Spring
     */
    @Override
    public void registerStompEndpoints(StompEndpointRegistry registry) {
        registry.addEndpoint("/websocket")
                // authenticate the client before the handshake is accepted
                .addInterceptors(myHandshakeInterceptor())
                // expose the authenticated user as the session principal
                .setHandshakeHandler(myDefaultHandshakeHandler())
                // NOTE(review): "*" accepts connections from any origin. Restrict this to known
                // front-end origins; on Spring 5.3+ with SockJS, setAllowedOriginPatterns may be
                // required instead.
                .setAllowedOrigins("*")
                // serve the endpoint through the SockJS protocol
                .withSockJS();
    }

    /**
     * Creates the handshake interceptor that authenticates the connecting user.
     *
     * <p>The JWT is read from the {@code token} query parameter. Query strings are often written
     * to access logs, so treat the token as exposed there.
     *
     * @return an interceptor that rejects unauthenticated handshakes and stores the authenticated
     *     principal under {@link #USER_ATTRIBUTE}
     */
    private HandshakeInterceptor myHandshakeInterceptor() {
        return new HandshakeInterceptor() {
            /**
             * Authenticates the handshake request.
             *
             * @return {@code true} to accept the handshake, {@code false} to reject it
             */
            @Override
            public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response,
                    WebSocketHandler wsHandler, Map<String, Object> attributes) throws Exception {
                // The query parameter is read through the Servlet API. Reject any other request
                // type instead of failing with a ClassCastException.
                if (!(request instanceof ServletServerHttpRequest)) {
                    return rejectHandshake(response);
                }
                String token = ((ServletServerHttpRequest) request)
                        .getServletRequest()
                        .getParameter(TOKEN_PARAMETER);
                // A missing or blank token cannot be a valid JWT, so reject it before decoding.
                if (!StringUtils.hasText(token)) {
                    return rejectHandshake(response);
                }
                Principal user = authenticate(token);
                if (user == null) {
                    return rejectHandshake(response);
                }
                // Picked up by determineUser() in the handshake handler.
                attributes.put(USER_ATTRIBUTE, user);
                return true;
            }

            @Override
            public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response,
                    WebSocketHandler wsHandler, Exception exception) {
                // Nothing to do after the handshake completes.
            }
        };
    }

    /**
     * Marks the response as unauthorized and signals that the handshake must not proceed.
     *
     * @param response the handshake response
     * @return always {@code false}, so callers can {@code return rejectHandshake(response);}
     */
    private static boolean rejectHandshake(ServerHttpResponse response) {
        response.setStatusCode(HttpStatus.UNAUTHORIZED);
        return false;
    }

    /**
     * Creates the handshake handler that binds the authenticated user to the WebSocket session.
     *
     * @return a handler whose session principal is the one stored by the handshake interceptor
     */
    private DefaultHandshakeHandler myDefaultHandshakeHandler() {
        return new DefaultHandshakeHandler() {
            @Override
            protected Principal determineUser(ServerHttpRequest request, WebSocketHandler wsHandler,
                    Map<String, Object> attributes) {
                // Use the principal that beforeHandshake stored after successful authentication.
                return (Principal) attributes.get(USER_ATTRIBUTE);
            }
        };
    }

    /**
     * Configures destination prefixes for the message broker.
     *
     * @param registry the broker registry supplied by Spring
     */
    @Override
    public void configureMessageBroker(MessageBrokerRegistry registry) {
        // Destination prefixes handled by the in-memory simple broker (client subscriptions).
        registry.enableSimpleBroker(
                "/topic/room", "/topic/login"
        );
        // Prefix for point-to-point (per-user) destinations; "/user" is also Spring's default.
        registry.setUserDestinationPrefix("/user");
        // Prefix for messages the client sends to @MessageMapping handlers on the server.
        registry.setApplicationDestinationPrefixes("/app");
    }

    /**
     * Authenticates a user from a JWT.
     *
     * <p>The {@code user} claim must contain a JSON-serialized {@link AuthorizedUser}. The
     * resulting principal's {@code getName()} is the username, which must be globally unique
     * because it identifies the user for {@code /user/...} destinations.
     *
     * @param token the raw JWT taken from the handshake request
     * @return the authenticated principal, or {@code null} if the token carries no usable user
     * @throws AuthorizationException with {@code INVALID_CLAIM} if the {@code user} claim is not
     *     valid JSON for {@link AuthorizedUser}
     */
    private Principal authenticate(String token) {
        DecodedJWT jwtToken = jwtService.decode(token);
        if (jwtToken == null) {
            return null;
        }
        String userJson = jwtToken.getClaim(USER_CLAIM).asString();
        if (!StringUtils.hasLength(userJson)) {
            return null;
        }

        AuthorizedUser authorizedUser;
        try {
            authorizedUser = objectMapper.readValue(userJson, AuthorizedUser.class);
        } catch (IOException e) {
            // NOTE(review): the parse failure cause is not attached; pass it along if
            // AuthorizationException offers a constructor that accepts a cause.
            throw new AuthorizationException(AuthorizationResponseEnum.INVALID_CLAIM);
        }

        // The JSON literal "null" deserializes to null. A blank username cannot serve as a
        // unique principal name, so both cases are treated as unauthenticated.
        String username = authorizedUser == null ? null : authorizedUser.getUsername();
        if (!StringUtils.hasText(username)) {
            return null;
        }
        return new MyPrincipal(username);
    }

}
