/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.search.vectors;

import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import org.elasticsearch.action.search.SearchRequestBuilder;
import org.elasticsearch.common.Strings;
import org.elasticsearch.index.query.AbstractQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.fetch.StoredFieldsContext;
import org.elasticsearch.search.fetch.subphase.FetchSourceContext;
import org.elasticsearch.search.fetch.subphase.FieldAndFormat;
import org.elasticsearch.search.vectors.KnnVectorQueryBuilder;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentParser;

/**
 * Parses a kNN search REST request (URL parameters plus an optional body) and
 * applies the parsed state to a {@link SearchRequestBuilder}.
 *
 * <p>Instances are created only through {@link #parseRestRequest(RestRequest)}. The
 * recognised body sections are the {@code knn} object, {@code filter}, the
 * source-fetching section, fetch fields, doc-value fields and stored fields.
 */
public class KnnSearchRequestParser {
    static final String INDEX_PARAM = "index";
    static final String ROUTING_PARAM = "routing";
    static final ParseField KNN_SECTION_FIELD = new ParseField("knn");
    static final ParseField FILTER_FIELD = new ParseField("filter");

    /** Body parser; its fields are registered in the static initializer directly below. */
    private static final ObjectParser<KnnSearchRequestParser, Void> PARSER = new ObjectParser<>("knn-search");

    // Must stay textually after the PARSER declaration: static initializers run in source order.
    static {
        PARSER.declareField(KnnSearchRequestParser::knnSearch, KnnSearch::parse, KNN_SECTION_FIELD, ObjectParser.ValueType.OBJECT);
        PARSER.declareFieldArray(
            KnnSearchRequestParser::filter,
            (p, c) -> AbstractQueryBuilder.parseTopLevelQuery(p),
            FILTER_FIELD,
            ObjectParser.ValueType.OBJECT_ARRAY
        );
        PARSER.declareField(
            (p, request, c) -> request.fetchSource(FetchSourceContext.fromXContent(p)),
            SearchSourceBuilder._SOURCE_FIELD,
            ObjectParser.ValueType.OBJECT_ARRAY_BOOLEAN_OR_STRING
        );
        PARSER.declareFieldArray(
            KnnSearchRequestParser::fields,
            (p, c) -> FieldAndFormat.fromXContent(p),
            SearchSourceBuilder.FETCH_FIELDS_FIELD,
            ObjectParser.ValueType.OBJECT_ARRAY
        );
        PARSER.declareFieldArray(
            KnnSearchRequestParser::docValueFields,
            (p, c) -> FieldAndFormat.fromXContent(p),
            SearchSourceBuilder.DOCVALUE_FIELDS_FIELD,
            ObjectParser.ValueType.OBJECT_ARRAY
        );
        PARSER.declareField(
            (p, request, c) -> request.storedFields(
                StoredFieldsContext.fromXContent(SearchSourceBuilder.STORED_FIELDS_FIELD.getPreferredName(), p)
            ),
            SearchSourceBuilder.STORED_FIELDS_FIELD,
            ObjectParser.ValueType.STRING_ARRAY
        );
    }

    private final String[] indices;
    private String routing;
    private KnnSearch knnSearch;
    private List<QueryBuilder> filters;
    private FetchSourceContext fetchSource;
    private List<FieldAndFormat> fields;
    private List<FieldAndFormat> docValueFields;
    private StoredFieldsContext storedFields;

    /**
     * Builds a parser state object from a REST request.
     *
     * <p>The {@code index} parameter is split on commas into the target indices and the
     * {@code routing} parameter is stored as-is. When the request carries content (or a
     * source parameter), that content is parsed into the remaining fields.
     *
     * @param restRequest the incoming REST request
     * @return the populated parser state
     * @throws IOException if obtaining or reading the request content fails
     */
    public static KnnSearchRequestParser parseRestRequest(RestRequest restRequest) throws IOException {
        String[] targetIndices = Strings.splitStringByCommaToArray(restRequest.param(INDEX_PARAM));
        KnnSearchRequestParser parsed = new KnnSearchRequestParser(targetIndices);
        parsed.routing(restRequest.param(ROUTING_PARAM));
        if (restRequest.hasContentOrSourceParam()) {
            try (XContentParser contentParser = restRequest.contentOrSourceParamParser()) {
                PARSER.parse(contentParser, parsed, null);
            }
        }
        return parsed;
    }

    private KnnSearchRequestParser(String[] indices) {
        this.indices = indices;
    }

    private void knnSearch(KnnSearch knnSearch) {
        this.knnSearch = knnSearch;
    }

    private void filter(List<QueryBuilder> filter) {
        this.filters = filter;
    }

    private void routing(String routing) {
        this.routing = routing;
    }

    private void fetchSource(FetchSourceContext fetchSource) {
        this.fetchSource = fetchSource;
    }

    private void fields(List<FieldAndFormat> fields) {
        this.fields = fields;
    }

    private void docValueFields(List<FieldAndFormat> docValueFields) {
        this.docValueFields = docValueFields;
    }

    private void storedFields(StoredFieldsContext storedFields) {
        this.storedFields = storedFields;
    }

    /**
     * Applies the parsed request to the given search request builder: indices, routing,
     * and a search source whose query is the kNN vector query (plus any filters), whose
     * size is {@code k}, and which carries the parsed source/stored/fetch/doc-value
     * field settings.
     *
     * <p>The {@code knn} section is validated before the builder is modified, so a
     * rejected request leaves {@code builder} untouched.
     *
     * @param builder the search request builder to populate
     * @throws IllegalArgumentException if the {@code knn} section is missing, or if its
     *         {@code k} / {@code num_candidates} values are rejected by
     *         {@link KnnSearch#toQueryBuilder()}
     */
    public void toSearchRequest(SearchRequestBuilder builder) {
        if (this.knnSearch == null) {
            throw new IllegalArgumentException("missing required [" + KNN_SECTION_FIELD.getPreferredName() + "] section in search body");
        }
        // May throw on invalid k / num_candidates; done up front for the same fail-fast reason.
        KnnVectorQueryBuilder queryBuilder = this.knnSearch.toQueryBuilder();
        if (this.filters != null) {
            queryBuilder.addFilterQueries(this.filters);
        }

        builder.setIndices(this.indices);
        builder.setRouting(this.routing);

        SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();
        // NOTE(review): Integer.MAX_VALUE presumably requests accurate total-hit tracking — confirm.
        sourceBuilder.trackTotalHitsUpTo(Integer.MAX_VALUE);
        sourceBuilder.query(queryBuilder);
        sourceBuilder.size(this.knnSearch.k);
        sourceBuilder.fetchSource(this.fetchSource);
        sourceBuilder.storedFields(this.storedFields);
        if (this.fields != null) {
            for (FieldAndFormat field : this.fields) {
                sourceBuilder.fetchField(field);
            }
        }
        if (this.docValueFields != null) {
            for (FieldAndFormat field : this.docValueFields) {
                sourceBuilder.docValueField(field.field, field.format);
            }
        }
        builder.setSource(sourceBuilder);
    }

    /**
     * The parsed contents of the {@code knn} body section: target field, query vector,
     * {@code k} and {@code num_candidates}.
     */
    static class KnnSearch {
        private static final int NUM_CANDS_LIMIT = 10000;
        static final ParseField FIELD_FIELD = new ParseField("field");
        static final ParseField K_FIELD = new ParseField("k");
        static final ParseField NUM_CANDS_FIELD = new ParseField("num_candidates");
        static final ParseField QUERY_VECTOR_FIELD = new ParseField("query_vector");

        /**
         * Constructor arguments arrive in declaration order (see the static initializer):
         * field, query vector, k, num_candidates.
         */
        private static final ConstructingObjectParser<KnnSearch, Void> PARSER = new ConstructingObjectParser<>("knn", args -> {
            // args[1] is registered through declareFloatArray, so its elements are Floats.
            @SuppressWarnings("unchecked")
            List<Float> vector = (List<Float>) args[1];
            float[] vectorArray = new float[vector.size()];
            for (int i = 0; i < vectorArray.length; i++) {
                vectorArray[i] = vector.get(i);
            }
            return new KnnSearch((String) args[0], vectorArray, (Integer) args[2], (Integer) args[3]);
        });

        // Must stay textually after the PARSER declaration: static initializers run in source order.
        static {
            PARSER.declareString(ConstructingObjectParser.constructorArg(), FIELD_FIELD);
            PARSER.declareFloatArray(ConstructingObjectParser.constructorArg(), QUERY_VECTOR_FIELD);
            PARSER.declareInt(ConstructingObjectParser.constructorArg(), K_FIELD);
            PARSER.declareInt(ConstructingObjectParser.constructorArg(), NUM_CANDS_FIELD);
        }

        final String field;
        final float[] queryVector;
        final int k;
        final int numCands;

        /**
         * Parses a {@code knn} section from the given parser.
         *
         * @param parser parser positioned at the {@code knn} object
         * @return the parsed section
         * @throws IOException if reading from the parser fails
         */
        public static KnnSearch parse(XContentParser parser) throws IOException {
            return PARSER.parse(parser, null);
        }

        /**
         * Stores the given values without validating them; range checks are performed
         * in {@link #toQueryBuilder()}.
         *
         * @param field       name of the field to search
         * @param queryVector the query vector (stored by reference, not copied)
         * @param k           value of the {@code k} parameter
         * @param numCands    value of the {@code num_candidates} parameter
         */
        KnnSearch(String field, float[] queryVector, int k, int numCands) {
            this.field = field;
            this.queryVector = queryVector;
            this.k = k;
            this.numCands = numCands;
        }

        /**
         * Validates {@code k} and {@code num_candidates} and builds the vector query.
         *
         * @return a new {@link KnnVectorQueryBuilder} for this section
         * @throws IllegalArgumentException if {@code k < 1}, if {@code num_candidates < k},
         *         or if {@code num_candidates} exceeds {@value #NUM_CANDS_LIMIT}
         */
        public KnnVectorQueryBuilder toQueryBuilder() {
            if (this.k < 1) {
                throw new IllegalArgumentException("[" + K_FIELD.getPreferredName() + "] must be greater than 0");
            }
            if (this.numCands < this.k) {
                throw new IllegalArgumentException(
                    "[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot be less than [" + K_FIELD.getPreferredName() + "]"
                );
            }
            if (this.numCands > NUM_CANDS_LIMIT) {
                throw new IllegalArgumentException("[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot exceed [" + NUM_CANDS_LIMIT + "]");
            }
            // NOTE(review): numCands is deliberately kept as both the third and fourth
            // constructor argument, exactly as before — confirm against the
            // KnnVectorQueryBuilder constructor that this (rather than k) is intended.
            final Integer candidates = this.numCands;
            return new KnnVectorQueryBuilder(this.field, this.queryVector, candidates, candidates, null, null);
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || this.getClass() != o.getClass()) {
                return false;
            }
            KnnSearch that = (KnnSearch) o;
            return this.k == that.k
                && this.numCands == that.numCands
                && Objects.equals(this.field, that.field)
                && Arrays.equals(this.queryVector, that.queryVector);
        }

        @Override
        public int hashCode() {
            int result = Objects.hash(this.field, this.k, this.numCands);
            result = 31 * result + Arrays.hashCode(this.queryVector);
            return result;
        }
    }
}

