package pl.wojciechkarpiel.jhou.types.inference;

import java.io.PrintStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.IdentityHashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import pl.wojciechkarpiel.jhou.ast.Abstraction;
import pl.wojciechkarpiel.jhou.ast.Application;
import pl.wojciechkarpiel.jhou.ast.Constant;
import pl.wojciechkarpiel.jhou.ast.Term;
import pl.wojciechkarpiel.jhou.ast.Variable;
import pl.wojciechkarpiel.jhou.ast.type.ArrowType;
import pl.wojciechkarpiel.jhou.ast.type.BaseType;
import pl.wojciechkarpiel.jhou.ast.type.Type;
import pl.wojciechkarpiel.jhou.ast.type.TypeVisitor;
import pl.wojciechkarpiel.jhou.ast.util.Id;
import pl.wojciechkarpiel.jhou.ast.util.Visitor;
import pl.wojciechkarpiel.jhou.substitution.Substitution;
import pl.wojciechkarpiel.jhou.types.TypeCalculator;
import pl.wojciechkarpiel.jhou.unifier.AllowedTypeInference;
import pl.wojciechkarpiel.jhou.unifier.DisagreementPair;
import pl.wojciechkarpiel.jhou.unifier.DisagreementSet;
import pl.wojciechkarpiel.jhou.unifier.SolutionIterator;
import pl.wojciechkarpiel.jhou.unifier.tree.WorkWorkNode;
import pl.wojciechkarpiel.jhou.util.DevNullPrintStream;
import pl.wojciechkarpiel.jhou.util.ListUtil;
import pl.wojciechkarpiel.jhou.util.MapUtil;

/* loaded from: input_file:pl/wojciechkarpiel/jhou/types/inference/TypeInference.class */
/**
 * Fills in missing ({@code null}) types of constants and variables in terms.
 *
 * <p>Type reconstruction is reduced to a unification problem solved by the project's own
 * unifier:
 * <ol>
 *   <li>every constant, variable, application and abstraction is annotated with a placeholder
 *       term standing for its type ({@link #annotate(Term)});</li>
 *   <li>known types are encoded as terms: each base type becomes a dedicated constant and an
 *       arrow type {@code A -> B} becomes {@code ARR A B} ({@link #ftchType(Type)});</li>
 *   <li>the application rule and all declared types are emitted as disagreement pairs
 *       ({@code collectDps()});</li>
 *   <li>the first substitution produced by {@link SolutionIterator} is decoded back into real
 *       {@link Type}s and the terms are rebuilt with them
 *       ({@link #recreateWithTypes(Term, Substitution)}).</li>
 * </ol>
 */
public class TypeInference {
    /** Sink for diagnostic messages (e.g. when an arbitrary base type is invented). */
    private final PrintStream printStream;
    /** Inference policy: whether inference may run and whether arbitrary types may be invented. */
    private final AllowedTypeInference allowedInference;
    // NOTE: DUMMY_TYPE must stay declared before ARROW, whose initializer reads it.
    /** Type given to every placeholder term used to encode types. */
    private static final Type DUMMY_TYPE = BaseType.freshBaseType("dummy_type");
    /** Binary constructor encoding arrow types: {@code ARR from to} stands for {@code from -> to}. */
    private static final Constant ARROW = Constant.freshConstant(new ArrowType(DUMMY_TYPE, new ArrowType(DUMMY_TYPE, DUMMY_TYPE)), "ARR");
    /** Base types invented while decoding a solution; created lazily by {@link #getNewTypes()}. */
    private Set<Type> newTypes = null;
    /** Base type -> constant that encodes it. */
    private final Map<Type, Constant> typeConstantMap = new HashMap<>();
    /** Constant -> base type it encodes (inverse of typeConstantMap plus invented types). */
    private final Map<Constant, Type> constantTypeMap = new HashMap<>();
    /** Every (term, type placeholder) annotation, in creation order; a term may occur repeatedly. */
    private final List<TermPair> annotations = new ArrayList<>();
    /**
     * First annotation recorded for each term, keyed by object identity. Mirrors the first
     * matching entry of {@link #annotations} so {@link #getAnon(Term)} is O(1) instead of a scan.
     */
    private final Map<Term, Term> firstAnnotation = new IdentityHashMap<>();
    /** Free variable -> its type placeholder, so all occurrences share one placeholder. */
    private final MapUtil<Variable, Variable> varCache = new MapUtil<>(new HashMap<>());
    /** Constant -> its type placeholder, so all occurrences share one placeholder. */
    private final MapUtil<Constant, Variable> conCache = new MapUtil<>(new HashMap<>());
    /** Currently bound variables -> type placeholder of their binder (scoped via withMapping). */
    private final MapUtil<Variable, Variable> bound = new MapUtil<>(new HashMap<>());

    /**
     * Thrown when a type is not determined by the constraints (a fresh base type would have to
     * be invented) and the inference policy is not {@link AllowedTypeInference#PERMISSIVE}.
     */
    public static class InferenceHasArbitrarySolutionsException extends RuntimeException {
    }

    /**
     * Thrown when some input term has a missing type but the policy is
     * {@link AllowedTypeInference#NO_INFERENCE_ALLOWED}.
     */
    public static class InferenceRequiredButNotAllowedException extends RuntimeException {
    }

    /** A term together with the placeholder term standing for its type. */
    public static final class TermPair {
        private final Term a;
        private final Term b;

        private TermPair(Term term, Term typeTerm) {
            this.a = term;
            this.b = typeTerm;
        }
    }

    /**
     * Infers missing types of a single term using the permissive policy, logging to
     * {@code System.out}.
     */
    public static Term inferMissing(Term term) {
        return inferMissing((List<Term>) ListUtil.of(term)).get(0);
    }

    /** Infers missing types using the permissive policy, logging to {@code System.out}. */
    public static List<Term> inferMissing(List<Term> list) {
        return inferMissing(list, System.out);
    }

    /** Infers missing types using the permissive policy, logging to {@code printStream}. */
    public static List<Term> inferMissing(List<Term> list, PrintStream printStream) {
        return inferMissing(list, AllowedTypeInference.PERMISSIVE, printStream);
    }

    /** Infers missing types under the given policy, logging to {@code System.out}. */
    public static List<Term> inferMissing(List<Term> list, AllowedTypeInference allowedTypeInference) {
        return inferMissing(list, allowedTypeInference, System.out);
    }

    /**
     * Infers missing types of all terms in {@code list}, additionally requiring that all of the
     * terms end up with the same type.
     *
     * @param list terms to type; returned unchanged if none of them needs inference
     * @param allowedTypeInference inference policy
     * @param printStream sink for diagnostic messages
     * @return the terms rebuilt with fully instantiated types, in input order
     * @throws InferenceRequiredButNotAllowedException if inference is needed but disallowed
     * @throws InferenceHasArbitrarySolutionsException if a type is undetermined and the policy
     *     is not permissive
     */
    public static List<Term> inferMissing(List<Term> list, AllowedTypeInference allowedTypeInference, PrintStream printStream) {
        return inferMissingInternal(list, allowedTypeInference, printStream);
    }

    private static List<Term> inferMissingInternal(List<Term> list, AllowedTypeInference allowedTypeInference, PrintStream printStream) {
        if (list.stream().noneMatch(TypeInference::needsInference)) {
            printStream.println("No need for type inference, types fully instantiated");
            return list;
        }
        if (allowedTypeInference == AllowedTypeInference.NO_INFERENCE_ALLOWED) {
            throw new InferenceRequiredButNotAllowedException();
        }
        TypeInference typeInference = new TypeInference(printStream, allowedTypeInference);
        for (Term term : list) {
            typeInference.annotate(term);
        }
        List<DisagreementPair> constraints = typeInference.collectDps();
        // All input terms must share one type: equate the types of consecutive terms.
        for (int i = 0; i + 1 < list.size(); i++) {
            constraints.add(new DisagreementPair(typeInference.getAnon(list.get(i)), typeInference.getAnon(list.get(i + 1))));
        }
        WorkWorkNode root = new WorkWorkNode(null, Substitution.empty(), new DisagreementSet(constraints), true);
        // Only the first solution is used. Behaviour for unsatisfiable constraints is whatever
        // SolutionIterator.next() does — not visible here.
        Substitution solution = new SolutionIterator(root, DevNullPrintStream.INSTANCE).next();
        return list.stream()
                .map(term -> typeInference.recreateWithTypes(term, solution))
                .collect(Collectors.toList());
    }

    /**
     * Returns whether any constant or variable (including abstraction binders) inside
     * {@code term} has a {@code null} type.
     */
    public static boolean needsInference(Term term) {
        return ((Boolean) term.visit(new Visitor<Boolean>() {
            @Override
            public Boolean visitConstant(Constant constant) {
                return constant.getType() == null;
            }

            @Override
            public Boolean visitVariable(Variable variable) {
                return variable.getType() == null;
            }

            @Override
            public Boolean visitApplication(Application application) {
                return needsInference(application.getFunction()) || needsInference(application.getArgument());
            }

            @Override
            public Boolean visitAbstraction(Abstraction abstraction) {
                return needsInference(abstraction.getVariable()) || needsInference(abstraction.getBody());
            }
        })).booleanValue();
    }

    private TypeInference(PrintStream printStream, AllowedTypeInference allowedTypeInference) {
        this.printStream = printStream;
        this.allowedInference = allowedTypeInference;
    }

    /** Returns the (mutable) set of base types invented during decoding, creating it on demand. */
    public final Set<Type> getNewTypes() {
        if (this.newTypes == null) {
            this.newTypes = new HashSet<>();
        }
        return this.newTypes;
    }

    /**
     * Decodes the type of a constant or variable from the unifier's solution.
     *
     * @throws IllegalArgumentException if {@code term} is neither a constant nor a variable
     */
    public Type getT(Term term, Substitution substitution) {
        if (!(term instanceof Variable) && !(term instanceof Constant)) {
            throw new IllegalArgumentException("Only constants and variables carry a type, got: " + term);
        }
        return fakeVarToRealType(substitution.substitute(getAnon(term)));
    }

    /**
     * Rebuilds {@code term} with every constant and variable given its decoded type, keeping
     * each symbol's id and using its {@code toString()} as the name.
     */
    public Term recreateWithTypes(Term term, final Substitution substitution) {
        return (Term) term.visit(new Visitor<Term>() {
            @Override
            public Term visitConstant(Constant constant) {
                return new Constant(constant.getId(), getT(constant, substitution), constant.toString());
            }

            @Override
            public Term visitVariable(Variable variable) {
                return new Variable(variable.getId(), getT(variable, substitution), variable.toString());
            }

            @Override
            public Term visitApplication(Application application) {
                return new Application(recreateWithTypes(application.getFunction(), substitution), recreateWithTypes(application.getArgument(), substitution));
            }

            @Override
            public Term visitAbstraction(Abstraction abstraction) {
                return new Abstraction((Variable) recreateWithTypes(abstraction.getVariable(), substitution), recreateWithTypes(abstraction.getBody(), substitution));
            }
        });
    }

    /**
     * Decodes a type-encoding term back into a {@link Type}. Constants map to base types (an
     * unknown constant means the type was unconstrained); {@code ARR a b} maps to {@code a -> b}.
     *
     * @throws InferenceHasArbitrarySolutionsException if a fresh base type is needed and the
     *     policy is not permissive
     * @throws IllegalStateException if {@code term} is not a valid type encoding
     */
    public Type fakeVarToRealType(Term term) {
        return (Type) term.visit(new Visitor<Type>() {
            @Override
            public Type visitConstant(Constant constant) {
                Type known = constantTypeMap.get(constant);
                if (known != null) {
                    return known;
                }
                if (allowedInference != AllowedTypeInference.PERMISSIVE) {
                    throw new InferenceHasArbitrarySolutionsException();
                }
                // Unconstrained type: invent a fresh base type and remember it for this constant.
                Id uniqueId = Id.uniqueId();
                BaseType baseType = new BaseType(uniqueId, "infered_arbitrarty_" + uniqueId.getId());
                printStream.println("Creating a new, arbitrary type: " + baseType);
                getNewTypes().add(baseType);
                constantTypeMap.put(constant, baseType);
                return baseType;
            }

            @Override
            public Type visitVariable(Variable variable) {
                // NOTE(review): reached only if the solution leaves a type placeholder unbound.
                throw new IllegalStateException("Type placeholder left unresolved by unification: " + variable);
            }

            @Override
            public Type visitApplication(Application application) {
                // A valid encoding has the shape (ARR from) to.
                Term head = application.getFunction();
                if (!(head instanceof Application) || ((Application) head).getFunction() != ARROW) {
                    throw new IllegalStateException("Not an encoded arrow type: " + application);
                }
                Application partial = (Application) head;
                return new ArrowType(fakeVarToRealType(partial.getArgument()), fakeVarToRealType(application.getArgument()));
            }

            @Override
            public Type visitAbstraction(Abstraction abstraction) {
                throw new IllegalStateException("An abstraction cannot encode a type: " + abstraction);
            }
        });
    }

    /**
     * Turns all annotations into unification constraints: for each application {@code f a}
     * with placeholder {@code t}, {@code type(f) = ARR type(a) t}; for each constant/variable
     * with a declared type, {@code encode(declared) = placeholder}.
     */
    private List<DisagreementPair> collectDps() {
        List<DisagreementPair> pairs = new ArrayList<>();
        for (TermPair annotation : annotations) {
            Term term = annotation.a;
            Term typeTerm = annotation.b;
            if (term instanceof Application) {
                Application application = (Application) term;
                Term functionType = getAnon(application.getFunction());
                Term expected = new Application(new Application(ARROW, getAnon(application.getArgument())), typeTerm);
                TypeCalculator.ensureEqualTypes(functionType, expected);
                pairs.add(new DisagreementPair(functionType, expected));
            }
            Type declared = null;
            if (term instanceof Constant) {
                declared = ((Constant) term).getType();
            } else if (term instanceof Variable) {
                declared = ((Variable) term).getType();
            }
            if (declared != null) {
                Term encoded = ftchType(declared);
                TypeCalculator.ensureEqualTypes(encoded, typeTerm);
                pairs.add(new DisagreementPair(encoded, typeTerm));
            }
        }
        return pairs;
    }

    /**
     * Encodes a real type as a term: each distinct base type gets one dedicated constant and
     * {@code A -> B} becomes {@code ARR enc(A) enc(B)}.
     */
    public Term ftchType(Type type) {
        return (Term) type.visit(new TypeVisitor<Term>() {
            @Override
            public Term visitBaseType(BaseType baseType) {
                Constant constant = typeConstantMap.get(baseType);
                if (constant == null) {
                    constant = new Constant(Id.uniqueId(), DUMMY_TYPE);
                    typeConstantMap.put(baseType, constant);
                    // Record the inverse mapping so decoding can recover the base type.
                    constantTypeMap.put(constant, baseType);
                }
                return constant;
            }

            @Override
            public Term visitArrowType(ArrowType arrowType) {
                return Application.apply(ARROW, ftchType(arrowType.getFrom()), ftchType(arrowType.getTo()));
            }
        });
    }

    /** Records {@code typeTerm} as the type placeholder of {@code term}; the first record wins for lookups. */
    public void addAnon(Term term, Term typeTerm) {
        annotations.add(new TermPair(term, typeTerm));
        firstAnnotation.putIfAbsent(term, typeTerm);
    }

    /**
     * Returns the first type placeholder recorded for {@code term} (compared by identity).
     *
     * @throws IllegalStateException if {@code term} was never annotated
     */
    public Term getAnon(Term term) {
        Term typeTerm = firstAnnotation.get(term);
        if (typeTerm == null) {
            throw new IllegalStateException("Term was never annotated: " + term);
        }
        return typeTerm;
    }

    /**
     * Walks {@code term} and records a type placeholder for every subterm. Constants and free
     * variables reuse a cached placeholder; bound variables use their binder's placeholder; an
     * abstraction {@code λx. body} gets {@code ARR type(x) type(body)}.
     */
    public void annotate(Term term) {
        term.visit(new Visitor<Void>() {
            @Override
            public Void visitConstant(Constant constant) {
                Variable placeholder;
                if (conCache.get(constant).isPresent()) {
                    placeholder = (Variable) conCache.get(constant).get();
                } else {
                    placeholder = Variable.freshVariable(DUMMY_TYPE);
                    conCache.put(constant, placeholder);
                }
                addAnon(constant, placeholder);
                return null;
            }

            @Override
            public Void visitVariable(Variable variable) {
                Variable placeholder;
                if (bound.get(variable).isPresent()) {
                    placeholder = (Variable) bound.get(variable).get();
                } else if (varCache.get(variable).isPresent()) {
                    placeholder = (Variable) varCache.get(variable).get();
                } else {
                    placeholder = Variable.freshVariable(DUMMY_TYPE);
                    varCache.put(variable, placeholder);
                }
                addAnon(variable, placeholder);
                return null;
            }

            @Override
            public Void visitApplication(Application application) {
                addAnon(application, Variable.freshVariable(DUMMY_TYPE));
                annotate(application.getArgument());
                annotate(application.getFunction());
                return null;
            }

            @Override
            public Void visitAbstraction(Abstraction abstraction) {
                Variable binderType = Variable.freshVariable(DUMMY_TYPE);
                addAnon(abstraction.getVariable(), binderType);
                // Occurrences of the binder inside the body share the binder's placeholder.
                bound.withMapping(abstraction.getVariable(), binderType, () -> {
                    annotate(abstraction.getBody());
                    return null;
                });
                addAnon(abstraction, new Application(new Application(ARROW, binderType), getAnon(abstraction.getBody())));
                return null;
            }
        });
    }
}
