/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.kylin.gluten;

import java.util.ArrayList;
import java.util.List;

import org.apache.gluten.execution.FilterExecTransformer;
import org.apache.commons.collections.CollectionUtils;
import org.apache.hadoop.util.Shell;
import org.apache.kylin.common.KylinConfig;
import org.apache.kylin.common.util.DateFormat;
import org.apache.kylin.common.util.Pair;
import org.apache.kylin.common.util.RandomUtil;
import org.apache.kylin.common.util.TempMetadataBuilder;
import org.apache.kylin.engine.spark.NLocalWithSparkSessionTest;
import org.apache.kylin.guava30.shaded.common.collect.Lists;
import org.apache.kylin.job.engine.JobEngineConfig;
import org.apache.kylin.job.impl.threadpool.NDefaultScheduler;
import org.apache.kylin.junit.TimeZoneTestRunner;
import org.apache.kylin.metadata.cube.model.NDataSegment;
import org.apache.kylin.metadata.model.NDataModelManager;
import org.apache.kylin.query.relnode.ContextUtil;
import org.apache.kylin.util.ExecAndComp;
import org.apache.spark.SparkConf;
import org.apache.spark.sql.SparderEnv;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.execution.SparkPlan;
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper;
import org.apache.spark.sql.execution.gluten.KylinFileSourceScanExecTransformer;
import org.apache.spark.sql.internal.SQLConf;
import org.apache.spark.sql.internal.StaticSQLConf;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Test;
import org.junit.runner.RunWith;

import lombok.val;
import scala.Option;
import scala.runtime.AbstractFunction1;

/**
 * Integration test that runs Kylin queries on a local Spark session with the Gluten plugin registered.
 *
 * <p>The test first checks that a {@code floor_datetime} filter is planned as a {@link FilterExecTransformer}
 * and yields the same count with Gluten on and off, then builds segments and verifies segment pruning and the
 * number of scanned files reported by {@link KylinFileSourceScanExecTransformer}.
 *
 * <p>NOTE(review): the session is configured with a hard-coded native library path
 * ({@code /usr/local/clickhouse/lib/libch.so}); presumably the test can only pass on hosts where that
 * library is installed - confirm against the CI setup.
 */
@RunWith(TimeZoneTestRunner.class)
public class KylinGlutenTest extends NLocalWithSparkSessionTest implements AdaptiveSparkPlanHelper {

    /** Session-level switch that turns Gluten off/on without restarting the session. */
    private static final String GLUTEN_ENABLED = "spark.gluten.enabled";

    /** Common query prefix; every test query appends its own {@code where} clause (note the trailing space). */
    private static final String BASE_SQL = "select count(*)  FROM TEST_ORDER LEFT JOIN TEST_KYLIN_FACT ON TEST_KYLIN_FACT.ORDER_ID = TEST_ORDER.ORDER_ID ";

    /**
     * Replaces the inherited Spark session with one that has the Gluten plugin registered.
     * Gluten starts disabled; the test enables it per phase through the session configuration.
     */
    @BeforeClass
    public static void initSpark() {
        if (Shell.MAC) {
            overwriteSystemPropBeforeClass("org.xerial.snappy.lib.name", "libsnappyjava.jnilib");//for snappy
        }
        // Stop any session that is still running so the configuration below actually takes effect.
        if (ss != null && !ss.sparkContext().isStopped()) {
            ss.stop();
        }
        sparkConf = new SparkConf().setAppName(RandomUtil.randomUUIDStr()).setMaster("local[4]");
        sparkConf.set("spark.serializer", "org.apache.spark.serializer.JavaSerializer");
        sparkConf.set(StaticSQLConf.CATALOG_IMPLEMENTATION().key(), "in-memory");
        sparkConf.set("spark.sql.shuffle.partitions", "1");
        sparkConf.set("spark.memory.fraction", "0.1");
        // opt memory
        sparkConf.set("spark.shuffle.detectCorrupt", "false");
        // For sinai_poc/query03, enable implicit cross join conversion
        sparkConf.set("spark.sql.crossJoin.enabled", "true");
        sparkConf.set("spark.sql.adaptive.enabled", "true");
        // for gluten
        sparkConf.set("spark.plugins", "org.apache.gluten.GlutenPlugin");
        sparkConf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.ColumnarShuffleManager");
        sparkConf.set("spark.io.compression.codec", "LZ4");
        sparkConf.set("spark.gluten.sql.columnar.libpath", "/usr/local/clickhouse/lib/libch.so");
        sparkConf.set("spark.gluten.sql.enable.native.validation", "false");
        sparkConf.set("spark.memory.offHeap.enabled", "true");
        sparkConf.set("spark.memory.offHeap.size", "2G");
        sparkConf.set("spark.gluten.sql.columnar.extended.columnar.pre.rules",
                "org.apache.spark.sql.execution.gluten.ConvertKylinFileSourceToGlutenRule");
        sparkConf.set("spark.gluten.sql.columnar.extended.expressions.transformer",
                "org.apache.spark.sql.catalyst.expressions.gluten.CustomerExpressionTransformer");
        sparkConf.set("spark.gluten.sql.columnar.extended.columnar.post.rules", "");
        sparkConf.set(GLUTEN_ENABLED, "false");
        sparkConf.set("spark.gluten.sql.validate.failure.printStack", "true");
        sparkConf.set(StaticSQLConf.WAREHOUSE_PATH().key(),
                TempMetadataBuilder.TEMP_TEST_METADATA + "/spark-warehouse");
        ss = SparkSession.builder().config(sparkConf).getOrCreate();
        SparderEnv.setSparkSession(ss);
    }

    /**
     * Loads the {@code file_pruning} test metadata and starts the job scheduler for the project.
     *
     * @throws IllegalStateException if the scheduler reports that it has not started after initialization
     */
    @Before
    public void setUp() throws Exception {
        overwriteSystemProp("kylin.job.scheduler.poll-interval-second", "1");
        this.createTestMetadata("src/test/resources/ut_meta/file_pruning");
        NDefaultScheduler scheduler = NDefaultScheduler.getInstance(getProject());
        scheduler.init(new JobEngineConfig(KylinConfig.getInstanceFromEnv()));
        if (!scheduler.hasStarted()) {
            throw new IllegalStateException("scheduler has not been started");
        }
    }

    /** Tears down the scheduler and removes the test metadata created by {@link #setUp()}. */
    @After
    public void after() throws Exception {
        NDefaultScheduler.destroyInstance();
        cleanupTestMetadata();
    }

    /**
     * End-to-end check: {@code floor_datetime} offload, index build with Gluten disabled, then segment
     * pruning and result comparison with Gluten enabled.
     */
    @Test
    public void testKylinWithGluten() throws Exception {
        val conf = SQLConf.get();
        try {
            assertFloorDatetimeOffloaded(conf);

            // Build the index with Gluten disabled.
            // build three segs
            // [2009-01-01 00:00:00, 2011-01-01 00:00:00)
            // [2011-01-01 00:00:00, 2013-01-01 00:00:00)
            // [2013-01-01 00:00:00, 2015-01-01 00:00:00)
            conf.setConfString(GLUTEN_ENABLED, "false");
            val dfId = "8c670664-8d05-466a-802f-83c023b56c77";
            buildMultiSegs(dfId, 10001);
            populateSSWithCSVData(getTestConfig(), getProject(), SparderEnv.getSparkSession());

            // Test querying index
            conf.setConfString(GLUTEN_ENABLED, "true");
            assertSegmentPruning(dfId);
        } finally {
            // Put the flag back to the value set in initSpark(), even when an assertion above fails,
            // so the shared session is not left with Gluten switched on.
            conf.setConfString(GLUTEN_ENABLED, "false");
        }
    }

    @Override
    public String getProject() {
        return "file_pruning";
    }

    /**
     * Runs a {@code floor_datetime} filter query once with Gluten enabled and once disabled, asserting that
     * the Gluten plan contains a {@link FilterExecTransformer} and that both runs return the same count.
     * Leaves the Gluten flag disabled on return.
     */
    private void assertFloorDatetimeOffloaded(SQLConf conf) throws Exception {
        KylinConfig config = KylinConfig.getInstanceFromEnv();
        populateSSWithCSVData(config, getProject(), SparderEnv.getSparkSession());

        String pushDownSql = BASE_SQL
                + "where floor_datetime(TEST_TIME_ENC, 'DAY') > TIMESTAMP '2011-01-01 00:00:00' and "
                + "floor_datetime(TEST_TIME_ENC, 'hour') < TIMESTAMP '2013-01-01 00:00:00'";

        conf.setConfString(GLUTEN_ENABLED, "true");
        val df = ss.sql(pushDownSql);
        val glutenResult = df.collect();
        val filterTransformer = findFilterTransformer(df.queryExecution().executedPlan());
        Assert.assertFalse(filterTransformer.isEmpty());

        conf.setConfString(GLUTEN_ENABLED, "false");
        val sparkResult = ss.sql(pushDownSql).collect();
        Assert.assertEquals(((Row[]) glutenResult)[0].getLong(0), ((Row[]) sparkResult)[0].getLong(0));
    }

    /**
     * Issues the pruning queries against the built model and checks, per query, the number of scanned
     * files and the segment ranges that survived pruning; finally compares a subset of the queries via
     * {@link ExecAndComp#execAndCompare}.
     */
    private void assertSegmentPruning(String dfId) throws Exception {
        String andPruning0 = BASE_SQL
                + "where floor(TEST_TIME_ENC to hour) > TIMESTAMP '2011-01-01 00:00:00' and "
                + "TEST_TIME_ENC < TIMESTAMP '2013-01-01 00:00:00'";
        String andPruning1 = BASE_SQL
                + "where TEST_TIME_ENC > TIMESTAMP '2011-01-01 00:00:00' and TEST_TIME_ENC = TIMESTAMP '2016-01-01 00:00:00'";

        String orPruning0 = BASE_SQL
                + "where TEST_TIME_ENC > TIMESTAMP '2011-01-01 00:00:00' or TEST_TIME_ENC = TIMESTAMP '2016-01-01 00:00:00'";
        String orPruning1 = BASE_SQL
                + "where TEST_TIME_ENC < TIMESTAMP '2009-01-01 00:00:00' or TEST_TIME_ENC > TIMESTAMP '2015-01-01 00:00:00'";

        String pruning0 = BASE_SQL + "where TEST_TIME_ENC < TIMESTAMP '2009-01-01 00:00:00'";
        String pruning1 = BASE_SQL + "where TEST_TIME_ENC <= TIMESTAMP '2009-01-01 00:00:00'";
        String pruning2 = BASE_SQL + "where TEST_TIME_ENC >= TIMESTAMP '2015-01-01 00:00:00'";

        String notEqual0 = BASE_SQL + "where TEST_TIME_ENC <> TIMESTAMP '2012-01-01 00:00:00'";

        String not0 = BASE_SQL
                + "where not (TEST_TIME_ENC < TIMESTAMP '2011-01-01 00:00:00' or TEST_TIME_ENC >= TIMESTAMP '2013-01-01 00:00:00')";

        String inPruning0 = BASE_SQL
                + "where TEST_TIME_ENC in (TIMESTAMP '2009-01-01 00:00:00',TIMESTAMP '2008-01-01 00:00:00',TIMESTAMP '2016-01-01 00:00:00')";
        String inPruning1 = BASE_SQL
                + "where TEST_TIME_ENC in (TIMESTAMP '2008-01-01 00:00:00',TIMESTAMP '2016-01-01 00:00:00')";

        val segmentRange1 = Pair.newPair("2009-01-01 00:00:00", "2011-01-01 00:00:00");
        val segmentRange2 = Pair.newPair("2011-01-01 00:00:00", "2013-01-01 00:00:00");
        val segmentRange3 = Pair.newPair("2013-01-01 00:00:00", "2015-01-01 00:00:00");

        assertResultsAndScanFiles(dfId, BASE_SQL, 3, false, rangesOf(segmentRange1, segmentRange2, segmentRange3));

        assertResultsAndScanFiles(dfId, andPruning0, 2, false, rangesOf(segmentRange1, segmentRange2));
        assertResultsAndScanFiles(dfId, andPruning1, 0, true, rangesOf());

        assertResultsAndScanFiles(dfId, orPruning0, 2, false, rangesOf(segmentRange2, segmentRange3));
        assertResultsAndScanFiles(dfId, orPruning1, 0, true, rangesOf());

        assertResultsAndScanFiles(dfId, pruning0, 0, true, rangesOf());
        assertResultsAndScanFiles(dfId, pruning1, 1, false, rangesOf(segmentRange1));
        assertResultsAndScanFiles(dfId, pruning2, 0, true, rangesOf());

        // pruning with "not equal" is not supported
        assertResultsAndScanFiles(dfId, notEqual0, 3, false,
                rangesOf(segmentRange1, segmentRange2, segmentRange3));

        assertResultsAndScanFiles(dfId, not0, 1, false, rangesOf(segmentRange2));

        assertResultsAndScanFiles(dfId, inPruning0, 1, false, rangesOf(segmentRange1));
        // Expected ranges are not inspected for empty-layout queries, so an empty list is passed.
        assertResultsAndScanFiles(dfId, inPruning1, 0, true, rangesOf());

        List<Pair<String, String>> query = new ArrayList<>();
        query.add(Pair.newPair("base", BASE_SQL));
        query.add(Pair.newPair("and_pruning0", andPruning0));
        query.add(Pair.newPair("or_pruning0", orPruning0));
        query.add(Pair.newPair("pruning1", pruning1));
        query.add(Pair.newPair("not_equal0", notEqual0));
        query.add(Pair.newPair("not0", not0));
        query.add(Pair.newPair("in_pruning0", inPruning0));
        ExecAndComp.execAndCompare(query, getProject(), ExecAndComp.CompareLevel.SAME, "default");
    }

    /** Collects the given segment ranges into a fresh mutable list; no arguments yields an empty list. */
    @SafeVarargs
    private static List<Pair<String, String>> rangesOf(Pair<String, String>... ranges) {
        List<Pair<String, String>> result = new ArrayList<>(ranges.length);
        for (Pair<String, String> range : ranges) {
            result.add(range);
        }
        return result;
    }

    /**
     * Plans {@code sql} against the model and verifies the storage context.
     *
     * <p>When {@code emptyLayout} is true, only asserts that the first query context has an empty layout
     * with layout id {@code -1}; the query is not executed and {@code expectedRanges} is ignored.
     * Otherwise the query is collected, the plan must contain a {@link KylinFileSourceScanExecTransformer}
     * whose {@code numFiles} metric equals {@code numScanFiles}, and the pruned segments must match
     * {@code expectedRanges}.
     *
     * @return {@code numScanFiles} for empty-layout queries, otherwise the observed {@code numFiles} metric
     */
    private long assertResultsAndScanFiles(String modelId, String sql, long numScanFiles, boolean emptyLayout,
            List<Pair<String, String>> expectedRanges) throws Exception {
        val df = ExecAndComp.queryModelWithoutCompute(getProject(), sql);
        val context = ContextUtil.listContexts().get(0);
        if (emptyLayout) {
            Assert.assertTrue(context.storageContext.isEmptyLayout());
            Assert.assertEquals(Long.valueOf(-1), context.storageContext.getLayoutId());
            return numScanFiles;
        }
        // Execute the query so that the scan metrics are populated before they are read.
        df.collect();

        val nativeFileScanOption = findFileSourceScanExec(df.queryExecution().executedPlan());
        Assert.assertFalse(nativeFileScanOption.isEmpty());
        val nativeFileScan = (KylinFileSourceScanExecTransformer) nativeFileScanOption.get();
        val actualNum = nativeFileScan.metrics().get("numFiles").get().value();
        Assert.assertEquals(numScanFiles, actualNum);
        val segmentIds = context.storageContext.getPrunedSegments();
        assertPrunedSegmentRange(modelId, segmentIds, expectedRanges);
        return actualNum;
    }

    /** Finds the first {@link KylinFileSourceScanExecTransformer} node in {@code plan}, if any. */
    private Option<SparkPlan> findFileSourceScanExec(SparkPlan plan) {
        return findFirstOfType(plan, KylinFileSourceScanExecTransformer.class);
    }

    /** Finds the first {@link FilterExecTransformer} node in {@code plan}, if any. */
    private Option<SparkPlan> findFilterTransformer(SparkPlan plan) {
        return findFirstOfType(plan, FilterExecTransformer.class);
    }

    /**
     * Searches {@code plan} via {@code find} (inherited from {@link AdaptiveSparkPlanHelper}) for the first
     * node that is an instance of {@code nodeType}.
     */
    private Option<SparkPlan> findFirstOfType(SparkPlan plan, Class<?> nodeType) {
        return find(plan, new AbstractFunction1<SparkPlan, Object>() {
            @Override
            public Object apply(SparkPlan node) {
                return nodeType.isInstance(node);
            }
        });
    }

    /**
     * Asserts that {@code prunedSegments} match {@code expectedRanges} position by position, comparing
     * start and end formatted with the model's partition date format. Does nothing when
     * {@code expectedRanges} is null or empty.
     */
    private void assertPrunedSegmentRange(String dfId, List<NDataSegment> prunedSegments,
            List<Pair<String, String>> expectedRanges) {
        // Nothing to compare: skip before doing any metadata lookup.
        if (CollectionUtils.isEmpty(expectedRanges)) {
            return;
        }
        val model = NDataModelManager.getInstance(getTestConfig(), getProject()).getDataModelDesc(dfId);
        val partitionColDateFormat = model.getPartitionDesc().getPartitionDateFormat();

        Assert.assertEquals(expectedRanges.size(), prunedSegments.size());
        for (int i = 0; i < prunedSegments.size(); i++) {
            val segment = prunedSegments.get(i);
            val start = DateFormat.formatToDateStr(segment.getTSRange().getStart(), partitionColDateFormat);
            val end = DateFormat.formatToDateStr(segment.getTSRange().getEnd(), partitionColDateFormat);
            val expectedRange = expectedRanges.get(i);
            Assert.assertEquals(expectedRange.getFirst(), start);
            Assert.assertEquals(expectedRange.getSecond(), end);
        }
    }
}
