package com.qkdata.common.trace;

import com.fasterxml.jackson.databind.ObjectMapper;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.IOUtils;
import org.slf4j.MDC;
import org.springframework.core.Ordered;
import org.springframework.http.HttpStatus;
import org.springframework.web.filter.OncePerRequestFilter;
import org.springframework.web.util.ContentCachingRequestWrapper;
import org.springframework.web.util.ContentCachingResponseWrapper;
import org.springframework.web.util.WebUtils;

import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.net.URI;
import java.net.URISyntaxException;
import java.time.Instant;
import java.time.LocalDateTime;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;

/**
 * Servlet filter that traces each HTTP exchange: request method, path, parameters, headers and
 * body, plus response status, response body and elapsed time.
 *
 * <p>The request and response are wrapped in {@link ContentCachingRequestWrapper} and
 * {@link ContentCachingResponseWrapper}, because a servlet input stream can be consumed only once.
 * After the filter chain returns, the cached response body is copied back to the real response.
 * The assembled {@link TraceLog} is written at DEBUG level; when DEBUG is disabled for this logger
 * no trace is built or serialized.
 *
 * @see org.springframework.boot.actuate.web.trace.servlet.HttpTraceFilter
 */
@Slf4j
public class HttpTraceLogFilter extends OncePerRequestFilter implements Ordered {

    /** Request headers (matched case-insensitively) whose values must never reach the log. */
    private static final String[] SENSITIVE_HEADERS = {"authorization", "proxy-authorization", "cookie"};

    /** Placeholder logged in place of a sensitive header value. */
    private static final String MASKED_VALUE = "******";

    private final int order = Ordered.LOWEST_PRECEDENCE - 10;

    private final ObjectMapper objectMapper;

    /**
     * Creates the filter.
     *
     * @param objectMapper mapper used to serialize parameters, headers and the trace itself
     * @throws NullPointerException if {@code objectMapper} is null
     */
    public HttpTraceLogFilter(ObjectMapper objectMapper) {
        this.objectMapper = Objects.requireNonNull(objectMapper, "objectMapper");
    }

    @Override
    public int getOrder() {
        return this.order;
    }

    /**
     * Runs the filter chain, tracing it when the request URL is a syntactically valid URI.
     * The trace id is removed from the {@link MDC} afterwards on every path, including the
     * untraced one, so it does not leak to the next request served by this thread.
     */
    @Override
    protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response,
                                    FilterChain filterChain) throws ServletException, IOException {
        try {
            if (isRequestValid(request)) {
                doFilterWithTrace(request, response, filterChain);
            } else {
                filterChain.doFilter(request, response);
            }
        } finally {
            MDC.remove(TraceIdInterceptor.TRACE_ID_KEY);
        }
    }

    /** Runs the chain with body caching enabled, then logs the trace and restores the response body. */
    private void doFilterWithTrace(HttpServletRequest request, HttpServletResponse response,
                                   FilterChain filterChain) throws ServletException, IOException {
        // The request stream can only be consumed once: cache it so the body is still readable
        // after downstream handlers have read it.
        HttpServletRequest cachedRequest = request instanceof ContentCachingRequestWrapper
                ? request : new ContentCachingRequestWrapper(request);
        HttpServletResponse cachedResponse = response instanceof ContentCachingResponseWrapper
                ? response : new ContentCachingResponseWrapper(response);

        // Pessimistic default: if the chain throws, the trace reports 500 rather than 200.
        int status = HttpStatus.INTERNAL_SERVER_ERROR.value();
        long startNanos = System.nanoTime();
        try {
            filterChain.doFilter(cachedRequest, cachedResponse);
            status = cachedResponse.getStatus();
        } finally {
            // nanoTime is monotonic, so wall-clock adjustments cannot skew the duration.
            long timeTakenMillis = (System.nanoTime() - startNanos) / 1_000_000L;
            try {
                logTrace(cachedRequest, cachedResponse, status, timeTakenMillis);
            } finally {
                // Copy the cached body to the real response even if tracing failed; otherwise
                // the client would receive an empty body.
                updateResponse(cachedResponse);
            }
        }
    }

    /**
     * Builds the {@link TraceLog} and writes it at DEBUG level. DEBUG keeps potentially large
     * request/response bodies out of the logs unless explicitly enabled. Tracing is best-effort:
     * any failure is logged as a warning and never propagated to the request.
     */
    private void logTrace(HttpServletRequest request, HttpServletResponse response,
                          int status, long timeTakenMillis) {
        if (!log.isDebugEnabled()) {
            return;
        }
        try {
            TraceLog traceLog = new TraceLog();
            traceLog.setTraceId(MDC.get(TraceIdInterceptor.TRACE_ID_KEY));
            traceLog.setTimeTaken(timeTakenMillis);
            traceLog.setCreateAt(LocalDateTime.now());
            traceLog.setMethod(request.getMethod());
            traceLog.setPath(request.getRequestURI());
            traceLog.setStatus(status);
            traceLog.setRequestBody(getRequestBody(request));
            traceLog.setResponseBody(getResponseBody(response));
            traceLog.setParameters(objectMapper.writeValueAsString(request.getParameterMap()));
            traceLog.setHeaders(objectMapper.writeValueAsString(getHeaders(request)));
            log.debug("Http Trace Log: {}", objectMapper.writeValueAsString(traceLog));
        } catch (IOException | RuntimeException e) {
            log.warn("Failed to write http trace log for {} {}", request.getMethod(), request.getRequestURI(), e);
        }
    }

    /**
     * Collects the first value of every request header, masking credential-bearing headers.
     *
     * @return header name to value; empty if the container does not expose header names
     */
    private Map<String, Object> getHeaders(HttpServletRequest request) {
        Map<String, Object> heads = new HashMap<>();
        Enumeration<String> headerNames = request.getHeaderNames();
        if (headerNames == null) {
            // The Servlet API allows null when the container denies access to headers.
            return heads;
        }
        while (headerNames.hasMoreElements()) {
            String name = headerNames.nextElement();
            heads.put(name, isSensitiveHeader(name) ? MASKED_VALUE : request.getHeader(name));
        }
        return heads;
    }

    /** Returns whether {@code name} is one of {@link #SENSITIVE_HEADERS}, ignoring case. */
    private static boolean isSensitiveHeader(String name) {
        for (String sensitive : SENSITIVE_HEADERS) {
            if (sensitive.equalsIgnoreCase(name)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns whether the full request URL parses as a {@link URI}. Requests with malformed
     * URLs are passed through the chain without tracing.
     */
    private boolean isRequestValid(HttpServletRequest request) {
        try {
            new URI(request.getRequestURL().toString());
            return true;
        } catch (URISyntaxException ex) {
            return false;
        }
    }

    /** Returns the cached request body, or {@code ""} if the request is not a caching wrapper. */
    private String getRequestBody(HttpServletRequest request) {
        ContentCachingRequestWrapper wrapper = WebUtils.getNativeRequest(request, ContentCachingRequestWrapper.class);
        return wrapper == null ? "" : decode(wrapper.getContentAsByteArray(), wrapper.getCharacterEncoding());
    }

    /** Returns the cached response body, or {@code ""} if the response is not a caching wrapper. */
    private String getResponseBody(HttpServletResponse response) {
        ContentCachingResponseWrapper wrapper = WebUtils.getNativeResponse(response, ContentCachingResponseWrapper.class);
        return wrapper == null ? "" : decode(wrapper.getContentAsByteArray(), wrapper.getCharacterEncoding());
    }

    /**
     * Decodes a cached body with the given character encoding.
     *
     * @return the decoded text, or {@code ""} if decoding fails with an {@link IOException}
     */
    private static String decode(byte[] content, String encoding) {
        try {
            return IOUtils.toString(content, encoding);
        } catch (IOException ignored) {
            // The body is informational only; an undecodable body is traced as empty.
            return "";
        }
    }

    /** Copies the cached response body to the underlying response so the client receives it. */
    private void updateResponse(HttpServletResponse response) throws IOException {
        ContentCachingResponseWrapper responseWrapper =
                WebUtils.getNativeResponse(response, ContentCachingResponseWrapper.class);
        Objects.requireNonNull(responseWrapper, "response is not backed by a ContentCachingResponseWrapper")
                .copyBodyToResponse();
    }

}
