/**
 *    Copyright (C) 2018-present MongoDB, Inc.
 *
 *    This program is free software: you can redistribute it and/or modify
 *    it under the terms of the Server Side Public License, version 1,
 *    as published by MongoDB, Inc.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    Server Side Public License for more details.
 *
 *    You should have received a copy of the Server Side Public License
 *    along with this program. If not, see
 *    <http://www.mongodb.com/licensing/server-side-public-license>.
 *
 *    As a special exception, the copyright holders give permission to link the
 *    code of portions of this program with the OpenSSL library under certain
 *    conditions as described in each individual source file and distribute
 *    linked combinations including the program with the OpenSSL library. You
 *    must comply with the Server Side Public License in all respects for
 *    all of the code used other than as permitted herein. If you modify file(s)
 *    with this exception, you may extend this exception to your version of the
 *    file(s), but you are not obligated to do so. If you do not wish to do so,
 *    delete this exception statement from your version. If you delete this
 *    exception statement from all source files in the program, then also delete
 *    it in the license file.
 */

#include "mongo/platform/basic.h"

#include <memory>

#include "mongo/client/dbclient_cursor.h"
#include "mongo/db/catalog/collection.h"
#include "mongo/db/catalog/database.h"
#include "mongo/db/catalog/index_catalog.h"
#include "mongo/db/client.h"
#include "mongo/db/db_raii.h"
#include "mongo/db/dbdirectclient.h"
#include "mongo/db/exec/index_scan.h"
#include "mongo/db/exec/plan_stage.h"
#include "mongo/db/json.h"
#include "mongo/db/matcher/expression_parser.h"
#include "mongo/db/namespace_string.h"
#include "mongo/db/query/plan_executor_factory.h"
#include "mongo/dbtests/dbtests.h"

/**
 * This file tests db/exec/index_scan.cpp
 */

namespace QueryStageTests {

using std::unique_ptr;

/**
 * Fixture shared by the IndexScan stage tests in this file.
 *
 * The constructor fills ns() with numObj() documents {foo: i, baz: i, bar: numObj() - i} for
 * i in [0, numObj()) and builds the indexes {foo: 1} and {foo: 1, baz: 1}. The destructor drops
 * the collection again.
 */
class IndexScanBase {
public:
    IndexScanBase() : _client(&_opCtx) {
        dbtests::WriteContextForTests ctx(&_opCtx, ns());

        for (int i = 0; i < numObj(); ++i) {
            BSONObjBuilder bob;
            bob.append("foo", i);
            bob.append("baz", i);
            bob.append("bar", numObj() - i);
            _client.insert(ns(), bob.obj());
        }

        addIndex(BSON("foo" << 1));
        addIndex(BSON("foo" << 1 << "baz" << 1));
    }

    virtual ~IndexScanBase() {
        dbtests::WriteContextForTests ctx(&_opCtx, ns());
        _client.dropCollection(ns());
    }

    /**
     * Builds an index with key pattern 'obj' on ns(); fails the test if creation fails.
     */
    void addIndex(const BSONObj& obj) {
        ASSERT_OK(dbtests::createIndex(&_opCtx, ns(), obj));
    }

    /**
     * Runs a standalone IndexScan stage built from 'params' to completion and returns the number
     * of results it produced. 'filterObj' is parsed into a MatchExpression and handed to the
     * stage as its filter (an empty object by default).
     *
     * Fails the test if 'filterObj' does not parse, if the executor cannot be built, or if
     * iteration stops in any state other than IS_EOF. Exceptions thrown while the stage runs are
     * not caught here and propagate to the caller.
     */
    int countResults(const IndexScanParams& params, BSONObj filterObj = BSONObj()) {
        AutoGetCollectionForReadCommand ctx(&_opCtx, NamespaceString(ns()));

        StatusWithMatchExpression statusWithMatcher =
            MatchExpressionParser::parse(filterObj, _expCtx);
        // Unlike verify(), ASSERT_OK reports the parser's error status when the filter is bad.
        ASSERT_OK(statusWithMatcher.getStatus());
        unique_ptr<MatchExpression> filterExpr = std::move(statusWithMatcher.getValue());

        unique_ptr<WorkingSet> ws = std::make_unique<WorkingSet>();
        // The stage only borrows the filter via a raw pointer; 'filterExpr' keeps ownership and,
        // being declared before 'exec', is destroyed after it.
        unique_ptr<IndexScan> ix = std::make_unique<IndexScan>(
            _expCtx.get(), ctx.getCollection(), params, ws.get(), filterExpr.get());

        auto statusWithPlanExecutor =
            plan_executor_factory::make(_expCtx,
                                        std::move(ws),
                                        std::move(ix),
                                        &ctx.getCollection(),
                                        PlanYieldPolicy::YieldPolicy::NO_YIELD,
                                        QueryPlannerParams::DEFAULT);
        ASSERT_OK(statusWithPlanExecutor.getStatus());
        auto exec = std::move(statusWithPlanExecutor.getValue());

        int count = 0;
        PlanExecutor::ExecState state;
        // Only the number of results matters, so no output object is requested; the typed
        // nullptr presumably selects the BSONObj* overload of getNext().
        for (RecordId dl; PlanExecutor::ADVANCED ==
             (state = exec->getNext(static_cast<BSONObj*>(nullptr), &dl));) {
            ++count;
        }
        ASSERT_EQUALS(PlanExecutor::IS_EOF, state);

        return count;
    }

    /**
     * Inserts numObj() additional documents whose 'geo' field is a [lng, lat] pair of rand()
     * values scaled into [0, 1]. Not called by the tests in this file.
     */
    void makeGeoData() {
        dbtests::WriteContextForTests ctx(&_opCtx, ns());

        for (int i = 0; i < numObj(); ++i) {
            double lat = double(rand()) / RAND_MAX;
            double lng = double(rand()) / RAND_MAX;
            _client.insert(ns(), BSON("geo" << BSON_ARRAY(lng << lat)));
        }
    }

    /**
     * Returns the first ready index on ns() whose key pattern is 'obj', or nullptr if there is
     * none. The collection is only locked for the duration of the lookup.
     */
    const IndexDescriptor* getIndex(const BSONObj& obj) {
        AutoGetCollectionForReadCommand collection(&_opCtx, NamespaceString(ns()));
        std::vector<const IndexDescriptor*> indexes;
        collection->getIndexCatalog()->findIndexesByKeyPattern(
            &_opCtx, obj, IndexCatalog::InclusionPolicy::kReady, &indexes);
        return indexes.empty() ? nullptr : indexes[0];
    }

    /**
     * Returns scan parameters over 'descriptor' configured as a forward simple range with an
     * empty endKey and both start and end keys included. Callers set startKey and adjust the
     * other fields as needed. Fails the test if 'descriptor' is null (e.g. getIndex() found no
     * matching index).
     */
    IndexScanParams makeIndexScanParams(OperationContext* opCtx,
                                        const IndexDescriptor* descriptor) {
        // Fail with a readable test assertion instead of handing a null descriptor to
        // IndexScanParams.
        ASSERT(descriptor != nullptr);
        IndexScanParams params(opCtx, descriptor);
        params.bounds.isSimpleRange = true;
        params.bounds.endKey = BSONObj();
        params.bounds.boundInclusion = BoundInclusion::kIncludeBothStartAndEndKeys;
        params.direction = 1;
        return params;
    }

    /** Number of documents inserted by the constructor (and by makeGeoData()). */
    static int numObj() {
        return 50;
    }
    /** Namespace of the collection the fixture populates and drops. */
    static const char* ns() {
        return "unittests.IndexScan";
    }

protected:
    const ServiceContext::UniqueOperationContext _txnPtr = cc().makeOperationContext();
    OperationContext& _opCtx = *_txnPtr;

    boost::intrusive_ptr<ExpressionContext> _expCtx =
        new ExpressionContext(&_opCtx, nullptr, NamespaceString(ns()));

private:
    DBDirectClient _client;
};

class QueryStageIXScanBasic : public IndexScanBase {
public:
    virtual ~QueryStageIXScanBasic() {}

    void run() {
        // foo <= 20
        auto params = makeIndexScanParams(&_opCtx, getIndex(BSON("foo" << 1)));
        params.bounds.startKey = BSON("" << 20);
        params.direction = -1;

        ASSERT_EQUALS(countResults(params), 21);
    }
};

class QueryStageIXScanLowerUpper : public IndexScanBase {
public:
    virtual ~QueryStageIXScanLowerUpper() {}

    void run() {
        // 20 <= foo < 30
        auto params = makeIndexScanParams(&_opCtx, getIndex(BSON("foo" << 1)));
        params.bounds.startKey = BSON("" << 20);
        params.bounds.endKey = BSON("" << 30);
        params.bounds.boundInclusion = BoundInclusion::kIncludeStartKeyOnly;
        params.direction = 1;

        ASSERT_EQUALS(countResults(params), 10);
    }
};

class QueryStageIXScanLowerUpperIncl : public IndexScanBase {
public:
    virtual ~QueryStageIXScanLowerUpperIncl() {}

    void run() {
        // 20 <= foo <= 30
        auto params = makeIndexScanParams(&_opCtx, getIndex(BSON("foo" << 1)));
        params.bounds.startKey = BSON("" << 20);
        params.bounds.endKey = BSON("" << 30);

        ASSERT_EQUALS(countResults(params), 11);
    }
};

/**
 * Scans the {foo: 1} index over the closed range 20 <= foo <= 30 with a stage filter of
 * foo == 25, so exactly one result is expected.
 */
class QueryStageIXScanLowerUpperInclFilter : public IndexScanBase {
public:
    virtual ~QueryStageIXScanLowerUpperInclFilter() {}

    void run() {
        // 20 <= foo <= 30 (makeIndexScanParams includes both bounds by default)
        // filter: foo == 25
        auto params = makeIndexScanParams(&_opCtx, getIndex(BSON("foo" << 1)));
        params.bounds.startKey = BSON("" << 20);
        params.bounds.endKey = BSON("" << 30);

        ASSERT_EQUALS(countResults(params, BSON("foo" << 25)), 1);
    }
};

/**
 * Scans the {foo: 1} index with a filter on 'baz', a field that index does not cover. The
 * fixture's compound {foo: 1, baz: 1} index would cover it, but it is not the one scanned here.
 * Running the scan is expected to throw AssertionException.
 */
class QueryStageIXScanCantMatch : public IndexScanBase {
public:
    virtual ~QueryStageIXScanCantMatch() {}

    void run() {
        // 20 <= foo <= 30 (makeIndexScanParams includes both bounds by default)
        // baz == 25 (not covered by the {foo: 1} index, should error.)
        auto params = makeIndexScanParams(&_opCtx, getIndex(BSON("foo" << 1)));
        params.bounds.startKey = BSON("" << 20);
        params.bounds.endKey = BSON("" << 30);

        ASSERT_THROWS(countResults(params, BSON("baz" << 25)), AssertionException);
    }
};

class All : public OldStyleSuiteSpecification {
public:
    All() : OldStyleSuiteSpecification("query_stage_tests") {}

    void setupTests() {
        add<QueryStageIXScanBasic>();
        add<QueryStageIXScanLowerUpper>();
        add<QueryStageIXScanLowerUpperIncl>();
        add<QueryStageIXScanLowerUpperInclFilter>();
        add<QueryStageIXScanCantMatch>();
    }
};

OldStyleSuiteInitializer<All> queryStageTestsAll;

}  // namespace QueryStageTests
