package pl.wojciechkarpiel.jhou.alpha;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import pl.wojciechkarpiel.jhou.ast.Abstraction;
import pl.wojciechkarpiel.jhou.ast.Application;
import pl.wojciechkarpiel.jhou.ast.Constant;
import pl.wojciechkarpiel.jhou.ast.Term;
import pl.wojciechkarpiel.jhou.ast.Variable;
import pl.wojciechkarpiel.jhou.ast.type.Type;
import pl.wojciechkarpiel.jhou.ast.util.Visitor;
import pl.wojciechkarpiel.jhou.substitution.Substitution;
import pl.wojciechkarpiel.jhou.substitution.SubstitutionPair;
import pl.wojciechkarpiel.jhou.termHead.BetaEtaNormal;
import pl.wojciechkarpiel.jhou.termHead.Head;
import pl.wojciechkarpiel.jhou.util.ListUtil;
import pl.wojciechkarpiel.jhou.util.MapUtil;
import pl.wojciechkarpiel.jhou.util.Pair;

/**
 * Utilities for comparing lambda terms up to consistent renaming of bound variables
 * (alpha-equivalence), and for aligning the binders of two beta-eta normal forms onto a
 * shared set of fresh variables.
 */
public class AlphaEqual {

    /** A (left, right) pair of beta-eta normal forms, as produced by {@link #alphaEqualizeHeading}. */
    public static class BenPair extends Pair<BetaEtaNormal, BetaEtaNormal> {
        public BenPair(BetaEtaNormal betaEtaNormal, BetaEtaNormal betaEtaNormal2) {
            super(betaEtaNormal, betaEtaNormal2);
        }
    }

    /**
     * Single-use structural comparison of two terms that treats bound variables as equal when
     * they are bound at corresponding abstractions.
     *
     * <p>Each pair of corresponding binders is mapped to one shared fresh variable: the left
     * binder in {@link #leftSub}, the right binder in {@link #rightSub}. Variable occurrences
     * are compared after resolving them through these maps, so free variables are compared by
     * identity and bound ones by binding position.
     */
    public static class LazyAlphaEqual {
        /** Bound variable of the left term -> fresh variable shared with its right counterpart. */
        private final MapUtil<Variable, Variable> leftSub = new MapUtil<>(new HashMap<Variable, Variable>());
        /** Bound variable of the right term -> fresh variable shared with its left counterpart. */
        private final MapUtil<Variable, Variable> rightSub = new MapUtil<>(new HashMap<Variable, Variable>());

        /** Returns whether {@code term} and {@code term2} are alpha-equivalent. */
        static boolean isAlphaEqualLazy(Term term, Term term2) {
            return new LazyAlphaEqual().alphaEqual(term, term2);
        }

        private LazyAlphaEqual() {
        }

        /** Dispatches on the shape of {@code left} and compares it with {@code right}. */
        private boolean alphaEqual(Term left, Term right) {
            return (Boolean) left.visit(new Visitor<Boolean>() {
                @Override
                public Boolean visitConstant(Constant constant) {
                    return constant.equals(right);
                }

                @Override
                public Boolean visitVariable(Variable variable) {
                    return variablesEqual(variable, right);
                }

                @Override
                public Boolean visitApplication(Application application) {
                    return applicationsEqual(application, right);
                }

                @Override
                public Boolean visitAbstraction(Abstraction abstraction) {
                    return abstractionsEqual(abstraction, right);
                }
            });
        }

        /** Compares two variable occurrences after resolving each through its side's binder map. */
        private boolean variablesEqual(Variable left, Term right) {
            if (!(right instanceof Variable)) {
                return false;
            }
            Variable rightVariable = (Variable) right;
            Variable leftResolved = (Variable) leftSub.get(left).orElse(left);
            Variable rightResolved = (Variable) rightSub.get(rightVariable).orElse(rightVariable);
            return leftResolved.equals(rightResolved);
        }

        /** Compares applications component-wise: function first, then argument. */
        private boolean applicationsEqual(Application left, Term right) {
            if (!(right instanceof Application)) {
                return false;
            }
            Application other = (Application) right;
            return alphaEqual(left.getFunction(), other.getFunction())
                    && alphaEqual(left.getArgument(), other.getArgument());
        }

        /**
         * Compares abstractions whose binders have equal types by binding both binders to one
         * fresh variable and comparing the bodies.
         */
        private boolean abstractionsEqual(Abstraction left, Term right) {
            if (!(right instanceof Abstraction)) {
                return false;
            }
            Abstraction other = (Abstraction) right;
            Variable leftBinder = left.getVariable();
            Variable rightBinder = other.getVariable();
            if (!rightBinder.getType().equals(leftBinder.getType())) {
                return false;
            }
            // Rebind even when both binders are the very same variable: an enclosing abstraction
            // may already have mapped it, and the inner binder must shadow that mapping.
            // Skipping this made e.g. (\x.\x.x) vs (\y.\x.y) compare equal.
            // NOTE(review): assumes MapUtil.withMapping restores the previous mapping afterwards.
            Variable fresh = Variable.freshVariable(rightBinder.getType());
            return (Boolean) leftSub.withMapping(leftBinder, fresh, () ->
                    (Boolean) rightSub.withMapping(rightBinder, fresh, () ->
                            Boolean.valueOf(alphaEqual(left.getBody(), other.getBody()))));
        }
    }

    private AlphaEqual() {
    }

    /** Returns whether the two abstractions are alpha-equivalent. */
    public static boolean isAlphaEqual(Abstraction abstraction, Abstraction abstraction2) {
        return LazyAlphaEqual.isAlphaEqualLazy(abstraction, abstraction2);
    }

    /**
     * Returns whether the beta-eta normal forms of the two terms have equal headings and the
     * same number of arguments, i.e. whether {@link #alphaEqualizeHeading} succeeds on them.
     *
     * @throws IllegalStateException propagated from {@link #alphaEqualizeHeading}
     */
    public static boolean headAlphaUnifiable(Term term, Term term2) {
        return alphaEqualizeHeading(BetaEtaNormal.normalize(term), BetaEtaNormal.normalize(term2)).isPresent();
    }

    /**
     * Renames the binders of two normal forms with equal headings onto a shared list of fresh
     * variables, so the results can be compared argument by argument.
     *
     * @param left  first normal form
     * @param right second normal form
     * @return the two renamed normal forms, sharing one fresh binder list; empty when the
     *     headings differ or the argument counts differ
     * @throws IllegalStateException if the headings are equal but the binders differ in length,
     *     contain a repeated variable, or differ in type at some position
     */
    public static Optional<BenPair> alphaEqualizeHeading(BetaEtaNormal left, BetaEtaNormal right) {
        if (!equalHeadings(left, right) || left.getArguments().size() != right.getArguments().size()) {
            return Optional.empty();
        }
        List<Variable> leftBinder = left.getBinder();
        List<Variable> rightBinder = right.getBinder();
        int size = leftBinder.size();
        if (rightBinder.size() != size
                || new HashSet<>(leftBinder).size() != size
                || new HashSet<>(rightBinder).size() != size) {
            throw new IllegalStateException(
                    "Binders must have equal length and no repeated variables: "
                            + leftBinder + " vs " + rightBinder);
        }
        List<Variable> freshBinder = new ArrayList<>(size);
        List<SubstitutionPair> leftRenaming = new ArrayList<>(size);
        List<SubstitutionPair> rightRenaming = new ArrayList<>(size);
        for (int i = 0; i < size; i++) {
            Variable leftVariable = leftBinder.get(i);
            Variable rightVariable = rightBinder.get(i);
            Type type = leftVariable.getType();
            if (!type.equals(rightVariable.getType())) {
                throw new IllegalStateException(
                        "Binder types differ at position " + i + ": " + type + " vs " + rightVariable.getType());
            }
            Variable fresh = Variable.freshVariable(type);
            freshBinder.add(fresh);
            leftRenaming.add(new SubstitutionPair(leftVariable, fresh));
            rightRenaming.add(new SubstitutionPair(rightVariable, fresh));
        }
        return Optional.of(new BenPair(
                substitute(new Substitution(leftRenaming), freshBinder, left),
                substitute(new Substitution(rightRenaming), freshBinder, right)));
    }

    /**
     * Applies {@code substitution} to the head and every argument of {@code normal} and
     * reassembles the result under {@code binder}.
     */
    private static BetaEtaNormal substitute(Substitution substitution, List<Variable> binder, BetaEtaNormal normal) {
        Head head = Head.fromTerm(substitution.substitute(normal.getHead().getTerm()));
        List<Term> arguments = normal.getArguments().stream()
                .map(substitution::substitute)
                .collect(Collectors.toList());
        return BetaEtaNormal.fromFakeNormal(head, binder, arguments);
    }

    /**
     * Compares the binder+head parts of two normal forms (arguments dropped) after converting
     * each back to a term. NOTE(review): whether this is alpha-aware depends on Term.equals.
     */
    private static boolean equalHeadings(BetaEtaNormal left, BetaEtaNormal right) {
        return headingOnly(left).equals(headingOnly(right));
    }

    /** Rebuilds {@code normal}'s binder and head with no arguments, as a term. */
    private static Term headingOnly(BetaEtaNormal normal) {
        return BetaEtaNormal.fromFakeNormal(normal.getHead(), normal.getBinder(), ListUtil.of(new Term[0])).backToTerm();
    }
}
