/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.sequence;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.tribuo.Output;
import org.tribuo.hash.HashedFeatureMap;
import org.tribuo.hash.Hasher;
import org.tribuo.provenance.SkeletalTrainerProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.sequence.ImmutableSequenceDataset;
import org.tribuo.sequence.SequenceDataset;
import org.tribuo.sequence.SequenceModel;
import org.tribuo.sequence.SequenceTrainer;

/**
 * A {@link SequenceTrainer} that hashes the feature map of the incoming dataset
 * with the configured {@link Hasher} and then delegates training to an inner
 * sequence trainer.
 *
 * @param <T> The output type.
 */
public final class HashingSequenceTrainer<T extends Output<T>>
implements SequenceTrainer<T> {
    private static final Logger logger = Logger.getLogger(HashingSequenceTrainer.class.getName());

    // Not final: both fields are left unset by the private no-arg constructor
    // and are expected to be populated through the @Config mechanism.
    @Config(mandatory=true, description="Trainer to use.")
    private SequenceTrainer<T> innerTrainer;
    @Config(mandatory=true, description="Feature hashing function to use.")
    private Hasher hasher;

    /**
     * No-arg constructor; leaves the configurable fields unset.
     * Presumably used by the configuration system -- confirm against the framework.
     */
    private HashingSequenceTrainer() {
    }

    /**
     * Constructs a hashing sequence trainer.
     *
     * @param trainer The trainer that is invoked on the hashed dataset.
     * @param hasher  The feature hashing function used to build the hashed feature map.
     */
    public HashingSequenceTrainer(SequenceTrainer<T> trainer, Hasher hasher) {
        this.innerTrainer = trainer;
        this.hasher = hasher;
    }

    /**
     * Builds a hashed feature map from the dataset's feature map, creates an
     * immutable copy of the dataset that uses it, and trains the inner trainer
     * on that copy. The feature counts before and after hashing are logged at
     * {@code INFO} level.
     *
     * @param sequenceExamples   The dataset to train on.
     * @param instanceProvenance Provenance passed through unchanged to the inner trainer.
     * @return The model produced by the inner trainer.
     * @throws IllegalStateException If the model returned by the inner trainer does not
     *                               carry a {@link HashedFeatureMap}.
     */
    @Override
    public SequenceModel<T> train(SequenceDataset<T> sequenceExamples, Map<String, Provenance> instanceProvenance) {
        logger.log(Level.INFO, "Before hashing, had " + sequenceExamples.getFeatureIDMap().size() + " features.");
        ImmutableSequenceDataset<T> hashedData = ImmutableSequenceDataset.changeFeatureMap(sequenceExamples, HashedFeatureMap.generateHashedFeatureMap(sequenceExamples.getFeatureIDMap(), this.hasher));
        // No cast needed: hashedData is already usable as a SequenceDataset<T> (see the train call below).
        logger.log(Level.INFO, "After hashing, had " + hashedData.getFeatureIDMap().size() + " features.");
        SequenceModel<T> model = this.innerTrainer.train(hashedData, instanceProvenance);
        // The check runs after training: it inspects the feature map the trained model ended up with.
        if (!(model.featureIDMap instanceof HashedFeatureMap)) {
            throw new IllegalStateException("Trainer " + this.innerTrainer.getClass().getName() + " does not support hashing.");
        }
        return model;
    }

    /**
     * Returns the invocation count reported by the inner trainer.
     *
     * @return The inner trainer's invocation count.
     */
    @Override
    public int getInvocationCount() {
        return this.innerTrainer.getInvocationCount();
    }

    /**
     * Describes this trainer as {@code HashingSequenceTrainer(trainer=...,hasher=...)}.
     * Unset fields are rendered as {@code null} rather than causing an exception.
     *
     * @return A string describing the inner trainer and the hasher.
     */
    @Override
    public String toString() {
        // Plain concatenation (no explicit toString() calls) so a not-yet-configured
        // instance prints "null" instead of throwing a NullPointerException.
        return "HashingSequenceTrainer(trainer=" + this.innerTrainer + ",hasher=" + this.hasher + ")";
    }

    /**
     * Creates a provenance object for this trainer.
     *
     * @return A new {@link HashingSequenceTrainerProvenance} built from this instance.
     */
    public TrainerProvenance getProvenance() {
        return new HashingSequenceTrainerProvenance(this);
    }

    /**
     * Provenance for {@link HashingSequenceTrainer}; all behaviour comes from
     * {@link SkeletalTrainerProvenance}.
     */
    public static class HashingSequenceTrainerProvenance
    extends SkeletalTrainerProvenance {
        private static final long serialVersionUID = 1L;

        /**
         * Builds the provenance from a live trainer instance.
         *
         * @param host The trainer to describe.
         * @param <T>  The output type of the trainer.
         */
        <T extends Output<T>> HashingSequenceTrainerProvenance(HashingSequenceTrainer<T> host) {
            super(host);
        }

        /**
         * Builds the provenance from a map of named provenance values, which is
         * first passed through {@code extractProvenanceInfo}.
         *
         * @param map The provenance values keyed by name.
         */
        public HashingSequenceTrainerProvenance(Map<String, Provenance> map) {
            super(HashingSequenceTrainerProvenance.extractProvenanceInfo(map));
        }
    }
}

