/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.action.undeploy;

import com.google.common.annotations.VisibleForTesting;
import java.time.Instant;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.ExceptionsHelper;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionType;
import org.opensearch.action.bulk.BulkItemResponse;
import org.opensearch.action.bulk.BulkResponse;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.action.support.WriteRequest;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.index.query.TermQueryBuilder;
import org.opensearch.index.query.TermsQueryBuilder;
import org.opensearch.ml.cluster.DiscoveryNodeHelper;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.model.MLModelState;
import org.opensearch.ml.common.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.common.transport.deploy.MLDeployModelRequest;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelAction;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodesRequest;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodesResponse;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsRequest;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsResponse;
import org.opensearch.ml.engine.ModelHelper;
import org.opensearch.ml.helper.ModelAccessControlHelper;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.task.MLTaskDispatcher;
import org.opensearch.ml.task.MLTaskManager;
import org.opensearch.ml.utils.RestActionUtils;
import org.opensearch.ml.utils.TenantAwareHelper;
import org.opensearch.remote.metadata.client.BulkDataObjectRequest;
import org.opensearch.remote.metadata.client.SdkClient;
import org.opensearch.remote.metadata.client.SearchDataObjectRequest;
import org.opensearch.remote.metadata.client.UpdateDataObjectRequest;
import org.opensearch.remote.metadata.client.WriteDataObjectRequest;
import org.opensearch.remote.metadata.common.SdkClientUtils;
import org.opensearch.search.SearchHit;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.tasks.Task;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;
import org.opensearch.transport.client.Client;

/**
 * Transport action that undeploys one or more ML models.
 *
 * <p>Flow implemented by this class:
 * <ul>
 *   <li>Exactly one model id: the caller's access to that model is validated first. Hidden models
 *       require a super admin; other models go through model-group access control.</li>
 *   <li>Several model ids: only allowed when model access control is disabled. Hidden models are
 *       removed from the request unless the caller is a super admin.</li>
 *   <li>If no node responds, or every responding node reports the models as {@code not_found},
 *       the model index documents are written to {@code UNDEPLOYED} directly.</li>
 * </ul>
 */
public class TransportUndeployModelsAction extends HandledTransportAction<ActionRequest, MLUndeployModelsResponse> {
    @Generated
    private static final Logger log = LogManager.getLogger(TransportUndeployModelsAction.class);
    TransportService transportService;
    ModelHelper modelHelper;
    MLTaskManager mlTaskManager;
    ClusterService clusterService;
    ThreadPool threadPool;
    Client client;
    SdkClient sdkClient;
    NamedXContentRegistry xContentRegistry;
    DiscoveryNodeHelper nodeFilter;
    MLTaskDispatcher mlTaskDispatcher;
    MLModelManager mlModelManager;
    ModelAccessControlHelper modelAccessControlHelper;
    private final MLFeatureEnabledSetting mlFeatureEnabledSetting;

    @Inject
    public TransportUndeployModelsAction(
        TransportService transportService,
        ActionFilters actionFilters,
        ModelHelper modelHelper,
        MLTaskManager mlTaskManager,
        ClusterService clusterService,
        ThreadPool threadPool,
        Client client,
        SdkClient sdkClient,
        NamedXContentRegistry xContentRegistry,
        DiscoveryNodeHelper nodeFilter,
        MLTaskDispatcher mlTaskDispatcher,
        MLModelManager mlModelManager,
        ModelAccessControlHelper modelAccessControlHelper,
        MLFeatureEnabledSetting mlFeatureEnabledSetting
    ) {
        // Requests for this action that arrive over the transport layer are undeploy requests, so they
        // must be deserialized with the undeploy request's stream constructor.
        super("cluster:admin/opensearch/ml/undeploy_models", transportService, actionFilters, MLUndeployModelsRequest::new);
        this.transportService = transportService;
        this.modelHelper = modelHelper;
        this.mlTaskManager = mlTaskManager;
        this.clusterService = clusterService;
        this.threadPool = threadPool;
        this.client = client;
        this.sdkClient = sdkClient;
        this.xContentRegistry = xContentRegistry;
        this.nodeFilter = nodeFilter;
        this.mlTaskDispatcher = mlTaskDispatcher;
        this.mlModelManager = mlModelManager;
        this.modelAccessControlHelper = modelAccessControlHelper;
        this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
    }

    /**
     * Validates the request and dispatches to the single-model or multi-model undeploy path.
     * All failures, including rejected requests, are reported through {@code listener}.
     */
    @Override
    protected void doExecute(Task task, ActionRequest request, ActionListener<MLUndeployModelsResponse> listener) {
        MLUndeployModelsRequest undeployModelsRequest = MLUndeployModelsRequest.fromActionRequest(request);
        String[] modelIds = undeployModelsRequest.getModelIds();
        String tenantId = undeployModelsRequest.getTenantId();
        String[] targetNodeIds = undeployModelsRequest.getNodeIds();

        log.info("Executing undeploy model action for modelIds: {}", Arrays.toString(modelIds));
        if (!TenantAwareHelper.validateTenantId(mlFeatureEnabledSetting, tenantId, listener)) {
            return;
        }
        if (modelIds == null) {
            log.error("No modelIds provided in undeploy.");
            listener.onFailure(new IllegalArgumentException("Must set specific model ids to undeploy"));
            return;
        }

        if (modelIds.length == 1) {
            String modelId = modelIds[0];
            validateAccess(modelId, tenantId, ActionListener.wrap(hasPermissionToUndeploy -> {
                if (hasPermissionToUndeploy) {
                    undeployModels(targetNodeIds, modelIds, tenantId, listener);
                } else {
                    listener.onFailure(new IllegalArgumentException("No permission to undeploy model " + modelId));
                }
            }, listener::onFailure));
            return;
        }

        // Multi-model undeploy is only allowed when model access control is disabled. Report the
        // rejection through the listener rather than throwing out of doExecute.
        if (modelAccessControlHelper.isModelAccessControlEnabled()) {
            listener.onFailure(new IllegalArgumentException("only support undeploy one model"));
            return;
        }
        searchHiddenModels(
            modelIds,
            ActionListener.wrap(hiddenModels -> undeployNonHiddenModels(modelIds, targetNodeIds, tenantId, listener, hiddenModels), e -> {
                log.error("Failed to search model index", e);
                listener.onFailure(e);
            })
        );
    }

    /**
     * Undeploys {@code modelIds}, dropping any model found in {@code hiddenModels} unless the caller is a
     * super admin.
     *
     * @param hiddenModels result of {@link #searchHiddenModels}; {@code null} when the model index does not exist
     */
    private void undeployNonHiddenModels(
        String[] modelIds,
        String[] targetNodeIds,
        String tenantId,
        ActionListener<MLUndeployModelsResponse> listener,
        SearchResponse hiddenModels
    ) {
        boolean hasHiddenModels = hiddenModels != null
            && hiddenModels.getHits().getTotalHits() != null
            && hiddenModels.getHits().getTotalHits().value() != 0L;
        if (!hasHiddenModels || isSuperAdminUserWrapper(clusterService, client)) {
            undeployModels(targetNodeIds, modelIds, tenantId, listener);
            return;
        }

        List<String> hiddenModelIds = Arrays.stream(hiddenModels.getHits().getHits()).map(SearchHit::getId).collect(Collectors.toList());
        String[] modelIdsToUndeploy = Arrays.stream(modelIds).filter(modelId -> !hiddenModelIds.contains(modelId)).toArray(String[]::new);
        if (modelIdsToUndeploy.length == 0) {
            // Every requested model is hidden, so there is nothing this caller may undeploy. Do not
            // forward an empty id list: it would also produce an empty bulk index update.
            // NOTE(review): the node-level undeploy may treat an empty list as "all models" - confirm.
            listener
                .onFailure(
                    new OpenSearchStatusException(
                        "User doesn't have privilege to perform this operation on these models",
                        RestStatus.FORBIDDEN
                    )
                );
            return;
        }
        undeployModels(targetNodeIds, modelIdsToUndeploy, tenantId, listener);
    }

    /**
     * Sends the undeploy request to the target nodes. If no node responds, or every node reports all
     * {@code modelIds} as {@code not_found}, falls back to marking the models UNDEPLOYED in the model index.
     */
    private void undeployModels(String[] targetNodeIds, String[] modelIds, String tenantId, ActionListener<MLUndeployModelsResponse> listener) {
        log.debug("Initiating undeploy on nodes: {}, for modelIds: {}", Arrays.toString(targetNodeIds), Arrays.toString(modelIds));
        MLUndeployModelNodesRequest nodesRequest = new MLUndeployModelNodesRequest(targetNodeIds, modelIds);
        nodesRequest.setTenantId(tenantId);

        client.execute(MLUndeployModelAction.INSTANCE, nodesRequest, ActionListener.wrap(response -> {
            log.info("Undeploy response received from nodes");
            boolean modelNotFoundInNodesCache = response.getNodes().stream().allMatch(nodeResponse -> {
                Map<String, String> status = nodeResponse.getModelUndeployStatus();
                return status != null && Arrays.stream(modelIds).allMatch(modelId -> "not_found".equalsIgnoreCase(status.get(modelId)));
            });
            if (response.getNodes().isEmpty() || modelNotFoundInNodesCache) {
                log.warn("No nodes service these models, performing manual `UNDEPLOY` write to model index");
                bulkSetModelIndexToUndeploy(modelIds, tenantId, listener, response);
                return;
            }
            log.info("Successfully undeployed model(s) from nodes: {}", Arrays.toString(modelIds));
            listener.onResponse(new MLUndeployModelsResponse(response));
        }, listener::onFailure));
    }

    /**
     * Writes UNDEPLOYED state (and zeroed worker-node fields) for every model id into the model index
     * using one bulk request with IMMEDIATE refresh. On success, responds with the supplied nodes response.
     */
    private void bulkSetModelIndexToUndeploy(
        String[] modelIds,
        String tenantId,
        ActionListener<MLUndeployModelsResponse> listener,
        MLUndeployModelNodesResponse mlUndeployModelNodesResponse
    ) {
        BulkDataObjectRequest bulkRequest = BulkDataObjectRequest.builder().globalIndex(".plugins-ml-model").build();
        bulkRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
        long lastUpdatedTime = Instant.now().toEpochMilli();
        for (String modelId : modelIds) {
            Map<String, Object> updateDocument = new HashMap<>();
            updateDocument.put("model_state", MLModelState.UNDEPLOYED.name());
            updateDocument.put("planning_worker_nodes", List.of());
            updateDocument.put("planning_worker_node_count", 0);
            updateDocument.put("last_updated_time", lastUpdatedTime);
            updateDocument.put("current_worker_node_count", 0);
            UpdateDataObjectRequest updateRequest = UpdateDataObjectRequest
                .builder()
                .id(modelId)
                .tenantId(tenantId)
                .dataObject(updateDocument)
                .build();
            bulkRequest.add(updateRequest);
        }
        log.info("No nodes running these models: {}", Arrays.toString(modelIds));

        try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
            ActionListener<MLUndeployModelsResponse> listenerWithContextRestoration = ActionListener.runBefore(listener, threadContext::restore);
            ActionListener<BulkResponse> bulkResponseListener = ActionListener
                .wrap(
                    bulkResponse -> listenerWithContextRestoration.onResponse(new MLUndeployModelsResponse(mlUndeployModelNodesResponse)),
                    e -> {
                        String modelsNotFoundMessage = "Failed to set the following modelId(s) to UNDEPLOY in index: "
                            + Arrays.toString(modelIds);
                        log.error(modelsNotFoundMessage, e);
                        listenerWithContextRestoration
                            .onFailure(
                                new OpenSearchStatusException(
                                    modelsNotFoundMessage + ". " + e.getMessage(),
                                    RestStatus.INTERNAL_SERVER_ERROR,
                                    e
                                )
                            );
                    }
                );

            sdkClient.bulkDataObjectAsync(bulkRequest).whenComplete((response, throwable) -> {
                if (throwable != null) {
                    bulkResponseListener.onFailure(SdkClientUtils.unwrapAndConvertToException(throwable, OpenSearchStatusException.class));
                    return;
                }
                BulkResponse bulkResponse;
                try {
                    bulkResponse = BulkResponse.fromXContent(response.parser());
                } catch (Exception e) {
                    bulkResponseListener.onFailure(e);
                    return;
                }
                BulkItemResponse[] items = bulkResponse.getItems();
                long failureCount = Arrays.stream(items).filter(BulkItemResponse::isFailed).count();
                log.info("Executed {} bulk operations with {} failures, Took: {}", items.length, failureCount, bulkResponse.getTook());
                List<String> undeployedModelIds = Arrays
                    .stream(items)
                    .filter(item -> !item.isFailed())
                    .map(BulkItemResponse::getId)
                    .collect(Collectors.toList());
                log.debug("Successfully set the following modelId(s) to UNDEPLOY in index: {}", undeployedModelIds);
                bulkResponseListener.onResponse(bulkResponse);
            });
        } catch (Exception e) {
            log.error("Unexpected error while setting the following modelId(s) to UNDEPLOY in index: {}", Arrays.toString(modelIds), e);
            listener.onFailure(e);
        }
    }

    /**
     * Checks whether the current user may undeploy {@code modelId}. Hidden models require a super admin;
     * other models are delegated to model-group access control. The result (or failure) goes to {@code listener}.
     */
    private void validateAccess(String modelId, String tenantId, ActionListener<Boolean> listener) {
        User user = RestActionUtils.getUserContext(client);
        boolean isSuperAdmin = isSuperAdminUserWrapper(clusterService, client);
        String[] excludes = new String[] { "model_content", "content" };
        try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
            mlModelManager.getModel(modelId, tenantId, null, excludes, ActionListener.runBefore(ActionListener.wrap(mlModel -> {
                if (!TenantAwareHelper.validateTenantResource(mlFeatureEnabledSetting, tenantId, mlModel.getTenantId(), listener)) {
                    return;
                }
                Boolean isHidden = mlModel.getIsHidden();
                if (isHidden != null && isHidden) {
                    if (isSuperAdmin) {
                        listener.onResponse(true);
                    } else {
                        listener
                            .onFailure(
                                new OpenSearchStatusException(
                                    "User doesn't have privilege to perform this operation on this model",
                                    RestStatus.FORBIDDEN
                                )
                            );
                    }
                } else {
                    modelAccessControlHelper
                        .validateModelGroupAccess(
                            user,
                            mlFeatureEnabledSetting,
                            tenantId,
                            mlModel.getModelGroupId(),
                            "cluster:admin/opensearch/ml/undeploy_models",
                            client,
                            sdkClient,
                            listener
                        );
                }
            }, e -> {
                log.error("Failed to find Model", e);
                listener.onFailure(e);
            }), context::restore));
        } catch (Exception e) {
            log.error("Failed to undeploy ML model", e);
            listener.onFailure(e);
        }
    }

    /**
     * Searches the model index for the hidden, non-chunk model documents among {@code modelIds}.
     * Responds with {@code null} when the model index does not exist.
     */
    public void searchHiddenModels(String[] modelIds, ActionListener<SearchResponse> listener) throws IllegalArgumentException {
        try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
            TermsQueryBuilder termsQuery = QueryBuilders.termsQuery("_id", modelIds);
            TermQueryBuilder isHiddenQuery = QueryBuilders.termQuery("is_hidden", true);
            SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
            searchSourceBuilder
                .query(QueryBuilders.boolQuery().must(termsQuery).must(isHiddenQuery).mustNot(QueryBuilders.existsQuery("chunk_number")));
            SearchDataObjectRequest searchDataObjectRequest = SearchDataObjectRequest
                .builder()
                .indices(new String[] { ".plugins-ml-model" })
                .searchSourceBuilder(searchSourceBuilder)
                .build();

            sdkClient.searchDataObjectAsync(searchDataObjectRequest).whenComplete((r, throwable) -> {
                context.restore();
                if (throwable != null) {
                    Exception cause = SdkClientUtils.unwrapAndConvertToException(throwable);
                    log.error("Failed to search model index", cause);
                    // A missing model index means there are no hidden models to filter out.
                    if (ExceptionsHelper.unwrap(cause, IndexNotFoundException.class) != null) {
                        listener.onResponse(null);
                    } else {
                        listener.onFailure(cause);
                    }
                    return;
                }
                try {
                    SearchResponse searchResponse = r.searchResponse();
                    log.info("Model Index search complete: {}", searchResponse.getHits().getTotalHits());
                    listener.onResponse(searchResponse);
                } catch (Exception e) {
                    log.error("Failed to parse search response", e);
                    listener
                        .onFailure(new OpenSearchStatusException("Failed to parse search response", RestStatus.INTERNAL_SERVER_ERROR, e));
                }
            });
        } catch (Exception e) {
            log.error("Failed to search model index", e);
            listener.onFailure(e);
        }
    }

    /** Overridable hook (for tests) around {@link RestActionUtils#isSuperAdminUser}. */
    @VisibleForTesting
    boolean isSuperAdminUserWrapper(ClusterService clusterService, Client client) {
        return RestActionUtils.isSuperAdminUser(clusterService, client);
    }
}

