/*
 * Decompiled with CFR 0.152.
 */
package org.apache.drill.exec.store.mongo;

import com.mongodb.client.model.Accumulators;
import com.mongodb.client.model.Aggregates;
import com.mongodb.client.model.BsonField;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.function.BiFunction;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.validate.SqlValidatorUtil;
import org.apache.drill.exec.store.mongo.common.MongoOp;
import org.bson.BsonArray;
import org.bson.BsonDocument;
import org.bson.BsonElement;
import org.bson.BsonInt32;
import org.bson.BsonNull;
import org.bson.BsonString;
import org.bson.BsonValue;
import org.bson.Document;
import org.bson.conversions.Bson;

public class MongoAggregateUtils {
    public static List<String> mongoFieldNames(RelDataType rowType) {
        List renamed = rowType.getFieldNames().stream().map(name -> name.startsWith("$") ? "_" + name.substring(2) : name).collect(Collectors.toList());
        return SqlValidatorUtil.uniquify(renamed, (boolean)true);
    }

    public static String maybeQuote(String s) {
        if (!MongoAggregateUtils.needsQuote(s)) {
            return s;
        }
        return MongoAggregateUtils.quote(s);
    }

    public static String quote(String s) {
        return "'" + s + "'";
    }

    private static boolean needsQuote(String s) {
        int n = s.length();
        for (int i = 0; i < n; ++i) {
            char c = s.charAt(i);
            if (Character.isJavaIdentifierPart(c) && c != '$') continue;
            return true;
        }
        return false;
    }

    public static List<Bson> getAggregateOperations(Aggregate aggregate, RelDataType rowType) {
        String id;
        List<String> inNames = MongoAggregateUtils.mongoFieldNames(rowType);
        List<String> outNames = MongoAggregateUtils.mongoFieldNames(aggregate.getRowType());
        if (aggregate.getGroupSet().cardinality() == 1) {
            String inName2 = inNames.get(aggregate.getGroupSet().nth(0));
            id = "$" + inName2;
        } else {
            List elements = StreamSupport.stream(aggregate.getGroupSet().spliterator(), false).map(inNames::get).map(inName -> new BsonElement(inName, (BsonValue)new BsonString("$" + inName))).collect(Collectors.toList());
            id = new BsonDocument(elements);
        }
        int outNameIndex = aggregate.getGroupSet().cardinality();
        ArrayList<BsonField> accumList = new ArrayList<BsonField>();
        for (AggregateCall aggCall : aggregate.getAggCallList()) {
            accumList.add(MongoAggregateUtils.bsonAggregate(inNames, outNames.get(outNameIndex++), aggCall));
        }
        ArrayList<Bson> operationsList = new ArrayList<Bson>();
        operationsList.add((Bson)Aggregates.group((Object)id, accumList).toBsonDocument());
        ArrayList<BsonElement> projectFields = new ArrayList<BsonElement>();
        if (aggregate.getGroupSet().cardinality() == 1) {
            for (int index = 0; index < outNames.size(); ++index) {
                String outName = outNames.get(index);
                projectFields.add(new BsonElement(MongoAggregateUtils.maybeQuote(outName), (BsonValue)new BsonString("$" + (index == 0 ? "_id" : outName))));
            }
        } else {
            projectFields.add(new BsonElement("_id", (BsonValue)new BsonInt32(0)));
            Iterator iterator = aggregate.getGroupSet().iterator();
            while (iterator.hasNext()) {
                int group = (Integer)iterator.next();
                projectFields.add(new BsonElement(MongoAggregateUtils.maybeQuote(outNames.get(group)), (BsonValue)new BsonString("$_id." + outNames.get(group))));
            }
            outNameIndex = aggregate.getGroupSet().cardinality();
            for (AggregateCall ignored : aggregate.getAggCallList()) {
                String outName = outNames.get(outNameIndex++);
                projectFields.add(new BsonElement(MongoAggregateUtils.maybeQuote(outName), (BsonValue)new BsonString("$" + outName)));
            }
        }
        if (!aggregate.getGroupSet().isEmpty()) {
            operationsList.add((Bson)Aggregates.project((Bson)new BsonDocument(projectFields)).toBsonDocument());
        }
        return operationsList;
    }

    private static BsonField bsonAggregate(List<String> inNames, String outName, AggregateCall aggCall) {
        String aggregationName = aggCall.getAggregation().getName();
        List args = aggCall.getArgList();
        if (aggregationName.equals(SqlStdOperatorTable.COUNT.getName())) {
            Integer expr;
            if (args.size() == 0) {
                expr = 1;
            } else {
                assert (args.size() == 1);
                String inName = inNames.get((Integer)args.get(0));
                expr = new BsonDocument(MongoOp.COND.getCompareOp(), (BsonValue)new BsonArray(Arrays.asList(new Document(MongoOp.EQUAL.getCompareOp(), (Object)new BsonArray(Arrays.asList(new BsonString(MongoAggregateUtils.quote(inName)), BsonNull.VALUE))).toBsonDocument(), new BsonInt32(0), new BsonInt32(1))));
            }
            return Accumulators.sum((String)MongoAggregateUtils.maybeQuote(outName), (Object)expr);
        }
        BiFunction<String, String, BsonField> mongoAccumulator = MongoAggregateUtils.mongoAccumulator(aggregationName);
        if (mongoAccumulator != null) {
            return mongoAccumulator.apply(MongoAggregateUtils.maybeQuote(outName), "$" + inNames.get((Integer)args.get(0)));
        }
        return null;
    }

    private static <T> BiFunction<String, T, BsonField> mongoAccumulator(String aggregationName) {
        if (aggregationName.equals(SqlStdOperatorTable.SUM.getName()) || aggregationName.equals(SqlStdOperatorTable.SUM0.getName())) {
            return Accumulators::sum;
        }
        if (aggregationName.equals(SqlStdOperatorTable.MIN.getName())) {
            return Accumulators::min;
        }
        if (aggregationName.equals(SqlStdOperatorTable.MAX.getName())) {
            return Accumulators::max;
        }
        if (aggregationName.equals(SqlStdOperatorTable.AVG.getName())) {
            return Accumulators::avg;
        }
        if (aggregationName.equals(SqlStdOperatorTable.FIRST.getName())) {
            return Accumulators::first;
        }
        if (aggregationName.equals(SqlStdOperatorTable.LAST.getName())) {
            return Accumulators::last;
        }
        if (aggregationName.equals(SqlStdOperatorTable.STDDEV.getName()) || aggregationName.equals(SqlStdOperatorTable.STDDEV_SAMP.getName())) {
            return Accumulators::stdDevSamp;
        }
        if (aggregationName.equals(SqlStdOperatorTable.STDDEV_POP.getName())) {
            return Accumulators::stdDevPop;
        }
        return null;
    }

    public static boolean supportsAggregation(AggregateCall aggregateCall) {
        String name = aggregateCall.getAggregation().getName();
        return name.equals(SqlStdOperatorTable.COUNT.getName()) || name.equals(SqlStdOperatorTable.SUM.getName()) || name.equals(SqlStdOperatorTable.SUM0.getName()) || name.equals(SqlStdOperatorTable.MIN.getName()) || name.equals(SqlStdOperatorTable.MAX.getName()) || name.equals(SqlStdOperatorTable.AVG.getName()) || name.equals(SqlStdOperatorTable.FIRST.getName()) || name.equals(SqlStdOperatorTable.LAST.getName()) || name.equals(SqlStdOperatorTable.STDDEV.getName()) || name.equals(SqlStdOperatorTable.STDDEV_SAMP.getName()) || name.equals(SqlStdOperatorTable.STDDEV_POP.getName());
    }
}

