package ws.palladian.kaggle.restaurants.classifier.nn;

import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import ws.palladian.classification.utils.NoNormalizer;
import ws.palladian.classification.utils.Normalization;
import ws.palladian.core.Model;

/**
 * A {@link Model} backed by a Deeplearning4j {@link MultiLayerNetwork}.
 * <p>
 * Besides the network itself, the model keeps the list of category names, the list of feature
 * names and the {@link Normalization} supplied at construction time. The lists are stored as
 * given (they are not copied). The package-private getters expose them as unmodifiable views.
 * NOTE(review): the list orders presumably correspond to the network's output and input units;
 * confirm against the code that builds this model.
 */
public class MultiLayerNetworkModel implements Model {
    private static final long serialVersionUID = 1;

    /** The wrapped network. */
    private final MultiLayerNetwork model;
    /** Category names, in the order supplied by the creator. */
    private final List<String> categoryNames;
    /** Feature names, in the order supplied by the creator. */
    private final List<String> featureNames;
    /** Normalization associated with this model. */
    private final Normalization normalization;

    /**
     * Creates a model that applies no normalization ({@link NoNormalizer#NO_NORMALIZATION}).
     *
     * @param multiLayerNetwork the network to wrap
     * @param categoryNames the category names (stored by reference, not copied)
     * @param featureNames the feature names (stored by reference, not copied)
     */
    MultiLayerNetworkModel(MultiLayerNetwork multiLayerNetwork, List<String> categoryNames, List<String> featureNames) {
        this(multiLayerNetwork, categoryNames, featureNames, NoNormalizer.NO_NORMALIZATION);
    }

    /**
     * Creates a model with an explicit normalization.
     *
     * @param multiLayerNetwork the network to wrap
     * @param categoryNames the category names (stored by reference, not copied)
     * @param featureNames the feature names (stored by reference, not copied)
     * @param normalization the normalization to associate with this model
     */
    MultiLayerNetworkModel(MultiLayerNetwork multiLayerNetwork, List<String> categoryNames, List<String> featureNames,
            Normalization normalization) {
        this.model = multiLayerNetwork;
        this.categoryNames = categoryNames;
        this.featureNames = featureNames;
        this.normalization = normalization;
    }

    /** @return the wrapped network (the live instance, not a copy) */
    MultiLayerNetwork getModel() {
        return this.model;
    }

    /** @return an unmodifiable view of the category names, in their original order */
    List<String> getCategoryNames() {
        return Collections.unmodifiableList(this.categoryNames);
    }

    /** @return an unmodifiable view of the feature names, in their original order */
    List<String> getFeatureNames() {
        return Collections.unmodifiableList(this.featureNames);
    }

    /**
     * Returns the category names as a set.
     *
     * @return a new, mutable {@link HashSet} of the category names; duplicates are collapsed and
     *         the original order is not preserved
     */
    public Set<String> getCategories() {
        // Diamond instead of the raw HashSet type: avoids an unchecked conversion warning.
        return new HashSet<>(this.categoryNames);
    }

    /** @return the normalization associated with this model */
    Normalization getNormalization() {
        return this.normalization;
    }

    /**
     * Renders the {@code "W"} parameter of every layer of the network, one layer per line.
     * A layer without such a parameter presumably yields {@code null} from {@code getParam}.
     * {@link StringBuilder#append(Object)} then renders it as the text {@code "null"}.
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        for (Layer layer : this.model.getLayers()) {
            sb.append(layer.getParam("W"));
            sb.append('\n');
        }
        return sb.toString();
    }
}
