package com.qkdata.common.oauth;

import com.auth0.jwt.exceptions.*;
import com.auth0.jwt.interfaces.DecodedJWT;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.qkdata.biz.sys.entity.SysUserPO;
import com.qkdata.biz.enums.AccountStatusEnum;
import com.qkdata.biz.sys.service.ShiroService;
import com.qkdata.common.jwt.JWTService;
import lombok.extern.slf4j.Slf4j;
import org.apache.shiro.authc.AuthenticationException;
import org.apache.shiro.authc.AuthenticationInfo;
import org.apache.shiro.authc.AuthenticationToken;
import org.apache.shiro.authc.SimpleAuthenticationInfo;
import org.apache.shiro.authz.AuthorizationInfo;
import org.apache.shiro.authz.SimpleAuthorizationInfo;
import org.apache.shiro.realm.AuthorizingRealm;
import org.apache.shiro.subject.PrincipalCollection;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

import javax.sound.sampled.Line;
import java.io.IOException;
import java.util.HashSet;
import java.util.Set;

/**
 * Shiro realm that authenticates {@link OAuthToken} submissions carrying a JWT access token and
 * authorizes the resulting {@link SysUserPO} with the roles and permissions supplied by
 * {@link ShiroService}.
 *
 * <p>This bean is a Spring {@code @Component} (singleton scope by default) and is therefore shared
 * by all concurrently authenticating requests. For that reason it keeps no per-request state in
 * fields: the decoded JWT only ever lives in a local variable.
 */
@Slf4j
@Component
public class OAuthRealm extends AuthorizingRealm {

    /** Name of the JWT claim holding the JSON-serialized {@link AuthorizedUser}. */
    private static final String USER_CLAIM = "user";

    /** Shared JSON mapper; ObjectMapper is thread-safe once configured and is never reconfigured here. */
    private static final ObjectMapper MAPPER = new ObjectMapper();

    @Autowired
    private ShiroService shiroService;
    @Autowired
    private JWTService jwtService;

    /**
     * Restricts this realm to {@link OAuthToken} instances.
     *
     * @param token the submitted authentication token
     * @return {@code true} only if {@code token} is an {@link OAuthToken}
     */
    @Override
    public boolean supports(AuthenticationToken token) {
        return token instanceof OAuthToken;
    }

    /**
     * Authorization (invoked when permissions/roles are checked).
     *
     * @param principals principals of the authenticated subject; the primary principal is the
     *                   {@link SysUserPO} placed there by {@link #doGetAuthenticationInfo}
     * @return the user's roles and string permissions as reported by {@link ShiroService}
     */
    @Override
    protected AuthorizationInfo doGetAuthorizationInfo(PrincipalCollection principals) {
        SysUserPO user = (SysUserPO) principals.getPrimaryPrincipal();
        Long userId = user.getId();

        // User roles
        Set<String> roles = shiroService.getUserRoles(userId);
        // User permission list
        Set<String> permissions = shiroService.getUserPermissions(userId);

        SimpleAuthorizationInfo info = new SimpleAuthorizationInfo();
        info.setRoles(roles);
        info.setStringPermissions(permissions);
        return info;
    }

    /**
     * Authentication: verifies the JWT access token and resolves the {@link SysUserPO} it names.
     *
     * @param token the submitted token; its principal is expected to be the raw access-token string
     * @return authentication info whose principal is the loaded {@link SysUserPO} and whose
     *         credentials are the access token itself
     * @throws AuthenticationException if the token is missing, fails verification, has expired,
     *         lacks a parseable {@code user} claim, refers to an unknown user, or the account is
     *         disabled
     */
    @Override
    protected AuthenticationInfo doGetAuthenticationInfo(AuthenticationToken token) throws AuthenticationException {
        String accessToken = (String) token.getPrincipal();

        if (accessToken == null) {
            throw new AuthenticationException(AuthorizationResponseEnum.MISSING_TOKEN.text());
        }
        try {
            // Local variable, not a field: this realm is shared across concurrent requests.
            DecodedJWT decodedJWT = jwtService.decode(accessToken);

            // asString() yields null when the claim is absent; readValue(null, ...) would throw an
            // unchecked exception instead of a proper authentication failure.
            String userJson = decodedJWT.getClaim(USER_CLAIM).asString();
            if (userJson == null) {
                throw new AuthenticationException(AuthorizationResponseEnum.INVALID_CLAIM.text());
            }
            AuthorizedUser authUser = MAPPER.readValue(userJson, AuthorizedUser.class);
            // A literal JSON "null" deserializes to null; reject it instead of hitting an NPE.
            if (authUser == null || authUser.getUsername() == null) {
                throw new AuthenticationException(AuthorizationResponseEnum.INVALID_CLAIM.text());
            }

            SysUserPO user = shiroService.getUserByUserName(authUser.getUsername());
            if (user == null) {
                throw new AuthenticationException(AuthorizationResponseEnum.INVALID_CLAIM.text());
            }
            if (user.getStatus() == AccountStatusEnum.DISABLE) {
                throw new AuthenticationException("您的帐号已被禁用");
            }
            return new SimpleAuthenticationInfo(user, accessToken, getName());
        } catch (TokenExpiredException e) {
            // The token is a bearer credential: never write it to the logs.
            log.warn("TOKEN已过期: {}", e.getMessage());
            throw new AuthenticationException(AuthorizationResponseEnum.EXPIRED_TOKEN.text(), e);
        } catch (JWTVerificationException e) {
            // Covers decode, algorithm-mismatch, signature and claim failures, plus any other
            // verification subtype that would otherwise escape as an unchecked exception.
            log.warn("校验TOKEN失败: {}", e.getMessage());
            throw new AuthenticationException(AuthorizationResponseEnum.INVALID_TOKEN.text(), e);
        } catch (IOException e) {
            log.warn("Failed to parse user claim from TOKEN: {}", e.getMessage());
            throw new AuthenticationException(AuthorizationResponseEnum.INVALID_CLAIM.text(), e);
        }
    }
}
