/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.training.listener;

import ai.djl.training.Trainer;
import ai.djl.training.TrainingResult;
import ai.djl.training.listener.TrainingListener;
import java.time.Duration;

/**
 * A {@link TrainingListener} that aborts training by throwing an {@link EarlyStoppedException}
 * from {@link #onEpoch(Trainer)} as soon as one of its stop criteria is met.
 *
 * <p>Once the epoch number reported by the trainer is at least {@code minEpochs}, the criteria are
 * evaluated in this order:
 *
 * <ol>
 *   <li>the loss is strictly below {@code objectiveSuccess};
 *   <li>at least {@code maxMillis} milliseconds elapsed since the run state was last reset;
 *   <li>the loss failed to improve on the previous finite loss by {@code earlyStopPctImprovement}
 *       percent for {@code epochPatience} consecutive epochs.
 * </ol>
 *
 * <p>The loss is the validation loss when the training result provides one, otherwise the train
 * loss. Instances carry mutable per-run state and are created through {@link #builder()}.
 */
public final class EarlyStoppingListener implements TrainingListener {
    private final double objectiveSuccess;
    private final int minEpochs;
    private final long maxMillis;
    private final double earlyStopPctImprovement;
    private final int epochPatience;

    // Mutable per-run state; see resetRunState().
    private long startTimeMills;
    private double prevLoss;
    private int numberOfEpochsWithoutImprovements;

    private EarlyStoppingListener(double objectiveSuccess, int minEpochs, long maxMillis, double earlyStopPctImprovement, int earlyStopPatience) {
        this.objectiveSuccess = objectiveSuccess;
        this.minEpochs = minEpochs;
        this.maxMillis = maxMillis;
        this.earlyStopPctImprovement = earlyStopPctImprovement;
        this.epochPatience = earlyStopPatience;
        // Without this, an onEpoch() call that is not preceded by onTrainingBegin() would see
        // prevLoss == 0.0 (finite) and a start time of 0, and could stop on bogus data.
        resetRunState();
    }

    /**
     * Evaluates the stop criteria for the epoch that just finished and records its loss.
     *
     * <p>No criterion is evaluated while the reported epoch is below {@code minEpochs}; the loss
     * is still remembered (when finite) as the baseline for the next epoch. A non-finite loss
     * never replaces the baseline and, when a baseline exists, counts as "no improvement".
     *
     * @param trainer the trainer whose training result supplies the epoch number and losses
     * @throws EarlyStoppedException if any stop criterion is met
     */
    @Override
    public void onEpoch(Trainer trainer) {
        TrainingResult result = trainer.getTrainingResult();
        int currentEpoch = result.getEpoch();
        double loss = getLoss(result);

        if (currentEpoch >= minEpochs) {
            if (loss < objectiveSuccess) {
                throw new EarlyStoppedException(currentEpoch, String.format("validation loss %s < objectiveSuccess %s", loss, objectiveSuccess));
            }

            long elapsedMillis = System.currentTimeMillis() - startTimeMills;
            if (elapsedMillis >= maxMillis) {
                throw new EarlyStoppedException(currentEpoch, String.format("%s ms elapsed >= %s maxMillis", elapsedMillis, maxMillis));
            }

            if (Double.isFinite(prevLoss)) {
                // The required step is a percentage of the baseline's magnitude, subtracted from
                // the baseline. Scaling the baseline itself (prevLoss * (100 - pct) / 100) would
                // move the goal towards zero, i.e. in the wrong direction for a negative loss.
                double goalImprovement = prevLoss - Math.abs(prevLoss) * earlyStopPctImprovement / 100.0;
                if (loss <= goalImprovement) {
                    numberOfEpochsWithoutImprovements = 0;
                } else {
                    ++numberOfEpochsWithoutImprovements;
                    // The counter is incremented before the comparison, so a patience of 0 or 1
                    // both stop on the first epoch that misses the goal.
                    if (numberOfEpochsWithoutImprovements >= epochPatience) {
                        throw new EarlyStoppedException(currentEpoch, String.format("failed to achieve %s%% improvement %s times in a row", earlyStopPctImprovement, epochPatience));
                    }
                }
            }
        }

        if (Double.isFinite(loss)) {
            prevLoss = loss;
        }
    }

    /**
     * Picks the loss used by the stop criteria.
     *
     * @param trainingResult the result to read from
     * @return the validation loss if present, else the train loss, else {@code NaN}
     */
    private static double getLoss(TrainingResult trainingResult) {
        Float vLoss = trainingResult.getValidateLoss();
        if (vLoss != null) {
            return vLoss.floatValue();
        }
        Float tLoss = trainingResult.getTrainLoss();
        return tLoss == null ? Double.NaN : tLoss.floatValue();
    }

    /** Clears the baseline loss and the miss counter, and restarts the elapsed-time clock. */
    private void resetRunState() {
        startTimeMills = System.currentTimeMillis();
        prevLoss = Double.NaN;
        numberOfEpochsWithoutImprovements = 0;
    }

    /** Does nothing; this listener only acts on epoch boundaries. */
    @Override
    public void onTrainingBatch(Trainer trainer, TrainingListener.BatchData batchData) {
    }

    /** Does nothing; this listener only acts on epoch boundaries. */
    @Override
    public void onValidationBatch(Trainer trainer, TrainingListener.BatchData batchData) {
    }

    /**
     * Resets the per-run state so the same listener instance can be used for a new run.
     *
     * @param trainer unused
     */
    @Override
    public void onTrainingBegin(Trainer trainer) {
        resetRunState();
    }

    /** Does nothing. */
    @Override
    public void onTrainingEnd(Trainer trainer) {
    }

    /**
     * Creates a builder with the default settings.
     *
     * @return a new {@link Builder}
     */
    public static Builder builder() {
        return new Builder();
    }

    /** Thrown from {@link EarlyStoppingListener#onEpoch(Trainer)} to abort training. */
    public static class EarlyStoppedException
    extends RuntimeException {
        private static final long serialVersionUID = 1L;
        private final int stopEpoch;

        /**
         * Constructs the exception.
         *
         * @param stopEpoch the epoch number at which training was stopped
         * @param message a description of the criterion that triggered the stop
         */
        public EarlyStoppedException(int stopEpoch, String message) {
            super(message);
            this.stopEpoch = stopEpoch;
        }

        /**
         * Returns the epoch number at which training was stopped.
         *
         * @return the stop epoch
         */
        public int getStopEpoch() {
            return this.stopEpoch;
        }
    }

    /** Builder for {@link EarlyStoppingListener}; every setting is optional. */
    public static final class Builder {
        private double objectiveSuccess = 0.0;
        private int minEpochs = 0;
        private long maxMillis = Long.MAX_VALUE;
        private double earlyStopPctImprovement = 0.0;
        private int epochPatience = 0;

        /** Constructs a builder with the default settings. */
        public Builder() {
        }

        /**
         * Sets the loss value below which training is stopped; defaults to {@code 0.0}.
         *
         * @param objectiveSuccess the threshold; training stops once the loss is strictly below it
         * @return this builder
         */
        public Builder optObjectiveSuccess(double objectiveSuccess) {
            this.objectiveSuccess = objectiveSuccess;
            return this;
        }

        /**
         * Sets the epoch number before which no stop criterion is evaluated; defaults to
         * {@code 0}.
         *
         * @param minEpochs the minimum epoch number
         * @return this builder
         */
        public Builder optMinEpochs(int minEpochs) {
            this.minEpochs = minEpochs;
            return this;
        }

        /**
         * Sets the time budget as a {@link Duration}; defaults to {@code Long.MAX_VALUE}
         * milliseconds.
         *
         * @param duration the time budget, converted with {@link Duration#toMillis()}
         * @return this builder
         */
        public Builder optMaxDuration(Duration duration) {
            this.maxMillis = duration.toMillis();
            return this;
        }

        /**
         * Sets the time budget in milliseconds; defaults to {@code Long.MAX_VALUE}.
         *
         * @param maxMillis the time budget in milliseconds
         * @return this builder
         */
        public Builder optMaxMillis(int maxMillis) {
            this.maxMillis = maxMillis;
            return this;
        }

        /**
         * Sets the percentage by which the loss must improve on the previous epoch to count as an
         * improvement; defaults to {@code 0.0}.
         *
         * @param earlyStopPctImprovement the required improvement, in percent
         * @return this builder
         */
        public Builder optEarlyStopPctImprovement(double earlyStopPctImprovement) {
            this.earlyStopPctImprovement = earlyStopPctImprovement;
            return this;
        }

        /**
         * Sets how many consecutive epochs may miss the improvement goal before training is
         * stopped; defaults to {@code 0}.
         *
         * @param epochPatience the number of consecutive non-improving epochs that triggers a stop
         * @return this builder
         */
        public Builder optEpochPatience(int epochPatience) {
            this.epochPatience = epochPatience;
            return this;
        }

        /**
         * Builds the listener from the current settings.
         *
         * @return a new {@link EarlyStoppingListener}
         */
        public EarlyStoppingListener build() {
            return new EarlyStoppingListener(this.objectiveSuccess, this.minEpochs, this.maxMillis, this.earlyStopPctImprovement, this.epochPatience);
        }
    }
}

