#pragma once
#include "Mesh.h"
#include "Material.h"
#include "Light.h"
#include "Camera.h"
#include "VKBase.h"
#include "Input.h"

#include <glm/glm.hpp>

#include <cstddef>
#include <memory>
#include <string>
#include <utility>
#include <vector>

using namespace Celestiq::Vulkan;

/// Abstract interface for a GPU-side buffer that can be (re)built, filled and
/// bound to a descriptor set.
///
/// NOTE(review): the intended call order is presumably
/// create() -> upload() -> writeDescriptor(); confirm against Scene's .cpp.
class BindableBuffer {
public:
    /// Polymorphic base: owners hold implementations through base pointers.
    virtual ~BindableBuffer() = default;

    /// Prepare (allocate) the underlying buffer.
    virtual void create() = 0;

    /// Copy CPU-side contents into the buffer prepared by create().
    virtual void upload() = 0;

    /// Record this buffer into `set` at descriptor binding index `binding`.
    virtual void writeDescriptor(descriptorSet& set, uint32_t binding) = 0;
};

/// Generic storage-buffer manager.
///
/// Pulls a std::vector<T> from fetch(), keeps it as a CPU-side snapshot in
/// `data`, and mirrors it byte-for-byte into a storageBuffer of matching size.
///
/// @tparam T  element type copied verbatim into the GPU buffer; its C++ layout
///            must match the shader-side struct (not checked here).
template<typename T>
class TypedBufferManager : public BindableBuffer {
protected:
    std::unique_ptr<storageBuffer> buffer;  // null until create() has seen non-empty data
    std::vector<T> data;                    // snapshot taken by the last create()

public:
    /// Produces the elements to place in the buffer; invoked by create().
    virtual std::vector<T> fetch() const = 0;

    /// Snapshots fetch() into `data` and allocates a storageBuffer sized to it.
    ///
    /// The new snapshot and buffer are built in locals and committed only after
    /// allocation succeeds. If the storageBuffer constructor throws, `data` and
    /// `buffer` therefore still describe the previous, consistent state. Before
    /// this change, `data` could grow while `buffer` stayed at its old size, and
    /// the next upload() would write past the end of the buffer.
    ///
    /// When fetch() returns nothing, any previously created buffer is kept alive
    /// (as before), so descriptors that already reference it remain valid.
    void create() override {
        std::vector<T> fresh = fetch();
        if (fresh.empty()) {
            data = std::move(fresh);
            return;
        }

        auto freshBuffer = std::make_unique<storageBuffer>(fresh.size() * sizeof(T));

        // Commit: both moves are non-throwing, so the pair is updated atomically.
        data = std::move(fresh);
        buffer = std::move(freshBuffer);
    }

    /// Copies `data` into the GPU buffer.
    ///
    /// Does nothing when there is nothing to copy. It also does nothing when no
    /// buffer exists, e.g. if a subclass populated `data` without calling create().
    void upload() override {
        if (buffer && !data.empty())
            buffer->TransferData(data.data(), data.size() * sizeof(T));
    }

    /// Writes the whole buffer as a STORAGE_BUFFER descriptor at `binding`.
    ///
    /// NOTE(review): when no buffer exists the binding is left unwritten. If the
    /// shader statically uses this binding, Vulkan requires a valid descriptor
    /// unless the binding is partially bound. Confirm the layout flags.
    void writeDescriptor(descriptorSet& set, uint32_t binding) override {
        if (buffer) {
            VkDescriptorBufferInfo info{
                .buffer = buffer->getHandle(), .offset = 0, .range = VK_WHOLE_SIZE
            };
            set.Write(makeSpanFromOne(info), VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, binding);
        }
    }
};


/// Concatenates the vertices of every collected mesh into one storage buffer,
/// mesh after mesh, in collect() order.
class VertexBufferManager : public TypedBufferManager<VertexData> {
private:
    std::vector<Mesh*> meshes;  // non-owning; the unique_ptrs passed to collect() own the meshes

public:
    /// Replaces the collected set with borrowed pointers to each mesh in `input`.
    /// The meshes are dereferenced later by fetch(), so they must outlive that call.
    void collect(const std::vector<std::unique_ptr<Mesh>>& input) {
        meshes.clear();
        meshes.reserve(input.size());
        for (const std::unique_ptr<Mesh>& owner : input) {
            Mesh* const borrowed = owner.get();  // extract the raw pointer; ownership stays with `input`
            meshes.push_back(borrowed);
        }
    }

    /// Returns all vertices of the collected meshes back to back.
    std::vector<VertexData> fetch() const override {
        std::vector<VertexData> merged;
        for (Mesh* const current : meshes) {
            const auto& meshVerts = current->getVertices();
            merged.insert(merged.end(), meshVerts.begin(), meshVerts.end());
        }
        return merged;
    }
};


/// Concatenates the indices of every collected mesh into one storage buffer.
/// Each index is rebased by the number of vertices contributed by the meshes
/// before it, so the indices address the concatenated vertex stream directly.
class IndexBufferManager : public TypedBufferManager<uint32_t> {
private:
    std::vector<Mesh*> meshes;  // non-owning; the unique_ptrs passed to collect() own the meshes

public:
    /// Replaces the collected set with borrowed pointers to each mesh in `input`.
    /// The meshes are dereferenced later by fetch(), so they must outlive that call.
    void collect(const std::vector<std::unique_ptr<Mesh>>& input) {
        meshes.clear();
        meshes.reserve(input.size());
        for (const std::unique_ptr<Mesh>& owner : input) {
            Mesh* const borrowed = owner.get();  // extract the raw pointer; ownership stays with `input`
            meshes.push_back(borrowed);
        }
    }

    /// Returns every mesh's indices, offset into the merged vertex numbering.
    std::vector<uint32_t> fetch() const override {
        std::vector<uint32_t> merged;
        uint32_t baseVertex = 0;  // vertices emitted by all earlier meshes
        for (Mesh* const current : meshes) {
            for (const uint32_t localIndex : current->getIndices())
                merged.push_back(baseVertex + localIndex);
            baseVertex += static_cast<uint32_t>(current->getVertices().size());
        }
        return merged;
    }
};


struct alignas(16) sceneObject{
    glm::mat4 transform;
    int indexOffset;
    int indexCount;
    int materialID;
    float padding;
};

/// Builds one sceneObject record per collected mesh, in collect() order.
class ObjectBufferManager : public TypedBufferManager<sceneObject> {
private:
    std::vector<Mesh*> meshes;  // non-owning; the unique_ptrs passed to collect() own the meshes

public:
    /// Replaces the collected set with borrowed pointers to each mesh in `input`.
    /// The meshes are dereferenced later by fetch(), so they must outlive that call.
    void collect(const std::vector<std::unique_ptr<Mesh>>& input) {
        meshes.clear();
        meshes.reserve(input.size());
        for (const auto& m : input)
            meshes.push_back(m.get());  // extract the raw pointer; ownership stays with `input`
    }

    /// Returns one record per mesh with its transform, its index range in the
    /// concatenated index stream, and its material ID.
    ///
    /// No vertex offset is stored, because IndexBufferManager::fetch already
    /// rebases each index by the running vertex count. The unused vertex-offset
    /// bookkeeping was therefore removed.
    std::vector<sceneObject> fetch() const override {
        std::vector<sceneObject> result;
        result.reserve(meshes.size());  // exactly one record per mesh

        uint32_t indexOffset = 0;  // running start of each mesh in the merged index stream
        for (Mesh* const mesh : meshes) {
            const auto indexCount = static_cast<uint32_t>(mesh->getIndices().size());

            sceneObject obj{};
            obj.transform   = mesh->get_ModelMatrix();
            obj.indexOffset = static_cast<int>(indexOffset);  // GPU struct stores int
            obj.indexCount  = static_cast<int>(indexCount);
            obj.materialID  = mesh->get_materialID();
            obj.padding     = 0.0f;
            result.push_back(obj);

            indexOffset += indexCount;
        }

        return result;
    }
};


/// Storage-buffer manager whose contents come from MaterialManager::get().
class MaterialBufferManager : public TypedBufferManager<MaterialData> {
public:
    /// Snapshot of the material data currently held by MaterialManager.
    std::vector<MaterialData> fetch() const override {
        auto&& materials = MaterialManager::get();
        return materials.getMaterialDataBuffer();
    }
};


/// Storage-buffer manager whose contents are the face-light entries from Lights::get().
class FaceLightBufferManager : public TypedBufferManager<FaceLightData> {
public:
    /// Snapshot of the face-light data currently held by Lights.
    std::vector<FaceLightData> fetch() const override {
        auto&& lights = Lights::get();
        return lights.getFaceLightDataBuffer();
    }
};

/// Storage-buffer manager whose contents are the directional-light entries from Lights::get().
class DirectionalLightBufferManager : public TypedBufferManager<DirectionalLightData> {
public:
    /// Snapshot of the directional-light data currently held by Lights.
    std::vector<DirectionalLightData> fetch() const override {
        auto&& lights = Lights::get();
        return lights.getDirectionalLightDataBuffer();
    }
};


/// Owns the scene's meshes and camera, the per-scene buffer managers, and the
/// descriptor set layout and descriptor set they are written into.
///
/// NOTE(review): member declaration order is kept as-is on purpose, because it
/// fixes destruction order of the Vulkan-backed members.
class Scene{
private:
    std::vector<std::unique_ptr<Mesh>> s_meshes;   // owning storage for all meshes
    std::vector<sceneObject> s_object;
    std::unique_ptr<Camera> r_camera;

    std::unique_ptr<descriptorSetLayout> s_descriptorSetLayout;
    std::unique_ptr<descriptorSet> s_descriptorSet;

    // One manager per storage buffer exposed to shaders.
    std::unique_ptr<VertexBufferManager>         s_vertexBufferMgr;
    std::unique_ptr<IndexBufferManager>          s_indexBufferMgr;
    std::unique_ptr<ObjectBufferManager>         s_objectBufferMgr;
    std::unique_ptr<MaterialBufferManager>       s_materialBufferMgr;
    std::unique_ptr<FaceLightBufferManager>      s_faceLightBufferMgr;
    std::unique_ptr<DirectionalLightBufferManager> s_directionalLightBufferMgr;


public:
    void initScene();
    void initDescriptor(descriptorPool* pool);
    void writeDescriptor();
    void update(float deltaTime);

    /// Handle of the scene descriptor set layout. Dereferences
    /// s_descriptorSetLayout without a null check, so it must already exist
    /// (presumably created by initDescriptor() — confirm).
    [[nodiscard]] VkDescriptorSetLayout getDescriptorSetLayout() const { return s_descriptorSetLayout->getHandle(); }

    /// Handle of the scene descriptor set. Dereferences s_descriptorSet without
    /// a null check, so it must already exist (presumably created by
    /// initDescriptor() — confirm).
    [[nodiscard]] VkDescriptorSet getDescriptorSet() const { return s_descriptorSet->getHandle(); }

private:
    void initBufferManager();
    void uploadSceneToGPU();

public:
    static glm::vec3 hexToVec3(const std::string& hexStr);
};