package com.odianyun.db.mybatis.interceptor;

import com.github.pagehelper.Page;
import com.github.pagehelper.PageHelper;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Properties;
import java.util.Set;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.mapping.ResultMap;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.session.Configuration;
import org.apache.ibatis.session.RowBounds;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;

/* loaded from: input_file:BOOT-INF/lib/ody-db-0.0.10-SNAPSHOT.jar:com/odianyun/db/mybatis/interceptor/OdyPageDialectHelper.class */
/**
 * PageHelper dialect that skips the automatic {@code count} query for paged selects on large tables.
 *
 * <p>Before PageHelper runs its count query, this dialect extracts the first table referenced after
 * {@code FROM} in the paged SQL. It looks up that table's {@code TABLE_ROWS} in
 * {@code INFORMATION_SCHEMA.TABLES}. If the value exceeds the configured maximum, {@link #beforeCount}
 * returns {@code false} so no count query is issued. When a page ends up with a total of 0 but has
 * results (presumably because the count was skipped), {@link #afterPage} reports a total of
 * {@code -1}.
 *
 * <p>Configuration properties (read in {@link #setProperties}):
 * <ul>
 *   <li>{@code ignoreTables}: comma-separated table names that are always counted normally.</li>
 *   <li>{@code defaultMaxTableCount}: row threshold above which counting is skipped (default 200000).</li>
 * </ul>
 */
public class OdyPageDialectHelper extends PageHelper {
    private static final Logger LOG = Logger.getLogger(OdyPageDialectHelper.class.getName());

    /** Appended to the original statement id to form the id of the synthetic table-count statement. */
    private static final String TABLE_COUNT_SUFFIX = "_TABLE_COUNT";

    /** Invocation of the currently executing query; see {@link #setInvocation}. */
    private static final ThreadLocal<Invocation> invocationThreadLocal = new ThreadLocal<>();

    /**
     * Matches the first {@code FROM <identifier>} in the SQL, case-insensitively. The identifier may
     * be schema-qualified and/or backtick-quoted.
     */
    private static final Pattern TABLE_PATTERN =
            Pattern.compile("\\s+from\\s+([`A-Za-z0-9_.]+)", Pattern.CASE_INSENSITIVE);

    /**
     * Row-count lookup. Parameters, in order: schema name (may be null), table name. An explicitly
     * given schema wins; otherwise the connection's current database is used.
     */
    private static final String TABLE_COUNT_SQL =
            "SELECT TABLE_ROWS FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = IFNULL(?, DATABASE()) AND TABLE_NAME = ? LIMIT 1";

    /** Table names (without schema) that are never subject to count skipping; null when not configured. */
    private Set<String> ignoreTableSet;

    /** Threshold used when {@link #tableMaxCountFun} is absent or returns null for a table. */
    private long defaultMaxTableCount = 200000;

    // Static hooks installed via the public setters and read on query threads; volatile for safe publication.
    private static volatile Supplier<Boolean> skipLimitFun;
    private static volatile Function<String, Long> tableMaxCountFun;

    /**
     * Decides whether PageHelper should run its count query.
     *
     * @return {@code false} when the current invocation targets a table whose row count exceeds the
     *     threshold; otherwise whatever {@link PageHelper#beforeCount} decides
     */
    @Override // com.github.pagehelper.PageHelper, com.github.pagehelper.Dialect
    public boolean beforeCount(MappedStatement mappedStatement, Object parameter, RowBounds rowBounds) {
        // Read the volatile hook once so a concurrent reset cannot null it between check and call.
        Supplier<Boolean> skipLimit = skipLimitFun;
        boolean limitDisabled = skipLimit != null && Boolean.TRUE.equals(skipLimit.get());
        if (!limitDisabled) {
            Invocation invocation = invocationThreadLocal.get();
            if (invocation != null && closePageCount(invocation)) {
                return false;
            }
        }
        return super.beforeCount(mappedStatement, parameter, rowBounds);
    }

    /**
     * Delegates to PageHelper, then marks the total as unknown ({@code -1}) when the page has rows but
     * a total of 0.
     */
    @Override // com.github.pagehelper.PageHelper, com.github.pagehelper.Dialect
    @SuppressWarnings("rawtypes") // signature dictated by the PageHelper superclass
    public Object afterPage(List pageList, Object parameter, RowBounds rowBounds) {
        Object result = super.afterPage(pageList, parameter, rowBounds);
        if (result instanceof Page) {
            Page<?> page = (Page<?>) result;
            if (page.getTotal() == 0 && !CollectionUtils.isEmpty(page.getResult())) {
                page.setTotal(-1L);
            }
        }
        return result;
    }

    /** Clears this thread's stored invocation, then lets PageHelper clean up its own state. */
    @Override // com.github.pagehelper.PageHelper, com.github.pagehelper.Dialect
    public void afterAll() {
        invocationThreadLocal.remove();
        super.afterAll();
    }

    /**
     * Reads {@code defaultMaxTableCount} and {@code ignoreTables}, then forwards all properties to
     * PageHelper.
     *
     * @throws NumberFormatException if {@code defaultMaxTableCount} is present but not a valid long
     */
    @Override // com.github.pagehelper.PageHelper, com.github.pagehelper.Dialect
    public void setProperties(Properties properties) {
        String ignoreTables = properties.getProperty("ignoreTables");
        String maxCount = properties.getProperty("defaultMaxTableCount");
        if (StringUtils.hasText(maxCount)) {
            this.defaultMaxTableCount = Long.parseLong(maxCount.trim());
        }
        if (StringUtils.hasText(ignoreTables)) {
            if (this.ignoreTableSet == null) {
                this.ignoreTableSet = new HashSet<>();
            }
            for (String table : ignoreTables.split(",")) {
                String name = table.trim();
                // Skip blanks produced by inputs like "a,,b" or a trailing comma.
                if (!name.isEmpty()) {
                    this.ignoreTableSet.add(name);
                }
            }
        }
        super.setProperties(properties);
    }

    /**
     * Returns {@code true} when the count query for this invocation should be skipped because the
     * target table's row count exceeds its threshold. Any failure while probing is logged and treated
     * as "do not skip".
     */
    private boolean closePageCount(Invocation invocation) {
        Object[] args = invocation.getArgs();
        MappedStatement mappedStatement = (MappedStatement) args[0];
        Object parameter = args[1];
        Executor executor = (Executor) invocation.getTarget();
        try {
            // The 4-arg Executor.query overload carries no BoundSql; the 6-arg one passes it last.
            BoundSql boundSql = args.length == 4 ? mappedStatement.getBoundSql(parameter) : (BoundSql) args[5];
            LinkedHashMap<String, String> dbTable = parseDbTableName(boundSql);
            if (dbTable == null) {
                return false;
            }
            if (this.ignoreTableSet != null && this.ignoreTableSet.contains(dbTable.get("tableName"))) {
                return false;
            }
            MappedStatement countStatement =
                    tableCountMappedStatement(mappedStatement, mappedStatement.getId() + TABLE_COUNT_SUFFIX);
            Long tableCount = queryTableCount(executor, countStatement, dbTable);
            return tableCount != null && tableCount > maxTableCount(dbTable);
        } catch (Exception e) {
            LOG.log(Level.WARNING, e, () -> "Table size probe failed for statement "
                    + mappedStatement.getId() + "; running the normal count query");
            return false;
        }
    }

    /**
     * Executes {@link #TABLE_COUNT_SQL} through the given executor.
     *
     * @param dbTable parameter map with keys {@code dbName} then {@code tableName}; the insertion
     *     order must match the placeholder order in the SQL
     * @return the reported row count, or {@code null} when the table is not found or has no value
     */
    private Long queryTableCount(Executor executor, MappedStatement countStatement,
                                 LinkedHashMap<String, String> dbTable) throws SQLException {
        Configuration configuration = countStatement.getConfiguration();
        List<ParameterMapping> parameterMappings = new ArrayList<>(dbTable.size());
        for (String name : dbTable.keySet()) {
            parameterMappings.add(new ParameterMapping.Builder(configuration, name, String.class).build());
        }
        BoundSql boundSql = new BoundSql(configuration, TABLE_COUNT_SQL, parameterMappings, dbTable);
        List<Object> rows = executor.query(countStatement, dbTable, RowBounds.DEFAULT, null,
                executor.createCacheKey(countStatement, dbTable, RowBounds.DEFAULT, boundSql), boundSql);
        if (rows == null || rows.isEmpty()) {
            return null;
        }
        Object first = rows.get(0);
        return first instanceof Number ? Long.valueOf(((Number) first).longValue()) : null;
    }

    /**
     * Returns the row threshold for the table in {@code dbTable}. It uses {@link #tableMaxCountFun}
     * when that is installed and yields a value, else {@link #defaultMaxTableCount}.
     */
    private long maxTableCount(LinkedHashMap<String, String> dbTable) {
        Function<String, Long> perTable = tableMaxCountFun;
        if (perTable != null) {
            Long configured = perTable.apply(dbTable.get("tableName"));
            if (configured != null) {
                return configured;
            }
        }
        return this.defaultMaxTableCount;
    }

    /**
     * Builds a synthetic SELECT statement with the given id that maps a single {@code Long} column.
     * Execution settings and cache are copied from the original statement. The SQL source is unused
     * because the BoundSql is supplied explicitly at execution time.
     */
    private MappedStatement tableCountMappedStatement(MappedStatement mappedStatement, String countStatementId) {
        Configuration configuration = mappedStatement.getConfiguration();
        MappedStatement.Builder builder = new MappedStatement.Builder(configuration, countStatementId,
                mappedStatement.getSqlSource(), mappedStatement.getSqlCommandType());
        builder.resource(mappedStatement.getResource());
        builder.fetchSize(mappedStatement.getFetchSize());
        builder.statementType(mappedStatement.getStatementType());
        builder.timeout(mappedStatement.getTimeout());
        builder.resultMaps(Collections.singletonList(
                new ResultMap.Builder(configuration, mappedStatement.getId(), Long.class, Collections.emptyList()).build()));
        builder.resultSetType(mappedStatement.getResultSetType());
        builder.cache(mappedStatement.getCache());
        builder.flushCacheRequired(mappedStatement.isFlushCacheRequired());
        builder.useCache(mappedStatement.isUseCache());
        return builder.build();
    }

    /**
     * Extracts the first {@code FROM} target of the SQL.
     *
     * @return a map of {@code dbName} (null when unqualified) followed by {@code tableName}, or
     *     {@code null} when no table could be found
     */
    private LinkedHashMap<String, String> parseDbTableName(BoundSql boundSql) {
        Matcher matcher = TABLE_PATTERN.matcher(boundSql.getSql());
        if (!matcher.find()) {
            return null;
        }
        // Drop MySQL identifier quotes so `db`.`t` resolves the same as db.t.
        String qualified = matcher.group(1).replace("`", "");
        int dot = qualified.indexOf('.');
        String dbName = dot == -1 ? null : qualified.substring(0, dot);
        String tableName = dot == -1 ? qualified : qualified.substring(dot + 1);
        if (tableName.isEmpty()) {
            return null;
        }
        LinkedHashMap<String, String> dbTable = new LinkedHashMap<>();
        // Insertion order must match the placeholder order in TABLE_COUNT_SQL.
        dbTable.put("dbName", dbName == null || dbName.isEmpty() ? null : dbName);
        dbTable.put("tableName", tableName);
        return dbTable;
    }

    /**
     * Stores the invocation of the query about to be paged for the current thread. It is consulted by
     * {@link #beforeCount} and cleared by {@link #afterAll} or {@link #removeInvocation}.
     *
     * <p>NOTE(review): the decompiler reports this was originally package-private. It is kept public so
     * existing callers keep compiling.
     */
    public static void setInvocation(Invocation invocation) {
        invocationThreadLocal.set(invocation);
    }

    /**
     * Clears the current thread's stored invocation.
     *
     * <p>NOTE(review): the decompiler reports this was originally package-private. It is kept public so
     * existing callers keep compiling.
     */
    public static void removeInvocation() {
        invocationThreadLocal.remove();
    }

    /** Installs a global hook; when it returns {@code TRUE}, count skipping is disabled. */
    public static void setSkipLimitFunction(Supplier<Boolean> supplier) {
        skipLimitFun = supplier;
    }

    /**
     * Installs a global per-table threshold lookup. A {@code null} result falls back to the default
     * threshold.
     */
    public static void setTableMaxCountFunction(Function<String, Long> function) {
        tableMaxCountFun = function;
    }
}
