import  torch
import torch.nn as nn
import random
import torch.nn.utils.rnn as rnn_utils

# LSTM hyperparameters
input_size = 384  # Dimension of each frame embedding fed to the LSTM
hidden_size = 20
num_layers = 2

# Dataset / binning configuration
N_EXAMS = 10
N_FRAMES_PER_EXAM = 20
N_BINS = 5
N_CLASSES = 5

# Creating exams: random stand-in data of shape
# (N_EXAMS, N_FRAMES_PER_EXAM, 1, 384, 384) — presumably one single-channel
# 384x384 image per frame.
exams = torch.rand(N_EXAMS, N_FRAMES_PER_EXAM, 1, 384, 384)


def compute_bin_bounds(n_frames, n_bins):
    """Return the inclusive ``(start, end)`` frame indices of each bin.

    The ``n_frames`` frames are split into ``n_bins`` contiguous,
    non-overlapping bins that together cover every frame. Integer floor
    division keeps the boundaries exact (no float rounding). When
    ``n_frames`` is not divisible by ``n_bins``, bin sizes differ by at most
    one frame.

    Raises:
        ValueError: if ``n_bins`` < 1 or ``n_frames`` < ``n_bins``, because
            some bin would then be empty and no frame could be sampled from it.
    """
    if n_bins < 1 or n_frames < n_bins:
        raise ValueError(
            f"need n_bins >= 1 and n_frames >= n_bins, got "
            f"n_frames={n_frames}, n_bins={n_bins}"
        )
    return [
        (i * n_frames // n_bins, (i + 1) * n_frames // n_bins - 1)
        for i in range(n_bins)
    ]


# Binning logic: for every exam, pick one random frame from each bin. Each
# exam is reduced to N_BINS frames, kept in bin order.
bin_bounds = compute_bin_bounds(N_FRAMES_PER_EXAM, N_BINS)
binned_exams = []
for exam in exams:
    # random.randint is inclusive on both ends; each pick is one frame of
    # shape (1, 384, 384).
    picked = [exam[random.randint(start, end)] for start, end in bin_bounds]
    binned_exams.append(torch.stack(picked))

binned_exams_t = torch.stack(binned_exams)
print(binned_exams_t.shape)

print("-" * 68)

# Encode frames for the LSTM. A single linear layer stands in for the frame
# encoder (a ViT, per the variable name). It maps each flattened 1x384x384
# frame to an `input_size`-dimensional embedding.
frame_encoder_vit_sim = nn.Linear(384 * 384, input_size)

# Merge the exam and bin axes so all N_EXAMS * N_BINS frames are encoded as
# one batch, then flatten every frame to a single pixel vector.
frames_flat_for_vit = binned_exams_t.view(N_EXAMS * N_BINS, 1, 384, 384)
all_encoded_frames = frame_encoder_vit_sim(frames_flat_for_vit.flatten(start_dim=1))
print(f"The shape of all the encoded frames: {all_encoded_frames.shape}")

# Split the batch axis back out: one sequence of N_BINS embeddings per exam.
sequence_embeddings = all_encoded_frames.view(N_EXAMS, N_BINS, input_size)
print(f"The shape of sequential embeddings {sequence_embeddings.shape}")



# Sequence model over the per-exam frame embeddings. With batch_first=True
# the input is (batch, seq, input_size).
lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)


# Initialize hidden and cell states with zeros, the same as nn.LSTM's default
# when no initial state is passed. Random initial states would make an exam's
# prediction depend on a random draw instead of only on its frames.
# The batch size comes from the input rather than N_EXAMS, so this still works
# for a smaller final batch. new_zeros matches the input's dtype and device.
batch_size = sequence_embeddings.size(0)
h_0 = sequence_embeddings.new_zeros(num_layers, batch_size, hidden_size)
c_0 = sequence_embeddings.new_zeros(num_layers, batch_size, hidden_size)

print("-"*68)

# Forward pass
output, (h_n, c_n) = lstm(sequence_embeddings, (h_0, c_0))
print(f"Output shape: {output.shape}")
print(f"Final hidden state shape: {h_n.shape}")
print(f"Final cell state shape: {c_n.shape}")

print("-" * 68)

# Take the final hidden state of the last LSTM layer (h_n[-1]) as the
# per-exam summary vector, with shape (batch, hidden_size).
last_hidden_state_for_classification = h_n[-1]
print(f"The shape of the last hidden state: {last_hidden_state_for_classification.shape}")

print("-" * 68)
# Classifier: a linear head producing one score (logit) per class.
classifier = nn.Linear(hidden_size, N_CLASSES)

logits = classifier(last_hidden_state_for_classification)
print(f"Shape of the logits: {logits.shape}")  # one row of N_CLASSES logits per exam

# Softmax over the class axis turns each exam's logits into probabilities.
probabilities = logits.softmax(dim=1)
print(f" The probabilities of the {N_EXAMS} exams:\n {probabilities}")


# In practice this would be integrated into the training pipeline (loss function, optimizer, etc.). Note that nn.CrossEntropyLoss expects the raw logits, not the softmax probabilities.