/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.engine.algorithms.question_answering;

import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.translate.ServingTranslator;
import ai.djl.translate.TranslatorContext;
import java.io.IOException;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.jetbrains.annotations.NotNull;
import org.opensearch.common.xcontent.json.JsonXContent;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.model.MLModelConfig;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.engine.algorithms.question_answering.QAConstants;
import org.opensearch.ml.engine.algorithms.question_answering.sentence.DefaultSentenceSegmenter;
import org.opensearch.ml.engine.algorithms.question_answering.sentence.Sentence;
import org.opensearch.ml.engine.algorithms.question_answering.sentence.SentenceSegmenter;

public class SentenceHighlightingQATranslator
implements ServingTranslator {
    @Generated
    private static final Logger log = LogManager.getLogger(SentenceHighlightingQATranslator.class);
    private final SentenceSegmenter segmenter;
    private HuggingFaceTokenizer tokenizer;
    private final MLModelConfig modelConfig;

    /**
     * Reads a single entry from the model's {@code all_config} JSON string and converts it
     * to the requested type.
     *
     * <p>The raw value is converted through its string form: {@code Boolean.valueOf} for
     * {@code Boolean.class}, {@code Integer.valueOf} for {@code Integer.class}, and a plain
     * {@code Class.cast} of the string for every other type.
     *
     * @param key          name of the entry to look up in the parsed configuration map
     * @param defaultValue value returned when there is no model config, no {@code all_config},
     *                     no entry for {@code key}, or when parsing/conversion fails
     * @param valueType    class the value is converted to
     * @param <T>          type of the returned value
     * @return the converted value, or {@code defaultValue} in any of the fallback cases above
     */
    <T> T readFromModelAllConfig(String key, T defaultValue, Class<T> valueType) {
        if (this.modelConfig == null || this.modelConfig.getAllConfig() == null) {
            return defaultValue;
        }
        try (XContentParser parser = JsonXContent.jsonXContent.createParser(null, null, this.modelConfig.getAllConfig())) {
            Object rawValue = parser.map().get(key);
            if (rawValue == null) {
                return defaultValue;
            }
            String text = rawValue.toString();
            if (valueType == Boolean.class) {
                return valueType.cast(Boolean.valueOf(text));
            }
            if (valueType == Integer.class) {
                return valueType.cast(Integer.valueOf(text));
            }
            return valueType.cast(text);
        } catch (Exception e) {
            // Any failure (malformed JSON, bad number, incompatible type) falls back to the default.
            log.warn("Failed to read {} from config, using default value", key, e);
            return defaultValue;
        }
    }

    /**
     * Creates a translator for the given model configuration, leaving every other builder
     * option at its default.
     *
     * @param modelConfig model configuration handed to the builder
     * @return a newly built translator
     */
    public static SentenceHighlightingQATranslator create(MLModelConfig modelConfig) {
        SentenceHighlightingQATranslatorBuilder translatorBuilder = SentenceHighlightingQATranslator.builder();
        translatorBuilder.modelConfig(modelConfig);
        return translatorBuilder.build();
    }

    /**
     * Builds the tokenizer from the {@code tokenizer.json} file located in the model directory.
     *
     * <p>Maximum length, overlap stride, overflowing-token handling and padding are read from the
     * model's {@code all_config} (keys {@code token_max_length}, {@code token_overlap_stride},
     * {@code with_overflowing_tokens}, {@code padding}), falling back to the {@link QAConstants}
     * defaults. The tokenizer is always built with the {@code optTruncateSecondOnly} option.
     *
     * @param ctx translator context giving access to the model path
     * @throws IOException declared by this method; raised if building the tokenizer fails with one
     */
    public void prepare(TranslatorContext ctx) throws IOException {
        Path modelDir = ctx.getModel().getModelPath();

        int maxLength = this.readFromModelAllConfig("token_max_length", QAConstants.DEFAULT_TOKEN_MAX_LENGTH, Integer.class);
        int stride = this.readFromModelAllConfig("token_overlap_stride", QAConstants.DEFAULT_TOKEN_OVERLAP_STRIDE_LENGTH, Integer.class);
        boolean keepOverflow = this.readFromModelAllConfig("with_overflowing_tokens", QAConstants.DEFAULT_WITH_OVERFLOWING_TOKENS, Boolean.class);
        boolean pad = this.readFromModelAllConfig("padding", QAConstants.DEFAULT_PADDING, Boolean.class);

        this.tokenizer = HuggingFaceTokenizer
            .builder()
            .optTokenizerPath(modelDir.resolve("tokenizer.json"))
            .optMaxLength(maxLength)
            .optStride(stride)
            .optWithOverflowingTokens(keepOverflow)
            .optTruncateSecondOnly()
            .optPadding(pad)
            .build();
    }

    /**
     * Receives translator arguments. This implementation ignores them: the method body is empty.
     *
     * @param arguments argument map; unused
     */
    public void setArguments(Map<String, ?> arguments) {
    }

    /**
     * Turns a question/context request into the tensors fed to the model.
     *
     * <p>Reads the {@code question}, {@code context} and {@code chunk} entries from the input,
     * segments the context into sentences, and stores the question, context and sentence list as
     * context attachments ({@code "question"}, {@code "context"}, {@code "sentences"}) so that
     * output processing can retrieve them. The selected chunk is then tokenized and every token
     * is labelled with a sentence id.
     *
     * @param ctx   translator context providing the NDManager and attachment storage
     * @param input request carrying {@code question}, {@code context} and {@code chunk}
     * @return the model inputs built by {@code createModelInputs}
     * @throws IllegalArgumentException wrapping any failure raised while building the inputs
     */
    public NDList processInput(TranslatorContext ctx, Input input) {
        try {
            NDManager ndManager = ctx.getNDManager();

            String questionText = input.getAsString("question");
            String contextText = input.getAsString("context");
            int chunkIndex = Integer.parseInt(input.getAsString("chunk"));

            ctx.setAttachment("question", questionText);
            ctx.setAttachment("context", contextText);

            List<Sentence> segmented = this.segmenter.segment(contextText);
            ctx.setAttachment("sentences", segmented);

            int[] sentenceIdPerWord = this.createWordLevelSentenceIds(segmented, contextText);
            Encoding chunkEncoding = this.getChunkEncoding(questionText, contextText, chunkIndex);
            int[] sentenceIdPerToken = this.createSentenceIdsArray(chunkEncoding, sentenceIdPerWord, chunkIndex);

            return this.createModelInputs(ndManager, chunkEncoding, sentenceIdPerToken);
        } catch (Exception e) {
            log.error("Error processing input", e);
            String message = String.format(Locale.ROOT, "Error processing input: %s", e.getMessage());
            throw new IllegalArgumentException(message, e);
        }
    }

    /**
     * Tokenizes the question/context pair and returns the encoding of the requested chunk.
     *
     * <p>Chunk {@code 0} is the primary encoding; chunk {@code k >= 1} is the {@code k}-th entry
     * of the primary encoding's overflowing encodings.
     *
     * @param question    question text (first sequence)
     * @param context     context text (second sequence)
     * @param chunkNumber zero-based chunk index
     * @return the encoding for the requested chunk
     * @throws IllegalArgumentException if {@code chunkNumber} is negative or beyond the number
     *                                  of available overflowing encodings
     */
    private Encoding getChunkEncoding(String question, String context, int chunkNumber) {
        // Reject negative indices up front: they would otherwise pass the upper-bound check
        // below and fail with an ArrayIndexOutOfBoundsException instead of a clear message.
        if (chunkNumber < 0) {
            throw new IllegalArgumentException("Invalid chunk number: " + chunkNumber);
        }
        Encoding fullEncoding = this.tokenizer.encode(question, context);
        if (chunkNumber == 0) {
            return fullEncoding;
        }
        Encoding[] overflowEncodings = fullEncoding.getOverflowing();
        if (overflowEncodings == null || chunkNumber > overflowEncodings.length) {
            throw new IllegalArgumentException("Invalid chunk number: " + chunkNumber);
        }
        // Chunk numbers are 1-based with respect to the overflow array.
        return overflowEncodings[chunkNumber - 1];
    }

    /**
     * Builds the per-token sentence id array for one encoding.
     *
     * <p>Every position starts at {@code -100}. From the first token whose type id is {@code 1}
     * onwards, a token whose word id is not {@code -1} and lies inside
     * {@code wordLevelSentenceIds} receives the sentence id of that word.
     *
     * @param encoding             tokenized question/context chunk
     * @param wordLevelSentenceIds sentence id for each context word, indexed by word id
     * @param chunkNumber          chunk index; not used by this method
     * @return an array with one sentence id per token
     */
    private int[] createSentenceIdsArray(Encoding encoding, int[] wordLevelSentenceIds, int chunkNumber) {
        long[] tokenWordIds = encoding.getWordIds();
        int[] tokenSentenceIds = new int[tokenWordIds.length];
        Arrays.fill(tokenSentenceIds, -100);

        int firstContextToken = this.findContextStartIndex(encoding.getTypeIds());
        for (int pos = firstContextToken; pos < tokenWordIds.length; pos++) {
            long word = tokenWordIds[pos];
            boolean hasKnownWord = word != -1L && word < (long) wordLevelSentenceIds.length;
            if (hasKnownWord) {
                tokenSentenceIds[pos] = wordLevelSentenceIds[(int) word];
            }
        }
        return tokenSentenceIds;
    }

    /**
     * Finds the position of the first token whose type id equals {@code 1}.
     *
     * @param typeIds token type ids of an encoding
     * @return index of the first {@code 1}, or {@code 0} when the array contains none
     */
    private int findContextStartIndex(long[] typeIds) {
        int position = 0;
        while (position < typeIds.length && typeIds[position] != 1L) {
            position++;
        }
        // Running off the end means no type id was 1; fall back to the start.
        return position < typeIds.length ? position : 0;
    }

    /**
     * Packs an encoding and its sentence ids into named model input arrays.
     *
     * <p>The returned list holds, in order, {@code input_ids}, {@code attention_mask},
     * {@code token_type_ids} and {@code sentence_ids}.
     *
     * @param manager          manager used to allocate the arrays
     * @param encoding         tokenized chunk providing ids, attention mask and type ids
     * @param sentenceIdsArray per-token sentence ids
     * @return the four named arrays as an {@link NDList}
     */
    private NDList createModelInputs(NDManager manager, Encoding encoding, int[] sentenceIdsArray) {
        NDArray sentenceIds = manager.create(sentenceIdsArray);
        sentenceIds.setName("sentence_ids");

        NDArray ids = manager.create(encoding.getIds());
        ids.setName("input_ids");

        NDArray mask = manager.create(encoding.getAttentionMask());
        mask.setName("attention_mask");

        NDArray typeIds = manager.create(encoding.getTypeIds());
        typeIds.setName("token_type_ids");

        return new NDList(ids, mask, typeIds, sentenceIds);
    }

    /**
     * Assigns a sentence index to every whitespace-separated word of the context.
     *
     * <p>Words come from {@code context.split("\\s+")}. A word's character span is estimated
     * as the summed length of all preceding words plus one separator character each. A word is
     * assigned to a sentence when that span lies within the sentence's
     * {@code [startIndex, endIndex]} range; if several sentences match, the last one wins. Words
     * matching no sentence keep the default id {@code 0}.
     *
     * <p>NOTE(review): the offset estimate assumes exactly one separator character between
     * words, so runs of whitespace make the estimated offsets drift from the real ones.
     * This is unchanged from the previous behavior.
     *
     * @param sentences sentences of the context with their start/end offsets
     * @param context   full context text
     * @return one sentence index per word, in word order
     */
    private int[] createWordLevelSentenceIds(List<Sentence> sentences, String context) {
        String[] contextWords = context.split("\\s+");
        int wordCount = contextWords.length;

        // Word offsets do not depend on the sentence, so compute them once instead of
        // re-summing all preceding word lengths for every (sentence, word) pair.
        int[] wordStarts = new int[wordCount];
        int[] wordEnds = new int[wordCount];
        int offset = 0;
        for (int wordIdx = 0; wordIdx < wordCount; ++wordIdx) {
            wordStarts[wordIdx] = offset;
            wordEnds[wordIdx] = offset + contextWords[wordIdx].length();
            offset = wordEnds[wordIdx] + 1; // +1 for the separator
        }

        int[] wordSentenceIds = new int[wordCount];
        for (int sentIdx = 0; sentIdx < sentences.size(); ++sentIdx) {
            Sentence sentence = sentences.get(sentIdx);
            int startIndex = sentence.getStartIndex();
            int endIndex = sentence.getEndIndex();
            for (int wordIdx = 0; wordIdx < wordCount; ++wordIdx) {
                if (wordStarts[wordIdx] >= startIndex && wordEnds[wordIdx] <= endIndex) {
                    wordSentenceIds[wordIdx] = sentIdx;
                }
            }
        }
        return wordSentenceIds;
    }

    /**
     * Converts the model output into a {@code highlights} tensor.
     *
     * <p>The sentence list stored under the {@code "sentences"} attachment is retrieved, and
     * every value of every output array that is a valid sentence index is treated as a
     * highlighted sentence. The result is a {@link ModelTensor} named {@code highlights} whose
     * data map holds one detail entry per highlighted sentence.
     *
     * <p>An error output (see {@code createErrorOutput}) is returned instead when the attachment
     * is missing, empty or of the wrong type, when no valid index is present in the model output,
     * or when any exception occurs.
     *
     * @param ctx  translator context holding the {@code "sentences"} attachment
     * @param list arrays produced by the model
     * @return the serialized highlights, or an error output
     */
    public Output processOutput(TranslatorContext ctx, NDList list) {
        try {
            List<Sentence> segmented;
            try {
                @SuppressWarnings("unchecked") // element type is verified on access below
                List<Sentence> attached = (List<Sentence>) ctx.getAttachment("sentences");
                segmented = attached;
            } catch (ClassCastException castFailure) {
                log.error("Failed to cast sentences data to expected format", castFailure);
                return this.createErrorOutput("Failed to process sentences data");
            }
            if (segmented == null || segmented.isEmpty()) {
                return this.createErrorOutput("No sentences found in context");
            }

            // Collect every model output value that is a valid sentence index.
            HashSet<Integer> selected = new HashSet<>();
            for (NDArray modelOutput : list) {
                for (long candidate : modelOutput.toLongArray()) {
                    if (candidate >= 0L && candidate < (long) segmented.size()) {
                        selected.add((int) candidate);
                    }
                }
            }
            if (selected.isEmpty()) {
                log.warn("No relevant sentences found in model output");
                return this.createErrorOutput("No relevant sentences found");
            }

            List<SentenceData> annotated = new ArrayList<>();
            int position = 0;
            for (Sentence sentence : segmented) {
                boolean relevant = selected.contains(position);
                annotated.add(new SentenceData(sentence.getText(), relevant, position));
                position++;
            }

            List<Map<String, Object>> highlights = getRelevantSentenceDetails(annotated, segmented);
            Map<String, List<Map<String, Object>>> payload = new HashMap<>();
            payload.put("highlights", highlights);

            ModelTensor highlightTensor = ModelTensor.builder().name("highlights").dataAsMap(payload).build();
            Output result = new Output();
            result.add(new ModelTensors(List.of(highlightTensor)).toBytes());
            return result;
        } catch (Exception e) {
            log.error("Error processing model output", e);
            return this.createErrorOutput("Error processing model output: " + e.getMessage());
        }
    }

    /**
     * Builds the detail maps for all sentences flagged as relevant.
     *
     * <p>Each map contains {@code text}, {@code position} (index in the list), and the
     * {@code start} / {@code end} indices taken from the sentence at the same position.
     *
     * @param sentenceDataList per-sentence relevance data, parallel to {@code sentences}
     * @param sentences        sentences providing the start and end indices
     * @return detail maps for the relevant sentences, in list order; empty if none is relevant
     */
    @NotNull
    private static List<Map<String, Object>> getRelevantSentenceDetails(List<SentenceData> sentenceDataList, List<Sentence> sentences) {
        List<Map<String, Object>> details = new ArrayList<>();
        for (int position = 0; position < sentenceDataList.size(); position++) {
            SentenceData entry = sentenceDataList.get(position);
            if (entry.isRelevant()) {
                Sentence source = sentences.get(position);
                Map<String, Object> detail = new HashMap<>();
                detail.put("text", entry.text());
                detail.put("position", position);
                detail.put("start", source.getStartIndex());
                detail.put("end", source.getEndIndex());
                details.add(detail);
            }
        }
        return details;
    }

    /**
     * Builds a {@code 400 Bad Request} output describing a failure.
     *
     * <p>The output carries one {@link ModelTensor} named {@code error} whose data map holds the
     * message under {@code error} and an empty list under {@code highlights}.
     *
     * @param errorMessage message stored in the output
     * @return the error output
     */
    private Output createErrorOutput(String errorMessage) {
        Map<String, Object> payload = new HashMap<>();
        payload.put("error", errorMessage);
        payload.put("highlights", new ArrayList<>());

        ModelTensor errorTensor = ModelTensor.builder().name("error").dataAsMap(payload).build();

        Output response = new Output(400, "Bad Request");
        response.add(new ModelTensors(List.of(errorTensor)).toBytes());
        return response;
    }

    /**
     * Supplies the default value for the {@code segmenter} field: a new
     * {@link DefaultSentenceSegmenter}.
     */
    @Generated
    private static SentenceSegmenter $default$segmenter() {
        return new DefaultSentenceSegmenter();
    }

    /**
     * Creates a translator by storing the three arguments in the corresponding fields.
     *
     * @param segmenter   sentence segmenter stored in {@code segmenter}
     * @param tokenizer   initial tokenizer; {@code prepare} later assigns a newly built one
     * @param modelConfig model configuration stored in {@code modelConfig}
     */
    @Generated
    SentenceHighlightingQATranslator(SentenceSegmenter segmenter, HuggingFaceTokenizer tokenizer, MLModelConfig modelConfig) {
        this.segmenter = segmenter;
        this.tokenizer = tokenizer;
        this.modelConfig = modelConfig;
    }

    /**
     * Returns a new, empty builder for this translator.
     */
    @Generated
    public static SentenceHighlightingQATranslatorBuilder builder() {
        return new SentenceHighlightingQATranslatorBuilder();
    }

    /**
     * Returns the sentence segmenter this translator was constructed with.
     */
    @Generated
    public SentenceSegmenter getSegmenter() {
        return this.segmenter;
    }

    /**
     * Returns the current tokenizer: the one passed to the constructor until {@code prepare}
     * replaces it with the tokenizer it builds.
     */
    @Generated
    public HuggingFaceTokenizer getTokenizer() {
        return this.tokenizer;
    }

    /**
     * Returns the model configuration this translator was constructed with.
     */
    @Generated
    public MLModelConfig getModelConfig() {
        return this.modelConfig;
    }

    /**
     * Builder for {@link SentenceHighlightingQATranslator}.
     *
     * <p>When no segmenter has been set, {@link #build()} uses the translator's default
     * segmenter. The tokenizer and model configuration are passed through as given.
     */
    @Generated
    public static class SentenceHighlightingQATranslatorBuilder {
        /** Whether {@link #segmenter(SentenceSegmenter)} has been called. */
        @Generated
        private boolean segmenter$set;
        /** Segmenter supplied by the caller; only meaningful when {@code segmenter$set} is true. */
        @Generated
        private SentenceSegmenter segmenter$value;
        @Generated
        private HuggingFaceTokenizer tokenizer;
        @Generated
        private MLModelConfig modelConfig;

        @Generated
        SentenceHighlightingQATranslatorBuilder() {
        }

        /**
         * Sets the sentence segmenter and records that it was explicitly provided.
         *
         * @param segmenter segmenter to use
         * @return this builder
         */
        @Generated
        public SentenceHighlightingQATranslatorBuilder segmenter(SentenceSegmenter segmenter) {
            this.segmenter$set = true;
            this.segmenter$value = segmenter;
            return this;
        }

        /**
         * Sets the tokenizer passed to the translator's constructor.
         *
         * @param tokenizer tokenizer to use
         * @return this builder
         */
        @Generated
        public SentenceHighlightingQATranslatorBuilder tokenizer(HuggingFaceTokenizer tokenizer) {
            this.tokenizer = tokenizer;
            return this;
        }

        /**
         * Sets the model configuration passed to the translator's constructor.
         *
         * @param modelConfig model configuration to use
         * @return this builder
         */
        @Generated
        public SentenceHighlightingQATranslatorBuilder modelConfig(MLModelConfig modelConfig) {
            this.modelConfig = modelConfig;
            return this;
        }

        /**
         * Creates the translator, substituting the default segmenter when none was set.
         *
         * @return a new translator
         */
        @Generated
        public SentenceHighlightingQATranslator build() {
            SentenceSegmenter effectiveSegmenter = this.segmenter$set
                ? this.segmenter$value
                : SentenceHighlightingQATranslator.$default$segmenter();
            return new SentenceHighlightingQATranslator(effectiveSegmenter, this.tokenizer, this.modelConfig);
        }

        @Generated
        public String toString() {
            return "SentenceHighlightingQATranslator.SentenceHighlightingQATranslatorBuilder(segmenter$value="
                + this.segmenter$value
                + ", tokenizer="
                + this.tokenizer
                + ", modelConfig="
                + this.modelConfig
                + ")";
        }
    }

    /**
     * Per-sentence data used while assembling the output.
     *
     * @param text       sentence text
     * @param isRelevant whether the sentence is flagged as relevant
     * @param position   index of the sentence in the sentence list
     */
    private record SentenceData(String text, boolean isRelevant, int position) {
    }
}

