package org.apache.hama.ml.recommendation.cf.function;

import org.apache.hama.commons.io.VectorWritable;
import org.apache.hama.commons.math.DenseDoubleMatrix;
import org.apache.hama.commons.math.DenseDoubleVector;
import org.apache.hama.commons.math.DoubleVector;
import org.apache.hama.ml.recommendation.cf.function.OnlineUpdate;

/* loaded from: input_file:org/apache/hama/ml/recommendation/cf/function/MeanAbsError.class */
/**
 * Online update function that scores a (user, item) pair as the dot product of an
 * augmented user vector and an augmented item vector, and performs one gradient step
 * towards the expected score.
 *
 * <p>The augmented user vector is {@code user + userFeatureFactorized * userFeatures}.
 * The feature term is omitted when {@code userFeatures} is {@code null}. The augmented
 * item vector is built the same way from the item side.
 *
 * <p>NOTE(review): despite the class name, the step {@code 2 * TETTA * (expected - predicted)}
 * used below is the gradient of the squared error, not of the absolute error.
 */
public class MeanAbsError implements OnlineUpdate.Function {
    /** Learning rate (tau) of the gradient step. */
    private static final double TETTA = 0.01d;

    /**
     * Performs one gradient-descent step on all factorized parameters for a single
     * observed score.
     *
     * @param inputStructure current parameters, optional feature vectors and the expected score
     * @return the updated parameters; {@code itemFeatureFactorized} / {@code userFeatureFactorized}
     *     are only assigned when the corresponding feature vector is present
     */
    @Override // org.apache.hama.ml.recommendation.cf.function.OnlineUpdate.Function
    public OnlineUpdate.OutputStructure compute(OnlineUpdate.InputStructure inputStructure) {
        OnlineUpdate.OutputStructure outputStructure = new OnlineUpdate.OutputStructure();

        // All gradients below are evaluated at the parameters *before* this update.
        DoubleVector augmentedUser = augmentedUser(inputStructure);
        DoubleVector augmentedItem = augmentedItem(inputStructure);

        double error = inputStructure.expectedScore.get()
                - augmentedUser.multiply(augmentedItem).sum();
        double step = 2 * TETTA * error;

        // item <- item + step * augmentedUser ; user <- user + step * augmentedItem
        outputStructure.itemFactorized =
                new VectorWritable(inputStructure.item.getVector().add(augmentedUser.multiply(step)));
        outputStructure.userFactorized =
                new VectorWritable(inputStructure.user.getVector().add(augmentedItem.multiply(step)));

        if (inputStructure.itemFeatures != null) {
            // d(score)/d(itemFeatureFactorized[l][d]) = augmentedUser[l] * itemFeatures[d]
            outputStructure.itemFeatureFactorized = inputStructure.itemFeatureFactorized.add(
                    outerProduct(augmentedUser, inputStructure.itemFeatures.getVector()).multiply(step));
        }
        if (inputStructure.userFeatures != null) {
            // d(score)/d(userFeatureFactorized[l][c]) = augmentedItem[l] * userFeatures[c]
            outputStructure.userFeatureFactorized = inputStructure.userFeatureFactorized.add(
                    outerProduct(augmentedItem, inputStructure.userFeatures.getVector()).multiply(step));
        }
        return outputStructure;
    }

    /**
     * Predicts the score for the given (user, item) pair.
     *
     * @param inputStructure current parameters and optional feature vectors
     * @return dot product of the augmented user and augmented item vectors
     */
    @Override // org.apache.hama.ml.recommendation.cf.function.OnlineUpdate.Function
    public double predict(OnlineUpdate.InputStructure inputStructure) {
        return augmentedUser(inputStructure).multiply(augmentedItem(inputStructure)).sum();
    }

    /** Returns {@code user + userFeatureFactorized * userFeatures}, or {@code user} when there are no user features. */
    private static DoubleVector augmentedUser(OnlineUpdate.InputStructure in) {
        DoubleVector user = in.user.getVector();
        if (in.userFeatures == null) {
            return user;
        }
        return user.add(in.userFeatureFactorized.multiplyVector(in.userFeatures.getVector()));
    }

    /** Returns {@code item + itemFeatureFactorized * itemFeatures}, or {@code item} when there are no item features. */
    private static DoubleVector augmentedItem(OnlineUpdate.InputStructure in) {
        DoubleVector item = in.item.getVector();
        if (in.itemFeatures == null) {
            return item;
        }
        return item.add(in.itemFeatureFactorized.multiplyVector(in.itemFeatures.getVector()));
    }

    /** Builds the matrix whose row {@code l} is {@code features * coefficients[l]}. */
    private static DenseDoubleMatrix outerProduct(DoubleVector coefficients, DoubleVector features) {
        DoubleVector[] rows = new DoubleVector[coefficients.getLength()];
        for (int l = 0; l < rows.length; l++) {
            rows[l] = features.multiply(coefficients.get(l));
        }
        return new DenseDoubleMatrix(rows);
    }
}
