package org.headrest.lang.typing;

import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collector;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.headrest.lang.grammarutils.ASTFactory;
import org.headrest.lang.headREST.Conjunction;
import org.headrest.lang.headREST.Disjunction;
import org.headrest.lang.headREST.Expression;
import org.headrest.lang.headREST.InType;
import org.headrest.lang.headREST.NormalDisjunction;
import org.headrest.lang.headREST.NormalRefinedConjunction;
import org.headrest.lang.headREST.Type;
import org.headrest.lang.headREST.Variable;
import org.headrest.lang.headREST.VariableType;
import org.headrest.lang.headREST.WhereType;
import org.headrest.lang.validation.Environment;
import org.headrest.lang.validation.HeadRESTSwitchWithDerived;
import org.headrest.lang.validation.VariableSubstitutionExpression;

/* loaded from: input_file:org/headrest/lang/typing/TypeNormalizer.class */
/**
 * Rewrites a {@link Type} into a {@link NormalDisjunction}: a list of
 * {@link NormalRefinedConjunction} disjuncts, each carrying a variable name, a list of
 * conjunct types and a refinement expression over that variable.
 *
 * <p>Dispatch happens through the inherited switch; {@link #normalize(Type)} is the entry
 * point. Types without a dedicated case fall back to {@link #caseType(Type)}.
 */
public class TypeNormalizer extends HeadRESTSwitchWithDerived<NormalDisjunction> {
    private final ASTFactory factory = ASTFactory.getInstance();
    private final Environment environment = Environment.getInstance();

    /**
     * Normalizes the given type by dispatching on its concrete kind.
     *
     * @param type the type to normalize
     * @return the normal form produced by the matching {@code case...} method
     */
    public NormalDisjunction normalize(Type type) {
        return doSwitch(type);
    }

    /**
     * Fallback case: wraps the type as the only conjunct of a single disjunct, bound to a
     * fresh variable name and refined by the factory's true value.
     */
    @Override // org.headrest.lang.headREST.util.HeadRESTSwitch
    public NormalDisjunction caseType(Type type) {
        NormalRefinedConjunction trivial =
                factory.createNormalRefinedConjunction(factory.freshVariableName(), type, factory.createTrueValue());
        return factory.createNormalDisjunction(trivial);
    }

    /**
     * Resolves the named type through the environment and normalizes the result.
     *
     * <p>NOTE(review): what happens when the environment has no type for the name (or the
     * definition refers back to itself) is not visible from this class - confirm upstream
     * validation rules this out.
     */
    @Override // org.headrest.lang.headREST.util.HeadRESTSwitch
    public NormalDisjunction caseVariableType(VariableType variableType) {
        return normalize(environment.getType(variableType.getName()));
    }

    /**
     * Normalizes a where-type. The bound type is normalized first; then, for every resulting
     * disjunct, the where-expression is attached to that disjunct's conjuncts under the bind
     * name, normalized with {@link #normR}, and conjoined with the disjunct itself. The
     * per-disjunct results are joined into one disjunction.
     */
    @Override // org.headrest.lang.headREST.util.HeadRESTSwitch
    public NormalDisjunction caseWhereType(WhereType whereType) {
        NormalDisjunction boundNormalForm = normalize(whereType.getBind().getType());
        List<NormalDisjunction> refinedParts = new ArrayList<>();
        for (NormalRefinedConjunction disjunct : boundNormalForm.getDisjuncts()) {
            NormalRefinedConjunction withWhereExpression = factory.createNormalRefinedConjunction(
                    whereType.getBind().getName(), disjunct.getConjuncts(), whereType.getExpression());
            refinedParts.add(conjRD(disjunct, normR(withWhereExpression)));
        }
        return factory.joinToNormalDisjunction(refinedParts);
    }

    /**
     * Normalizes a refined conjunction by inspecting the shape of its expression:
     * <ul>
     *   <li>an {@link InType} whose tested expression is the conjunction's own variable: the
     *       tested type is appended to the conjuncts and the resulting intersection type is
     *       normalized;</li>
     *   <li>a {@link Disjunction}: both sides are normalized separately and joined;</li>
     *   <li>a {@link Conjunction}: both sides are normalized separately and conjoined;</li>
     *   <li>anything else: the conjunction becomes a single-disjunct disjunction as is.</li>
     * </ul>
     * The left operand is always processed before the right one, which fixes the order in
     * which fresh variable names are requested from the factory.
     */
    private NormalDisjunction normR(NormalRefinedConjunction refined) {
        Expression expression = refined.getExpression();

        if (testsOwnVariable(refined, expression)) {
            InType inType = (InType) expression;
            List<Type> narrowed = Stream.concat(refined.getConjuncts().stream(), Stream.of(inType.getType()))
                    .collect(Collectors.toList());
            return normalize(factory.createIntersectionType(narrowed));
        }

        if (expression instanceof Disjunction) {
            Disjunction disjunction = (Disjunction) expression;
            NormalDisjunction left = normR(withExpression(refined, disjunction.getLeft()));
            NormalDisjunction right = normR(withExpression(refined, disjunction.getRight()));
            return factory.joinToNormalDisjunction(left, right);
        }

        if (expression instanceof Conjunction) {
            Conjunction conjunction = (Conjunction) expression;
            NormalDisjunction left = normR(withExpression(refined, conjunction.getLeft()));
            NormalDisjunction right = normR(withExpression(refined, conjunction.getRight()));
            return conjDD(left, right);
        }

        return factory.createNormalDisjunction(refined);
    }

    /**
     * Tells whether the expression is an {@link InType} whose tested expression is a
     * {@link Variable} named like the refined conjunction's own variable.
     */
    private boolean testsOwnVariable(NormalRefinedConjunction refined, Expression expression) {
        if (!(expression instanceof InType)) {
            return false;
        }
        InType inType = (InType) expression;
        return inType.getExpression() instanceof Variable
                && ((Variable) inType.getExpression()).getName().equals(refined.getName());
    }

    /** Builds a refined conjunction with the same name and conjuncts but another expression. */
    private NormalRefinedConjunction withExpression(NormalRefinedConjunction refined, Expression expression) {
        return factory.createNormalRefinedConjunction(refined.getName(), refined.getConjuncts(), expression);
    }

    /**
     * Conjoins two disjunctions: every disjunct of {@code left} is conjoined with the whole
     * of {@code right}, and the partial results are joined.
     */
    private NormalDisjunction conjDD(NormalDisjunction left, NormalDisjunction right) {
        List<NormalDisjunction> parts = left.getDisjuncts().stream()
                .map(disjunct -> conjRD(disjunct, right))
                .collect(Collectors.toList());
        return factory.joinToNormalDisjunction(parts);
    }

    /**
     * Conjoins one refined conjunction with a disjunction by pairing it with each disjunct
     * of {@code disjunction}, in order.
     */
    private NormalDisjunction conjRD(NormalRefinedConjunction refined, NormalDisjunction disjunction) {
        List<NormalRefinedConjunction> pairs = disjunction.getDisjuncts().stream()
                .map(disjunct -> conjRR(refined, disjunct))
                .collect(Collectors.toList());
        return factory.createNormalDisjunction(pairs);
    }

    /**
     * Conjoins two refined conjunctions under a fresh variable name. The conjunct lists are
     * concatenated into one intersection type; each expression is passed through a
     * {@link VariableSubstitutionExpression} built from the fresh name and the expression's
     * original variable name (presumably renaming the old variable to the fresh one - confirm
     * against that class), and the two results are combined with a conjunction.
     */
    private NormalRefinedConjunction conjRR(NormalRefinedConjunction first, NormalRefinedConjunction second) {
        String freshName = factory.freshVariableName();

        List<Type> allConjuncts = Stream.concat(first.getConjuncts().stream(), second.getConjuncts().stream())
                .collect(Collectors.toList());
        Type intersection = factory.createIntersectionType(allConjuncts);

        Expression firstRenamed =
                new VariableSubstitutionExpression(freshName, first.getName()).substitute(first.getExpression());
        Expression secondRenamed =
                new VariableSubstitutionExpression(freshName, second.getName()).substitute(second.getExpression());

        return factory.createNormalRefinedConjunction(
                freshName, intersection, factory.createConjunction(firstRenamed, secondRenamed));
    }
}
