/*
 * Decompiled with CFR 0.152.
 */
package org.apache.solr.client.solrj.io.stream;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.stream.Stream;
import org.apache.solr.client.solrj.SolrClient;
import org.apache.solr.client.solrj.impl.CloudSolrClient;
import org.apache.solr.client.solrj.io.SolrClientCache;
import org.apache.solr.client.solrj.io.Tuple;
import org.apache.solr.client.solrj.io.comp.StreamComparator;
import org.apache.solr.client.solrj.io.stream.CloudSolrStream;
import org.apache.solr.client.solrj.io.stream.StreamContext;
import org.apache.solr.client.solrj.io.stream.StreamExecutorHelper;
import org.apache.solr.client.solrj.io.stream.TupleStream;
import org.apache.solr.client.solrj.io.stream.expr.Explanation;
import org.apache.solr.client.solrj.io.stream.expr.Expressible;
import org.apache.solr.client.solrj.io.stream.expr.StreamExplanation;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionNamedParameter;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionParameter;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionValue;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
import org.apache.solr.client.solrj.request.QueryRequest;
import org.apache.solr.client.solrj.response.QueryResponse;
import org.apache.solr.common.cloud.Replica;
import org.apache.solr.common.cloud.Slice;
import org.apache.solr.common.params.ModifiableSolrParams;
import org.apache.solr.common.params.SolrParams;
import org.apache.solr.common.util.NamedList;

/**
 * Streaming expression that runs a per-shard feature-selection query against a collection and
 * emits the top {@code numTerms} terms as tuples.
 *
 * <p>Each shard is queried non-distributed ({@code distrib=false}) with an {@code {!igain}}
 * filter query plus the {@code field}, {@code outcome}, {@code positiveLabel} and
 * {@code numTerms} parameters; all remaining named parameters are forwarded verbatim. The
 * per-shard {@code featuredTerms} scores and {@code docFreq} counts are summed across shards, the
 * terms are ranked by descending summed score, and one tuple per term is emitted with the fields
 * {@code id}, {@code index_i}, {@code term_s}, {@code score_f}, {@code featureSet_s} and
 * {@code idf_d}, followed by an EOF tuple.
 */
public class FeaturesSelectionStream extends TupleStream implements Expressible {
    private static final long serialVersionUID = 1L;
    protected String zkHost;
    protected String collection;
    /** Extra query parameters forwarded to every shard request. */
    protected Map<String, String> params;
    /** Lazily built on the first {@link #read()}; {@code null} until then. */
    protected Iterator<Tuple> tupleIterator;
    protected String field;
    protected String outcome;
    protected String featureSet;
    protected int positiveLabel;
    protected int numTerms;
    protected transient SolrClientCache clientCache;
    /** True when {@link #open()} created {@link #clientCache} itself and must close it. */
    private transient boolean doCloseCache;

    /**
     * Creates the stream from explicit settings.
     *
     * @param zkHost ZooKeeper connection string used to locate the collection
     * @param collectionName collection whose shards are queried
     * @param params extra query parameters forwarded to every shard
     * @param field field passed to the shards as {@code field}
     * @param outcome value passed to the shards as {@code outcome}
     * @param featureSet name used to build tuple ids and the {@code featureSet_s} field
     * @param positiveLabel value passed to the shards as {@code positiveLabel}
     * @param numTerms maximum number of term tuples emitted (also sent to each shard)
     * @throws IOException declared for API compatibility
     */
    public FeaturesSelectionStream(String zkHost, String collectionName, Map<String, String> params, String field, String outcome, String featureSet, int positiveLabel, int numTerms) throws IOException {
        this.init(collectionName, zkHost, params, field, outcome, featureSet, positiveLabel, numTerms);
    }

    /**
     * Creates the stream from a parsed streaming expression of the form
     * {@code features(collection, q=..., field=..., outcome=..., featureSet=..., numTerms=...,
     * [positiveLabel=...], [zkHost=...])}.
     *
     * @param expression the parsed expression
     * @param factory factory used to read operands and resolve the collection's zkHost
     * @throws IOException if operands are missing, unknown or malformed, or no zkHost is found
     */
    public FeaturesSelectionStream(StreamExpression expression, StreamFactory factory) throws IOException {
        String collectionName = factory.getValueOperand(expression, 0);
        List<StreamExpressionNamedParameter> namedParams = factory.getNamedOperands(expression);
        StreamExpressionNamedParameter zkHostExpression = factory.getNamedOperand(expression, "zkHost");

        // Every operand other than the leading collection name must be a named parameter.
        if (expression.getParameters().size() != 1 + namedParams.size()) {
            throw new IOException(String.format(Locale.ROOT, "invalid expression %s - unknown operands found", expression));
        }
        if (null == collectionName) {
            throw new IOException(String.format(Locale.ROOT, "invalid expression %s - collectionName expected as first operand", expression));
        }
        if (namedParams.isEmpty()) {
            throw new IOException(String.format(Locale.ROOT, "invalid expression %s - at least one named parameter expected. eg. 'q=*:*'", expression));
        }

        Map<String, String> params = new HashMap<>();
        for (StreamExpressionNamedParameter namedParam : namedParams) {
            if (!namedParam.getName().equals("zkHost")) {
                params.put(namedParam.getName(), namedParam.getParameter().toString().trim());
            }
        }

        // Stream-specific options are removed from params so that only the remaining entries are
        // forwarded to the shards as plain query parameters.
        String fieldParam = params.remove("field");
        if (fieldParam == null) {
            throw new IOException("field param cannot be null for FeaturesSelectionStream");
        }
        String outcomeParam = params.remove("outcome");
        if (outcomeParam == null) {
            throw new IOException("outcome param cannot be null for FeaturesSelectionStream");
        }
        String featureSetParam = params.remove("featureSet");
        if (featureSetParam == null) {
            throw new IOException("featureSet param cannot be null for FeaturesSelectionStream");
        }
        String positiveLabelParam = params.remove("positiveLabel");
        int positiveLabel = positiveLabelParam == null ? 1 : parseIntParam("positiveLabel", positiveLabelParam);
        String numTermsParam = params.remove("numTerms");
        if (numTermsParam == null) {
            throw new IOException("numTerms param cannot be null for FeaturesSelectionStream");
        }
        int numTerms = parseIntParam("numTerms", numTermsParam);

        String zkHost = null;
        if (null == zkHostExpression) {
            zkHost = factory.getCollectionZkHost(collectionName);
        } else if (zkHostExpression.getParameter() instanceof StreamExpressionValue) {
            zkHost = ((StreamExpressionValue) zkHostExpression.getParameter()).getValue();
        }
        if (null == zkHost) {
            throw new IOException(String.format(Locale.ROOT, "invalid expression %s - zkHost not found for collection '%s'", expression, collectionName));
        }
        this.init(collectionName, zkHost, params, fieldParam, outcomeParam, featureSetParam, positiveLabel, numTerms);
    }

    /** Parses an integer operand, reporting malformed input as an {@link IOException}. */
    private static int parseIntParam(String name, String value) throws IOException {
        try {
            return Integer.parseInt(value);
        } catch (NumberFormatException e) {
            throw new IOException(String.format(Locale.ROOT, "invalid value '%s' for %s param of FeaturesSelectionStream", value, name), e);
        }
    }

    /**
     * Serializes this stream back into an expression. Forwarded params come first, followed by
     * the stream-specific options and the zkHost.
     */
    @Override
    public StreamExpressionParameter toExpression(StreamFactory factory) throws IOException {
        StreamExpression expression = new StreamExpression(factory.getFunctionName(this.getClass()));
        expression.addParameter(this.collection);
        for (Map.Entry<String, String> param : this.params.entrySet()) {
            expression.addParameter(new StreamExpressionNamedParameter(param.getKey(), param.getValue()));
        }
        expression.addParameter(new StreamExpressionNamedParameter("field", this.field));
        expression.addParameter(new StreamExpressionNamedParameter("outcome", this.outcome));
        expression.addParameter(new StreamExpressionNamedParameter("featureSet", this.featureSet));
        expression.addParameter(new StreamExpressionNamedParameter("positiveLabel", String.valueOf(this.positiveLabel)));
        expression.addParameter(new StreamExpressionNamedParameter("numTerms", String.valueOf(this.numTerms)));
        expression.addParameter(new StreamExpressionNamedParameter("zkHost", this.zkHost));
        return expression;
    }

    /** Stores the constructor arguments in the corresponding fields. */
    private void init(String collectionName, String zkHost, Map<String, String> params, String field, String outcome, String featureSet, int positiveLabel, int numTopTerms) throws IOException {
        this.zkHost = zkHost;
        this.collection = collectionName;
        this.params = params;
        this.field = field;
        this.outcome = outcome;
        this.featureSet = featureSet;
        this.positiveLabel = positiveLabel;
        this.numTerms = numTopTerms;
    }

    /** Adopts the context's shared {@link SolrClientCache}; such a cache is never closed here. */
    @Override
    public void setStreamContext(StreamContext context) {
        this.clientCache = context.getSolrClientCache();
    }

    /** Creates a private {@link SolrClientCache} if none was supplied via the stream context. */
    @Override
    public void open() throws IOException {
        if (this.clientCache == null) {
            this.doCloseCache = true;
            this.clientCache = new SolrClientCache();
        } else {
            this.doCloseCache = false;
        }
    }

    /** This stream reads directly from the shards and has no child streams. */
    @Override
    public List<TupleStream> children() {
        return new ArrayList<>();
    }

    /**
     * Picks, for every slice of the collection, one random replica that is ACTIVE and hosted on a
     * live node, and returns its core URL.
     *
     * @throws IOException if the cluster state cannot be read or a slice has no usable replica
     */
    private List<String> getShardUrls() throws IOException {
        try {
            CloudSolrClient cloudSolrClient = this.clientCache.getCloudSolrClient(this.zkHost);
            Slice[] slices = CloudSolrStream.getSlices(this.collection, cloudSolrClient, false);
            Set<String> liveNodes = cloudSolrClient.getClusterState().getLiveNodes();
            Random random = new Random();
            List<String> baseUrls = new ArrayList<>();
            for (Slice slice : slices) {
                List<Replica> candidates = new ArrayList<>();
                for (Replica replica : slice.getReplicas()) {
                    if (replica.getState() == Replica.State.ACTIVE && liveNodes.contains(replica.getNodeName())) {
                        candidates.add(replica);
                    }
                }
                // Without this check an unavailable shard surfaced as an opaque IndexOutOfBoundsException.
                if (candidates.isEmpty()) {
                    throw new IOException(String.format(Locale.ROOT, "no active replica on a live node for shard '%s' of collection '%s'", slice.getName(), this.collection));
                }
                baseUrls.add(candidates.get(random.nextInt(candidates.size())).getCoreUrl());
            }
            return baseUrls;
        } catch (IOException e) {
            throw e;
        } catch (Exception e) {
            throw new IOException(e);
        }
    }

    /** Sends one {@link FeaturesSelectionCall} per URL in parallel and collects the responses. */
    private Collection<NamedList<?>> callShards(List<String> baseUrls) throws IOException {
        List<FeaturesSelectionCall> tasks = new ArrayList<>();
        for (String baseUrl : baseUrls) {
            tasks.add(new FeaturesSelectionCall(baseUrl, this.params, this.field, this.outcome, this.positiveLabel, this.numTerms, this.clientCache));
        }
        return StreamExecutorHelper.submitAllAndAwaitAggregatingExceptions(tasks, "FeaturesSelectionStream");
    }

    /**
     * Closes the {@link SolrClientCache} if {@link #open()} created it, and forgets it so a later
     * {@link #open()} does not reuse a closed cache.
     */
    @Override
    public void close() throws IOException {
        if (this.doCloseCache) {
            this.clientCache.close();
            this.clientCache = null;
            this.doCloseCache = false;
        }
    }

    /** The emitted tuples have no declared sort order. */
    @Override
    public StreamComparator getStreamSort() {
        return null;
    }

    @Override
    public Explanation toExplanation(StreamFactory factory) throws IOException {
        return new StreamExplanation(this.getStreamNodeId().toString())
                .withFunctionName(factory.getFunctionName(this.getClass()))
                .withImplementingClass(this.getClass().getName())
                .withExpressionType("stream-decorator")
                .withExpression(this.toExpression(factory).toString());
    }

    /**
     * Returns the next term tuple. On the first call all shards are queried and the full result
     * list is built; after the last term an EOF tuple is returned.
     *
     * @throws IOException if a shard request fails or a shard response is malformed
     */
    @Override
    public Tuple read() throws IOException {
        try {
            if (this.tupleIterator == null) {
                this.tupleIterator = this.buildTuples().iterator();
            }
            return this.tupleIterator.next();
        } catch (IOException e) {
            throw e;
        } catch (Exception e) {
            throw new IOException(e);
        }
    }

    /**
     * Queries every shard, sums term scores and document frequencies per term, and turns the
     * {@code numTerms} highest-scoring terms into tuples terminated by an EOF tuple.
     */
    private List<Tuple> buildTuples() throws IOException {
        Map<String, Double> termScores = new HashMap<>();
        Map<String, Long> docFreqs = new HashMap<>();
        long numDocs = 0L;
        for (NamedList<?> resp : this.callShards(this.getShardUrls())) {
            NamedList<?> shardTopTerms = (NamedList<?>) resp.get("featuredTerms");
            NamedList<?> shardDocFreqs = (NamedList<?>) resp.get("docFreq");
            Object shardNumDocs = resp.get("numDocs");
            if (shardTopTerms == null || shardDocFreqs == null || shardNumDocs == null) {
                throw new IOException("shard response is missing featuredTerms, docFreq or numDocs");
            }
            numDocs += ((Number) shardNumDocs).longValue();
            for (int i = 0; i < shardTopTerms.size(); i++) {
                String term = shardTopTerms.getName(i);
                double score = ((Number) shardTopTerms.getVal(i)).doubleValue();
                Object docFreq = shardDocFreqs.get(term);
                if (docFreq == null) {
                    throw new IOException(String.format(Locale.ROOT, "shard response has no docFreq for featured term '%s'", term));
                }
                termScores.merge(term, score, Double::sum);
                docFreqs.merge(term, ((Number) docFreq).longValue(), Long::sum);
            }
        }

        List<Tuple> tuples = new ArrayList<>(this.numTerms);
        int index = 0;
        for (Map.Entry<String, Double> termScore : this.sortByValue(termScores).entrySet()) {
            if (tuples.size() == this.numTerms) {
                break;
            }
            index++;
            String term = termScore.getKey();
            Tuple tuple = new Tuple();
            tuple.put("id", this.featureSet + "_" + index);
            tuple.put("index_i", index);
            tuple.put("term_s", term);
            tuple.put("score_f", termScore.getValue());
            tuple.put("featureSet_s", this.featureSet);
            long docFreq = docFreqs.get(term);
            // ln(numDocs / (docFreq + 1)); the +1 keeps the denominator non-zero.
            tuple.put("idf_d", Math.log((double) numDocs / (double) (docFreq + 1L)));
            tuples.add(tuple);
        }
        tuples.add(Tuple.EOF());
        return tuples;
    }

    /** Returns a copy of {@code map} whose iteration order is by descending value (stable on ties). */
    private <K, V extends Comparable<? super V>> Map<K, V> sortByValue(Map<K, V> map) {
        Map<K, V> result = new LinkedHashMap<>();
        map.entrySet().stream()
                .sorted(Map.Entry.<K, V>comparingByValue(Collections.reverseOrder()))
                .forEachOrdered(e -> result.put(e.getKey(), e.getValue()));
        return result;
    }

    /** Task that runs the non-distributed {@code {!igain}} request against a single shard core. */
    protected static class FeaturesSelectionCall implements Callable<NamedList<?>> {
        private final String baseUrl;
        private final String outcome;
        private final String field;
        private final Map<String, String> paramsMap;
        private final int positiveLabel;
        private final int numTerms;
        private final SolrClientCache clientCache;

        public FeaturesSelectionCall(String baseUrl, Map<String, String> paramsMap, String field, String outcome, int positiveLabel, int numTerms, SolrClientCache clientCache) {
            this.baseUrl = baseUrl;
            this.outcome = outcome;
            this.field = field;
            this.paramsMap = paramsMap;
            this.positiveLabel = positiveLabel;
            this.numTerms = numTerms;
            this.clientCache = clientCache;
        }

        /** Sends the request to {@link #baseUrl} and returns the raw response. */
        @Override
        public NamedList<?> call() throws Exception {
            ModifiableSolrParams params = new ModifiableSolrParams();
            SolrClient solrClient = this.clientCache.getHttpSolrClient(this.baseUrl);
            params.add("distrib", "false");
            params.add("fq", "{!igain}");
            for (Map.Entry<String, String> entry : this.paramsMap.entrySet()) {
                params.add(entry.getKey(), entry.getValue());
            }
            params.add("outcome", this.outcome);
            params.add("positiveLabel", Integer.toString(this.positiveLabel));
            params.add("field", this.field);
            params.add("numTerms", String.valueOf(this.numTerms));
            QueryRequest request = new QueryRequest((SolrParams) params);
            QueryResponse response = (QueryResponse) request.process(solrClient);
            return response.getResponse();
        }
    }
}

