/*
 * Decompiled with CFR 0.152.
 */
package io.jans.as.server.rate;

import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import io.github.bucket4j.Bandwidth;
import io.github.bucket4j.Bucket;
import io.jans.as.client.util.ClientUtil;
import io.jans.as.model.common.FeatureFlagType;
import io.jans.as.model.configuration.AppConfiguration;
import io.jans.as.model.configuration.rate.KeyExtractor;
import io.jans.as.model.configuration.rate.RateLimitConfig;
import io.jans.as.model.configuration.rate.RateLimitRule;
import io.jans.as.model.error.ErrorResponseFactory;
import io.jans.as.model.exception.InvalidJwtException;
import io.jans.as.model.jwt.Jwt;
import io.jans.as.model.util.Pair;
import io.jans.as.server.rate.RateLimitContext;
import io.jans.as.server.rate.RateLimitedException;
import io.jans.service.cdi.event.ConfigurationUpdate;
import jakarta.annotation.PostConstruct;
import jakarta.enterprise.context.ApplicationScoped;
import jakarta.enterprise.event.Observes;
import jakarta.inject.Inject;
import jakarta.servlet.http.HttpServletRequest;
import java.io.IOException;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import org.apache.commons.codec.digest.DigestUtils;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.jetbrains.annotations.NotNull;
import org.json.JSONObject;
import org.slf4j.Logger;

@ApplicationScoped
public class RateLimitService {
    private static final int DEFAULT_REQUEST_LIMIT = 10;
    private static final int DEFAULT_PERIOD_LIMIT = 60;
    public static final int KEY_LENGTH_LIMIT_FOR_DIGEST = 300;
    @Inject
    private Logger log;
    @Inject
    private AppConfiguration appConfiguration;
    @Inject
    private ErrorResponseFactory errorResponseFactory;
    private final Cache<String, Bucket> buckets = CacheBuilder.newBuilder().expireAfterWrite(2L, TimeUnit.MINUTES).build();
    private RateLimitConfig rateLimitConfiguration;

    public HttpServletRequest validateRateLimit(HttpServletRequest httpRequest) throws RateLimitedException, IOException {
        String method;
        if (!this.errorResponseFactory.isFeatureFlagEnabled(FeatureFlagType.RATE_LIMIT)) {
            return httpRequest;
        }
        if (this.rateLimitConfiguration == null || CollectionUtils.isEmpty((Collection)this.rateLimitConfiguration.getRateLimitRules())) {
            return httpRequest;
        }
        String requestPath = httpRequest.getRequestURI();
        List<RateLimitRule> matchedRules = this.matchRulesByPathAndMethod(requestPath, method = httpRequest.getMethod());
        if (matchedRules.isEmpty()) {
            return httpRequest;
        }
        RateLimitContext rateLimitContext = new RateLimitContext(httpRequest, this.rateLimitConfiguration.isRateLoggingEnabled());
        List<Pair<String, RateLimitRule>> keyWithRules = this.buildKeyPerRule(rateLimitContext, matchedRules);
        for (Pair<String, RateLimitRule> keyWithRule : keyWithRules) {
            String key = (String)keyWithRule.getFirst();
            RateLimitRule rule = (RateLimitRule)keyWithRule.getSecond();
            int requestLimit = this.getRequestLimit(rule.getRequestCount());
            int periodLimit = this.getPeriodLimit(rule.getPeriodInSeconds());
            key = RateLimitService.saveSpaceIfNeeded(key);
            try {
                Bucket bucket = (Bucket)this.buckets.get((Object)key, () -> this.newBucket(requestLimit, periodLimit));
                if (bucket.tryConsume(1L)) continue;
                String msg = String.format("Rate limited '%s'. Exceeds limit %s requests per %s seconds. Key: %s", requestPath, requestLimit, periodLimit, key);
                this.log.debug(msg);
                throw new RateLimitedException(msg);
            }
            catch (ExecutionException e) {
                this.log.error(e.getMessage(), (Throwable)e);
            }
        }
        if (rateLimitContext.isCachedRequestAvailable()) {
            return rateLimitContext.getCachedRequest();
        }
        return httpRequest;
    }

    @NotNull
    public static String saveSpaceIfNeeded(String key) {
        if (key.length() > 300) {
            key = DigestUtils.sha256Hex((String)key);
        }
        return key;
    }

    private List<Pair<String, RateLimitRule>> buildKeyPerRule(RateLimitContext rateLimitContext, List<RateLimitRule> matchedRules) {
        ArrayList<Pair<String, RateLimitRule>> keyWithRules = new ArrayList<Pair<String, RateLimitRule>>();
        for (RateLimitRule rule : matchedRules) {
            try {
                keyWithRules.add((Pair<String, RateLimitRule>)new Pair((Object)this.buildKey(rateLimitContext, rule), (Object)rule));
            }
            catch (IOException e) {
                this.log.error(e.getMessage(), (Throwable)e);
            }
        }
        return keyWithRules;
    }

    public String buildKey(RateLimitContext rateLimitContext, RateLimitRule rule) throws IOException {
        String requestPath = rateLimitContext.getRequest().getRequestURI();
        StringBuilder key = new StringBuilder(requestPath + "_");
        for (KeyExtractor keyExtractor : rule.getKeyExtractors()) {
            key.append(this.extractKey(keyExtractor, rateLimitContext)).append("_");
        }
        String keyString = key.toString();
        if (rateLimitContext.isRateLoggingEnabled() && this.log.isTraceEnabled()) {
            this.log.trace("Rate limit key: {}", (Object)keyString);
        }
        return keyString;
    }

    protected String extractKey(KeyExtractor keyExtractor, RateLimitContext rateLimitContext) throws IOException {
        StringBuilder key = new StringBuilder();
        switch (keyExtractor.getSource()) {
            case HEADER: {
                for (String header : keyExtractor.getParameterNames()) {
                    String value = rateLimitContext.getRequest().getHeader(header);
                    if (!StringUtils.isNotBlank((CharSequence)value)) continue;
                    key.append(value).append("_");
                }
                return key.toString();
            }
            case BODY: {
                String contentType = rateLimitContext.getRequest().getContentType();
                if (contentType != null && contentType.contains("application/json")) {
                    String bodyAsString = rateLimitContext.getCachedRequest().getCachedBodyAsString();
                    JSONObject jsonObject = this.parseBody(bodyAsString);
                    if (jsonObject != null) {
                        for (String name : keyExtractor.getParameterNames()) {
                            List values = ClientUtil.extractListByKey((JSONObject)jsonObject, (String)name);
                            if (values.isEmpty()) continue;
                            key.append(values).append("_");
                        }
                    }
                } else {
                    for (String name : keyExtractor.getParameterNames()) {
                        String value = rateLimitContext.getRequest().getParameter(name);
                        if (!StringUtils.isNotBlank((CharSequence)value)) continue;
                        key.append(value).append("_");
                    }
                }
                return key.toString();
            }
            case QUERY: {
                for (String name : keyExtractor.getParameterNames()) {
                    String value = rateLimitContext.getRequest().getParameter(name);
                    if (!StringUtils.isNotBlank((CharSequence)value)) continue;
                    key.append(value).append("_");
                }
                return key.toString();
            }
        }
        this.log.error("Invalid key extractor source: {}", (Object)keyExtractor.getSource());
        return "null";
    }

    public List<RateLimitRule> matchRulesByPathAndMethod(String requestPath, String method) {
        ArrayList<RateLimitRule> result = new ArrayList<RateLimitRule>();
        for (RateLimitRule rule : this.rateLimitConfiguration.getRateLimitRules()) {
            if (!rule.isWellFormed()) {
                this.log.error("Invalid rate limit rule: {}", (Object)rule);
                continue;
            }
            if (!rule.getPath().equals(requestPath) || !rule.getMethods().contains(method)) continue;
            result.add(rule);
        }
        return result;
    }

    private int getRequestLimit(Integer requestLimit) {
        if (requestLimit == null || requestLimit <= 0) {
            return 10;
        }
        return requestLimit;
    }

    private int getPeriodLimit(Integer periodInSeconds) {
        if (periodInSeconds == null || periodInSeconds <= 0) {
            periodInSeconds = 60;
        }
        return periodInSeconds;
    }

    private Bucket newBucket(int requestLimit, int periodInSeconds) {
        return Bucket.builder().addLimit(Bandwidth.builder().capacity((long)requestLimit).refillGreedy((long)requestLimit, Duration.ofSeconds(periodInSeconds)).build()).build();
    }

    public JSONObject parseBody(String body) {
        try {
            return new JSONObject(body);
        }
        catch (Exception e) {
            try {
                return Jwt.parseOrThrow((String)body).getClaims().toJsonObject();
            }
            catch (InvalidJwtException ex) {
                return null;
            }
        }
    }

    @PostConstruct
    public void init() {
        this.updateConfiguration(this.appConfiguration);
    }

    public void updateConfiguration(@Observes @ConfigurationUpdate AppConfiguration appConfiguration) {
        try {
            this.rateLimitConfiguration = appConfiguration.getRateLimitConfiguration();
            if (this.rateLimitConfiguration == null) {
                this.log.info("Rate limiting is not configured.");
            }
        }
        catch (Exception e) {
            this.log.error(e.getMessage(), (Throwable)e);
        }
    }
}

