/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.common.connector.functions.preprocess;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.ml.common.connector.functions.preprocess.ConnectorPreProcessFunction;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.utils.StringUtils;

/**
 * Pre-process function that converts the first document of a {@link TextDocsInputDataSet}
 * into a {@link RemoteInferenceInputDataSet} holding a single modality parameter.
 *
 * <p>The first document is either plain text or a JSON object carrying one of the keys
 * {@code text}, {@code image}, {@code video} or {@code audio}. The resulting parameter map
 * has exactly one entry: the detected modality name mapped to the extracted content.
 * Anything that is not a JSON object with one of those keys is passed through unchanged
 * under the {@code text} key.
 *
 * <p>Only the first document is used; any further documents are ignored.
 */
public class NovaMultiModalEmbeddingPreProcessFunction
extends ConnectorPreProcessFunction {
    @Generated
    private static final Logger log = LogManager.getLogger(NovaMultiModalEmbeddingPreProcessFunction.class);
    private static final ObjectMapper objectMapper = new ObjectMapper();

    /** Modality used when the input is plain text or carries no recognised key. */
    private static final String DEFAULT_MODALITY = "text";

    /**
     * Recognised modality keys in lookup priority order. A single list drives both the
     * parameter name and the extracted content so the two can never disagree.
     */
    private static final List<String> MODALITY_KEYS = List.of("text", "image", "video", "audio");

    public NovaMultiModalEmbeddingPreProcessFunction() {
        // NOTE(review): flag is declared in the superclass; presumably it lets an input that
        // is already a RemoteInferenceInputDataSet bypass this function - confirm there.
        this.returnDirectlyForRemoteInferenceInput = true;
    }

    /**
     * Validates that the input is a text-docs dataset whose first document is present.
     *
     * @param mlInput the input to validate
     * @throws IllegalArgumentException if there are no documents or the first one is {@code null}
     */
    @Override
    public void validate(MLInput mlInput) {
        this.validateTextDocsInput(mlInput);
        List<String> docs = ((TextDocsInputDataSet)mlInput.getInputDataset()).getDocs();
        // process() only ever reads the first document, so a null first element is
        // "no input" regardless of how many other documents follow it.
        if (docs.isEmpty() || docs.get(0) == null) {
            throw new IllegalArgumentException("No input provided");
        }
    }

    /**
     * Builds the remote inference input from the first document of the dataset.
     *
     * @param mlInput input whose dataset is a {@link TextDocsInputDataSet}
     * @return a dataset whose parameters are derived from a single
     *         {@code modality -> content} entry nested under {@code "parameters"}
     * @throws IllegalArgumentException if the first document is {@code null}
     */
    @Override
    public RemoteInferenceInputDataSet process(MLInput mlInput) {
        TextDocsInputDataSet inputData = (TextDocsInputDataSet)mlInput.getInputDataset();
        Map.Entry<String, String> modality = this.resolveModality(inputData.getDocs().get(0));
        Map<String, String> parametersMap = new HashMap<>();
        parametersMap.put(modality.getKey(), modality.getValue());
        return RemoteInferenceInputDataSet.builder().parameters(StringUtils.convertScriptStringToJsonString(Map.of("parameters", parametersMap))).build();
    }

    /**
     * Determines the modality parameter name and its content in a single pass.
     *
     * <p>Input that does not look like a JSON object is returned as-is under
     * {@code text} without attempting a parse, which avoids a parse failure and a
     * warning for every ordinary text document. JSON objects are searched for the first
     * key in {@link #MODALITY_KEYS}; malformed JSON or an object with none of those keys
     * falls back to {@code text} with the raw input.
     *
     * @param input the raw first document
     * @return an entry of modality name to content; never {@code null}
     * @throws IllegalArgumentException if {@code input} is {@code null}
     */
    private Map.Entry<String, String> resolveModality(String input) {
        if (input == null) {
            // Defensive: validate() normally rejects this before process() is reached.
            throw new IllegalArgumentException("No input provided");
        }
        // Tolerate leading whitespace: the parser accepts it, so the object check must too.
        if (!input.trim().startsWith("{")) {
            return Map.entry(DEFAULT_MODALITY, input);
        }
        try {
            JsonNode node = objectMapper.readTree(input);
            for (String key : MODALITY_KEYS) {
                JsonNode value = node.get(key);
                if (value != null) {
                    return Map.entry(key, value.asText());
                }
            }
        }
        catch (JsonProcessingException e) {
            log.warn("Failed to parse JSON: {}", e.getMessage());
        }
        return Map.entry(DEFAULT_MODALITY, input);
    }
}

