/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.neuralsearch.search.query;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Locale;
import java.util.Objects;
import java.util.Set;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.CollectorManager;
import org.apache.lucene.search.FieldDoc;
import org.apache.lucene.search.MultiCollector;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.SortField;
import org.opensearch.common.Nullable;
import org.opensearch.common.lucene.search.FilteredCollector;
import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
import org.opensearch.neuralsearch.query.HybridQuery;
import org.opensearch.neuralsearch.search.HitsThresholdChecker;
import org.opensearch.neuralsearch.search.collector.HybridCollapsingTopDocsCollector;
import org.opensearch.neuralsearch.search.collector.HybridCollectorFactory;
import org.opensearch.neuralsearch.search.collector.HybridCollectorFactoryDTO;
import org.opensearch.neuralsearch.search.collector.HybridSearchCollector;
import org.opensearch.neuralsearch.search.collector.HybridTopFieldDocSortCollector;
import org.opensearch.neuralsearch.search.collector.HybridTopScoreDocCollector;
import org.opensearch.neuralsearch.search.query.HybridCollectorResultsUtilParams;
import org.opensearch.neuralsearch.search.query.util.HybridSearchCollectorResultUtil;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.query.ReduceableSearchResult;
import org.opensearch.search.sort.SortAndFormats;

/**
 * {@link CollectorManager} used to execute hybrid queries.
 * <p>
 * Each call to {@link #newCollector()} builds a hybrid collector through {@link HybridCollectorFactory}. The
 * collector uses the hit count, hits-threshold checker, sort settings and {@code search_after} doc captured at
 * construction time. {@link #reduce(Collection)} gathers the hybrid collectors, unwrapping {@link MultiCollector}
 * and {@link FilteredCollector} wrappers. It then combines each collector's top docs into a single
 * {@link ReduceableSearchResult}.
 */
public class HybridCollectorManager
implements CollectorManager<Collector, ReduceableSearchResult> {
    @Generated
    private static final Logger log = LogManager.getLogger(HybridCollectorManager.class);

    /** Hybrid collector implementations whose results this manager knows how to reduce. */
    private static final Set<Class<?>> VALID_COLLECTOR_TYPES = Set.of(
        HybridTopScoreDocCollector.class,
        HybridTopFieldDocSortCollector.class,
        HybridCollapsingTopDocsCollector.class
    );

    private final int numHits;
    private final HitsThresholdChecker hitsThresholdChecker;
    private final SortAndFormats sortAndFormats;
    @Nullable
    private final FieldDoc after;
    private final SearchContext searchContext;

    /**
     * Creates a collector manager for a hybrid query executed against the given search context.
     * <p>
     * The number of hits each collector keeps is {@code pagination_depth} when it is set, otherwise the request
     * {@code size}. That value is capped at the reader's {@code numDocs()}.
     *
     * @param searchContext search context of the request; its {@code from} is reset to 0 when the index has
     *                      exactly one shard and {@code from > 0}
     * @param query         the query being executed; expected to be a {@link HybridQuery}
     * @return a new {@link HybridCollectorManager}
     * @throws IllegalArgumentException if a scroll is requested, if {@code from > 0} without
     *                                  {@code pagination_depth}, or if the sort criteria are not supported
     */
    public static CollectorManager createHybridCollectorManager(SearchContext searchContext, Query query) {
        if (searchContext.scrollContext() != null) {
            throw new IllegalArgumentException("Scroll operation is not supported in hybrid query");
        }
        final IndexReader reader = searchContext.searcher().getIndexReader();
        final int totalNumDocs = Math.max(0, reader.numDocs());
        // This must run before `from` may be reset below: it rejects `from > 0` when pagination_depth is missing.
        final int numDocs = Math.min(getSubqueryResultsRetrievalSize(searchContext, query), totalNumDocs);
        final int trackTotalHitsUpTo = searchContext.trackTotalHitsUpTo();
        if (searchContext.sort() != null) {
            validateSortCriteria(searchContext, searchContext.trackScores());
        }
        final boolean isSingleShard = searchContext.numberOfShards() == 1;
        if (isSingleShard && searchContext.from() > 0) {
            // NOTE(review): for single-shard requests `from` is reset to 0 here; presumably the offset is applied
            // later in the hybrid search pipeline — confirm.
            searchContext.from(0);
        }
        return new HybridCollectorManager(
            numDocs,
            new HitsThresholdChecker(Math.max(numDocs, trackTotalHitsUpTo)),
            searchContext.sort(),
            searchContext.searchAfter(),
            searchContext
        );
    }

    /**
     * Builds a new hybrid collector configured from this manager's settings.
     *
     * @return the collector produced by {@link HybridCollectorFactory#createCollector}
     */
    @Override
    public Collector newCollector() {
        return HybridCollectorFactory.createCollector(
            HybridCollectorFactoryDTO.builder()
                .sortAndFormats(this.sortAndFormats)
                .searchContext(this.searchContext)
                .hitsThresholdChecker(this.hitsThresholdChecker)
                .numHits(this.numHits)
                .after(this.after)
                .build()
        );
    }

    /**
     * Combines the results of all hybrid collectors found in {@code collectors} into one reducible result.
     *
     * @param collectors collectors previously created by {@link #newCollector()}, possibly wrapped
     * @return a result that, when reduced, applies every hybrid collector's top docs in iteration order
     * @throws IOException           if reading a collector's top docs fails
     * @throws IllegalStateException if no hybrid collector is found
     */
    @Override
    public ReduceableSearchResult reduce(Collection<Collector> collectors) throws IOException {
        final List<HybridSearchCollector> hybridSearchCollectors = this.getHybridSearchCollectors(collectors);
        if (hybridSearchCollectors.isEmpty()) {
            throw new IllegalStateException("cannot collect results of hybrid search query, there are no proper collectors");
        }
        return this.reduceSearchResults(this.getSearchResults(hybridSearchCollectors));
    }

    /**
     * Materializes each collector's top docs eagerly and wraps them in a deferred reduction step.
     */
    private List<ReduceableSearchResult> getSearchResults(List<HybridSearchCollector> hybridSearchCollectors) throws IOException {
        final List<ReduceableSearchResult> results = new ArrayList<>();
        final HybridCollectorResultsUtilParams hybridCollectorResultsUtilParams = new HybridCollectorResultsUtilParams.Builder()
            .searchContext(this.searchContext)
            .build();
        for (HybridSearchCollector collector : hybridSearchCollectors) {
            final HybridSearchCollectorResultUtil hybridSearchCollectorResultUtil = new HybridSearchCollectorResultUtil(
                hybridCollectorResultsUtilParams,
                collector
            );
            // Top docs are computed now, while the collector is in scope; only applying them is deferred.
            final TopDocsAndMaxScore topDocsAndMaxScore = hybridSearchCollectorResultUtil.getTopDocsAndMaxScore();
            results.add(result -> hybridSearchCollectorResultUtil.reduceCollectorResults(result, topDocsAndMaxScore));
        }
        return results;
    }

    /**
     * Extracts the hybrid collectors from {@code collectors}, in iteration order.
     * <p>
     * Direct collectors are classified the same way as children of a {@link MultiCollector}. A child is therefore
     * accepted if it is any of {@link #VALID_COLLECTOR_TYPES}, or a {@link FilteredCollector} wrapping one.
     * Anything else is skipped.
     */
    private List<HybridSearchCollector> getHybridSearchCollectors(Collection<Collector> collectors) {
        final List<HybridSearchCollector> hybridSearchCollectors = new ArrayList<>();
        for (Collector collector : collectors) {
            if (collector instanceof MultiCollector) {
                for (Collector sub : ((MultiCollector) collector).getCollectors()) {
                    addIfHybridSearchCollector(sub, hybridSearchCollectors);
                }
            } else {
                addIfHybridSearchCollector(collector, hybridSearchCollectors);
            }
        }
        return hybridSearchCollectors;
    }

    /** Appends {@code collector} (or the hybrid collector it wraps) to {@code sink} if it is a supported type. */
    private void addIfHybridSearchCollector(Collector collector, List<HybridSearchCollector> sink) {
        if (isHybridNonFilteredCollector(collector)) {
            sink.add((HybridSearchCollector) collector);
        } else if (isHybridFilteredCollector(collector)) {
            sink.add((HybridSearchCollector) ((FilteredCollector) collector).getCollector());
        }
    }

    /** Returns true if {@code collector} is an instance of one of {@link #VALID_COLLECTOR_TYPES}. */
    private boolean isHybridNonFilteredCollector(Collector collector) {
        return VALID_COLLECTOR_TYPES.stream().anyMatch(type -> type.isInstance(collector));
    }

    /** Returns true if {@code collector} is a {@link FilteredCollector} wrapping one of {@link #VALID_COLLECTOR_TYPES}. */
    private boolean isHybridFilteredCollector(Collector collector) {
        return collector instanceof FilteredCollector
            && VALID_COLLECTOR_TYPES.stream().anyMatch(type -> type.isInstance(((FilteredCollector) collector).getCollector()));
    }

    /**
     * Rejects sort settings that hybrid search does not support.
     *
     * @throws IllegalArgumentException if {@code _score} is mixed with other sort fields, or if
     *                                  {@code trackScores} is enabled together with any sort
     */
    private static void validateSortCriteria(SearchContext searchContext, boolean trackScores) {
        final SortField[] sortFields = searchContext.sort().sort.getSort();
        boolean hasFieldSort = false;
        boolean hasScoreSort = false;
        for (SortField sortField : sortFields) {
            if (sortField.getType() == SortField.Type.SCORE) {
                hasScoreSort = true;
            } else {
                hasFieldSort = true;
            }
            if (hasScoreSort && hasFieldSort) {
                break;
            }
        }
        if (hasScoreSort && hasFieldSort) {
            throw new IllegalArgumentException("_score sort criteria cannot be applied with any other criteria. Please select one sort criteria out of them.");
        }
        if (trackScores && hasFieldSort) {
            throw new IllegalArgumentException("Hybrid search results when sorted by any field, docId or _id, track_scores must be set to false.");
        }
        if (trackScores && hasScoreSort) {
            throw new IllegalArgumentException("Hybrid search results are by default sorted by _score, track_scores must be set to false.");
        }
    }

    /** Returns a result that reduces each of {@code results} into the target, in list order. */
    private ReduceableSearchResult reduceSearchResults(List<ReduceableSearchResult> results) {
        return result -> {
            for (ReduceableSearchResult r : results) {
                r.reduce(result);
            }
        };
    }

    /**
     * Returns how many results each sub-query should retrieve.
     * <p>
     * The value is {@code pagination_depth} from the hybrid query context when it is present, otherwise the
     * request {@code size}.
     *
     * @throws IllegalArgumentException if {@code from > 0} and {@code pagination_depth} is missing
     */
    private static int getSubqueryResultsRetrievalSize(SearchContext searchContext, Query query) {
        assert (query instanceof HybridQuery);
        final HybridQuery hybridQuery = (HybridQuery) query;
        final Integer paginationDepth = hybridQuery.getQueryContext().getPaginationDepth();
        if (Objects.isNull(paginationDepth) && searchContext.from() > 0) {
            throw new IllegalArgumentException("pagination_depth param is missing in the search request");
        }
        if (Objects.nonNull(paginationDepth)) {
            return paginationDepth;
        }
        return searchContext.size();
    }

    /**
     * @param numHits              number of top hits each created collector keeps
     * @param hitsThresholdChecker threshold checker shared by the created collectors
     * @param sortAndFormats       sort to apply, or {@code null} when results are ordered by score
     * @param after                {@code search_after} doc, or {@code null}
     * @param searchContext        search context of the request
     */
    @Generated
    public HybridCollectorManager(int numHits, HitsThresholdChecker hitsThresholdChecker, SortAndFormats sortAndFormats, FieldDoc after, SearchContext searchContext) {
        this.numHits = numHits;
        this.hitsThresholdChecker = hitsThresholdChecker;
        this.sortAndFormats = sortAndFormats;
        this.after = after;
        this.searchContext = searchContext;
    }
}

