package com.flipkart.fdp.ml.adapter;

import com.flipkart.fdp.ml.modelinfo.OneHotEncoderModelInfo;
import java.util.LinkedHashSet;
import org.apache.spark.ml.attribute.Attribute;
import org.apache.spark.ml.attribute.AttributeType;
import org.apache.spark.ml.attribute.BinaryAttribute;
import org.apache.spark.ml.attribute.NominalAttribute;
import org.apache.spark.ml.feature.OneHotEncoder;
import org.apache.spark.sql.DataFrame;

/* loaded from: input_file:com/flipkart/fdp/ml/adapter/OneHotEncoderModelInfoAdapter.class */
/**
 * Adapts a Spark {@link OneHotEncoder} into a {@link OneHotEncoderModelInfo}.
 *
 * <p>The number of categories is taken from the ML attribute metadata attached to the encoder's
 * input column in the supplied {@link DataFrame} schema.
 */
public class OneHotEncoderModelInfoAdapter extends AbstractModelInfoAdapter<OneHotEncoder, OneHotEncoderModelInfo> {

    /** Category count reported for input attributes that are neither nominal nor binary. */
    private static final int UNKNOWN_CATEGORY_COUNT = -1;

    /** Category count of a binary attribute that carries no explicit value labels. */
    private static final int BINARY_CATEGORY_COUNT = 2;

    /**
     * Builds the model info for {@code oneHotEncoder}.
     *
     * @param oneHotEncoder the encoder whose input and output column names are copied
     * @param dataFrame a DataFrame whose schema contains the encoder's input column
     * @return model info with the input key, output key and {@code numTypes} populated;
     *     {@code numTypes} is the category count minus one, or -2 when the input column's
     *     attribute is neither nominal nor binary
     * @throws IllegalArgumentException if the input column is nominal but declares neither
     *     value labels nor a value count
     */
    @Override
    public OneHotEncoderModelInfo getModelInfo(OneHotEncoder oneHotEncoder, DataFrame dataFrame) {
        String inputCol = oneHotEncoder.getInputCol();
        Attribute attribute = Attribute.fromStructField(dataFrame.schema().apply(inputCol));
        int numCategories = countCategories(attribute, inputCol);

        OneHotEncoderModelInfo modelInfo = new OneHotEncoderModelInfo();
        // The encoder's dropLast param is not consulted; one category is always subtracted.
        // NOTE(review): presumably mirrors OneHotEncoder's default dropLast=true — confirm
        // against the consuming transformer.
        modelInfo.setNumTypes(numCategories - 1);

        LinkedHashSet<String> inputKeys = new LinkedHashSet<>();
        inputKeys.add(inputCol);
        modelInfo.setInputKeys(inputKeys);
        modelInfo.setOutputKey(oneHotEncoder.getOutputCol());
        return modelInfo;
    }

    /**
     * Returns the number of categories described by {@code attribute}.
     *
     * <p>A nominal attribute uses the length of its value labels when present, otherwise its
     * declared value count. A binary attribute uses the length of its value labels when present,
     * otherwise two. Any other attribute type yields {@link #UNKNOWN_CATEGORY_COUNT}.
     *
     * @param attribute attribute metadata of the input column
     * @param inputCol input column name, used only in the error message
     * @throws IllegalArgumentException if a nominal attribute has neither labels nor a count
     */
    private static int countCategories(Attribute attribute, String inputCol) {
        if (attribute instanceof NominalAttribute) {
            NominalAttribute nominal = (NominalAttribute) attribute;
            if (nominal.values().isDefined()) {
                return nominal.values().get().length;
            }
            if (nominal.numValues().isDefined()) {
                // Scala Option[Int] surfaces in Java as Option<Object> holding a boxed Integer.
                return ((Integer) nominal.numValues().get()).intValue();
            }
            throw new IllegalArgumentException(
                "Nominal attribute of column '" + inputCol + "' declares neither values nor numValues");
        }
        if (attribute instanceof BinaryAttribute) {
            BinaryAttribute binary = (BinaryAttribute) attribute;
            return binary.values().isDefined() ? binary.values().get().length : BINARY_CATEGORY_COUNT;
        }
        return UNKNOWN_CATEGORY_COUNT;
    }

    /** Returns the Spark stage type this adapter consumes. */
    @Override
    public Class<OneHotEncoder> getSource() {
        return OneHotEncoder.class;
    }

    /** Returns the model-info type this adapter produces. */
    @Override
    public Class<OneHotEncoderModelInfo> getTarget() {
        return OneHotEncoderModelInfo.class;
    }
}
