/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.neuralsearch.sparse.query;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Locale;
import java.util.Objects;
import java.util.PriorityQueue;
import lombok.Generated;
import lombok.NonNull;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.PostingsEnum;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.LongBitSet;
import org.opensearch.neuralsearch.sparse.accessor.SparseVectorReader;
import org.opensearch.neuralsearch.sparse.codec.SparsePostingsEnum;
import org.opensearch.neuralsearch.sparse.common.DocWeightIterator;
import org.opensearch.neuralsearch.sparse.common.IteratorWrapper;
import org.opensearch.neuralsearch.sparse.data.DocumentCluster;
import org.opensearch.neuralsearch.sparse.data.SparseVector;
import org.opensearch.neuralsearch.sparse.query.SparseQueryContext;

/**
 * Base {@link Scorer} for cluster-based sparse vector retrieval.
 *
 * <p>For every query token found in {@link #fieldName}, a {@link SingleScorer} walks that token's
 * {@link DocumentCluster}s. A cluster is skipped when {@link #scoreHeap} is full and the cluster summary's dot
 * product with the query is below the heap minimum divided by the query's heap factor. The exception is a cluster
 * flagged "should not skip", which is always visited. Candidate documents are filtered by {@link #acceptedDocs} and
 * de-duplicated with {@link #visitedDocId}. Each one is then loaded through {@link #reader} and scored with an exact
 * dot product.
 */
public abstract class SeismicBaseScorer extends Scorer {
    @Generated
    private static final Logger log = LogManager.getLogger(SeismicBaseScorer.class);

    /** Capacity of {@link #scoreHeap}, whose minimum is the cluster-pruning threshold. */
    private static final int SEISMIC_HEAP_SIZE = 10;

    /** Best scores seen so far across all sub-scorers; used to prune clusters. */
    protected final HeapWrapper scoreHeap;
    /** Documents already scored, so a document reached through several tokens is evaluated only once. */
    protected final LongBitSet visitedDocId;
    /** Name of the sparse field being searched. */
    protected final String fieldName;
    /** Query tokens and the heap factor used for cluster pruning. */
    protected final SparseQueryContext sparseQueryContext;
    /** Dense form of the query vector, computed once and reused for every dot product. */
    protected final byte[] queryDenseVector;
    /** Documents allowed to match, or {@code null} to accept every document. */
    protected final Bits acceptedDocs;
    /** Reader used to load a document's sparse vector by doc id. */
    protected SparseVectorReader reader;
    /** One {@link SingleScorer} per query token present in the field. */
    protected List<Scorer> subScorers = new ArrayList<>();

    /**
     * Creates the scorer and builds one sub-scorer per query token through {@link #initialize(LeafReader)}.
     *
     * @param leafReader segment reader used to look up the field's terms
     * @param fieldName name of the sparse field
     * @param sparseQueryContext query tokens and the heap factor used for cluster pruning
     * @param maxDocCount number of bits in {@link #visitedDocId} (presumably the segment's maxDoc — confirm)
     * @param queryVector query vector; converted once to its dense form
     * @param reader reader used to load document vectors; must not be {@code null}
     * @param acceptedDocs documents allowed to match, or {@code null} to accept all
     * @throws IOException if reading terms or postings fails
     * @throws NullPointerException if {@code reader} is {@code null}
     */
    public SeismicBaseScorer(
        LeafReader leafReader,
        String fieldName,
        SparseQueryContext sparseQueryContext,
        int maxDocCount,
        SparseVector queryVector,
        @NonNull SparseVectorReader reader,
        Bits acceptedDocs
    ) throws IOException {
        Objects.requireNonNull(reader, "reader is marked non-null but is null");
        this.visitedDocId = new LongBitSet(maxDocCount);
        this.fieldName = fieldName;
        this.sparseQueryContext = sparseQueryContext;
        this.queryDenseVector = queryVector.toDenseVector();
        this.reader = reader;
        this.acceptedDocs = acceptedDocs;
        this.scoreHeap = new HeapWrapper(SEISMIC_HEAP_SIZE);
        // NOTE: initialize() is overridable and runs here, before any subclass constructor body has executed.
        this.initialize(leafReader);
    }

    /**
     * Creates a {@link SingleScorer} for every query token that exists in {@link #fieldName}'s terms dictionary.
     * Tokens that are not present are ignored.
     *
     * @param leafReader segment reader providing the field's terms
     * @throws IOException if reading terms or postings fails
     * @throws IllegalStateException if a token's postings are not a {@link SparsePostingsEnum}
     */
    protected void initialize(LeafReader leafReader) throws IOException {
        Terms terms = Terms.getTerms(leafReader, this.fieldName);
        for (String token : this.sparseQueryContext.getTokens()) {
            TermsEnum termsEnum = terms.iterator();
            if (!termsEnum.seekExact(new BytesRef(token))) {
                continue;
            }
            PostingsEnum postingsEnum = termsEnum.postings(null, PostingsEnum.FREQS);
            if (!(postingsEnum instanceof SparsePostingsEnum)) {
                throw new IllegalStateException(
                    String.format(
                        Locale.ROOT,
                        "posting enum is not SparsePostingsEnum, actual type: %s",
                        postingsEnum == null ? null : postingsEnum.getClass().getName()
                    )
                );
            }
            this.subScorers.add(new SingleScorer((SparsePostingsEnum) postingsEnum));
        }
    }

    /**
     * Drains every sub-scorer's iterator and scores each accepted, not-yet-visited document exactly once.
     *
     * <p>Every score is also offered to {@link #scoreHeap}. This tightens cluster pruning for the clusters that
     * remain to be visited.
     *
     * @param resultSize maximum number of best-scoring documents to keep; a value {@code <= 0} yields an empty list
     * @return up to {@code resultSize} (docId, score) pairs, sorted by ascending doc id
     * @throws IOException if iterating postings or reading a document vector fails
     */
    protected List<Pair<Integer, Integer>> searchUpfront(int resultSize) throws IOException {
        HeapWrapper resultHeap = new HeapWrapper(resultSize);
        for (Scorer scorer : this.subScorers) {
            DocIdSetIterator iterator = scorer.iterator();
            int docId;
            while ((docId = iterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
                if ((this.acceptedDocs != null && !this.acceptedDocs.get(docId)) || this.visitedDocId.get(docId)) {
                    continue;
                }
                // Mark before reading, so a document whose vector is missing is not retried via another token.
                this.visitedDocId.set(docId);
                SparseVector doc = this.reader.read(docId);
                if (doc == null) {
                    continue;
                }
                int score = doc.dotProduct(this.queryDenseVector);
                Pair<Integer, Integer> scored = Pair.of(docId, score);
                this.scoreHeap.add(scored);
                resultHeap.add(scored);
            }
        }
        return resultHeap.toOrderedList();
    }

    /** Returns a min-heap of (docId, score) pairs ordered by score, so the lowest score is at the head. */
    protected static PriorityQueue<Pair<Integer, Integer>> makeHeap() {
        return new PriorityQueue<>(Comparator.comparingInt(Pair::getRight));
    }

    @Generated
    public SparseVectorReader getReader() {
        return this.reader;
    }

    /** Bounded min-heap that keeps the {@code k} highest-scoring (docId, score) pairs offered to it. */
    protected static class HeapWrapper {
        private final PriorityQueue<Pair<Integer, Integer>> heap = SeismicBaseScorer.makeHeap();
        /**
         * A score must be strictly greater than this value to be admitted. After the first eviction it equals the
         * heap minimum. Kept as an int: scores are ints, and a float would lose precision above 2^24.
         */
        private int heapThreshold = Integer.MIN_VALUE;
        private final int k;

        HeapWrapper(int k) {
            this.k = k;
        }

        /** Returns {@code true} when the heap holds exactly {@code k} pairs. */
        public boolean isFull() {
            return this.heap.size() == this.k;
        }

        /**
         * Offers a (docId, score) pair. The pair is ignored when {@code k <= 0} or when its score does not exceed
         * the current threshold. If the heap grows past {@code k}, the lowest-scoring pair is evicted.
         */
        public void add(Pair<Integer, Integer> pair) {
            if (this.k <= 0) {
                // Without this guard the poll below would empty the heap and peek() would return null.
                return;
            }
            int score = pair.getRight();
            if (score <= this.heapThreshold) {
                return;
            }
            this.heap.add(pair);
            if (this.heap.size() > this.k) {
                this.heap.poll();
                // k >= 1, so exactly k pairs remain and peek() is non-null.
                this.heapThreshold = this.heap.peek().getRight();
            }
        }

        /** Returns a new list of the retained pairs, sorted by ascending doc id (exact int comparison). */
        public List<Pair<Integer, Integer>> toOrderedList() {
            List<Pair<Integer, Integer>> list = new ArrayList<>(this.heap);
            list.sort(Comparator.comparingInt(Pair::getLeft));
            return list;
        }

        public int size() {
            return this.heap.size();
        }

        /** Returns the lowest-scoring retained pair, or {@code null} when the heap is empty. */
        public Pair<Integer, Integer> peek() {
            return this.heap.peek();
        }
    }

    /**
     * Iterates the documents of one query token, cluster by cluster, and skips clusters that cannot beat the
     * current pruning threshold. {@link #score()} and {@link #getMaxScore(int)} are placeholders that return 0.
     * Documents are scored by the enclosing class instead (see {@link SeismicBaseScorer#searchUpfront(int)}).
     */
    class SingleScorer extends Scorer {
        private final IteratorWrapper<DocumentCluster> clusterIter;
        /** Document iterator of the cluster currently being visited; {@code null} before the first cluster. */
        private DocWeightIterator docs = null;

        public SingleScorer(SparsePostingsEnum postingsEnum) throws IOException {
            this.clusterIter = postingsEnum.clusterIterator();
        }

        @Override
        public int docID() {
            return this.docs == null ? -1 : this.docs.docID();
        }

        /**
         * Returns an iterator over the documents of every qualifying cluster, in cluster order.
         *
         * <p>NOTE(review): ids are only as ordered as the clusters themselves. Across clusters they may not increase
         * — confirm before using this iterator where strict DocIdSetIterator ordering is required.
         */
        @Override
        public DocIdSetIterator iterator() {
            return new DocIdSetIterator() {

                /** Returns the next cluster worth visiting, or {@code null} once all clusters are consumed. */
                private DocumentCluster nextQualifiedCluster() {
                    if (SingleScorer.this.clusterIter == null) {
                        return null;
                    }
                    for (DocumentCluster cluster = SingleScorer.this.clusterIter.next(); cluster != null;
                        cluster = SingleScorer.this.clusterIter.next()) {
                        if (cluster.isShouldNotSkip()) {
                            return cluster;
                        }
                        int score = cluster.getSummary().dotProduct(SeismicBaseScorer.this.queryDenseVector);
                        // Skip only when the heap is full and the summary score cannot reach
                        // (heap minimum / heap factor).
                        if (SeismicBaseScorer.this.scoreHeap.isFull()
                            && (float) score < (float) Objects.requireNonNull(SeismicBaseScorer.this.scoreHeap.peek())
                                .getRight() / SeismicBaseScorer.this.sparseQueryContext.getHeapFactor()) {
                            continue;
                        }
                        return cluster;
                    }
                    return null;
                }

                @Override
                public int docID() {
                    return SingleScorer.this.docID();
                }

                @Override
                public int nextDoc() throws IOException {
                    if (SingleScorer.this.docs != null) {
                        int docId = SingleScorer.this.docs.nextDoc();
                        if (docId != NO_MORE_DOCS) {
                            return docId;
                        }
                    }
                    // Advance to the next qualifying cluster. Clusters whose iterator is already empty are skipped,
                    // rather than reporting exhaustion for the whole token.
                    DocumentCluster cluster;
                    while ((cluster = nextQualifiedCluster()) != null) {
                        SingleScorer.this.docs = cluster.getDisi();
                        int docId = SingleScorer.this.docs.nextDoc();
                        if (docId != NO_MORE_DOCS) {
                            return docId;
                        }
                    }
                    return NO_MORE_DOCS;
                }

                /** NOTE(review): not implemented — always returns 0; only nextDoc() is used within this class. */
                @Override
                public int advance(int target) throws IOException {
                    return 0;
                }

                /** NOTE(review): placeholder cost of 0. */
                @Override
                public long cost() {
                    return 0L;
                }
            };
        }

        /** Placeholder; always 0. */
        @Override
        public float getMaxScore(int upTo) throws IOException {
            return 0.0f;
        }

        /** Placeholder; always 0. */
        @Override
        public float score() throws IOException {
            return 0.0f;
        }
    }

    /**
     * {@link DocIdSetIterator} over precomputed (docId, score) results. The results are expected in ascending doc id
     * order, as produced by {@link HeapWrapper#toOrderedList()}.
     */
    public static class ResultsDocValueIterator extends DocIdSetIterator {
        private final IteratorWrapper<Pair<Integer, Integer>> resultsIterator;
        private int docId;

        public ResultsDocValueIterator(List<Pair<Integer, Integer>> results) {
            this.resultsIterator = new IteratorWrapper<>(results.iterator());
            this.docId = -1;
        }

        @Override
        public int docID() {
            return this.docId;
        }

        @Override
        public int nextDoc() throws IOException {
            if (this.resultsIterator.next() == null) {
                this.docId = NO_MORE_DOCS;
                return NO_MORE_DOCS;
            }
            this.docId = this.resultsIterator.getCurrent().getLeft();
            return this.docId;
        }

        /** Moves to the first result whose doc id is {@code >= target}; stays put if already at or past it. */
        @Override
        public int advance(int target) throws IOException {
            if (target <= this.docId) {
                return this.docId;
            }
            while (this.resultsIterator.hasNext()) {
                Pair<Integer, Integer> pair = this.resultsIterator.next();
                if (pair.getKey() < target) {
                    continue;
                }
                this.docId = pair.getKey();
                return this.docId;
            }
            this.docId = NO_MORE_DOCS;
            return NO_MORE_DOCS;
        }

        /**
         * Returns the score of the current result, or 0 before the first result and after exhaustion.
         *
         * <p>NOTE(review): this is not the DocIdSetIterator cost estimate (number of matching docs). Callers
         * outside this file may read the current score through it — confirm before changing.
         */
        @Override
        public long cost() {
            if (this.resultsIterator.getCurrent() == null || this.docId == NO_MORE_DOCS) {
                return 0L;
            }
            return this.resultsIterator.getCurrent().getValue();
        }
    }
}

