package org.apache.asterix.optimizer.rules;

import com.google.common.collect.ImmutableSet;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import org.apache.asterix.dataflow.data.common.TypeResolverUtil;
import org.apache.asterix.lang.common.util.FunctionUtil;
import org.apache.asterix.om.functions.BuiltinFunctions;
import org.apache.asterix.om.typecomputer.base.TypeCastUtils;
import org.apache.asterix.om.types.IAType;
import org.apache.commons.lang3.mutable.Mutable;
import org.apache.commons.lang3.mutable.MutableObject;
import org.apache.hyracks.algebricks.common.exceptions.AlgebricksException;
import org.apache.hyracks.algebricks.core.algebra.base.ILogicalExpression;
import org.apache.hyracks.algebricks.core.algebra.base.ILogicalOperator;
import org.apache.hyracks.algebricks.core.algebra.base.IOptimizationContext;
import org.apache.hyracks.algebricks.core.algebra.base.LogicalExpressionTag;
import org.apache.hyracks.algebricks.core.algebra.expressions.AbstractFunctionCallExpression;
import org.apache.hyracks.algebricks.core.algebra.expressions.IVariableTypeEnvironment;
import org.apache.hyracks.algebricks.core.algebra.expressions.ScalarFunctionCallExpression;
import org.apache.hyracks.algebricks.core.algebra.functions.FunctionIdentifier;
import org.apache.hyracks.algebricks.core.rewriter.base.IAlgebraicRewriteRule;

/* loaded from: input_file:org/apache/asterix/optimizer/rules/InjectTypeCastForSwitchCaseRule.class */
/**
 * Rewrite rule that wraps arguments of {@code SWITCH_CASE}, {@code IF_MISSING}, {@code IF_NULL}
 * and {@code IF_MISSING_OR_NULL} calls in a {@code CAST_TYPE} call whenever
 * {@link TypeResolverUtil#needsCast} reports that the argument's type has to be cast to the type
 * computed for the enclosing function call.
 */
public class InjectTypeCastForSwitchCaseRule implements IAlgebraicRewriteRule {
    /** Functions whose arguments are all candidates for a cast (handled by {@link #rewriteFunction}). */
    private static final Set<FunctionIdentifier> IF_FUNCTIONS = ImmutableSet.of(BuiltinFunctions.IF_MISSING,
            BuiltinFunctions.IF_NULL, BuiltinFunctions.IF_MISSING_OR_NULL);

    /**
     * Applies the cast injection to every expression of the operator held by {@code mutable}.
     * Operators without inputs are left untouched.
     *
     * @param mutable reference to the operator to rewrite
     * @param iOptimizationContext optimization context used to (re)compute type environments
     * @return {@code true} if at least one expression of the operator was changed
     * @throws AlgebricksException if type computation or the expression transform fails
     */
    @Override
    public boolean rewritePost(Mutable<ILogicalOperator> mutable, IOptimizationContext iOptimizationContext)
            throws AlgebricksException {
        final ILogicalOperator op = mutable.getValue();
        if (op.getInputs().isEmpty()) {
            return false;
        }
        // Refresh the operator's type environment before inspecting expression types.
        iOptimizationContext.computeAndSetTypeEnvironmentForOperator(op);
        boolean changed = op.acceptExpressionTransform(exprRef -> injectTypeCast(op, exprRef, iOptimizationContext));
        if (changed) {
            // Expressions were replaced, so the type environment has to be recomputed.
            iOptimizationContext.computeAndSetTypeEnvironmentForOperator(op);
        }
        return changed;
    }

    /**
     * Recursively rewrites a single expression: nested function-call arguments are processed first,
     * then the call itself if it is {@code SWITCH_CASE} or one of {@link #IF_FUNCTIONS}.
     *
     * @return {@code true} if the expression or any of its descendants was changed
     */
    private boolean injectTypeCast(ILogicalOperator iLogicalOperator, Mutable<ILogicalExpression> mutable,
            IOptimizationContext iOptimizationContext) throws AlgebricksException {
        ILogicalExpression expr = mutable.getValue();
        if (expr.getExpressionTag() != LogicalExpressionTag.FUNCTION_CALL) {
            return false;
        }
        AbstractFunctionCallExpression call = (AbstractFunctionCallExpression) expr;
        boolean changed = false;
        for (Mutable<ILogicalExpression> argRef : call.getArguments()) {
            if (injectTypeCast(iLogicalOperator, argRef, iOptimizationContext)) {
                // A nested rewrite happened; recompute types before continuing.
                iOptimizationContext.computeAndSetTypeEnvironmentForOperator(iLogicalOperator);
                changed = true;
            }
        }
        FunctionIdentifier fid = call.getFunctionIdentifier();
        if (fid.equals(BuiltinFunctions.SWITCH_CASE)) {
            changed |= rewriteSwitchCase(iLogicalOperator, call, iOptimizationContext);
        } else if (IF_FUNCTIONS.contains(fid)) {
            changed |= rewriteFunction(iLogicalOperator, call, iOptimizationContext);
        }
        return changed;
    }

    /**
     * Injects casts into selected arguments of a {@code SWITCH_CASE} call. Visits the arguments at
     * indexes 2, 4, 6, ...; when the visited index is the second-to-last one, the last argument is
     * visited next as well.
     * NOTE(review): presumably these are the branch results plus the trailing default - confirm
     * against the switch-case argument layout.
     *
     * @return {@code true} if at least one argument was wrapped in a cast
     */
    private boolean rewriteSwitchCase(ILogicalOperator iLogicalOperator,
            AbstractFunctionCallExpression abstractFunctionCallExpression, IOptimizationContext iOptimizationContext)
            throws AlgebricksException {
        IVariableTypeEnvironment env = iLogicalOperator.computeInputTypeEnvironment(iOptimizationContext);
        IAType producedType = (IAType) env.getType(abstractFunctionCallExpression);
        List<Mutable<ILogicalExpression>> argRefs = abstractFunctionCallExpression.getArguments();
        int argCount = argRefs.size();
        boolean changed = false;
        // Step by 2, except step by 1 from the second-to-last argument onto the last one.
        for (int idx = 2; idx < argCount; idx += (idx + 2 == argCount) ? 1 : 2) {
            changed |= rewriteFunctionArgument(argRefs.get(idx), producedType, env);
        }
        return changed;
    }

    /**
     * Injects casts into every argument of the given call (used for {@link #IF_FUNCTIONS}).
     *
     * @return {@code true} if at least one argument was wrapped in a cast
     */
    private boolean rewriteFunction(ILogicalOperator iLogicalOperator,
            AbstractFunctionCallExpression abstractFunctionCallExpression, IOptimizationContext iOptimizationContext)
            throws AlgebricksException {
        IVariableTypeEnvironment env = iLogicalOperator.computeInputTypeEnvironment(iOptimizationContext);
        IAType producedType = (IAType) env.getType(abstractFunctionCallExpression);
        boolean changed = false;
        for (Mutable<ILogicalExpression> argRef : abstractFunctionCallExpression.getArguments()) {
            changed |= rewriteFunctionArgument(argRef, producedType, env);
        }
        return changed;
    }

    /**
     * Replaces the referenced argument with {@code CAST_TYPE(argument)} when
     * {@link TypeResolverUtil#needsCast} says a cast from the argument's type to {@code iAType}
     * is required. The new call inherits the argument's source location.
     *
     * @param mutable reference to the argument expression; updated in place when a cast is injected
     * @param iAType required type, i.e. the type computed for the enclosing function call
     * @param iVariableTypeEnvironment environment used to look up the argument's type
     * @return {@code true} if a cast was injected
     */
    private boolean rewriteFunctionArgument(Mutable<ILogicalExpression> mutable, IAType iAType,
            IVariableTypeEnvironment iVariableTypeEnvironment) throws AlgebricksException {
        ILogicalExpression argExpr = mutable.getValue();
        IAType argType = (IAType) iVariableTypeEnvironment.getType(argExpr);
        if (TypeResolverUtil.needsCast(iAType, argType)) {
            Mutable<ILogicalExpression> wrappedArg = new MutableObject<>(argExpr);
            List<Mutable<ILogicalExpression>> castArgs = new ArrayList<>(Collections.singletonList(wrappedArg));
            ScalarFunctionCallExpression castCall =
                    new ScalarFunctionCallExpression(FunctionUtil.getFunctionInfo(BuiltinFunctions.CAST_TYPE), castArgs);
            castCall.setSourceLocation(argExpr.getSourceLocation());
            TypeCastUtils.setRequiredAndInputTypes(castCall, iAType, argType);
            mutable.setValue(castCall);
            return true;
        }
        return false;
    }
}
