/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.iterative.rule;

import com.google.common.base.Verify;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import io.trino.Session;
import io.trino.SystemSessionProperties;
import io.trino.cost.PlanNodeStatsEstimate;
import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.metadata.Metadata;
import io.trino.metadata.TableHandle;
import io.trino.spi.connector.BasicRelationStatistics;
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.connector.JoinApplicationResult;
import io.trino.spi.connector.JoinCondition;
import io.trino.spi.connector.JoinStatistics;
import io.trino.spi.connector.JoinType;
import io.trino.spi.expression.ConnectorExpression;
import io.trino.spi.expression.Variable;
import io.trino.spi.predicate.Domain;
import io.trino.spi.predicate.TupleDomain;
import io.trino.spi.type.Type;
import io.trino.sql.ExpressionUtils;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.TypeProvider;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.iterative.rule.Rules;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.planner.plan.TableScanNode;
import io.trino.sql.tree.BooleanLiteral;
import io.trino.sql.tree.ComparisonExpression;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.SymbolReference;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;

/**
 * Pushes a join whose two sources are both {@link TableScanNode}s into the connector via
 * {@link Metadata#applyJoin}.
 * <p>
 * The rule fires only when every conjunct of the join's effective filter (the equi-join criteria
 * plus the optional extra filter) can be expressed as a connector {@link JoinCondition} that
 * compares one left-side symbol with one right-side symbol. If the connector accepts the join, the
 * join is replaced by a single table scan over the returned handle. That scan is wrapped in an
 * identity projection restricted to the join's output symbols.
 */
public class PushJoinIntoTableScan
        implements Rule<JoinNode> {
    private static final Capture<TableScanNode> LEFT_TABLE_SCAN = Capture.newCapture();
    private static final Capture<TableScanNode> RIGHT_TABLE_SCAN = Capture.newCapture();

    // A join whose left and right sources are both table scans; each scan is captured.
    private static final Pattern<JoinNode> PATTERN = Patterns.join()
            .with(Patterns.Join.left().matching(Patterns.tableScan().capturedAs(LEFT_TABLE_SCAN)))
            .with(Patterns.Join.right().matching(Patterns.tableScan().capturedAs(RIGHT_TABLE_SCAN)));

    private final Metadata metadata;

    /**
     * @param metadata used to ask the connector to perform the join; must not be null
     * @throws NullPointerException if {@code metadata} is null
     */
    public PushJoinIntoTableScan(Metadata metadata) {
        this.metadata = Objects.requireNonNull(metadata, "metadata is null");
    }

    @Override
    public Pattern<JoinNode> getPattern() {
        return PATTERN;
    }

    /** The rule runs only when the session allows pushdown into connectors. */
    @Override
    public boolean isEnabled(Session session) {
        return SystemSessionProperties.isAllowPushdownIntoConnectors(session);
    }

    /**
     * Attempts to replace {@code joinNode} with a single connector-side table scan.
     * Returns {@link Rule.Result#empty()} whenever any precondition fails or the connector declines.
     */
    @Override
    public Rule.Result apply(JoinNode joinNode, Captures captures, Rule.Context context) {
        if (joinNode.isCrossJoin()) {
            // Cross joins are not handled by this rule.
            return Rule.Result.empty();
        }

        TableScanNode left = captures.get(LEFT_TABLE_SCAN);
        TableScanNode right = captures.get(RIGHT_TABLE_SCAN);
        Verify.verify(!left.isUpdateTarget() && !right.isUpdateTarget(), "Unexpected Join over for-update table scan");

        Expression effectiveFilter = getEffectiveFilter(joinNode);
        FilterSplitResult filterSplitResult = splitFilter(effectiveFilter, left.getOutputSymbols(), right.getOutputSymbols(), context);
        if (!filterSplitResult.getRemainingFilter().equals(BooleanLiteral.TRUE_LITERAL)) {
            // All-or-nothing: at least one conjunct cannot be handed to the connector.
            return Rule.Result.empty();
        }

        if (left.getEnforcedConstraint().isNone() || right.getEnforcedConstraint().isNone()) {
            // Merging the enforced constraints below reads each side's domains via
            // getDomains().orElseThrow(), so both sides must expose a domains map.
            return Rule.Result.empty();
        }

        Map<String, ColumnHandle> leftAssignments = assignmentsByName(left.getAssignments());
        Map<String, ColumnHandle> rightAssignments = assignmentsByName(right.getAssignments());
        // The returned object computes each estimate only when the connector asks for it.
        JoinStatistics joinStatistics = getJoinStatistics(joinNode, left, right, context);

        Optional<JoinApplicationResult<TableHandle>> joinApplicationResult = metadata.applyJoin(
                context.getSession(),
                getJoinType(joinNode),
                left.getTable(),
                right.getTable(),
                filterSplitResult.getPushableConditions(),
                leftAssignments,
                rightAssignments,
                joinStatistics);
        if (joinApplicationResult.isEmpty()) {
            return Rule.Result.empty();
        }

        JoinApplicationResult<TableHandle> result = joinApplicationResult.get();
        Map<ColumnHandle, ColumnHandle> leftColumnHandlesMapping = result.getLeftColumnHandles();
        Map<ColumnHandle, ColumnHandle> rightColumnHandlesMapping = result.getRightColumnHandles();

        // Every symbol of both sides is rebound to the corresponding column of the joined table.
        Map<Symbol, ColumnHandle> assignments = ImmutableMap.<Symbol, ColumnHandle>builder()
                .putAll(remapAssignments(left.getAssignments(), leftColumnHandlesMapping))
                .putAll(remapAssignments(right.getAssignments(), rightColumnHandlesMapping))
                .build();

        // For an outer join, the side that is not preserved may produce NULLs.
        // Its enforced domains are therefore widened to admit NULL.
        JoinNode.Type joinType = joinNode.getType();
        TupleDomain<ColumnHandle> leftConstraint = deriveConstraint(
                left.getEnforcedConstraint(),
                leftColumnHandlesMapping,
                joinType == JoinNode.Type.RIGHT || joinType == JoinNode.Type.FULL);
        TupleDomain<ColumnHandle> rightConstraint = deriveConstraint(
                right.getEnforcedConstraint(),
                rightColumnHandlesMapping,
                joinType == JoinNode.Type.LEFT || joinType == JoinNode.Type.FULL);
        TupleDomain<ColumnHandle> newEnforcedConstraint = TupleDomain.withColumnDomains(
                ImmutableMap.<ColumnHandle, Domain>builder()
                        .putAll(leftConstraint.getDomains().orElseThrow())
                        .putAll(rightConstraint.getDomains().orElseThrow())
                        .build());

        // The new scan exposes every column of both sides; the identity projection narrows the
        // output back to the join's output symbols.
        return Rule.Result.ofPlanNode(
                new ProjectNode(
                        context.getIdAllocator().getNextId(),
                        new TableScanNode(
                                joinNode.getId(),
                                result.getTableHandle(),
                                ImmutableList.copyOf(assignments.keySet()),
                                assignments,
                                newEnforcedConstraint,
                                Rules.deriveTableStatisticsForPushdown(
                                        context.getStatsProvider(),
                                        context.getSession(),
                                        result.isPrecalculateStatistics(),
                                        joinNode),
                                false,
                                Optional.empty()),
                        Assignments.identity(joinNode.getOutputSymbols())));
    }

    /** Re-keys a table scan's assignments by symbol name, the shape passed to {@link Metadata#applyJoin}. */
    private static Map<String, ColumnHandle> assignmentsByName(Map<Symbol, ColumnHandle> assignments) {
        return assignments.entrySet().stream()
                .collect(ImmutableMap.toImmutableMap(entry -> entry.getKey().getName(), Map.Entry::getValue));
    }

    /**
     * Rebinds each symbol of a scan to the column handle that {@code columnHandlesMapping}
     * associates with its original column handle.
     */
    private static Map<Symbol, ColumnHandle> remapAssignments(Map<Symbol, ColumnHandle> assignments, Map<ColumnHandle, ColumnHandle> columnHandlesMapping) {
        return assignments.entrySet().stream()
                .collect(ImmutableMap.toImmutableMap(Map.Entry::getKey, entry -> columnHandlesMapping.get(entry.getValue())));
    }

    /**
     * Exposes engine-side statistics estimates to the connector.
     * Each getter queries the stats provider only when it is called.
     */
    private JoinStatistics getJoinStatistics(JoinNode join, TableScanNode left, TableScanNode right, Rule.Context context) {
        return new JoinStatistics() {
            @Override
            public Optional<BasicRelationStatistics> getLeftStatistics() {
                return getBasicRelationStats(left, left.getOutputSymbols(), context);
            }

            @Override
            public Optional<BasicRelationStatistics> getRightStatistics() {
                return getBasicRelationStats(right, right.getOutputSymbols(), context);
            }

            @Override
            public Optional<BasicRelationStatistics> getJoinStatistics() {
                return getBasicRelationStats(join, join.getOutputSymbols(), context);
            }
        };
    }

    /**
     * Estimates the row count and output size of {@code node}.
     * Returns empty when either estimate is unknown (NaN).
     */
    private static Optional<BasicRelationStatistics> getBasicRelationStats(PlanNode node, List<Symbol> outputSymbols, Rule.Context context) {
        PlanNodeStatsEstimate stats = context.getStatsProvider().getStats(node);
        TypeProvider types = context.getSymbolAllocator().getTypes();
        double outputRowCount = stats.getOutputRowCount();
        double outputSize = stats.getOutputSizeInBytes(outputSymbols, types);
        if (Double.isNaN(outputRowCount) || Double.isNaN(outputSize)) {
            return Optional.empty();
        }
        return Optional.of(new BasicRelationStatistics((long) outputRowCount, (long) outputSize));
    }

    /**
     * Translates one side's enforced constraint onto the joined table's column handles.
     *
     * @param sourceConstraint the side's enforced constraint, keyed by its original column handles
     * @param columnMapping original column handle to joined-table column handle
     * @param nullable when true, every domain is widened to also admit NULL
     *        (used for the non-preserved side of an outer join)
     */
    private TupleDomain<ColumnHandle> deriveConstraint(TupleDomain<ColumnHandle> sourceConstraint, Map<ColumnHandle, ColumnHandle> columnMapping, boolean nullable) {
        TupleDomain<ColumnHandle> constraint = sourceConstraint;
        if (nullable) {
            constraint = constraint.transformDomains((columnHandle, domain) -> domain.union(Domain.onlyNull(domain.getType())));
        }
        return constraint.transformKeys(columnMapping::get);
    }

    /**
     * Combines the join's equi-join criteria and its optional extra filter into a single
     * conjunction.
     */
    public Expression getEffectiveFilter(JoinNode node) {
        // AND everything in one flat list to avoid producing "TRUE AND filter" when there are
        // no equi-criteria.
        ImmutableList.Builder<Expression> conjuncts = ImmutableList.builder();
        for (JoinNode.EquiJoinClause clause : node.getCriteria()) {
            conjuncts.add(clause.toExpression());
        }
        node.getFilter().ifPresent(conjuncts::add);
        return ExpressionUtils.and(conjuncts.build());
    }

    /**
     * Partitions the conjuncts of {@code filter} into connector-pushable join conditions and a
     * remaining (non-pushable) conjunction.
     */
    private FilterSplitResult splitFilter(Expression filter, List<Symbol> leftSymbolsList, List<Symbol> rightSymbolsList, Rule.Context context) {
        Set<Symbol> leftSymbols = ImmutableSet.copyOf(leftSymbolsList);
        Set<Symbol> rightSymbols = ImmutableSet.copyOf(rightSymbolsList);
        ImmutableList.Builder<JoinCondition> comparisonConditions = ImmutableList.builder();
        ImmutableList.Builder<Expression> remainingConjuncts = ImmutableList.builder();
        for (Expression conjunct : ExpressionUtils.extractConjuncts(filter)) {
            getPushableJoinCondition(conjunct, leftSymbols, rightSymbols, context)
                    .ifPresentOrElse(comparisonConditions::add, () -> remainingConjuncts.add(conjunct));
        }
        return new FilterSplitResult(comparisonConditions.build(), ExpressionUtils.and(remainingConjuncts.build()));
    }

    /**
     * Converts {@code conjunct} into a {@link JoinCondition} when it is a comparison between a
     * symbol from the left side and a symbol from the right side (in either order).
     * Returns empty otherwise.
     */
    private Optional<JoinCondition> getPushableJoinCondition(Expression conjunct, Set<Symbol> leftSymbols, Set<Symbol> rightSymbols, Rule.Context context) {
        if (!(conjunct instanceof ComparisonExpression)) {
            return Optional.empty();
        }
        ComparisonExpression comparison = (ComparisonExpression) conjunct;
        if (!(comparison.getLeft() instanceof SymbolReference) || !(comparison.getRight() instanceof SymbolReference)) {
            return Optional.empty();
        }

        Symbol left = Symbol.from(comparison.getLeft());
        Symbol right = Symbol.from(comparison.getRight());
        ComparisonExpression.Operator operator = comparison.getOperator();
        if (!leftSymbols.contains(left)) {
            // Try the mirrored comparison, e.g. "r < l" becomes "l > r".
            Symbol tmp = left;
            left = right;
            right = tmp;
            operator = operator.flip();
        }

        if (leftSymbols.contains(left) && rightSymbols.contains(right)) {
            TypeProvider types = context.getSymbolAllocator().getTypes();
            return Optional.of(new JoinCondition(
                    joinConditionOperator(operator),
                    new Variable(left.getName(), types.get(left)),
                    new Variable(right.getName(), types.get(right))));
        }
        return Optional.empty();
    }

    /** Maps an engine comparison operator to the connector SPI equivalent. */
    private JoinCondition.Operator joinConditionOperator(ComparisonExpression.Operator operator) {
        switch (operator) {
            case EQUAL:
                return JoinCondition.Operator.EQUAL;
            case NOT_EQUAL:
                return JoinCondition.Operator.NOT_EQUAL;
            case LESS_THAN:
                return JoinCondition.Operator.LESS_THAN;
            case LESS_THAN_OR_EQUAL:
                return JoinCondition.Operator.LESS_THAN_OR_EQUAL;
            case GREATER_THAN:
                return JoinCondition.Operator.GREATER_THAN;
            case GREATER_THAN_OR_EQUAL:
                return JoinCondition.Operator.GREATER_THAN_OR_EQUAL;
            case IS_DISTINCT_FROM:
                return JoinCondition.Operator.IS_DISTINCT_FROM;
        }
        throw new IllegalArgumentException("Unknown operator: " + operator);
    }

    /** Maps the plan's join type to the connector SPI join type. */
    private JoinType getJoinType(JoinNode joinNode) {
        switch (joinNode.getType()) {
            case INNER:
                return JoinType.INNER;
            case LEFT:
                return JoinType.LEFT_OUTER;
            case RIGHT:
                return JoinType.RIGHT_OUTER;
            case FULL:
                return JoinType.FULL_OUTER;
        }
        throw new IllegalArgumentException("Unknown join type: " + joinNode.getType());
    }

    /** Result of {@link #splitFilter}: pushable join conditions plus the conjunction of everything else. */
    private static class FilterSplitResult {
        private final List<JoinCondition> pushableConditions;
        private final Expression remainingFilter;

        public FilterSplitResult(List<JoinCondition> pushableConditions, Expression remainingFilter) {
            this.pushableConditions = Objects.requireNonNull(pushableConditions, "pushableConditions is null");
            this.remainingFilter = Objects.requireNonNull(remainingFilter, "remainingFilter is null");
        }

        public List<JoinCondition> getPushableConditions() {
            return pushableConditions;
        }

        public Expression getRemainingFilter() {
            return remainingFilter;
        }
    }
}

