/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.sql.opensearch.planner.physical;

import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import lombok.Generated;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataframe.DataFrame;
import org.opensearch.ml.common.input.parameter.MLAlgoParams;
import org.opensearch.ml.common.input.parameter.clustering.KMeansParams;
import org.opensearch.ml.common.output.MLPredictionOutput;
import org.opensearch.sql.ast.expression.Literal;
import org.opensearch.sql.data.model.ExprValue;
import org.opensearch.sql.opensearch.planner.physical.MLCommonsOperatorActions;
import org.opensearch.sql.planner.physical.PhysicalPlan;
import org.opensearch.sql.planner.physical.PhysicalPlanNodeVisitor;
import org.opensearch.transport.client.node.NodeClient;

/**
 * Physical operator that turns the rows of its child plan into a data frame, requests an
 * ML Commons prediction for it, and iterates over the resulting values.
 *
 * <p>{@link #open()} must be called before {@link #hasNext()} or {@link #next()}: the result
 * iterator is only created there.
 */
public class MLCommonsOperator extends MLCommonsOperatorActions {
    private final PhysicalPlan input;
    private final String algorithm;
    private final Map<String, Literal> arguments;
    private final NodeClient nodeClient;
    /** Result iterator built by {@link #open()}; runtime state, so it is excluded from equals/hashCode. */
    private Iterator<ExprValue> iterator;

    /**
     * Builds the input data frame from the child plan, requests the prediction result for it and
     * prepares the iterator that {@link #hasNext()} / {@link #next()} delegate to.
     *
     * @throws IllegalArgumentException if the algorithm name (or a KMeans distance type) does not
     *     name a known constant, or the algorithm is not supported
     */
    public void open() {
        super.open();
        final DataFrame inputDataFrame = this.generateInputDataset(this.input);
        final MLAlgoParams mlAlgoParams =
                this.convertArgumentToMLParameter(this.arguments, this.algorithm);
        final MLPredictionOutput predictionResult = this.getMLPredictionResult(
                toFunctionName(this.algorithm), mlAlgoParams, inputDataFrame, this.nodeClient);
        // NOTE(review): raw iterators kept as in the original; parameterize them with the data
        // frame's row type once buildResult's signature is confirmed.
        final Iterator inputRowIter = inputDataFrame.iterator();
        final Iterator resultRowIter = predictionResult.getPredictionResult().iterator();
        this.iterator = new Iterator<ExprValue>() {
            // Iteration is driven by the input rows; buildResult receives both iterators per call.
            @Override
            public boolean hasNext() {
                return inputRowIter.hasNext();
            }

            @Override
            public ExprValue next() {
                return MLCommonsOperator.this.buildResult(
                        inputRowIter, inputDataFrame, predictionResult, resultRowIter);
            }
        };
    }

    public <R, C> R accept(PhysicalPlanNodeVisitor<R, C> visitor, C context) {
        return visitor.visitMLCommons(this, context);
    }

    /** Delegates to the iterator created by {@link #open()}; fails if the operator was not opened. */
    public boolean hasNext() {
        return this.iterator.hasNext();
    }

    /** Delegates to the iterator created by {@link #open()}; fails if the operator was not opened. */
    public ExprValue next() {
        return this.iterator.next();
    }

    /** Returns the single child plan whose rows feed the algorithm. */
    public List<PhysicalPlan> getChild() {
        return Collections.singletonList(this.input);
    }

    /**
     * Converts the query arguments into ML Commons algorithm parameters.
     *
     * <p>For KMeans, the optional {@code centroids} and {@code iterations} arguments are read as
     * {@link Integer}s and {@code distance_type} is matched case-insensitively against
     * {@link KMeansParams.DistanceType}; missing arguments are passed to the builder as {@code null}.
     *
     * @param arguments argument literals keyed by argument name
     * @param algorithm algorithm name, matched case-insensitively
     * @return the parameters for the requested algorithm
     * @throws IllegalArgumentException if the name is not a known {@link FunctionName} or the
     *     algorithm is not supported here
     */
    protected MLAlgoParams convertArgumentToMLParameter(Map<String, Literal> arguments, String algorithm) {
        final FunctionName functionName = toFunctionName(algorithm);
        switch (functionName) {
            case KMEANS:
                return KMeansParams.builder()
                        .centroids(integerArgument(arguments, "centroids"))
                        .iterations(integerArgument(arguments, "iterations"))
                        .distanceType(distanceTypeArgument(arguments))
                        .build();
            default:
                throw new IllegalArgumentException(String.format(
                        "unsupported algorithm: %s, available algorithms: %s.",
                        functionName, FunctionName.KMEANS));
        }
    }

    /**
     * Resolves an algorithm name case-insensitively. {@link Locale#ROOT} keeps the result
     * independent of the JVM default locale (e.g. Turkish dotted/dotless i).
     */
    private static FunctionName toFunctionName(String algorithm) {
        return FunctionName.valueOf(algorithm.toUpperCase(Locale.ROOT));
    }

    /** Returns the named argument's value cast to {@link Integer}, or {@code null} when absent. */
    private static Integer integerArgument(Map<String, Literal> arguments, String name) {
        final Literal literal = arguments.get(name);
        return literal == null ? null : (Integer) literal.getValue();
    }

    /**
     * Parses the {@code distance_type} argument case-insensitively (locale-independent), or returns
     * {@code null} when it is absent or has a null value.
     */
    private static KMeansParams.DistanceType distanceTypeArgument(Map<String, Literal> arguments) {
        final Literal literal = arguments.get("distance_type");
        if (literal == null || literal.getValue() == null) {
            return null;
        }
        final String distanceType = (String) literal.getValue();
        return KMeansParams.DistanceType.valueOf(distanceType.toUpperCase(Locale.ROOT));
    }

    @Generated
    public MLCommonsOperator(PhysicalPlan input, String algorithm, Map<String, Literal> arguments, NodeClient nodeClient) {
        this.input = input;
        this.algorithm = algorithm;
        this.arguments = arguments;
        this.nodeClient = nodeClient;
    }

    /** Compares input, algorithm, arguments and node client; the runtime iterator is ignored. */
    @Generated
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof MLCommonsOperator)) {
            return false;
        }
        final MLCommonsOperator other = (MLCommonsOperator) o;
        return other.canEqual(this)
                && Objects.equals(this.getInput(), other.getInput())
                && Objects.equals(this.getAlgorithm(), other.getAlgorithm())
                && Objects.equals(this.getArguments(), other.getArguments())
                && Objects.equals(this.getNodeClient(), other.getNodeClient());
    }

    @Generated
    protected boolean canEqual(Object other) {
        return other instanceof MLCommonsOperator;
    }

    /** Hashes the same fields that {@link #equals(Object)} compares. */
    @Generated
    public int hashCode() {
        final int prime = 59;
        int result = 1;
        result = result * prime + nullSafeHash(this.getInput());
        result = result * prime + nullSafeHash(this.getAlgorithm());
        result = result * prime + nullSafeHash(this.getArguments());
        result = result * prime + nullSafeHash(this.getNodeClient());
        return result;
    }

    /** Lombok-style field hash: 43 for {@code null}, so hash values stay as previously computed. */
    private static int nullSafeHash(Object value) {
        return value == null ? 43 : value.hashCode();
    }

    @Generated
    public PhysicalPlan getInput() {
        return this.input;
    }

    @Generated
    public String getAlgorithm() {
        return this.algorithm;
    }

    @Generated
    public Map<String, Literal> getArguments() {
        return this.arguments;
    }

    @Generated
    public NodeClient getNodeClient() {
        return this.nodeClient;
    }
}

