package graph

import (
	"context"
	"errors"
	"fmt"
	"sync"

	"github.com/emirpasic/gods/sets/hashset"
	"github.com/sourcegraph/conc"
	"github.com/sourcegraph/conc/panics"
	"go.opentelemetry.io/otel"
	"go.opentelemetry.io/otel/attribute"
	"go.opentelemetry.io/otel/trace"

	openfgav1 "github.com/openfga/api/proto/openfga/v1"

	"github.com/openfga/openfga/internal/checkutil"
	"github.com/openfga/openfga/internal/concurrency"
	openfgaErrors "github.com/openfga/openfga/internal/errors"
	"github.com/openfga/openfga/internal/validation"
	"github.com/openfga/openfga/pkg/logger"
	serverconfig "github.com/openfga/openfga/pkg/server/config"
	"github.com/openfga/openfga/pkg/storage"
	"github.com/openfga/openfga/pkg/telemetry"
	"github.com/openfga/openfga/pkg/tuple"
	"github.com/openfga/openfga/pkg/typesystem"
)

// tracer is the OpenTelemetry tracer used to start the spans created in this file.
var tracer = otel.Tracer("internal/graph/check")

// setOperatorType identifies the set operator (union, intersection or exclusion)
// that checkSetOperation is asked to evaluate.
type setOperatorType int

// Sentinel errors produced while resolving checks:
//   - ErrUnknownSetOperator wraps openfgaErrors.ErrUnknown and is returned by the
//     handlers built for a rewrite or set operator type this file does not handle.
//   - ErrPanic is wrapped around any panic recovered from a concurrently
//     evaluated handler.
var (
	ErrUnknownSetOperator = fmt.Errorf("%w: unexpected set operator type encountered", openfgaErrors.ErrUnknown)
	ErrPanic              = errors.New("panic captured")
)

// The set operators understood by checkSetOperation.
const (
	unionSetOperator setOperatorType = iota
	intersectionSetOperator
	exclusionSetOperator
)

// checkOutcome bundles the response and the error produced by a single
// CheckHandlerFunc so that both can be delivered together over a channel.
type checkOutcome struct {
	resp *ResolveCheckResponse
	err  error
}

// LocalChecker is a CheckResolver that evaluates Check requests in-process.
//
// Fields:
//   - delegate: the CheckResolver that sub-problems are dispatched to (see dispatch).
//     NewLocalChecker points it at the LocalChecker itself by default.
//   - concurrencyLimit: the limit handed to the set-operation reducers
//     (see WithResolveNodeBreadthLimit).
//   - usersetBatchSize: configured through WithUsersetBatchSize.
//   - logger: configured through WithLocalCheckerLogger.
//   - optimizationsEnabled: configured through WithOptimizations.
//   - maxResolutionDepth: ResolveCheck fails with ErrResolutionDepthExceeded when a
//     request's depth equals this value.
type LocalChecker struct {
	delegate             CheckResolver
	concurrencyLimit     int
	usersetBatchSize     int
	logger               logger.Logger
	optimizationsEnabled bool
	maxResolutionDepth   uint32
}

// LocalCheckerOption is a functional option that NewLocalChecker applies to the
// LocalChecker it is constructing.
type LocalCheckerOption func(d *LocalChecker)

// WithResolveNodeBreadthLimit configures the concurrency limit that the LocalChecker
// passes to its set-operation reducers. See server.WithResolveNodeBreadthLimit.
func WithResolveNodeBreadthLimit(limit uint32) LocalCheckerOption {
	breadth := int(limit)

	return func(checker *LocalChecker) {
		checker.concurrencyLimit = breadth
	}
}

// WithOptimizations records on the LocalChecker whether optimizations are enabled.
func WithOptimizations(enabled bool) LocalCheckerOption {
	apply := func(checker *LocalChecker) {
		checker.optimizationsEnabled = enabled
	}

	return apply
}

// WithUsersetBatchSize configures the userset batch size stored on the LocalChecker.
// See server.WithUsersetBatchSize.
func WithUsersetBatchSize(usersetBatchSize uint32) LocalCheckerOption {
	size := int(usersetBatchSize)

	return func(checker *LocalChecker) {
		checker.usersetBatchSize = size
	}
}

// WithLocalCheckerLogger sets the logger stored on the LocalChecker.
func WithLocalCheckerLogger(logger logger.Logger) LocalCheckerOption {
	apply := func(checker *LocalChecker) {
		checker.logger = logger
	}

	return apply
}

// WithMaxResolutionDepth sets the request depth at which ResolveCheck gives up with
// ErrResolutionDepthExceeded.
func WithMaxResolutionDepth(depth uint32) LocalCheckerOption {
	apply := func(checker *LocalChecker) {
		checker.maxResolutionDepth = depth
	}

	return apply
}

// NewLocalChecker builds a LocalChecker able to evaluate a Check request locally.
// It starts from the server configuration defaults and a no-op logger, and then
// applies the supplied options in the order they were given.
//
// Developers wanting a LocalChecker with other optional layers (e.g caching and others)
// are encouraged to use [[NewOrderedCheckResolvers]] instead.
func NewLocalChecker(opts ...LocalCheckerOption) *LocalChecker {
	lc := new(LocalChecker)
	lc.concurrencyLimit = serverconfig.DefaultResolveNodeBreadthLimit
	lc.usersetBatchSize = serverconfig.DefaultUsersetBatchSize
	lc.maxResolutionDepth = serverconfig.DefaultResolveNodeLimit
	lc.logger = logger.NewNoopLogger()

	// Unless configured otherwise, a LocalChecker dispatches subproblems to itself
	// (i.e. local dispatch).
	lc.delegate = lc

	for i := range opts {
		opts[i](lc)
	}

	return lc
}

// SetDelegate sets this LocalChecker's dispatch delegate, i.e. the CheckResolver
// that dispatched sub-problems are sent to.
func (c *LocalChecker) SetDelegate(delegate CheckResolver) {
	c.delegate = delegate
}

// GetDelegate returns this LocalChecker's dispatch delegate.
func (c *LocalChecker) GetDelegate() CheckResolver {
	return c.delegate
}

// CheckHandlerFunc defines a function that evaluates a check under the given context
// and returns either a ResolveCheckResponse or an error.
type CheckHandlerFunc func(ctx context.Context) (*ResolveCheckResponse, error)

// CheckFuncReducer defines a function that combines or reduces one or more CheckHandlerFunc into
// a single ResolveCheckResponse, with concurrencyLimit as the maximum number of concurrent
// evaluations that can be in flight at any given time. The union, intersection and exclusion
// functions in this file have this signature.
type CheckFuncReducer func(ctx context.Context, concurrencyLimit int, handlers ...CheckHandlerFunc) (*ResolveCheckResponse, error)

// resolver concurrently resolves one or more CheckHandlerFunc and yields the results on the provided resultChan.
// Callers of the 'resolver' function should be sure to invoke the callback returned from this function to ensure
// every concurrent check is evaluated. The concurrencyLimit can be set to provide a maximum number of concurrent
// evaluations in flight at any point.
//
// The returned callback waits for every goroutine started here and reports any
// recovered panic as an error wrapping ErrPanic. A handler's outcome is not sent on
// resultChan if ctx is done before that handler finishes, so readers of resultChan
// must also watch ctx.
func resolver(ctx context.Context, concurrencyLimit int, resultChan chan<- checkOutcome, handlers ...CheckHandlerFunc) func() error {
	// limiter is a counting semaphore: one slot is taken per in-flight handler.
	limiter := make(chan struct{}, concurrencyLimit)

	var wg conc.WaitGroup

	// checker runs a single handler and forwards its outcome to resultChan.
	checker := func(fn CheckHandlerFunc) {
		// Give the limiter slot back once this handler is no longer being waited on.
		defer func() {
			<-limiter
		}()

		// Capacity 1 lets the goroutine below finish its send even if nobody is
		// receiving any more (the ctx.Done branch of the select).
		resolved := make(chan checkOutcome, 1)

		// Do not start the handler at all if the context is already done.
		if ctx.Err() != nil {
			resultChan <- checkOutcome{nil, ctx.Err()}
			return
		}

		wg.Go(func() {
			defer close(resolved)

			resp, err := fn(ctx)
			resolved <- checkOutcome{resp, err}
		})

		select {
		case <-ctx.Done():
			// The outcome (if any) is dropped; nothing is sent to resultChan.
			return
		case res := <-resolved:
			resultChan <- res
		}
	}

	// Schedule the handlers from a separate goroutine so that resolver itself returns
	// immediately, even when the limiter is saturated.
	wg.Go(func() {
	outer:
		for _, handler := range handlers {
			fn := handler // capture loop var

			select {
			case limiter <- struct{}{}:
				wg.Go(func() {
					checker(fn)
				})
			case <-ctx.Done():
				// Stop scheduling the remaining handlers once the context is done.
				break outer
			}
		}
	})

	return func() error {
		// Wait for the scheduler and all handler goroutines, collecting any panic.
		recoveredError := wg.WaitAndRecover()
		close(limiter)

		if recoveredError != nil {
			return fmt.Errorf("%w: %s", ErrPanic, recoveredError.AsError())
		}

		return nil
	}
}

// union is the CheckFuncReducer for the union operator: the result is allowed as soon
// as any handler reports an allowed outcome, at which point the remaining handlers are
// cancelled.
//
// When no handler allows, the most recently observed handler error (if any) is
// returned; otherwise the response is not allowed and carries CycleDetected if any
// handler reported a cycle. A panic recovered from a handler overrides whatever result
// was computed and is returned as an error wrapping ErrPanic.
func union(ctx context.Context, concurrencyLimit int, handlers ...CheckHandlerFunc) (resp *ResolveCheckResponse, err error) {
	ctx, cancel := context.WithCancel(ctx)
	outcomes := make(chan checkOutcome, len(handlers))

	wait := resolver(ctx, concurrencyLimit, outcomes, handlers...)

	defer func() {
		// Stop outstanding handlers, wait for them, and only then close the channel
		// they write to.
		cancel()
		if panicErr := wait(); panicErr != nil {
			resp, err = nil, panicErr
		}
		close(outcomes)
	}()

	var (
		lastErr  error
		sawCycle bool
	)

	for range handlers {
		select {
		case <-ctx.Done():
			return nil, ctx.Err()
		case outcome := <-outcomes:
			if outcome.err != nil {
				lastErr = outcome.err
				continue
			}

			if outcome.resp.GetCycleDetected() {
				sawCycle = true
			}

			if outcome.resp.GetAllowed() {
				return outcome.resp, nil
			}
		}
	}

	if lastErr != nil {
		return nil, lastErr
	}

	return &ResolveCheckResponse{
		Allowed: false,
		ResolutionMetadata: ResolveCheckResponseMetadata{
			CycleDetected: sawCycle,
		},
	}, nil
}

// intersection implements a CheckFuncReducer that requires all of the provided CheckHandlerFunc to resolve
// to an allowed outcome. The first falsey outcome (not allowed, or cycle detected) causes premature
// termination of the reducer and is returned as-is.
//
// Handler errors do not terminate the reducer early: each one is recorded on the span found in ctx and
// joined into a single error, which is returned only if no falsey outcome was observed. With no
// handlers the result is not allowed. A panic recovered from a handler is returned as an error
// wrapping ErrPanic, together with a nil response.
func intersection(ctx context.Context, concurrencyLimit int, handlers ...CheckHandlerFunc) (resp *ResolveCheckResponse, err error) {
	if len(handlers) == 0 {
		return &ResolveCheckResponse{
			Allowed: false,
		}, nil
	}

	span := trace.SpanFromContext(ctx)

	ctx, cancel := context.WithCancel(ctx)
	resultChan := make(chan checkOutcome, len(handlers))

	drain := resolver(ctx, concurrencyLimit, resultChan, handlers...)

	defer func() {
		// Stop outstanding handlers and wait for them before closing the channel they
		// write to.
		cancel()
		drainErr := drain()
		if drainErr != nil {
			err = drainErr
			// Never hand back a response next to an error (mirrors union).
			resp = nil
		}
		close(resultChan)
	}()

	var elErr error
	for i := 0; i < len(handlers); i++ {
		select {
		case result := <-resultChan:
			if result.err != nil {
				telemetry.TraceError(span, result.err)
				elErr = errors.Join(elErr, result.err)
				continue
			}

			if result.resp.GetCycleDetected() || !result.resp.GetAllowed() {
				resp = result.resp
				return
			}
		case <-ctx.Done():
			err = ctx.Err()
			return
		}
	}

	// all operands are either truthy or we've seen at least one error
	if elErr != nil {
		err = elErr
		return
	}

	resp = &ResolveCheckResponse{
		Allowed: true,
	}

	return
}

// exclusion implements a CheckFuncReducer that requires a 'base' CheckHandlerFunc to resolve to an allowed
// outcome and a 'sub' CheckHandlerFunc to resolve to a falsey outcome. The base and sub computations are
// handled concurrently relative to one another (subject to concurrencyLimit).
//
// Exactly two handlers are expected: handlers[0] is the base and handlers[1] the sub; any other count
// yields an error wrapping openfgaErrors.ErrUnknown. The reducer returns early, cancelling the other
// handler, when either side reports a cycle, when the base is not allowed, or when the sub is allowed.
// Otherwise any handler errors are joined and returned. A panic in a handler is reported as that
// handler's error, wrapping ErrPanic.
func exclusion(ctx context.Context, concurrencyLimit int, handlers ...CheckHandlerFunc) (*ResolveCheckResponse, error) {
	if len(handlers) != 2 {
		return nil, fmt.Errorf("%w, expected two rewrite operands for exclusion operator, but got '%d'", openfgaErrors.ErrUnknown, len(handlers))
	}

	span := trace.SpanFromContext(ctx)

	limiter := make(chan struct{}, concurrencyLimit)

	ctx, cancel := context.WithCancel(ctx)
	baseChan := make(chan checkOutcome, 1)
	subChan := make(chan checkOutcome, 1)

	var wg sync.WaitGroup

	defer func() {
		// Cancel whatever is still running and wait for both goroutines to be
		// completely finished before closing the channels they send on.
		cancel()
		wg.Wait()
		close(baseChan)
		close(subChan)
	}()

	// run acquires a limiter slot and evaluates handler in its own goroutine,
	// delivering exactly one outcome on out (capacity 1, so the send never blocks).
	run := func(handler CheckHandlerFunc, out chan<- checkOutcome) {
		limiter <- struct{}{}
		wg.Add(1)
		go func() {
			// wg.Done must only run after the last possible send on out, including
			// the panic-path send below. Otherwise the deferred wg.Wait above could
			// return and close the channel while this goroutine is still about to
			// send on it. Deferred calls run last-in-first-out: the limiter slot is
			// released first, then the WaitGroup is signalled.
			defer wg.Done()
			defer func() {
				<-limiter
			}()

			recoveredError := panics.Try(func() {
				resp, err := handler(ctx)
				out <- checkOutcome{resp, err}
			})

			if recoveredError != nil {
				out <- checkOutcome{nil, fmt.Errorf("%w: %s", ErrPanic, recoveredError.AsError())}
			}
		}()
	}

	run(handlers[0], baseChan)
	run(handlers[1], subChan)

	response := &ResolveCheckResponse{
		Allowed: false,
	}

	var baseErr error
	var subErr error

	for i := 0; i < len(handlers); i++ {
		select {
		case baseResult := <-baseChan:
			if baseResult.err != nil {
				telemetry.TraceError(span, baseResult.err)
				baseErr = baseResult.err
				continue
			}

			if baseResult.resp.GetCycleDetected() {
				return &ResolveCheckResponse{
					Allowed: false,
					ResolutionMetadata: ResolveCheckResponseMetadata{
						CycleDetected: true,
					},
				}, nil
			}

			if !baseResult.resp.GetAllowed() {
				return response, nil
			}

		case subResult := <-subChan:
			if subResult.err != nil {
				telemetry.TraceError(span, subResult.err)
				subErr = subResult.err
				continue
			}

			if subResult.resp.GetCycleDetected() {
				return &ResolveCheckResponse{
					Allowed: false,
					ResolutionMetadata: ResolveCheckResponseMetadata{
						CycleDetected: true,
					},
				}, nil
			}

			if subResult.resp.GetAllowed() {
				return response, nil
			}
		case <-ctx.Done():
			return nil, ctx.Err()
		}
	}

	// base is either (true) or error, sub is either (false) or error:
	// true, false - true
	// true, error - error
	// error, false - error
	// error, error - error
	if baseErr != nil || subErr != nil {
		return nil, errors.Join(baseErr, subErr)
	}

	return &ResolveCheckResponse{
		Allowed: true,
	}, nil
}

// Close is a noop: the method body is empty.
func (c *LocalChecker) Close() {
}

// dispatch returns a handler that, when invoked, bumps the parent's dispatch counter,
// clones the parent request with tk as its tuple key and a depth increased by one, and
// sends the clone to the CheckResolver this LocalChecker delegates to. On error the
// handler returns a nil response.
func (c *LocalChecker) dispatch(_ context.Context, parentReq *ResolveCheckRequest, tk *openfgav1.TupleKey) CheckHandlerFunc {
	handler := func(ctx context.Context) (*ResolveCheckResponse, error) {
		parentReq.GetRequestMetadata().DispatchCounter.Add(1)

		child := parentReq.clone()
		child.TupleKey = tk
		child.GetRequestMetadata().Depth++

		resp, err := c.delegate.ResolveCheck(ctx, child)
		if err == nil {
			return resp, nil
		}

		return nil, err
	}

	return handler
}

// Compile-time assertion that *LocalChecker implements CheckResolver.
var _ CheckResolver = (*LocalChecker)(nil)

// ResolveCheck implements [[CheckResolver.ResolveCheck]].
//
// In order, it: bails out if ctx is already done; fails with ErrResolutionDepthExceeded
// when the request depth equals the configured maximum; answers "not allowed, cycle
// detected" if the tuple key was already visited by this request; answers "allowed" for a
// self-defining tuple key; and otherwise requires a typesystem and a relationship tuple
// reader in ctx, looks up the relation, answers "not allowed" when the typesystem
// reports no path from the user to the relation, and finally evaluates the relation's
// rewrite.
func (c *LocalChecker) ResolveCheck(
	ctx context.Context,
	req *ResolveCheckRequest,
) (*ResolveCheckResponse, error) {
	if ctxErr := ctx.Err(); ctxErr != nil {
		return nil, ctxErr
	}

	tupleKey := req.GetTupleKey()

	ctx, span := tracer.Start(ctx, "ResolveCheck", trace.WithAttributes(
		attribute.String("store_id", req.GetStoreID()),
		attribute.String("resolver_type", "LocalChecker"),
		attribute.String("tuple_key", tuple.TupleKeyWithConditionToString(tupleKey)),
	))
	defer span.End()

	if req.GetRequestMetadata().Depth == c.maxResolutionDepth {
		return nil, ErrResolutionDepthExceeded
	}

	// hasCycle also records the tuple key as visited on the request.
	if c.hasCycle(req) {
		span.SetAttributes(attribute.Bool("cycle_detected", true))

		cyclic := &ResolveCheckResponse{Allowed: false}
		cyclic.ResolutionMetadata = ResolveCheckResponseMetadata{CycleDetected: true}

		return cyclic, nil
	}

	if tuple.IsSelfDefining(tupleKey) {
		return &ResolveCheckResponse{Allowed: true}, nil
	}

	typesys, found := typesystem.TypesystemFromContext(ctx)
	if !found {
		return nil, fmt.Errorf("%w: typesystem missing in context", openfgaErrors.ErrUnknown)
	}

	if _, found = storage.RelationshipTupleReaderFromContext(ctx); !found {
		return nil, fmt.Errorf("%w: relationship tuple reader datastore missing in context", openfgaErrors.ErrUnknown)
	}

	relation := tupleKey.GetRelation()
	objectType, _ := tuple.SplitObject(tupleKey.GetObject())

	rel, err := typesys.GetRelation(objectType, relation)
	if err != nil {
		return nil, fmt.Errorf("relation '%s' undefined for object type '%s'", relation, objectType)
	}

	hasPath, err := typesys.PathExists(tupleKey.GetUser(), relation, objectType)
	switch {
	case err != nil:
		return nil, err
	case !hasPath:
		return &ResolveCheckResponse{Allowed: false}, nil
	}

	handler := c.CheckRewrite(ctx, req, rel.GetRewrite())

	resp, err := handler(ctx)
	if err != nil {
		telemetry.TraceError(span, err)
		return nil, err
	}

	return resp, nil
}

// hasCycle reports whether the request's tuple key is already present in the request's
// VisitedPaths. It mutates the request: VisitedPaths is created when nil, and a tuple
// key seen for the first time is added to it.
func (c *LocalChecker) hasCycle(req *ResolveCheckRequest) bool {
	if req.VisitedPaths == nil {
		req.VisitedPaths = make(map[string]struct{})
	}

	path := tuple.TupleKeyToString(req.GetTupleKey())
	if _, seen := req.VisitedPaths[path]; seen {
		return true
	}

	req.VisitedPaths[path] = struct{}{}

	return false
}

// checkPublicAssignable returns a handler that answers "allowed" when the request's
// object#relation has at least one valid tuple, with its condition met, whose user is
// the wildcard of the request user's type. The typesystem, the datastore, the store ID
// and the tuple key are captured when the handler is built.
func (c *LocalChecker) checkPublicAssignable(ctx context.Context, req *ResolveCheckRequest) CheckHandlerFunc {
	typesys, _ := typesystem.TypesystemFromContext(ctx)
	ds, _ := storage.RelationshipTupleReaderFromContext(ctx)
	storeID := req.GetStoreID()
	tk := req.GetTupleKey()
	wildcard := typesystem.WildcardRelationReference(tuple.GetType(tk.GetUser()))

	return func(ctx context.Context) (*ResolveCheckResponse, error) {
		ctx, span := tracer.Start(ctx, "checkPublicAssignable")
		defer span.End()

		readOpts := storage.ReadUsersetTuplesOptions{
			Consistency: storage.ConsistencyOptions{
				Preference: req.GetConsistency(),
			},
		}
		filter := storage.ReadUsersetTuplesFilter{
			Object:                      tk.GetObject(),
			Relation:                    tk.GetRelation(),
			AllowedUserTypeRestrictions: []*openfgav1.RelationReference{wildcard},
		}

		// We want to query via ReadUsersetTuples instead of ReadUserTuple tuples to take
		// advantage of the storage wrapper cache
		// (https://github.com/openfga/openfga/blob/af054d9693bd7ebd0420456b144c2fb6888aaf87/internal/graph/storagewrapper.go#L139).
		// In the future, if storage wrapper cache is available for ReadUserTuple, we can switch it to ReadUserTuple.
		iter, err := ds.ReadUsersetTuples(ctx, storeID, filter, readOpts)
		if err != nil {
			return nil, err
		}

		validTuples := storage.NewFilteredTupleKeyIterator(
			storage.NewTupleKeyIteratorFromTupleIterator(iter),
			validation.FilterInvalidTuples(typesys),
		)
		filteredIter := storage.NewConditionsFilteredTupleKeyIterator(
			validTuples,
			checkutil.BuildTupleKeyConditionFilter(ctx, req.GetContext(), typesys),
		)
		defer filteredIter.Stop()

		// A single surviving tuple is enough: it means a public wildcard is assigned.
		_, err = filteredIter.Next(ctx)
		switch {
		case err == nil:
			span.SetAttributes(attribute.Bool("allowed", true))
			return &ResolveCheckResponse{Allowed: true}, nil
		case errors.Is(err, storage.ErrIteratorDone):
			return &ResolveCheckResponse{Allowed: false}, nil
		default:
			return nil, err
		}
	}
}

// checkDirectUserTuple returns a handler that looks up the exact request tuple in the
// datastore. The handler answers "allowed" only when the tuple exists, passes
// validation against the typesystem, and has its condition met. A missing or invalid
// tuple yields "not allowed" rather than an error.
func (c *LocalChecker) checkDirectUserTuple(ctx context.Context, req *ResolveCheckRequest) CheckHandlerFunc {
	typesys, _ := typesystem.TypesystemFromContext(ctx)
	reqTupleKey := req.GetTupleKey()

	return func(ctx context.Context) (*ResolveCheckResponse, error) {
		ctx, span := tracer.Start(ctx, "checkDirectUserTuple",
			trace.WithAttributes(attribute.String("tuple_key", tuple.TupleKeyWithConditionToString(reqTupleKey))))
		defer span.End()

		denied := &ResolveCheckResponse{Allowed: false}

		// The datastore comes from the handler's own context.
		ds, _ := storage.RelationshipTupleReaderFromContext(ctx)

		found, err := ds.ReadUserTuple(ctx, req.GetStoreID(), reqTupleKey, storage.ReadUserTupleOptions{
			Consistency: storage.ConsistencyOptions{
				Preference: req.GetConsistency(),
			},
		})
		switch {
		case errors.Is(err, storage.ErrNotFound):
			return denied, nil
		case err != nil:
			return nil, err
		}

		// Discard a tuple yielded by the database query that is invalid for the model.
		storedKey := found.GetKey()
		if validation.ValidateTupleForRead(typesys, storedKey) != nil {
			return denied, nil
		}

		conditionFilter := checkutil.BuildTupleKeyConditionFilter(ctx, req.Context, typesys)

		conditionMet, err := conditionFilter(storedKey)
		if err != nil {
			telemetry.TraceError(span, err)
			return nil, err
		}

		if !conditionMet {
			return denied, nil
		}

		span.SetAttributes(attribute.Bool("allowed", true))

		return &ResolveCheckResponse{Allowed: true}, nil
	}
}

// shouldCheckDirectTuple reports whether checkDirectUserTuple should run, i.e. whether
// the typesystem in ctx says the request user's type (and relation, if any) is directly
// related to the request's object type and relation. An error from the typesystem
// lookup is ignored and the accompanying value is returned.
func shouldCheckDirectTuple(ctx context.Context, reqTupleKey *openfgav1.TupleKey) bool {
	typesys, _ := typesystem.TypesystemFromContext(ctx)

	user := reqTupleKey.GetUser()
	target := typesystem.DirectRelationReference(tuple.GetType(reqTupleKey.GetObject()), reqTupleKey.GetRelation())
	source := typesystem.DirectRelationReference(tuple.GetType(user), tuple.GetRelation(user))

	related, _ := typesys.IsDirectlyRelated(target, source)

	return related
}

// shouldCheckPublicAssignable reports whether checkPublicAssignable should run, i.e.
// whether the request user is not a userset and the typesystem in ctx says the request's
// object type and relation are publicly assignable for the user's type. An error from
// the typesystem lookup is ignored and the accompanying value is returned.
func shouldCheckPublicAssignable(ctx context.Context, reqTupleKey *openfgav1.TupleKey) bool {
	typesys, _ := typesystem.TypesystemFromContext(ctx)

	user := reqTupleKey.GetUser()

	// A userset user can, by definition, never be a wildcard.
	if tuple.IsObjectRelation(user) {
		return false
	}

	target := typesystem.DirectRelationReference(tuple.GetType(reqTupleKey.GetObject()), reqTupleKey.GetRelation())

	assignable, _ := typesys.IsPubliclyAssignable(target, tuple.GetType(user))

	return assignable
}

// checkDirect composes up to three CheckHandlerFunc which evaluate direct relationships with the provided
// 'object#relation'. The first handler looks up direct matches on the provided 'object#relation@user',
// the second handler looks up wildcard matches on the provided 'object#relation@user:*',
// while the third handler looks up relationships between the target 'object#relation' and any usersets
// related to it.
//
// Each handler is only included when it applies (see shouldCheckDirectTuple,
// shouldCheckPublicAssignable and the directly related userset types of the relation);
// the included handlers are combined with union.
func (c *LocalChecker) checkDirect(parentctx context.Context, req *ResolveCheckRequest) CheckHandlerFunc {
	return func(ctx context.Context) (*ResolveCheckResponse, error) {
		ctx, span := tracer.Start(ctx, "checkDirect")
		defer span.End()

		if ctx.Err() != nil {
			return nil, ctx.Err()
		}

		typesys, _ := typesystem.TypesystemFromContext(parentctx) // note: use of 'parentctx' not 'ctx' - this is important

		ds, _ := storage.RelationshipTupleReaderFromContext(parentctx)

		storeID := req.GetStoreID()
		reqTupleKey := req.GetTupleKey()
		objectType := tuple.GetType(reqTupleKey.GetObject())
		relation := reqTupleKey.GetRelation()

		// directlyRelatedUsersetTypes could be "group#member"
		directlyRelatedUsersetTypes, _ := typesys.DirectlyRelatedUsersets(objectType, relation)

		// checkDirectUsersetTuples reads the userset tuples assigned to the target
		// 'object#relation', filters out invalid ones and those whose condition is
		// not met, and hands the resulting iterator to the selected resolver.
		// TODO(jpadilla): can we lift this function up?
		checkDirectUsersetTuples := func(ctx context.Context) (*ResolveCheckResponse, error) {
			ctx, span := tracer.Start(ctx, "checkDirectUsersetTuples", trace.WithAttributes(
				attribute.String("userset", tuple.ToObjectRelationString(reqTupleKey.GetObject(), reqTupleKey.GetRelation())),
				attribute.String("resolver", "slow"),
			))
			defer span.End()

			if ctx.Err() != nil {
				return nil, ctx.Err()
			}

			opts := storage.ReadUsersetTuplesOptions{
				Consistency: storage.ConsistencyOptions{
					Preference: req.GetConsistency(),
				},
			}

			// Start from the default resolver; a specialised one is only considered
			// when the request user is not itself a userset.
			resolver := c.defaultUserset
			isUserset := tuple.IsObjectRelation(reqTupleKey.GetUser())
			userType := tuple.GetType(reqTupleKey.GetUser())

			if !isUserset {
				if len(directlyRelatedUsersetTypes) < 2 && typesys.UsersetUseWeight2Resolver(objectType, relation, userType, directlyRelatedUsersetTypes) {
					// If there are more than 1 directly related userset types of the same type, we cannot do userset optimization because
					// we cannot rely on the fact that the object ID matches. Instead, we need to take into consideration
					// on the relation as well.
					resolver = c.weight2Userset
					span.SetAttributes(attribute.String("resolver", "weight2"))
				} else if typesys.UsersetUseRecursiveResolver(objectType, relation, userType) {
					resolver = c.recursiveUserset
					span.SetAttributes(attribute.String("resolver", "recursive"))
				}
			}

			iter, err := ds.ReadUsersetTuples(ctx, storeID, storage.ReadUsersetTuplesFilter{
				Object:                      reqTupleKey.GetObject(),
				Relation:                    reqTupleKey.GetRelation(),
				AllowedUserTypeRestrictions: directlyRelatedUsersetTypes,
			}, opts)
			if err != nil {
				return nil, err
			}

			// Drop tuples that are invalid for the model, then tuples whose condition
			// is not met.
			filteredIter := storage.NewConditionsFilteredTupleKeyIterator(
				storage.NewFilteredTupleKeyIterator(
					storage.NewTupleKeyIteratorFromTupleIterator(iter),
					validation.FilterInvalidTuples(typesys),
				),
				checkutil.BuildTupleKeyConditionFilter(ctx, req.GetContext(), typesys),
			)
			defer filteredIter.Stop()

			return resolver(ctx, req, filteredIter)
		}

		// Collect only the handlers that can possibly produce an allowed outcome.
		var checkFuncs []CheckHandlerFunc

		if shouldCheckDirectTuple(ctx, req.GetTupleKey()) {
			checkFuncs = []CheckHandlerFunc{c.checkDirectUserTuple(parentctx, req)}
		}

		if shouldCheckPublicAssignable(ctx, reqTupleKey) {
			checkFuncs = append(checkFuncs, c.checkPublicAssignable(parentctx, req))
		}

		if len(directlyRelatedUsersetTypes) > 0 {
			checkFuncs = append(checkFuncs, checkDirectUsersetTuples)
		}

		resp, err := union(ctx, c.concurrencyLimit, checkFuncs...)
		if err != nil {
			telemetry.TraceError(span, err)
			return nil, err
		}

		return resp, nil
	}
}

// checkComputedUserset returns a handler that re-evaluates the Check against the
// rewritten relation (the computed userset relation) on the same object and user. The
// rewritten request is a clone of req, prepared when the handler is built.
func (c *LocalChecker) checkComputedUserset(_ context.Context, req *ResolveCheckRequest, rewrite *openfgav1.Userset) CheckHandlerFunc {
	original := req.GetTupleKey()

	rewritten := req.clone()
	rewritten.TupleKey = tuple.NewTupleKey(
		original.GetObject(),
		rewrite.GetComputedUserset().GetRelation(),
		original.GetUser(),
	)

	handler := func(ctx context.Context) (*ResolveCheckResponse, error) {
		ctx, span := tracer.Start(ctx, "checkComputedUserset")
		defer span.End()

		// ResolveCheck is called directly instead of going through dispatch, so the
		// resolution depth is not increased.
		return c.ResolveCheck(ctx, rewritten)
	}

	return handler
}

// checkTTU looks up all tuples of the target tupleset relation on the provided object and for each one
// of them evaluates the computed userset of the TTU rewrite rule for them.
//
// The returned handler reads the tupleset tuples, filters out invalid ones and those
// whose condition is not met, and passes the resulting iterator to one of the TTU
// resolvers (default, weight2 or recursive).
func (c *LocalChecker) checkTTU(parentctx context.Context, req *ResolveCheckRequest, rewrite *openfgav1.Userset) CheckHandlerFunc {
	return func(ctx context.Context) (*ResolveCheckResponse, error) {
		ctx, span := tracer.Start(ctx, "checkTTU", trace.WithAttributes(attribute.String("resolver", "slow")))
		defer span.End()

		if ctx.Err() != nil {
			return nil, ctx.Err()
		}

		typesys, _ := typesystem.TypesystemFromContext(parentctx) // note: use of 'parentctx' not 'ctx' - this is important

		ds, _ := storage.RelationshipTupleReaderFromContext(parentctx)

		objectType, relation := tuple.GetType(req.GetTupleKey().GetObject()), req.GetTupleKey().GetRelation()

		userType := tuple.GetType(req.GetTupleKey().GetUser())

		// Attach the typesystem and datastore taken from parentctx to the handler's
		// own context, which is what gets passed to the resolver below.
		ctx = typesystem.ContextWithTypesystem(ctx, typesys)
		ctx = storage.ContextWithRelationshipTupleReader(ctx, ds)

		tuplesetRelation := rewrite.GetTupleToUserset().GetTupleset().GetRelation()
		computedRelation := rewrite.GetTupleToUserset().GetComputedUserset().GetRelation()

		tk := req.GetTupleKey()
		object := tk.GetObject()

		span.SetAttributes(
			attribute.String("tupleset_relation", tuple.ToObjectRelationString(tuple.GetType(object), tuplesetRelation)),
			attribute.String("computed_relation", computedRelation),
		)

		opts := storage.ReadOptions{
			Consistency: storage.ConsistencyOptions{
				Preference: req.GetConsistency(),
			},
		}

		// Read the tuples of the tupleset relation on the object; the user part of
		// the lookup key is left empty.
		storeID := req.GetStoreID()
		iter, err := ds.Read(
			ctx,
			storeID,
			tuple.NewTupleKey(object, tuplesetRelation, ""),
			opts,
		)
		if err != nil {
			return nil, err
		}

		// filter out invalid tuples yielded by the database iterator
		filteredIter := storage.NewConditionsFilteredTupleKeyIterator(
			storage.NewFilteredTupleKeyIterator(
				storage.NewTupleKeyIteratorFromTupleIterator(iter),
				validation.FilterInvalidTuples(typesys),
			),
			checkutil.BuildTupleKeyConditionFilter(ctx, req.GetContext(), typesys),
		)
		defer filteredIter.Stop()

		resolver := c.defaultTTU

		// TODO: optimize the case where user is an userset.
		// If the user is a userset, we will not be able to use the shortcut because the algo
		// will look up the objects associated with user.
		isUserset := tuple.IsObjectRelation(tk.GetUser())

		if !isUserset {
			if typesys.TTUUseWeight2Resolver(objectType, relation, userType, rewrite.GetTupleToUserset()) {
				resolver = c.weight2TTU
				span.SetAttributes(attribute.String("resolver", "weight2"))
			} else if typesys.TTUUseRecursiveResolver(objectType, relation, userType, rewrite.GetTupleToUserset()) {
				resolver = c.recursiveTTU
				span.SetAttributes(attribute.String("resolver", "recursive"))
			}
		}

		return resolver(ctx, req, rewrite, filteredIter)
	}
}

// checkSetOperation builds one handler per child rewrite (via CheckRewrite) and returns
// a handler that combines them with the given reducer, inside a span named after the
// operator ("union", "intersection" or "exclusion"). For an unrecognised setOpType the
// returned handler always fails with ErrUnknownSetOperator.
func (c *LocalChecker) checkSetOperation(
	ctx context.Context,
	req *ResolveCheckRequest,
	setOpType setOperatorType,
	reducer CheckFuncReducer,
	children ...*openfgav1.Userset,
) CheckHandlerFunc {
	var spanName string

	switch setOpType {
	case unionSetOperator:
		spanName = "union"
	case intersectionSetOperator:
		spanName = "intersection"
	case exclusionSetOperator:
		spanName = "exclusion"
	default:
		return func(ctx context.Context) (*ResolveCheckResponse, error) {
			return nil, ErrUnknownSetOperator
		}
	}

	// The child handlers are built eagerly, when the set operation is composed.
	var handlers []CheckHandlerFunc
	for i := range children {
		handlers = append(handlers, c.CheckRewrite(ctx, req, children[i]))
	}

	return func(ctx context.Context) (resp *ResolveCheckResponse, err error) {
		ctx, span := tracer.Start(ctx, spanName)
		defer func() {
			// Record the reducer's error (if any) on the span before ending it.
			if err != nil {
				telemetry.TraceError(span, err)
			}
			span.End()
		}()

		return reducer(ctx, c.concurrencyLimit, handlers...)
	}
}

// CheckRewrite returns the handler that evaluates the given rewrite rule for req:
// direct ("this"), computed userset, tuple-to-userset, or a union, intersection or
// difference of child rewrites. For any other rewrite the returned handler always
// fails with ErrUnknownSetOperator.
func (c *LocalChecker) CheckRewrite(
	ctx context.Context,
	req *ResolveCheckRequest,
	rewrite *openfgav1.Userset,
) CheckHandlerFunc {
	switch node := rewrite.GetUserset().(type) {
	case *openfgav1.Userset_This:
		return c.checkDirect(ctx, req)
	case *openfgav1.Userset_ComputedUserset:
		return c.checkComputedUserset(ctx, req, rewrite)
	case *openfgav1.Userset_TupleToUserset:
		return c.checkTTU(ctx, req, rewrite)
	case *openfgav1.Userset_Union:
		operands := node.Union.GetChild()
		return c.checkSetOperation(ctx, req, unionSetOperator, union, operands...)
	case *openfgav1.Userset_Intersection:
		operands := node.Intersection.GetChild()
		return c.checkSetOperation(ctx, req, intersectionSetOperator, intersection, operands...)
	case *openfgav1.Userset_Difference:
		base, subtract := node.Difference.GetBase(), node.Difference.GetSubtract()
		return c.checkSetOperation(ctx, req, exclusionSetOperator, exclusion, base, subtract)
	}

	// Every recognised rewrite returned above; anything else is unsupported.
	return func(ctx context.Context) (*ResolveCheckResponse, error) {
		return nil, ErrUnknownSetOperator
	}
}

// TODO: make these subsequent functions generic and move outside this package.

// usersetMessage is the element type of the channel returned by
// streamedLookupUsersetFromIterator: it carries either a userset read from the
// iterator or the error that ended the stream.
type usersetMessage struct {
	userset string
	err     error
}

// streamedLookupUsersetFromIterator drains iter in a background goroutine and streams
// every userset it yields on the returned channel (buffered, capacity 100). An
// iterator error other than "done or cancelled", or a panic in the goroutine (wrapped
// in ErrPanic), is sent as a final message. The channel is always closed when the
// goroutine ends.
func streamedLookupUsersetFromIterator(ctx context.Context, iter storage.TupleMapper) <-chan usersetMessage {
	out := make(chan usersetMessage, 100)

	// emit forwards msg unless ctx is done first.
	emit := func(msg usersetMessage) {
		concurrency.TrySendThroughChannel(ctx, msg, out)
	}

	go func() {
		defer func() {
			if r := recover(); r != nil {
				emit(usersetMessage{err: fmt.Errorf("%w: %s", ErrPanic, r)})
			}

			close(out)
		}()

		for {
			userset, err := iter.Next(ctx)

			switch {
			case err == nil:
				emit(usersetMessage{userset: userset})
			case storage.IterIsDoneOrCancelled(err):
				return
			default:
				emit(usersetMessage{err: err})
				return
			}
		}
	}()

	return out
}

// processUsersetMessage records userset in primarySet and then reports whether
// secondarySet contains it. It is used to find the intersection between the usersets
// reached from the user and the usersets reached from the object.
func processUsersetMessage(userset string,
	primarySet *hashset.Set,
	secondarySet *hashset.Set) bool {
	// The insertion happens before the lookup, as the two sets may be the same one.
	primarySet.Add(userset)

	inSecondary := secondarySet.Contains(userset)

	return inSecondary
}
