/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.engine.algorithms.remote.streaming;

import com.fasterxml.jackson.core.JsonParseException;
import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.core.JsonToken;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
import javax.naming.AuthenticationException;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.ml.common.connector.AwsConnector;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.engine.algorithms.remote.RemoteConnectorThrottlingException;
import org.opensearch.ml.engine.algorithms.remote.streaming.BaseStreamingHandler;
import org.opensearch.ml.engine.algorithms.remote.streaming.StreamPredictActionListener;
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.auth.credentials.AwsCredentials;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.auth.credentials.AwsSessionCredentials;
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
import software.amazon.awssdk.core.document.Document;
import software.amazon.awssdk.http.async.SdkAsyncHttpClient;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient;
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClientBuilder;
import software.amazon.awssdk.services.bedrockruntime.model.AccessDeniedException;
import software.amazon.awssdk.services.bedrockruntime.model.ContentBlock;
import software.amazon.awssdk.services.bedrockruntime.model.ContentBlockDeltaEvent;
import software.amazon.awssdk.services.bedrockruntime.model.ContentBlockStartEvent;
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamOutput;
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamRequest;
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamResponseHandler;
import software.amazon.awssdk.services.bedrockruntime.model.Message;
import software.amazon.awssdk.services.bedrockruntime.model.MessageStopEvent;
import software.amazon.awssdk.services.bedrockruntime.model.ResourceNotFoundException;
import software.amazon.awssdk.services.bedrockruntime.model.StopReason;
import software.amazon.awssdk.services.bedrockruntime.model.SystemContentBlock;
import software.amazon.awssdk.services.bedrockruntime.model.ThrottlingException;
import software.amazon.awssdk.services.bedrockruntime.model.Tool;
import software.amazon.awssdk.services.bedrockruntime.model.ToolConfiguration;
import software.amazon.awssdk.services.bedrockruntime.model.ToolInputSchema;
import software.amazon.awssdk.services.bedrockruntime.model.ToolResultBlock;
import software.amazon.awssdk.services.bedrockruntime.model.ToolResultContentBlock;
import software.amazon.awssdk.services.bedrockruntime.model.ToolSpecification;
import software.amazon.awssdk.services.bedrockruntime.model.ToolUseBlock;
import software.amazon.awssdk.services.bedrockruntime.model.ValidationException;
import software.amazon.awssdk.services.s3.model.InvalidRequestException;

public class BedrockStreamingHandler
extends BaseStreamingHandler {
    @Generated
    private static final Logger log = LogManager.getLogger(BedrockStreamingHandler.class);
    private final SdkAsyncHttpClient httpClient;
    private final AwsConnector connector;
    private static final String STOP_REASON_TOOL_USE = "StopReason=tool_use";

    public BedrockStreamingHandler(SdkAsyncHttpClient httpClient, AwsConnector connector) {
        this.httpClient = httpClient;
        this.connector = connector;
    }

    @Override
    public void startStream(String action, Map<String, String> parameters, String payload, StreamPredictActionListener<MLTaskResponse, ?> listener) {
        try {
            AtomicBoolean isStreamClosed = new AtomicBoolean(false);
            AtomicReference toolName = new AtomicReference();
            AtomicReference toolInput = new AtomicReference();
            AtomicReference toolUseId = new AtomicReference();
            StringBuilder toolInputAccumulator = new StringBuilder();
            AtomicReference<StreamState> currentState = new AtomicReference<StreamState>(StreamState.STREAMING_CONTENT);
            BedrockRuntimeAsyncClient bedrockClient = this.buildBedrockRuntimeAsyncClient();
            ConverseStreamRequest request = this.buildConverseStreamRequest(payload, parameters);
            ConverseStreamResponseHandler handler = ((ConverseStreamResponseHandler.Builder)((ConverseStreamResponseHandler.Builder)((ConverseStreamResponseHandler.Builder)((ConverseStreamResponseHandler.Builder)ConverseStreamResponseHandler.builder().onResponse(response -> {})).onError(error -> {
                log.error("Converse stream error: {}", (Object)error.getMessage());
                if (this.isThrottlingError((Throwable)error)) {
                    listener.onFailure((Exception)((Object)new RemoteConnectorThrottlingException("Error from remote service: The request was denied due to remote server throttling. To change the retry policy and behavior, please update the connector client_config.", RestStatus.BAD_REQUEST, new Object[0])));
                } else if (this.isClientError((Throwable)error)) {
                    listener.onFailure((Exception)new OpenSearchStatusException("Error from remote service: " + error.getMessage(), RestStatus.BAD_REQUEST, new Object[0]));
                } else {
                    listener.onFailure((Exception)new MLException("Error from remote service: " + error.getMessage(), error));
                }
            })).onComplete(() -> {
                if (currentState.get() != StreamState.WAITING_FOR_TOOL_RESULT) {
                    this.sendCompletionResponse(isStreamClosed, listener);
                } else {
                    log.debug("Tool execution in progress - keeping stream open");
                }
            })).subscriber(event -> this.handleStreamEvent((ConverseStreamOutput)event, listener, isStreamClosed, toolName, toolInput, toolUseId, toolInputAccumulator, currentState))).build();
            bedrockClient.converseStream(request, handler);
        }
        catch (Exception e) {
            log.error("Failed to execute Bedrock streaming", (Throwable)e);
            this.handleError(e, listener);
        }
    }

    @Override
    public void handleError(Throwable error, StreamPredictActionListener<MLTaskResponse, ?> listener) {
        log.error("HTTP streaming error", error);
        listener.onFailure((Exception)new MLException("Fail to execute streaming", error));
    }

    private boolean isThrottlingError(Throwable error) {
        return error.getMessage().contains("throttling") || error.getMessage().contains("TooManyRequestsException") || error.getMessage().contains("Rate exceeded");
    }

    private boolean isClientError(Throwable error) {
        return error instanceof ValidationException || error instanceof InvalidRequestException || error instanceof AuthenticationException;
    }

    private ConverseStreamRequest buildConverseStreamRequest(String payload, Map<String, String> parameters) {
        try {
            ObjectMapper mapper = new ObjectMapper();
            JsonNode payloadJson = mapper.readTree(payload);
            return (ConverseStreamRequest)ConverseStreamRequest.builder().modelId(parameters.get("model")).system((Collection)this.getOptionalNode(payloadJson, "system").map(this::parseSystemMessages).orElse(null)).messages((Collection)this.getOptionalNode(payloadJson, "messages").map(this::parseMessages).orElse(null)).toolConfig((ToolConfiguration)this.getOptionalNode(payloadJson, "toolConfig").map(this::parseToolConfig).orElse(null)).build();
        }
        catch (Exception e) {
            throw new MLException("Failed to parse payload for Bedrock request", (Throwable)e);
        }
    }

    private Optional<JsonNode> getOptionalNode(JsonNode json, String field) {
        return Optional.ofNullable(json.get(field));
    }

    private void handleStreamEvent(ConverseStreamOutput event, StreamPredictActionListener<MLTaskResponse, ?> listener, AtomicBoolean isStreamClosed, AtomicReference<String> toolName, AtomicReference<Map<String, Object>> toolInput, AtomicReference<String> toolUseId, StringBuilder toolInputAccumulator, AtomicReference<StreamState> currentState) {
        switch (currentState.get().ordinal()) {
            case 0: {
                if (this.isToolUseDetected(event)) {
                    currentState.set(StreamState.TOOL_CALL_DETECTED);
                    this.extractToolInfo(event, toolName, toolUseId);
                    break;
                }
                if (this.isContentDelta(event)) {
                    this.sendContentResponse(this.getTextContent(event), false, listener);
                    break;
                }
                if (!this.isStreamComplete(event)) break;
                currentState.set(StreamState.COMPLETED);
                this.sendCompletionResponse(isStreamClosed, listener);
                break;
            }
            case 1: {
                if (!this.isToolInputDelta(event)) break;
                currentState.set(StreamState.ACCUMULATING_TOOL_INPUT);
                this.accumulateToolInput(this.getToolInputFragment(event), toolInput, toolInputAccumulator);
                break;
            }
            case 2: {
                if (this.isToolInputDelta(event)) {
                    this.accumulateToolInput(this.getToolInputFragment(event), toolInput, toolInputAccumulator);
                    break;
                }
                if (!this.isToolInputComplete(event)) break;
                currentState.set(StreamState.WAITING_FOR_TOOL_RESULT);
                listener.onResponse(this.createToolUseResponse(toolName, toolInput, toolUseId));
                break;
            }
            case 3: {
                log.debug("Waiting for tool result - keeping stream open");
                break;
            }
        }
    }

    private void extractToolInfo(ConverseStreamOutput event, AtomicReference<String> toolName, AtomicReference<String> toolUseId) {
        ContentBlockStartEvent startEvent = (ContentBlockStartEvent)event;
        if (startEvent.start() != null && startEvent.start().toolUse() != null) {
            toolName.set(startEvent.start().toolUse().name());
            toolUseId.set(startEvent.start().toolUse().toolUseId());
        }
    }

    private String getTextContent(ConverseStreamOutput event) {
        ContentBlockDeltaEvent contentEvent = (ContentBlockDeltaEvent)event;
        return contentEvent.delta().text();
    }

    private String getToolInputFragment(ConverseStreamOutput event) {
        ContentBlockDeltaEvent contentEvent = (ContentBlockDeltaEvent)event;
        return contentEvent.delta().toolUse().input();
    }

    private boolean isToolUseDetected(ConverseStreamOutput event) {
        return event.sdkEventType() == ConverseStreamOutput.EventType.CONTENT_BLOCK_START;
    }

    private boolean isContentDelta(ConverseStreamOutput event) {
        return event.sdkEventType() == ConverseStreamOutput.EventType.CONTENT_BLOCK_DELTA && ((ContentBlockDeltaEvent)event).delta().text() != null;
    }

    private boolean isToolInputDelta(ConverseStreamOutput event) {
        return event.sdkEventType() == ConverseStreamOutput.EventType.CONTENT_BLOCK_DELTA && ((ContentBlockDeltaEvent)event).delta().toolUse() != null;
    }

    private boolean isStreamComplete(ConverseStreamOutput event) {
        return event.sdkEventType() == ConverseStreamOutput.EventType.MESSAGE_STOP && !event.toString().contains(STOP_REASON_TOOL_USE);
    }

    private boolean isToolInputComplete(ConverseStreamOutput event) {
        return event.sdkEventType() == ConverseStreamOutput.EventType.MESSAGE_STOP && event.toString().contains(STOP_REASON_TOOL_USE);
    }

    private MLTaskResponse createToolUseResponse(AtomicReference<String> toolName, AtomicReference<Map<String, Object>> toolInput, AtomicReference<String> toolUseId) {
        if (toolName == null || toolInput == null || toolUseId == null) {
            throw new IllegalArgumentException("Tool references cannot be null");
        }
        Map<String, String> wrappedResponse = Map.of("output", Map.of("message", Map.of("content", List.of(Map.of("toolUse", Map.of("name", toolName.get(), "input", toolInput.get(), "toolUseId", toolUseId.get()))))), "stopReason", "tool_use");
        ModelTensor tensor = ModelTensor.builder().name("response").dataAsMap(wrappedResponse).build();
        ModelTensors tensors = ModelTensors.builder().mlModelTensors(List.of(tensor)).build();
        ModelTensorOutput output = ModelTensorOutput.builder().mlModelOutputs(List.of(tensors)).build();
        return new MLTaskResponse((MLOutput)output);
    }

    private void accumulateToolInput(String inputFragment, AtomicReference<Map<String, Object>> toolInput, StringBuilder toolInputAccumulator) {
        if (inputFragment == null) {
            return;
        }
        ObjectMapper objectMapper = new ObjectMapper();
        toolInputAccumulator.append(inputFragment);
        String accumulated = toolInputAccumulator.toString();
        try {
            JsonParser parser = objectMapper.getFactory().createParser(accumulated);
            JsonToken firstToken = parser.nextToken();
            if (firstToken != JsonToken.START_OBJECT) {
                log.debug("Input does not start with an object: {}", (Object)accumulated);
                return;
            }
            int objectDepth = 1;
            while (parser.nextToken() != null) {
                JsonToken currentToken = parser.getCurrentToken();
                if (currentToken == JsonToken.START_OBJECT) {
                    ++objectDepth;
                } else if (currentToken == JsonToken.END_OBJECT) {
                    --objectDepth;
                }
                if (objectDepth != 0) continue;
                if (parser.nextToken() != null) {
                    log.debug("Extra content after JSON object: {}", (Object)accumulated);
                    return;
                }
                Map parsedInput = (Map)objectMapper.readValue(accumulated, Map.class);
                toolInput.set(parsedInput);
                log.debug("Successfully parsed tool input: {}", (Object)parsedInput);
                return;
            }
            log.debug("Incomplete JSON object: {}", (Object)accumulated);
        }
        catch (JsonParseException e) {
            log.debug("Invalid or incomplete JSON: {}", (Object)accumulated);
        }
        catch (IOException e) {
            log.error("Error parsing JSON input", (Throwable)e);
        }
    }

    private BedrockRuntimeAsyncClient buildBedrockRuntimeAsyncClient() {
        StaticCredentialsProvider awsCredentialsProvider = this.connector.getSessionToken() != null ? StaticCredentialsProvider.create((AwsCredentials)AwsSessionCredentials.create((String)this.connector.getAccessKey(), (String)this.connector.getSecretKey(), (String)this.connector.getSessionToken())) : StaticCredentialsProvider.create((AwsCredentials)AwsBasicCredentials.create((String)this.connector.getAccessKey(), (String)this.connector.getSecretKey()));
        return (BedrockRuntimeAsyncClient)((BedrockRuntimeAsyncClientBuilder)((BedrockRuntimeAsyncClientBuilder)((BedrockRuntimeAsyncClientBuilder)BedrockRuntimeAsyncClient.builder().region(Region.of((String)this.connector.getRegion()))).credentialsProvider((AwsCredentialsProvider)awsCredentialsProvider)).httpClient(this.httpClient)).build();
    }

    private List<SystemContentBlock> parseSystemMessages(JsonNode systemArray) {
        return systemArray.findValuesAsText("text").stream().map(text -> (SystemContentBlock)SystemContentBlock.builder().text(text).build()).collect(Collectors.toList());
    }

    private List<Message> parseMessages(JsonNode messagesArray) {
        ArrayList<Message> messages = new ArrayList<Message>();
        for (JsonNode messageItem : messagesArray) {
            messages.add(this.buildMessage(messageItem));
        }
        return messages;
    }

    private Message buildMessage(JsonNode messageItem) {
        String role = messageItem.has("role") && messageItem.get("role") != null ? messageItem.get("role").asText() : "assistant";
        List<ContentBlock> contentBlocks = this.buildContentBlocks(messageItem.get("content"));
        return (Message)Message.builder().role(role).content(contentBlocks).build();
    }

    private List<ContentBlock> buildContentBlocks(JsonNode contentArray) {
        ArrayList<ContentBlock> blocks = new ArrayList<ContentBlock>();
        if (contentArray != null && contentArray.isArray()) {
            for (JsonNode item : contentArray) {
                this.addContentBlock(blocks, item);
            }
        }
        return blocks;
    }

    private void addContentBlock(List<ContentBlock> blocks, JsonNode item) {
        if (item.has("text")) {
            blocks.add((ContentBlock)ContentBlock.builder().text(item.get("text").asText()).build());
        }
        if (item.has("toolResult")) {
            blocks.add(this.buildToolResultBlock(item.get("toolResult")));
        }
        if (item.has("toolUse")) {
            blocks.add(this.buildToolUseBlock(item.get("toolUse")));
        }
    }

    private ContentBlock buildToolResultBlock(JsonNode toolResult) {
        String text = this.extractResultText(toolResult.get("content"));
        return (ContentBlock)ContentBlock.builder().toolResult((ToolResultBlock)ToolResultBlock.builder().toolUseId(toolResult.get("toolUseId").asText()).content(new ToolResultContentBlock[]{(ToolResultContentBlock)ToolResultContentBlock.builder().text(text).build()}).build()).build();
    }

    private String extractResultText(JsonNode content) {
        if (content.isArray() && content.size() > 0) {
            return content.get(0).get("text").asText();
        }
        return content.isTextual() ? content.asText() : "";
    }

    private ContentBlock buildToolUseBlock(JsonNode toolUse) {
        Document input = toolUse.has("input") ? this.buildDocumentFromJsonNode(toolUse.get("input")) : Document.fromMap(Map.of());
        return (ContentBlock)ContentBlock.builder().toolUse((ToolUseBlock)ToolUseBlock.builder().toolUseId(toolUse.get("toolUseId").asText()).name(toolUse.get("name").asText()).input(input).build()).build();
    }

    private ToolConfiguration parseToolConfig(JsonNode toolConfig) {
        if (!toolConfig.has("tools")) {
            return null;
        }
        ArrayList<Tool> tools = new ArrayList<Tool>();
        for (JsonNode toolItem : toolConfig.get("tools")) {
            if (!toolItem.has("toolSpec")) continue;
            tools.add(this.buildTool(toolItem.get("toolSpec")));
        }
        return (ToolConfiguration)ToolConfiguration.builder().tools(tools).build();
    }

    private Tool buildTool(JsonNode toolSpec) {
        Document schema = this.buildDocumentFromJsonNode(toolSpec.get("inputSchema").get("json"));
        return (Tool)Tool.builder().toolSpec((ToolSpecification)ToolSpecification.builder().name(toolSpec.get("name").asText()).description(toolSpec.get("description").asText()).inputSchema((ToolInputSchema)ToolInputSchema.builder().json(schema).build()).build()).build();
    }

    private Document buildDocumentFromJsonNode(JsonNode node) {
        if (node.isObject()) {
            HashMap map = new HashMap();
            node.fields().forEachRemaining(entry -> map.put((String)entry.getKey(), this.buildDocumentFromJsonNode((JsonNode)entry.getValue())));
            return Document.fromMap(map);
        }
        if (node.isArray()) {
            ArrayList<Document> list = new ArrayList<Document>();
            for (JsonNode item : node) {
                list.add(this.buildDocumentFromJsonNode(item));
            }
            return Document.fromList(list);
        }
        if (node.isTextual()) {
            return Document.fromString((String)node.asText());
        }
        if (node.isBoolean()) {
            return Document.fromBoolean((boolean)node.asBoolean());
        }
        if (node.isNumber()) {
            return Document.fromNumber((double)(node.isInt() ? (double)node.asInt() : node.asDouble()));
        }
        return Document.fromString((String)node.toString());
    }

    private static enum StreamState {
        STREAMING_CONTENT,
        TOOL_CALL_DETECTED,
        ACCUMULATING_TOOL_INPUT,
        WAITING_FOR_TOOL_RESULT,
        COMPLETED;

    }
}

