/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.dataframe.inference;

import java.util.Deque;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.Executor;
import java.util.function.Function;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.ActionFuture;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.DocWriteRequest;
import org.elasticsearch.action.bulk.BulkRequest;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.support.SubscribableListener;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.client.internal.OriginSettingClient;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.Strings;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.AggregationBuilders;
import org.elasticsearch.search.aggregations.metrics.Max;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.dataframe.inference.TestDocsIterator;
import org.elasticsearch.xpack.ml.dataframe.stats.DataCountsTracker;
import org.elasticsearch.xpack.ml.dataframe.stats.ProgressTracker;
import org.elasticsearch.xpack.ml.extractor.ExtractedField;
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
import org.elasticsearch.xpack.ml.extractor.SourceSupplier;
import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModel;
import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
import org.elasticsearch.xpack.ml.utils.MlIndicesUtils;
import org.elasticsearch.xpack.ml.utils.persistence.LimitAwareBulkIndexer;
import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService;

/**
 * Runs inference with a trained model over the test documents of a data frame analytics job's
 * destination index and writes each inference result back into the document it was computed from.
 *
 * <p>Work can be resumed: before inferring, the runner counts the destination documents whose
 * {@code <results_field>.is_training} flag is {@code false}. It also reads the maximum
 * {@code ml__incremental_id} among those documents. The test-docs iterator is created from that
 * id (see {@link #restoreInferenceState()}).
 *
 * <p>{@link #cancel()} sets a volatile flag. The flag is checked before each batch and is also
 * used as the "keep retrying" predicate for bulk indexing.
 */
public class InferenceRunner {
    private static final Logger LOGGER = LogManager.getLogger(InferenceRunner.class);

    /** Progress reported while documents are still being processed; 100 is only reported on completion. */
    private static final int MAX_PROGRESS_BEFORE_COMPLETION = 98;

    /** Name of the executor that {@link #handleLocalModel} (and hence {@link #inferTestDocs}) runs on. */
    private static final String UTILITY_THREAD_POOL_NAME = "ml_utility";

    /** Field holding each document's incremental id; also used as the max-aggregation name. */
    private static final String INCREMENTAL_ID_FIELD = "ml__incremental_id";

    /** Sub-field of the results field that flags whether a document was used for training. */
    private static final String IS_TRAINING_FIELD = "is_training";

    private final Settings settings;
    private final Client client;
    private final ModelLoadingService modelLoadingService;
    private final ResultsPersisterService resultsPersisterService;
    private final TaskId parentTaskId;
    private final DataFrameAnalyticsConfig config;
    private final ExtractedFields extractedFields;
    private final ProgressTracker progressTracker;
    private final DataCountsTracker dataCountsTracker;
    private final Function<Long, TestDocsIterator> testDocsIteratorFactory;
    private final ThreadPool threadPool;
    private volatile boolean isCancelled;

    InferenceRunner(
        Settings settings,
        Client client,
        ModelLoadingService modelLoadingService,
        ResultsPersisterService resultsPersisterService,
        TaskId parentTaskId,
        DataFrameAnalyticsConfig config,
        ExtractedFields extractedFields,
        ProgressTracker progressTracker,
        DataCountsTracker dataCountsTracker,
        Function<Long, TestDocsIterator> testDocsIteratorFactory,
        ThreadPool threadPool
    ) {
        this.settings = Objects.requireNonNull(settings);
        this.client = Objects.requireNonNull(client);
        this.modelLoadingService = Objects.requireNonNull(modelLoadingService);
        this.resultsPersisterService = Objects.requireNonNull(resultsPersisterService);
        this.parentTaskId = Objects.requireNonNull(parentTaskId);
        this.config = Objects.requireNonNull(config);
        this.extractedFields = Objects.requireNonNull(extractedFields);
        this.progressTracker = Objects.requireNonNull(progressTracker);
        this.dataCountsTracker = Objects.requireNonNull(dataCountsTracker);
        this.testDocsIteratorFactory = Objects.requireNonNull(testDocsIteratorFactory);
        this.threadPool = threadPool;
    }

    /** Requests cancellation; the in-flight batch finishes, and no further batches are started. */
    public void cancel() {
        this.isCancelled = true;
    }

    /**
     * Loads the model and runs inference on the test documents.
     *
     * <p>If the runner is already cancelled, the listener is completed immediately with {@code null}.
     * Otherwise the model is fetched via
     * {@link ModelLoadingService#getModelForInternalInference}, and inference continues on the
     * {@code ml_utility} executor. Any failure is logged and wrapped by {@link #handleException}
     * before it reaches {@code listener}.
     *
     * @param modelId  id of the trained model to infer with
     * @param listener completed with {@code null} on success, or with the wrapped failure
     */
    public void run(String modelId, ActionListener<Void> listener) {
        if (isCancelled) {
            listener.onResponse(null);
            return;
        }
        LOGGER.info("[{}] Started inference on test data against model [{}]", config.getId(), modelId);
        SubscribableListener.<LocalModel>newForked(l -> modelLoadingService.getModelForInternalInference(modelId, l))
            .<Void>andThen(threadPool.executor(UTILITY_THREAD_POOL_NAME), threadPool.getThreadContext(), this::handleLocalModel)
            .addListener(listener.delegateResponse((delegate, e) -> delegate.onFailure(handleException(modelId, e))));
    }

    /**
     * Restores the resume state, runs inference with {@code localModel}, then completes the listener.
     * The model is closed when this method exits, whether it returns normally or throws. Exceptions
     * propagate to the calling {@code andThen} step (presumably routed to the listener's failure
     * path — see SubscribableListener).
     */
    private void handleLocalModel(ActionListener<Void> listener, LocalModel localModel) {
        try (localModel) {
            InferenceState inferenceState = restoreInferenceState();
            dataCountsTracker.setTestDocsCount(inferenceState.processedTestDocsCount());
            TestDocsIterator testDocsIterator = testDocsIteratorFactory.apply(inferenceState.lastIncrementalId());
            LOGGER.debug("Loaded inference model [{}]", localModel);
            inferTestDocs(localModel, testDocsIterator, inferenceState.processedTestDocsCount());
            listener.onResponse(null);
        }
    }

    /**
     * Logs {@code e} and wraps it in an exception carrying the job id, the model id, and the cause message.
     * An {@link ElasticsearchException} keeps its REST status and is wrapped around its root cause;
     * any other exception becomes a server error.
     */
    private Exception handleException(String modelId, Exception e) {
        LOGGER.error(() -> Strings.format("[%s] Error running inference on model [%s]", config.getId(), modelId), e);
        if (e instanceof ElasticsearchException elasticsearchException) {
            Throwable rootCause = elasticsearchException.getRootCause();
            return new ElasticsearchStatusException(
                "[{}] failed running inference on model [{}]; cause was [{}]",
                elasticsearchException.status(),
                rootCause,
                config.getId(),
                modelId,
                rootCause.getMessage()
            );
        }
        return ExceptionsHelper.serverError(
            "[{}] failed running inference on model [{}]; cause was [{}]",
            e,
            config.getId(),
            modelId,
            e.getMessage()
        );
    }

    /**
     * Queries the destination index, tolerating an unavailable index. It finds how many documents
     * already have {@code <results_field>.is_training == false}, and the largest
     * {@code ml__incremental_id} among them.
     *
     * <p>This is a blocking search, executed with the job's stored headers under the ML origin.
     *
     * @return the resume state; {@code lastIncrementalId} is {@code null} when the count is zero
     */
    private InferenceState restoreInferenceState() {
        SearchRequest searchRequest = new SearchRequest(config.getDest().getIndex());
        searchRequest.indicesOptions(MlIndicesUtils.addIgnoreUnavailable(SearchRequest.DEFAULT_INDICES_OPTIONS));
        searchRequest.source(
            new SearchSourceBuilder().size(0)
                .query(
                    QueryBuilders.boolQuery()
                        .filter(QueryBuilders.termQuery(config.getDest().getResultsField() + "." + IS_TRAINING_FIELD, false))
                )
                .fetchSource(false)
                .aggregation(AggregationBuilders.max(INCREMENTAL_ID_FIELD).field(INCREMENTAL_ID_FIELD))
                .trackTotalHits(true)
        );

        SearchResponse searchResponse = ClientHelper.executeWithHeaders(
            config.getHeaders(),
            ClientHelper.ML_ORIGIN,
            client,
            () -> client.search(searchRequest).actionGet()
        );
        try {
            Max maxIncrementalIdAgg = searchResponse.getAggregations().get(INCREMENTAL_ID_FIELD);
            long processedTestDocCount = searchResponse.getHits().getTotalHits().value;
            Long lastIncrementalId = processedTestDocCount == 0L ? null : (long) maxIncrementalIdAgg.value();
            if (lastIncrementalId != null) {
                LOGGER.debug(
                    () -> Strings.format(
                        "[%s] Resuming inference; last incremental id [%s]; processed test doc count [%s]",
                        config.getId(),
                        lastIncrementalId,
                        processedTestDocCount
                    )
                );
            }
            return new InferenceState(lastIncrementalId, processedTestDocCount);
        } finally {
            // The response is ref-counted; release it once the values above have been extracted.
            searchResponse.decRef();
        }
    }

    /**
     * Infers every document returned by {@code testDocsIterator} and bulk-indexes the results.
     *
     * <p>Progress is reported per document and capped at {@link #MAX_PROGRESS_BEFORE_COMPLETION}.
     * It is set to 100 once the loop ends, unless the runner was cancelled. Cancellation is checked
     * before each batch.
     *
     * @param processedTestDocsCount number of test docs already processed in a previous run; used as
     *                               the starting point for the progress computation
     */
    private void inferTestDocs(LocalModel model, TestDocsIterator testDocsIterator, long processedTestDocsCount) {
        assert ThreadPool.assertCurrentThreadPool(UTILITY_THREAD_POOL_NAME)
            : Strings.format(
                "inferTestDocs must execute from [MachineLearning.UTILITY_THREAD_POOL_NAME] but thread is [%s]",
                Thread.currentThread().getName()
            );

        String resultsField = config.getDest().getResultsField();
        long totalDocCount = 0L;
        long processedDocCount = processedTestDocsCount;

        // Closing the indexer flushes any pending requests, including after a cancellation break.
        try (LimitAwareBulkIndexer bulkIndexer = new LimitAwareBulkIndexer(settings, this::executeBulkRequest)) {
            while (testDocsIterator.hasNext()) {
                if (isCancelled) {
                    break;
                }
                Deque<SearchHit> batch = testDocsIterator.next();
                // Total hits are only known once the first batch has been fetched.
                if (totalDocCount == 0L) {
                    totalDocCount = testDocsIterator.getTotalHits();
                }
                for (SearchHit doc : batch) {
                    dataCountsTracker.incrementTestDocsCount();
                    SourceSupplier sourceSupplier = new SourceSupplier(doc);
                    InferenceResults inferenceResults = model.inferNoStats(featuresFromDoc(doc, sourceSupplier));
                    bulkIndexer.addAndExecuteIfNeeded(createIndexRequest(doc, sourceSupplier, inferenceResults, resultsField));

                    processedDocCount++;
                    int progressPercent = Math.min(
                        (int) (processedDocCount * 100.0 / totalDocCount),
                        MAX_PROGRESS_BEFORE_COMPLETION
                    );
                    progressTracker.updateInferenceProgress(progressPercent);
                }
            }
        }
        if (isCancelled == false) {
            progressTracker.updateInferenceProgress(100);
        }
    }

    /**
     * Builds the model input from the extracted fields. A field is included only when it yields
     * exactly one value for this document; multi-valued and missing fields are skipped.
     */
    private Map<String, Object> featuresFromDoc(SearchHit doc, SourceSupplier sourceSupplier) {
        Map<String, Object> features = new HashMap<>();
        for (ExtractedField extractedField : extractedFields.getAllFields()) {
            Object[] values = extractedField.value(doc, sourceSupplier);
            if (values.length == 1) {
                features.put(extractedField.getName(), values[0]);
            }
        }
        return features;
    }

    /**
     * Builds an index request that overwrites {@code hit} in place. The request keeps the document's
     * original source and adds the inference results, with {@code is_training: false}, under
     * {@code resultField}.
     */
    private IndexRequest createIndexRequest(SearchHit hit, SourceSupplier sourceSupplier, InferenceResults results, String resultField) {
        Map<String, Object> resultsMap = new LinkedHashMap<>(results.asMap());
        resultsMap.put(IS_TRAINING_FIELD, false);

        Map<String, Object> source = new LinkedHashMap<>(sourceSupplier.get());
        source.put(resultField, resultsMap);

        IndexRequest indexRequest = new IndexRequest(hit.getIndex());
        indexRequest.id(hit.getId());
        indexRequest.source(source);
        indexRequest.opType(DocWriteRequest.OpType.INDEX);
        indexRequest.setParentTask(parentTaskId);
        return indexRequest;
    }

    /** Persists a bulk request with the job's headers, retrying while the runner is not cancelled. */
    private void executeBulkRequest(BulkRequest bulkRequest) {
        resultsPersisterService.bulkIndexWithHeadersWithRetry(
            config.getHeaders(),
            bulkRequest,
            config.getId(),
            () -> isCancelled == false,
            retryMessage -> {}
        );
    }

    /**
     * Creates a runner whose test docs are read by a {@link TestDocsIterator} over an
     * ML-origin client, starting after the given last incremental id.
     */
    public static InferenceRunner create(
        Settings settings,
        Client client,
        ModelLoadingService modelLoadingService,
        ResultsPersisterService resultsPersisterService,
        TaskId parentTaskId,
        DataFrameAnalyticsConfig config,
        ExtractedFields extractedFields,
        ProgressTracker progressTracker,
        DataCountsTracker dataCountsTracker,
        ThreadPool threadPool
    ) {
        return new InferenceRunner(
            settings,
            client,
            modelLoadingService,
            resultsPersisterService,
            parentTaskId,
            config,
            extractedFields,
            progressTracker,
            dataCountsTracker,
            lastIncrementalId -> new TestDocsIterator(
                new OriginSettingClient(client, ClientHelper.ML_ORIGIN),
                config,
                extractedFields,
                lastIncrementalId
            ),
            threadPool
        );
    }

    /**
     * Resume state restored from the destination index.
     *
     * @param lastIncrementalId      max incremental id of already-processed test docs, or {@code null} if none
     * @param processedTestDocsCount number of already-processed test docs
     */
    private record InferenceState(@Nullable Long lastIncrementalId, long processedTestDocsCount) {}
}

