/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.action.prediction;

import java.util.Map;

import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.Nullable;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.settings.MLCommonsSettings;
import org.opensearch.ml.common.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
import org.opensearch.ml.engine.algorithms.remote.streaming.StreamPredictActionListener;
import org.opensearch.ml.helper.ModelAccessControlHelper;
import org.opensearch.ml.model.MLModelCacheHelper;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.task.MLPredictTaskRunner;
import org.opensearch.ml.task.MLTaskRunner;
import org.opensearch.ml.utils.MLNodeUtils;
import org.opensearch.ml.utils.RestActionUtils;
import org.opensearch.ml.utils.TenantAwareHelper;
import org.opensearch.remote.metadata.client.SdkClient;
import org.opensearch.tasks.Task;
import org.opensearch.transport.StreamTransportService;
import org.opensearch.transport.TransportChannel;
import org.opensearch.transport.TransportService;
import org.opensearch.transport.client.Client;

/**
 * Transport action that serves streaming prediction requests for ML models.
 *
 * <p>When a {@link StreamTransportService} is injected, this action registers
 * {@link #messageReceived} on it; the incoming {@link TransportChannel} is wrapped in a
 * {@link StreamPredictActionListener} and the request is processed by
 * {@link #doExecute(Task, ActionRequest, ActionListener, TransportChannel)}. Processing validates
 * the tenant id, resolves the calling user, loads the model (from the cache or via
 * {@link MLModelManager}), checks model-group access and the model's enabled flag, and finally
 * runs the prediction through the predict task runner on the stream transport.
 *
 * <p>Models for which {@link FunctionName#isDLModel} is true are never executed here: they are
 * rejected either because local models are disabled, because of rate limiting, or because this
 * action does not serve non-streaming requests.
 */
public class TransportPredictionStreamTaskAction
extends HandledTransportAction<ActionRequest, MLTaskResponse> {
    @Generated
    private static final Logger log = LogManager.getLogger(TransportPredictionStreamTaskAction.class);

    /** Action name used for the transport registration, the stream handler and access checks. */
    private static final String STREAM_ACTION_NAME = "cluster:admin/opensearch/ml/predict/stream";

    private MLTaskRunner<MLPredictionTaskRequest, MLTaskResponse> mlPredictTaskRunner;
    private TransportService transportService;
    private MLModelCacheHelper modelCacheHelper;
    private Client client;
    private SdkClient sdkClient;
    private ClusterService clusterService;
    private MLModelManager mlModelManager;
    private ModelAccessControlHelper modelAccessControlHelper;
    // Kept in sync with ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE; not read elsewhere in this class.
    private volatile boolean enableAutomaticDeployment;
    private MLFeatureEnabledSetting mlFeatureEnabledSetting;
    /** Stream transport service shared with other components; may be null if unavailable. */
    public static StreamTransportService streamTransportService;
    // First non-null StreamTransportService handed to any constructor; backs the public field.
    private static StreamTransportService streamTransportServiceInstance;

    /**
     * Creates the action and, if a stream transport is available, registers the stream handler.
     *
     * @param streamTransportService stream transport to register on; may be null, in which case
     *     only a warning is logged and no stream handler is registered
     */
    @Inject
    public TransportPredictionStreamTaskAction(
        TransportService transportService,
        ActionFilters actionFilters,
        MLModelCacheHelper modelCacheHelper,
        MLPredictTaskRunner mlPredictTaskRunner,
        ClusterService clusterService,
        Client client,
        SdkClient sdkClient,
        MLModelManager mlModelManager,
        ModelAccessControlHelper modelAccessControlHelper,
        MLFeatureEnabledSetting mlFeatureEnabledSetting,
        Settings settings,
        @Nullable StreamTransportService streamTransportService
    ) {
        super(STREAM_ACTION_NAME, transportService, actionFilters, MLPredictionTaskRequest::new);
        this.mlPredictTaskRunner = mlPredictTaskRunner;
        this.transportService = transportService;
        this.modelCacheHelper = modelCacheHelper;
        this.clusterService = clusterService;
        this.client = client;
        this.sdkClient = sdkClient;
        this.mlModelManager = mlModelManager;
        this.modelAccessControlHelper = modelAccessControlHelper;
        this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
        // Only the first non-null instance is remembered; later constructions reuse it.
        if (streamTransportServiceInstance == null) {
            streamTransportServiceInstance = streamTransportService;
        }
        TransportPredictionStreamTaskAction.streamTransportService = streamTransportServiceInstance;
        this.enableAutomaticDeployment = MLCommonsSettings.ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE.get(settings);
        clusterService
            .getClusterSettings()
            .addSettingsUpdateConsumer(MLCommonsSettings.ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE, it -> this.enableAutomaticDeployment = it);
        if (streamTransportService != null) {
            streamTransportService
                .registerRequestHandler(
                    STREAM_ACTION_NAME,
                    "opensearch_ml_predict_stream",
                    MLPredictionTaskRequest::new,
                    this::messageReceived
                );
        } else {
            log.warn("StreamTransportService is not available.");
        }
    }

    /**
     * Returns the shared stream transport service.
     *
     * @return the value of the public static {@code streamTransportService} field; may be null
     */
    public static StreamTransportService getStreamTransportService() {
        return streamTransportService;
    }

    /**
     * Stream transport entry point: wraps the channel in a {@link StreamPredictActionListener} and
     * runs the prediction pipeline with that channel attached to the request.
     */
    public void messageReceived(MLPredictionTaskRequest request, TransportChannel channel, Task task) {
        StreamPredictActionListener streamListener = new StreamPredictActionListener(channel);
        this.doExecute(task, request, (ActionListener<MLTaskResponse>) streamListener, channel);
    }

    /**
     * Regular transport entry point. Only requests that already carry a streaming channel are
     * served; anything else fails with {@link UnsupportedOperationException}.
     */
    @Override
    protected void doExecute(Task task, ActionRequest request, ActionListener<MLTaskResponse> listener) {
        // Guarded cast: a request that is not an MLPredictionTaskRequest instance cannot carry a
        // streaming channel, so it is rejected instead of failing with ClassCastException.
        TransportChannel channel = request instanceof MLPredictionTaskRequest
            ? ((MLPredictionTaskRequest) request).getStreamingChannel()
            : null;
        if (channel != null) {
            this.doExecute(task, request, listener, channel);
        } else {
            listener.onFailure(new UnsupportedOperationException("Use doExecute with TransportChannel for streaming requests"));
        }
    }

    /**
     * Runs the streaming prediction pipeline for {@code request} over {@code channel}.
     *
     * <p>The thread context is stashed while the model is resolved; {@code listener} is always
     * notified with the original context restored.
     */
    protected void doExecute(Task task, ActionRequest request, ActionListener<MLTaskResponse> listener, TransportChannel channel) {
        final MLPredictionTaskRequest mlPredictionTaskRequest = MLPredictionTaskRequest.fromActionRequest(request);
        mlPredictionTaskRequest.setStreamingChannel(channel);
        final String modelId = mlPredictionTaskRequest.getModelId();
        final String tenantId = mlPredictionTaskRequest.getTenantId();
        if (!TenantAwareHelper.validateTenantId(this.mlFeatureEnabledSetting, tenantId, listener)) {
            return;
        }
        User user = mlPredictionTaskRequest.getUser();
        if (user == null) {
            user = RestActionUtils.getUserContext(this.client);
            mlPredictionTaskRequest.setUser(user);
        }
        final User userInfo = user;
        try (ThreadContext.StoredContext context = this.client.threadPool().getThreadContext().stashContext()) {
            final ActionListener<MLTaskResponse> wrappedListener = ActionListener.runBefore(listener, context::restore);
            MLModel cachedMlModel = this.modelCacheHelper.getModelInfo(modelId);
            ActionListener<MLModel> modelActionListener = new ActionListener<MLModel>() {
                @Override
                public void onResponse(MLModel mlModel) {
                    context.restore();
                    onModelResolved(mlModel, mlPredictionTaskRequest, modelId, tenantId, userInfo, wrappedListener);
                }

                @Override
                public void onFailure(Exception e) {
                    log.error("Failed to find model {}", modelId, e);
                    wrappedListener.onFailure(e);
                }
            };
            if (cachedMlModel != null) {
                modelActionListener.onResponse(cachedMlModel);
            } else {
                this.mlModelManager.getModel(modelId, tenantId, modelActionListener);
            }
        } catch (Exception e) {
            log.error("Failed to predict {}", mlPredictionTaskRequest, e);
            listener.onFailure(e);
        }
    }

    /**
     * Caches the resolved model, rejects DL models when local models are disabled, and starts the
     * model-group access check.
     */
    private void onModelResolved(
        MLModel mlModel,
        MLPredictionTaskRequest mlPredictionTaskRequest,
        String modelId,
        String tenantId,
        User userInfo,
        ActionListener<MLTaskResponse> wrappedListener
    ) {
        this.modelCacheHelper.setModelInfo(modelId, mlModel);
        FunctionName functionName = mlModel.getAlgorithm();
        if (FunctionName.isDLModel(functionName) && !this.mlFeatureEnabledSetting.isLocalModelEnabled()) {
            // Report through the listener rather than throwing out of onResponse, so the failure
            // is not misreported as a model-lookup failure or escape the listener chain.
            wrappedListener.onFailure(new UnsupportedOperationException("Streaming is not supported for local model."));
            return;
        }
        mlPredictionTaskRequest.getMlInput().setAlgorithm(functionName);
        ActionListener<Boolean> accessListener = ActionListener
            .wrap(
                access -> onAccessChecked(access, functionName, mlPredictionTaskRequest, modelId, userInfo, wrappedListener),
                wrappedListener::onFailure
            );
        this.modelAccessControlHelper
            .validateModelGroupAccess(
                userInfo,
                this.mlFeatureEnabledSetting,
                tenantId,
                mlModel.getModelGroupId(),
                STREAM_ACTION_NAME,
                this.client,
                this.sdkClient,
                accessListener
            );
    }

    /**
     * Applies the post-access checks (privilege, enabled flag, DL-model rate limits) and, for
     * non-DL models, validates the input schema and runs the streaming prediction. Exceptions
     * thrown here are routed to the failure handler of the {@code ActionListener.wrap} caller.
     */
    private void onAccessChecked(
        Boolean access,
        FunctionName functionName,
        MLPredictionTaskRequest mlPredictionTaskRequest,
        String modelId,
        User userInfo,
        ActionListener<MLTaskResponse> wrappedListener
    ) {
        if (!access) {
            wrappedListener
                .onFailure(
                    new OpenSearchStatusException("User Doesn't have privilege to perform this operation on this model", RestStatus.FORBIDDEN)
                );
            return;
        }
        Boolean isModelEnabled = this.modelCacheHelper.getIsModelEnabled(modelId);
        if (isModelEnabled != null && !isModelEnabled) {
            wrappedListener.onFailure(new OpenSearchStatusException("Model is disabled.", RestStatus.FORBIDDEN));
            return;
        }
        if (FunctionName.isDLModel(functionName)) {
            if (this.modelCacheHelper.getRateLimiter(modelId) != null && !this.modelCacheHelper.getRateLimiter(modelId).request()) {
                wrappedListener.onFailure(new OpenSearchStatusException("Request is throttled at model level.", RestStatus.TOO_MANY_REQUESTS));
            } else if (userInfo != null
                && this.modelCacheHelper.getUserRateLimiter(modelId, userInfo.getName()) != null
                && !this.modelCacheHelper.getUserRateLimiter(modelId, userInfo.getName()).request()) {
                wrappedListener
                    .onFailure(
                        new OpenSearchStatusException(
                            "Request is throttled at user level. If you think there's an issue, please contact your cluster admin.",
                            RestStatus.TOO_MANY_REQUESTS
                        )
                    );
            } else {
                wrappedListener
                    .onFailure(
                        new OpenSearchStatusException(
                            "Non-streaming requests are not supported by the streaming transport action",
                            RestStatus.BAD_REQUEST
                        )
                    );
            }
            return;
        }
        validateInputSchema(modelId, mlPredictionTaskRequest.getMlInput());
        executePredictStream(mlPredictionTaskRequest, wrappedListener, modelId);
    }

    /**
     * Runs the prediction through the predict task runner on the stream transport and, once the
     * listener has been notified, records the request duration and refreshes the model's last
     * access time in the cache.
     */
    private void executePredictStream(MLPredictionTaskRequest mlPredictionTaskRequest, ActionListener<MLTaskResponse> wrappedListener, String modelId) {
        String requestId = mlPredictionTaskRequest.getRequestID();
        long startTime = System.nanoTime();
        FunctionName functionName = this.modelCacheHelper
            .getOptionalFunctionName(modelId)
            .orElse(mlPredictionTaskRequest.getMlInput().getAlgorithm());
        ActionListener<MLTaskResponse> timedListener = ActionListener.runAfter(wrappedListener, () -> {
            double durationInMs = (System.nanoTime() - startTime) / 1_000_000.0;
            this.modelCacheHelper.addPredictRequestDuration(modelId, durationInMs);
            this.modelCacheHelper.refreshLastAccessTime(modelId);
            log.debug("completed predict request {} for model {}", requestId, modelId);
        });
        this.mlPredictTaskRunner.run(functionName, mlPredictionTaskRequest, (TransportService) streamTransportService, timedListener);
    }

    /**
     * Validates {@code mlInput} against the model's cached {@code "input"} interface schema, if any.
     *
     * @throws OpenSearchStatusException with {@link RestStatus#BAD_REQUEST} if serialization,
     *     parameter processing, or schema validation fails; the original exception is the cause
     */
    public void validateInputSchema(String modelId, MLInput mlInput) {
        Map<String, String> modelInterface = this.modelCacheHelper.getModelInterface(modelId);
        if (modelInterface == null) {
            return;
        }
        String inputSchemaString = modelInterface.get("input");
        if (inputSchemaString == null) {
            return;
        }
        try {
            String inputString = mlInput.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS).toString();
            String processedInputString = MLNodeUtils.processRemoteInferenceInputDataSetParametersValue(inputString, inputSchemaString);
            MLNodeUtils.validateSchema(inputSchemaString, processedInputString);
        } catch (Exception e) {
            throw new OpenSearchStatusException(
                "Error validating input schema, if you think this is expected, please update your 'input' field in the 'interface' field for this model: "
                    + e.getMessage(),
                RestStatus.BAD_REQUEST,
                e
            );
        }
    }
}

