/*
 * Decompiled with CFR 0.152.
 */
package com.mentalfrostbyte.jello.util.game.player.rotation;

import com.mentalfrostbyte.Client;
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

public class NeuralNetwork {
    private static final Random random = new Random();
    public static final int INPUT_SIZE = 8;
    public static final int HIDDEN_SIZE = 16;
    public static final int OUTPUT_SIZE = 2;
    private float[][] weightsInputToHidden = new float[8][16];
    private float[][] weightsHiddenToOutput = new float[16][2];
    private float[] hiddenBiases = new float[16];
    private float[] outputBiases = new float[2];
    private static final float LEARNING_RATE = 0.03f;
    private static final float MOMENTUM = 0.7f;
    private float[][] prevDeltaInputHidden = new float[8][16];
    private float[][] prevDeltaHiddenOutput = new float[16][2];
    private List<TrainingSample> batchSamples = new ArrayList<TrainingSample>();
    private static final int BATCH_SIZE = 12;
    private static final int EPOCHS_PER_BATCH = 15;
    private static final String WEIGHTS_FILE = "jelloai_weights.dat";
    private int trainingCounter = 0;
    private static final int SAVE_INTERVAL = 50;
    private boolean initialized = false;
    private int sampleCount = 0;
    private static final int CONFIDENCE_THRESHOLD = 100;

    public void initialize() {
        if (!this.loadWeights()) {
            this.initializeRandomWeights();
        }
        this.initialized = true;
    }

    private void initializeRandomWeights() {
        int j;
        int i;
        for (i = 0; i < 8; ++i) {
            for (j = 0; j < 16; ++j) {
                this.weightsInputToHidden[i][j] = (random.nextFloat() - 0.5f) * 0.1f;
                this.prevDeltaInputHidden[i][j] = 0.0f;
            }
        }
        for (i = 0; i < 16; ++i) {
            this.hiddenBiases[i] = (random.nextFloat() - 0.5f) * 0.1f;
            for (j = 0; j < 2; ++j) {
                this.weightsHiddenToOutput[i][j] = (random.nextFloat() - 0.5f) * 0.1f;
                this.prevDeltaHiddenOutput[i][j] = 0.0f;
            }
        }
        for (i = 0; i < 2; ++i) {
            this.outputBiases[i] = (random.nextFloat() - 0.5f) * 0.1f;
        }
    }

    public boolean isInitialized() {
        return this.initialized;
    }

    public float getConfidence() {
        if (!this.initialized) {
            return 0.0f;
        }
        return Math.min(1.0f, (float)this.sampleCount / 100.0f);
    }

    public void addToBatch(float[] inputs, float[] expected, float weight) {
        if (this.batchSamples.size() >= 12) {
            this.trainBatch();
        }
        this.batchSamples.add(new TrainingSample(inputs, expected, weight));
    }

    public void trainBatch() {
        if (this.batchSamples.isEmpty()) {
            return;
        }
        float totalError = 0.0f;
        for (int epoch = 0; epoch < 15; ++epoch) {
            Collections.shuffle(this.batchSamples);
            for (TrainingSample sample : this.batchSamples) {
                totalError += this.trainNetworkImmediate(sample.inputs, sample.expected, sample.weight);
            }
        }
        Client.logger.info("JelloAI: Batch training completed, average error: " + totalError / (float)(this.batchSamples.size() * 15));
        this.batchSamples.clear();
    }

    public float trainNetworkImmediate(float[] inputs, float[] expectedOutputs, float weight) {
        int i;
        ++this.sampleCount;
        boolean pushAway = weight < 0.0f;
        float absWeight = Math.abs(weight);
        float[] hiddenOutputs = new float[16];
        for (int i2 = 0; i2 < 16; ++i2) {
            float sum = this.hiddenBiases[i2];
            for (int j = 0; j < 8; ++j) {
                sum += inputs[j] * this.weightsInputToHidden[j][i2];
            }
            hiddenOutputs[i2] = this.sigmoid(sum);
        }
        float[] outputs = new float[2];
        for (int i3 = 0; i3 < 2; ++i3) {
            float sum = this.outputBiases[i3];
            for (int j = 0; j < 16; ++j) {
                sum += hiddenOutputs[j] * this.weightsHiddenToOutput[j][i3];
            }
            outputs[i3] = this.tanh(sum);
        }
        Client.logger.info("JelloAI Training - Expected: [" + expectedOutputs[0] + ", " + expectedOutputs[1] + "], Predicted: [" + outputs[0] + ", " + outputs[1] + "], Weight: " + weight);
        float totalError = 0.0f;
        float[] outputErrors = new float[2];
        for (int i4 = 0; i4 < 2; ++i4) {
            float error = expectedOutputs[i4] - outputs[i4];
            if (pushAway) {
                error = -error;
            }
            outputErrors[i4] = error * this.tanhDerivative(outputs[i4]) * absWeight;
            totalError += error * error;
        }
        float[] hiddenErrors = new float[16];
        for (i = 0; i < 16; ++i) {
            float error = 0.0f;
            for (int j = 0; j < 2; ++j) {
                error += outputErrors[j] * this.weightsHiddenToOutput[i][j];
            }
            hiddenErrors[i] = error * this.sigmoidDerivative(hiddenOutputs[i]);
        }
        for (i = 0; i < 16; ++i) {
            for (int j = 0; j < 2; ++j) {
                float delta = 0.03f * outputErrors[j] * hiddenOutputs[i] + 0.7f * this.prevDeltaHiddenOutput[i][j];
                float[] fArray = this.weightsHiddenToOutput[i];
                int n = j;
                fArray[n] = fArray[n] + delta;
                this.prevDeltaHiddenOutput[i][j] = delta;
            }
        }
        for (i = 0; i < 2; ++i) {
            int n = i;
            this.outputBiases[n] = this.outputBiases[n] + 0.03f * outputErrors[i];
        }
        for (i = 0; i < 8; ++i) {
            for (int j = 0; j < 16; ++j) {
                float delta = 0.03f * hiddenErrors[j] * inputs[i] + 0.7f * this.prevDeltaInputHidden[i][j];
                float[] fArray = this.weightsInputToHidden[i];
                int n = j;
                fArray[n] = fArray[n] + delta;
                this.prevDeltaInputHidden[i][j] = delta;
            }
        }
        for (i = 0; i < 16; ++i) {
            int n = i;
            this.hiddenBiases[n] = this.hiddenBiases[n] + 0.03f * hiddenErrors[i];
        }
        ++this.trainingCounter;
        if (this.trainingCounter >= 50) {
            this.saveWeights();
            this.trainingCounter = 0;
        }
        return totalError;
    }

    public void saveWeights() {
        try {
            int j;
            int i;
            File file = new File(Client.getInstance().file, WEIGHTS_FILE);
            DataOutputStream out = new DataOutputStream(new FileOutputStream(file));
            for (i = 0; i < 8; ++i) {
                for (j = 0; j < 16; ++j) {
                    out.writeFloat(this.weightsInputToHidden[i][j]);
                }
            }
            for (i = 0; i < 16; ++i) {
                out.writeFloat(this.hiddenBiases[i]);
            }
            for (i = 0; i < 16; ++i) {
                for (j = 0; j < 2; ++j) {
                    out.writeFloat(this.weightsHiddenToOutput[i][j]);
                }
            }
            for (i = 0; i < 2; ++i) {
                out.writeFloat(this.outputBiases[i]);
            }
            out.close();
            Client.logger.info("JelloAI: Saved neural network weights");
        }
        catch (Exception e) {
            Client.logger.error("Error saving JelloAI weights", (Throwable)e);
        }
    }

    private boolean loadWeights() {
        try {
            int j;
            int i;
            File file = new File(Client.getInstance().file, WEIGHTS_FILE);
            if (!file.exists()) {
                return false;
            }
            DataInputStream in = new DataInputStream(new FileInputStream(file));
            for (i = 0; i < 8; ++i) {
                for (j = 0; j < 16; ++j) {
                    this.weightsInputToHidden[i][j] = in.readFloat();
                }
            }
            for (i = 0; i < 16; ++i) {
                this.hiddenBiases[i] = in.readFloat();
            }
            for (i = 0; i < 16; ++i) {
                for (j = 0; j < 2; ++j) {
                    this.weightsHiddenToOutput[i][j] = in.readFloat();
                }
            }
            for (i = 0; i < 2; ++i) {
                this.outputBiases[i] = in.readFloat();
            }
            in.close();
            Client.logger.info("JelloAI: Loaded neural network weights");
            return true;
        }
        catch (Exception e) {
            Client.logger.error("Error loading JelloAI weights", (Throwable)e);
            return false;
        }
    }

    private float sigmoid(float x) {
        return (float)(1.0 / (1.0 + Math.exp(-x)));
    }

    private float sigmoidDerivative(float x) {
        return x * (1.0f - x);
    }

    private float tanh(float x) {
        return (float)Math.tanh(x);
    }

    private float tanhDerivative(float x) {
        return 1.0f - x * x;
    }

    private void shuffleArray(Object[] array) {
        for (int i = array.length - 1; i > 0; --i) {
            int index = random.nextInt(i + 1);
            Object temp = array[index];
            array[index] = array[i];
            array[i] = temp;
        }
    }

    public void trainNetwork(TrainingSample[] samples) {
        if (samples.length == 0) {
            return;
        }
        this.sampleCount += samples.length;
        this.shuffleArray(samples);
        for (int epoch = 0; epoch < 3; ++epoch) {
            for (TrainingSample sample : samples) {
                float[] inputs = sample.getInputs();
                float[] expectedOutputs = sample.getExpectedOutputs();
                float weight = sample.getWeight();
                this.trainNetworkImmediate(inputs, expectedOutputs, weight);
            }
        }
        this.saveWeights();
    }

    public float[] predict(float[] inputs) {
        if (!this.initialized) {
            return new float[2];
        }
        float[] hiddenOutputs = new float[16];
        for (int i = 0; i < 16; ++i) {
            float sum = this.hiddenBiases[i];
            for (int j = 0; j < 8; ++j) {
                sum += inputs[j] * this.weightsInputToHidden[j][i];
            }
            hiddenOutputs[i] = this.sigmoid(sum);
        }
        float[] outputs = new float[2];
        for (int i = 0; i < 2; ++i) {
            float sum = this.outputBiases[i];
            for (int j = 0; j < 16; ++j) {
                sum += hiddenOutputs[j] * this.weightsHiddenToOutput[j][i];
            }
            outputs[i] = this.tanh(sum);
        }
        return outputs;
    }

    private static class TrainingSample {
        float[] inputs;
        float[] expected;
        float weight;

        TrainingSample(float[] inputs, float[] expected, float weight) {
            this.inputs = inputs;
            this.expected = expected;
            this.weight = weight;
        }

        public float[] getInputs() {
            return this.inputs;
        }

        public float[] getExpectedOutputs() {
            return this.expected;
        }

        public float getWeight() {
            return this.weight;
        }
    }
}

