/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.sql.opensearch.storage.scan;

import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Supplier;
import java.util.stream.Stream;
import lombok.Generated;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelOptCost;
import org.apache.calcite.plan.RelOptPlanner;
import org.apache.calcite.plan.RelOptTable;
import org.apache.calcite.plan.RelTrait;
import org.apache.calcite.plan.RelTraitSet;
import org.apache.calcite.rel.RelCollations;
import org.apache.calcite.rel.RelFieldCollation;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.RelWriter;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.TableScan;
import org.apache.calcite.rel.externalize.RelWriterImpl;
import org.apache.calcite.rel.hint.RelHint;
import org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.calcite.rel.metadata.RelMdUtil;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.util.NumberUtil;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.opensearch.script.Script;
import org.opensearch.search.sort.FieldSortBuilder;
import org.opensearch.search.sort.ScoreSortBuilder;
import org.opensearch.search.sort.ScriptSortBuilder;
import org.opensearch.search.sort.SortBuilder;
import org.opensearch.search.sort.SortBuilders;
import org.opensearch.search.sort.SortOrder;
import org.opensearch.sql.calcite.plan.AliasFieldsWrappable;
import org.opensearch.sql.common.setting.Settings;
import org.opensearch.sql.data.type.ExprType;
import org.opensearch.sql.opensearch.data.type.OpenSearchTextType;
import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder;
import org.opensearch.sql.opensearch.request.PredicateAnalyzer;
import org.opensearch.sql.opensearch.storage.OpenSearchIndex;
import org.opensearch.sql.opensearch.storage.scan.context.AggPushDownAction;
import org.opensearch.sql.opensearch.storage.scan.context.AggregationBuilderAction;
import org.opensearch.sql.opensearch.storage.scan.context.FilterDigest;
import org.opensearch.sql.opensearch.storage.scan.context.LimitDigest;
import org.opensearch.sql.opensearch.storage.scan.context.OSRequestBuilderAction;
import org.opensearch.sql.opensearch.storage.scan.context.PushDownContext;
import org.opensearch.sql.opensearch.storage.scan.context.PushDownOperation;
import org.opensearch.sql.opensearch.storage.scan.context.PushDownType;
import org.opensearch.sql.opensearch.storage.scan.context.RareTopDigest;
import org.opensearch.sql.opensearch.storage.scan.context.SortExprDigest;

/**
 * Base Calcite {@link TableScan} over an OpenSearch index that records the operations pushed down
 * into it (filter, project, sort, limit, aggregation, ...) in a {@link PushDownContext}.
 *
 * <p>Provides row-count and cost estimation driven by the pushed-down operations, plus helpers for
 * pushing sorts (plain field sorts and script/expression sorts) into the scan. Subclasses supply
 * {@link #buildScan} and {@link #copy()}.
 */
public abstract class AbstractCalciteIndexScan extends TableScan implements AliasFieldsWrappable {
    private static final Logger LOG = LogManager.getLogger(AbstractCalciteIndexScan.class);

    /** Index this scan reads from; never null (checked in the constructor). */
    public final OpenSearchIndex osIndex;
    /** Row type returned by {@link #deriveRowType()}. */
    protected final RelDataType schema;
    /** Operations pushed down into this scan, in the order they were added. */
    protected final PushDownContext pushDownContext;

    protected AbstractCalciteIndexScan(RelOptCluster cluster, RelTraitSet traitSet, List<RelHint> hints, RelOptTable table, OpenSearchIndex osIndex, RelDataType schema, PushDownContext pushDownContext) {
        super(cluster, traitSet, hints, table);
        this.osIndex = Objects.requireNonNull(osIndex, "OpenSearch index");
        this.schema = schema;
        this.pushDownContext = pushDownContext;
    }

    /** Returns the schema supplied at construction instead of deriving it from the table. */
    @Override
    public RelDataType deriveRowType() {
        return schema;
    }

    /**
     * Adds a {@code PushDownContext} item to the plan explanation when anything was pushed down.
     * For {@link RelWriterImpl} writers the generated request builder is appended as well.
     */
    @Override
    public RelWriter explainTerms(RelWriter pw) {
        String explainString = String.valueOf(pushDownContext);
        if (pw instanceof RelWriterImpl) {
            explainString = explainString + ", " + pushDownContext.createRequestBuilder();
        }
        return super.explainTerms(pw).itemIf("PushDownContext", explainString, !pushDownContext.isEmpty());
    }

    /** Returns the {@code QUERY_SIZE_LIMIT} setting of the underlying index. */
    protected Integer getQuerySizeLimit() {
        return (Integer) osIndex.getSettings().getSettingValue(Settings.Key.QUERY_SIZE_LIMIT);
    }

    /**
     * Estimates the output row count by folding the pushed-down operations in order, starting from
     * the index's max result window.
     */
    @Override
    public double estimateRowCount(RelMetadataQuery mq) {
        return pushDownContext.stream().reduce(
            osIndex.getMaxResultWindow().doubleValue(),
            (rowCount, operation) -> switch (operation.type()) {
                case AGGREGATION -> mq.getRowCount((RelNode) operation.digest());
                case PROJECT, SORT, SORT_EXPR -> rowCount;
                case SORT_AGG_METRICS -> NumberUtil.min(rowCount, osIndex.getQueryBucketSize().doubleValue());
                case FILTER, SCRIPT -> NumberUtil.multiply(rowCount, RelMdUtil.guessSelectivity((RexNode) ((FilterDigest) operation.digest()).condition()));
                case LIMIT -> Math.min(rowCount, (double) ((LimitDigest) operation.digest()).limit());
                case RARE_TOP -> rareTopRowCount((RareTopDigest) operation.digest(), rowCount);
            },
            // The fold is order-dependent and cannot be combined; the combiner is only
            // invoked for parallel streams. NOTE(review): assumes pushDownContext.stream() is sequential.
            (a, b) -> null);
    }

    /**
     * Computes a CPU-only cost by walking the pushed-down operations in order. Row estimates are
     * updated per operation, then an external cost of {@code rows * fieldCount} is added, and the
     * total is scaled by the {@code CALCITE_PUSHDOWN_ROWCOUNT_ESTIMATION_FACTOR} setting.
     */
    @Override
    public @Nullable RelOptCost computeSelfCost(RelOptPlanner planner, RelMetadataQuery mq) {
        double dRows = osIndex.getMaxResultWindow().doubleValue();
        double dCpu = 0.0;
        for (PushDownOperation operation : pushDownContext) {
            switch (operation.type()) {
                case AGGREGATION -> {
                    dRows = mq.getRowCount((RelNode) operation.digest());
                    dCpu += dRows * getAggMultiplier(operation);
                }
                case PROJECT -> {
                    // Adds no cost here; it only influences the field count used below.
                }
                case SORT -> dCpu += dRows;
                case SORT_AGG_METRICS -> {
                    // NOTE(review): 0.9 / 10 are heuristic reduction factors; their rationale is not visible here.
                    dRows = dRows * 0.9 / 10.0;
                    dCpu += dRows;
                }
                case SORT_EXPR -> {
                    @SuppressWarnings("unchecked")
                    List<SortExprDigest> sortKeys = (List<SortExprDigest>) operation.digest();
                    // Only keys carrying an expression (i.e. script sorts) add CPU cost.
                    long complexExprCount = sortKeys.stream().filter(key -> key.getExpression() != null).count();
                    dCpu += NumberUtil.multiply(dRows, 1.1 * complexExprCount).doubleValue();
                }
                case FILTER -> dRows = NumberUtil.multiply(dRows, RelMdUtil.guessSelectivity((RexNode) ((FilterDigest) operation.digest()).condition()));
                case SCRIPT -> {
                    FilterDigest filterDigest = (FilterDigest) operation.digest();
                    dRows = NumberUtil.multiply(dRows, RelMdUtil.guessSelectivity((RexNode) filterDigest.condition()));
                    dCpu += NumberUtil.multiply(dRows, Math.pow(1.1, filterDigest.scriptCount())).doubleValue();
                }
                // The "- 1.0" appears intended to make a pushed LIMIT slightly cheaper than the
                // non-pushed plan even when the limit exceeds the estimated rows. NOTE(review): confirm.
                case LIMIT -> dRows = Math.min(dRows, (double) ((LimitDigest) operation.digest()).limit()) - 1.0;
                case RARE_TOP -> {
                    dRows = rareTopRowCount((RareTopDigest) operation.digest(), dRows);
                    dCpu += dRows * 1.125;
                }
            }
        }
        double estimateRowCountFactor = (Double) osIndex.getSettings().getSettingValue(Settings.Key.CALCITE_PUSHDOWN_ROWCOUNT_ESTIMATION_FACTOR);
        // External cost: every remaining row carries every output field.
        dCpu += dRows * getRowType().getFieldList().size();
        return planner.getCostFactory().makeCost(dCpu * estimateRowCountFactor, 0.0, 0.0);
    }

    /**
     * Row estimate after a rare/top operation: {@code number} when there is no by-list, otherwise
     * {@code number * rowCount * (1 - 0.5^groupCount)}.
     */
    private static double rareTopRowCount(RareTopDigest digest, double rowCount) {
        int factor = digest.number();
        int groupCount = digest.byList().size();
        return groupCount == 0 ? (double) factor : (double) factor * rowCount * (1.0 - Math.pow(0.5, groupCount));
    }

    /**
     * CPU multiplier for a pushed-down aggregation. The base is 1, plus 0.125 per aggregate call,
     * plus 0.0125 per SUM. The result is scaled by {@code 1.1^scriptCount} of the aggregation action.
     */
    private static float getAggMultiplier(PushDownOperation operation) {
        List<AggregateCall> aggCalls = ((Aggregate) operation.digest()).getAggCallList();
        float multiplier = 1.0f + aggCalls.size() * 0.125f;
        for (AggregateCall aggCall : aggCalls) {
            if ("SUM".equals(aggCall.getAggregation().getName())) {
                multiplier += 0.0125f;
            }
        }
        int scriptCount = ((AggPushDownAction) operation.action()).getScriptCount();
        return multiplier * (float) Math.pow(1.1f, scriptCount);
    }

    /** Creates a new scan of the concrete subclass type with the given state. */
    protected abstract AbstractCalciteIndexScan buildScan(RelOptCluster cluster, RelTraitSet traitSet, List<RelHint> hints, RelOptTable table, OpenSearchIndex osIndex, RelDataType schema, PushDownContext pushDownContext);

    /** Delegates to the index's alias mapping. */
    public Map<String, String> getAliasMapping() {
        return osIndex.getAliasMapping();
    }

    /** Returns a copy of this scan. */
    public abstract AbstractCalciteIndexScan copy();

    /** Maps each collation's field index to the corresponding field name of this scan's row type. */
    protected List<String> getCollationNames(List<RelFieldCollation> collations) {
        List<String> fieldNames = getRowType().getFieldNames();
        return collations.stream().map(collation -> fieldNames.get(collation.getFieldIndex())).toList();
    }

    /**
     * Returns true if any of the given names is an aggregator (non-group) output of a pushed-down
     * aggregation.
     */
    protected boolean isAnyCollationNameInAggregators(List<String> collations) {
        return pushDownContext.stream()
            .filter(operation -> operation.type() == PushDownType.AGGREGATION)
            .map(operation -> (LogicalAggregate) operation.digest())
            .anyMatch(aggregate -> isAnyCollationNameInAggregators(aggregate, collations));
    }

    /** Checks the names against the aggregate's output fields that follow its group-by keys. */
    private boolean isAnyCollationNameInAggregators(LogicalAggregate aggregate, List<String> collations) {
        List<String> fieldNames = aggregate.getRowType().getFieldNames();
        int groupOffset = aggregate.getGroupSet().cardinality();
        List<String> aggregatorNames = fieldNames.subList(groupOffset, fieldNames.size());
        return collations.stream().anyMatch(aggregatorNames::contains);
    }

    /**
     * Pushes a field sort into a new scan.
     *
     * <p>If an aggregation is already pushed, the sort is recorded as an aggregation-bucket sort.
     * Sorting on an aggregator output is rejected in that case. Otherwise OpenSearch field/score
     * sort builders are created; text fields are mapped to their keyword sub-field.
     *
     * @return the new scan, or {@code null} if the sort cannot be pushed down
     */
    public AbstractCalciteIndexScan pushDownSort(List<RelFieldCollation> collations) {
        try {
            List<String> collationNames = getCollationNames(collations);
            if (pushDownContext.isAggregatePushed() && isAnyCollationNameInAggregators(collationNames)) {
                return null;
            }
            RelTraitSet traitsWithCollations = getTraitSet().plus(RelCollations.of(collations));
            PushDownContext pushDownContextWithoutSort = pushDownContext.cloneWithoutSort();
            if (pushDownContext.isAggregatePushed()) {
                AggregationBuilderAction action = aggAction -> aggAction.pushDownSortIntoAggBucket(collations, getRowType().getFieldNames());
                pushDownContextWithoutSort.add(PushDownType.SORT, collations, action);
                return buildScan(getCluster(), traitsWithCollations, hints, table, osIndex, getRowType(), pushDownContextWithoutSort.clone());
            }
            AbstractCalciteIndexScan newScan = buildScan(getCluster(), traitsWithCollations, hints, table, osIndex, getRowType(), pushDownContextWithoutSort);
            List<SortBuilder<?>> builders = new ArrayList<>();
            for (RelFieldCollation collation : collations) {
                String fieldName = getRowType().getFieldNames().get(collation.getFieldIndex());
                SortOrder order = toSortOrder(collation.getDirection());
                SortBuilder<?> sortBuilder;
                if ("_score".equals(fieldName)) {
                    sortBuilder = SortBuilders.scoreSort();
                } else {
                    String missing = toMissingValue(collation.nullDirection);
                    ExprType fieldType = osIndex.getFieldTypes().get(fieldName);
                    String field = OpenSearchTextType.toKeywordSubField(fieldName, fieldType);
                    sortBuilder = SortBuilders.fieldSort(field).missing(missing);
                }
                builders.add(sortBuilder.order(order));
            }
            OSRequestBuilderAction action = requestBuilder -> requestBuilder.pushDownSort(builders);
            newScan.pushDownContext.add(PushDownType.SORT, builders.toString(), action);
            return newScan;
        } catch (Exception e) {
            if (LOG.isDebugEnabled()) {
                // Log the raw collations. Re-deriving field names here could throw the very
                // exception being handled (e.g. an out-of-range index) and escape this catch.
                LOG.debug("Cannot pushdown the sort {}", collations, e);
            }
            return null;
        }
    }

    /**
     * Pushes sort keys that may be expressions into a new scan. Simple field references become
     * field sorts; other keys become script sorts.
     *
     * @return the new scan, or {@code null} if the list is null/empty or push-down fails
     */
    public AbstractCalciteIndexScan pushdownSortExpr(List<SortExprDigest> sortExprDigests) {
        try {
            if (sortExprDigests == null || sortExprDigests.isEmpty()) {
                return null;
            }
            AbstractCalciteIndexScan newScan = buildScan(getCluster(), getTraitSet(), hints, table, osIndex, getRowType(), pushDownContext.cloneWithoutSort());
            List<Supplier<SortBuilder<?>>> sortBuilderSuppliers = new ArrayList<>();
            for (SortExprDigest digest : sortExprDigests) {
                SortOrder order = toSortOrder(digest.getDirection());
                if (digest.isSimpleFieldReference()) {
                    String missing = toMissingValue(digest.getNullDirection());
                    sortBuilderSuppliers.add(() -> SortBuilders.fieldSort(digest.getFieldName()).order(order).missing(missing));
                    continue;
                }
                RexNode sortExpr = digest.getExpression();
                assert sortExpr instanceof RexCall : "sort expression should be RexCall";
                // Plain map on purpose: a double-brace anonymous subclass would retain a hidden
                // reference to this scan inside the script parameters.
                Map<String, Object> missingValueParams = new LinkedHashMap<>();
                missingValueParams.put("MISSING_MAX", digest.isMissingMax());
                PredicateAnalyzer.ScriptQueryExpression scriptExpr = new PredicateAnalyzer.ScriptQueryExpression(sortExpr, getRowType(), osIndex.getAllFieldTypes(), getCluster(), missingValueParams);
                ScriptSortBuilder.ScriptSortType sortType = getScriptSortType(sortExpr.getType());
                sortBuilderSuppliers.add(() -> SortBuilders.scriptSort((Script) scriptExpr.getScript(), sortType).order(order));
            }
            OSRequestBuilderAction action = requestBuilder -> requestBuilder.pushDownSortSuppliers(sortBuilderSuppliers);
            newScan.pushDownContext.add(PushDownType.SORT_EXPR, sortExprDigests, action);
            return newScan;
        } catch (Exception e) {
            if (LOG.isDebugEnabled()) {
                LOG.debug("Cannot pushdown sort expressions: {}", sortExprDigests, e);
            }
            return null;
        }
    }

    /** DESCENDING maps to {@link SortOrder#DESC}; every other direction maps to {@link SortOrder#ASC}. */
    private static SortOrder toSortOrder(RelFieldCollation.Direction direction) {
        return RelFieldCollation.Direction.DESCENDING.equals(direction) ? SortOrder.DESC : SortOrder.ASC;
    }

    /** Maps a null direction to the OpenSearch {@code missing} value ({@code null} when unspecified). */
    private static String toMissingValue(RelFieldCollation.NullDirection nullDirection) {
        return switch (nullDirection) {
            case FIRST -> "_first";
            case LAST -> "_last";
            default -> null;
        };
    }

    /** True when no aggregation is pushed and the table unwraps to an {@link OpenSearchIndex}. */
    public boolean noAggregatePushed() {
        if (pushDownContext.isAggregatePushed()) {
            return false;
        }
        RelOptTable relOptTable = getTable();
        return relOptTable.unwrap(OpenSearchIndex.class) != null;
    }

    public boolean isLimitPushed() {
        return pushDownContext.isLimitPushed();
    }

    public boolean isMetricsOrderPushed() {
        return pushDownContext.isMeasureOrderPushed();
    }

    public boolean isTopKPushed() {
        return pushDownContext.isTopKPushed();
    }

    public boolean isScriptPushed() {
        return pushDownContext.isScriptPushed();
    }

    public boolean isProjectPushed() {
        return pushDownContext.isProjectPushed();
    }

    /**
     * Chooses the script sort type: STRING for character types, NUMBER for integer and approximate
     * numeric types.
     *
     * @throws OpenSearchRequestBuilder.PushDownUnSupportedException for any other type
     */
    private ScriptSortBuilder.ScriptSortType getScriptSortType(RelDataType relDataType) {
        SqlTypeName typeName = relDataType.getSqlTypeName();
        if (SqlTypeName.CHAR_TYPES.contains(typeName)) {
            return ScriptSortBuilder.ScriptSortType.STRING;
        }
        if (SqlTypeName.INT_TYPES.contains(typeName) || SqlTypeName.APPROX_TYPES.contains(typeName)) {
            return ScriptSortBuilder.ScriptSortType.NUMBER;
        }
        throw new OpenSearchRequestBuilder.PushDownUnSupportedException("Unsupported type for sort expression pushdown: " + String.valueOf(relDataType));
    }

    @Generated
    public OpenSearchIndex getOsIndex() {
        return osIndex;
    }

    @Generated
    public RelDataType getSchema() {
        return schema;
    }

    @Generated
    public PushDownContext getPushDownContext() {
        return pushDownContext;
    }
}

