#version 450
#pragma shader_stage(compute)
// Compute shader: casts one primary ray per pixel through a two-level BVH
// (TLAS over instances, BLAS over triangles) and writes the base color of the
// closest hit into resultImage.
// NOTE(review): `#pragma shader_stage` is understood by shaderc/glslc; other
// front ends may need the stage supplied another way (e.g. a .comp extension).

// 16x16 invocations per workgroup; main() maps each invocation to one pixel.
layout(local_size_x = 16, local_size_y = 16) in;

// binding = 0: Vertex
// Per-vertex attributes. Only Position is read in this shader (by the triangle
// test in traceBLAS_stack).
// NOTE(review): under std430 a vec3 is 16-byte aligned, so this struct has
// implicit padding (Position@0, tangentW@12, Normal@16, Tangent@32,
// TexCoords@48, array stride 64 bytes). Confirm the host-side vertex struct
// uses the same offsets and stride.
struct Vertex {
    vec3 Position;
    float tangentW; // occupies the 4 bytes right after Position (offset 12)
    vec3 Normal;
    vec3 Tangent;  
    vec2 TexCoords;
};
layout(std430, set = 0, binding = 0) readonly buffer VertexBuffer {
    Vertex vertices[];
};

// binding = 1: Index
// Indices into `vertices`. traceBLAS_stack reads three consecutive entries
// (idx + 0, idx + 1, idx + 2) as the corners of one triangle.
layout(std430, set = 0, binding = 1) readonly buffer IndexBuffer {
    uint indices[];
};

// binding = 2: BLASNode
// Axis-aligned bounding box given by its minimum and maximum corners.
struct AABB {
    vec3 min;
    vec3 max;
};
// Bottom-level BVH node (triangle hierarchy of one mesh). traceBLAS_stack
// treats a node whose left and right are both negative as a leaf; a leaf's
// `indices` holds up to three entries, and -1 ends the list early.
struct BLASNode {
    AABB bounds;     // instance-local space (tested against the local-space ray)
    ivec3 indices;
    int left;        // child index into blasNodes, negative if absent
    int right;       // child index into blasNodes, negative if absent
};
// No `set` qualifier: this falls back to the default descriptor set 0
// (GL_KHR_vulkan_glsl), the same set as the bindings that name it explicitly.
layout(std430, binding = 2) readonly buffer BLASBuffer {
    BLASNode blasNodes[];
};

// binding = 3: TLASInstance
// One placed mesh instance; TLAS leaves refer to it through instanceIndex.
struct TLASInstance {
    mat4 transform;    // local -> world; traceBLAS_stack inverts it to move the ray into local space
    AABB worldBounds;  // not read by the code in this shader
    int rootNodeIndex; // index of the corresponding BLAS root node in blasNodes
    int baseIndexOffset; // added to each BLAS leaf entry before indexing `indices`
    int materialID;    // index into materials
};
// No `set` qualifier: defaults to descriptor set 0.
layout(std430, binding = 3) readonly buffer TLASInstanceBuffer {
    TLASInstance instances[];
};

// binding = 4: TLASNode
// Top-level BVH node over instances. main() starts traversal at the LAST
// element of tlasNodes. Leaves are marked with left == right == -1 and refer to
// instances[instanceIndex].
struct TLASNode {
    AABB bounds;       // world space (tested against the world-space ray)
    int left;          // child index into tlasNodes, or -1
    int right;         // child index into tlasNodes, or -1
    int instanceIndex; // used only at leaves
};
// No `set` qualifier: defaults to descriptor set 0.
layout(std430, binding = 4) readonly buffer TLASNodeBuffer {
    TLASNode tlasNodes[];
};

// binding = 5: Material
// Per-material parameters, indexed by TLASInstance.materialID. Only baseColor
// is read by the code in this shader. The *Texture fields are texture indices
// (-1 = none); no texture/sampler bindings are declared in the code shown.
struct Material {
    vec3 baseColor;
    int baseColorTexture;  // -1 means no texture

    vec3 normal;
    int normalTexture;

    float metallic;
    int metallicTexture;

    float roughness;
    int roughnessTexture;
};
layout(std430, set = 0, binding = 5) readonly buffer MaterialBuffer {
    Material materials[];
};

// binding = 6: camera
// No explicit layout qualifier; uniform blocks use std140 by default when
// targeting Vulkan (confirm with the toolchain), putting cameraPosition at
// byte offset 64.
layout(set = 0, binding = 6) uniform Camera {
    mat4 viewProjectionMatrix; // view-projection matrix; main() inverts it to unproject pixels
    vec3 cameraPosition;       // camera position; used as the primary-ray origin
};

// binding = 7: FaceLight
// Area (face) light. Not referenced by the code in this shader.
// NOTE(review): `vertices` is a single vec3, yet its comment describes a vertex
// winding order; presumably an array of positions (e.g. vec3 vertices[3]) was
// intended. Confirm against the host-side struct before using it - changing
// the type alters the std430 layout.
struct FaceLight {
    vec3 vertices; // order: counter-clockwise
    vec3 color;
    float intensity;
};
layout(std430, set = 0, binding = 7) readonly buffer FaceLightBuffer {
    FaceLight faceLights[];
};

// binding = 8: Directional
// Directional light. Not referenced by the code in this shader.
struct DirectionalLight {
    vec3 direction; // unit vector
    float intensity; // fills the 4 bytes after direction under std430
    vec3 color;
};
layout(std430, set = 0, binding = 8) readonly buffer DirectionalLightBuffer {
    DirectionalLight directionalLights[];
};

// set = 1, binding = 0: output image (stores the final color).
// main() writes one texel per in-bounds invocation; rgba8 storage format.
layout(set = 1, binding = 0, rgba8) uniform writeonly image2D resultImage;

// Forward declarations; the definitions follow main().
// Ray vs. axis-aligned box slab test.
bool intersectAABB(vec3 origin, vec3 dir, AABB aabb);
// Ray vs. triangle test; writes the hit parameter to t on success.
bool intersectTriangle(vec3 orig, vec3 dir, vec3 v0, vec3 v1, vec3 v2, out float t);
// Top-level BVH traversal over instances; updates the closest hit in place.
void traceTLAS_stack(int rootIndex, vec3 rayOrig, vec3 rayDir, inout float minT, inout vec3 hitColor);
// Bottom-level BVH traversal for one instance; updates the closest hit in place.
void traceBLAS_stack(int rootIndex, vec3 rayOrig, vec3 rayDir, mat4 model, int materialID, int baseIndexOffset, inout float minT, inout vec3 hitColor);

// Entry point: one invocation per output pixel (16x16 workgroups). Builds a
// primary ray from the camera through the pixel center, traces it through the
// TLAS/BLAS hierarchy, and stores the base color of the closest hit (black if
// nothing is hit) in resultImage.
void main() {
    uvec2 pixelCoord = gl_GlobalInvocationID.xy;
    // Named imgSize so the local does not shadow the built-in imageSize().
    ivec2 imgSize = imageSize(resultImage);
    // The dispatch is rounded up to whole workgroups; skip invocations that
    // fall outside the image.
    if (pixelCoord.x >= uint(imgSize.x) || pixelCoord.y >= uint(imgSize.y))
        return;

    // 1. Generate the primary ray.
    // Sample the pixel center (+0.5); using the corner shifts the image by
    // half a pixel.
    vec2 ndc = ((vec2(pixelCoord) + 0.5) / vec2(imgSize)) * 2.0 - 1.0;
    // Unproject a point on this pixel's line of sight (NDC depth -1). Assuming
    // viewProjectionMatrix maps world -> clip, the result after the
    // perspective divide is a world-space position.
    // NOTE(review): inverse() runs per invocation; a precomputed inverse in the
    // Camera UBO would be cheaper.
    vec4 rayWorld = inverse(viewProjectionMatrix) * vec4(ndc, -1.0, 1.0);
    rayWorld /= rayWorld.w;

    vec3 rayOrig = cameraPosition;
    vec3 rayDir = normalize(rayWorld.xyz - rayOrig);

    // 2. Closest-hit state: world-space distance and color of the nearest hit.
    float minT = 1e20;
    vec3 hitColor = vec3(0.0);

    // 3. Traverse from the TLAS root, which is stored as the last node.
    //    An empty TLAS would give index -1, so it is skipped explicitly.
    int tlasRoot = int(tlasNodes.length()) - 1;
    if (tlasRoot >= 0)
        traceTLAS_stack(tlasRoot, rayOrig, rayDir, minT, hitColor);

    // 4. Write the result.
    imageStore(resultImage, ivec2(pixelCoord), vec4(hitColor, 1.0));
}


// ---------- AABB Intersection ----------
// Slab test: returns true if the ray origin + s * dir, for some s >= 0, touches
// the box (a box containing the origin counts as a hit). dir does not need to
// be normalized.
// Direction components with magnitude below DIR_EPSILON are replaced by a tiny
// value of the same sign. 1.0 / 0.0 yields +/-inf, and inf * 0.0 (origin
// exactly on a slab plane) yields NaN, which min()/max() do not handle in a
// defined way; this happens for axis-aligned rays.
bool intersectAABB(vec3 origin, vec3 dir, AABB aabb) {
    const float DIR_EPSILON = 1e-20;
    // +1 for non-negative components (including -0.0), -1 otherwise.
    vec3 dirSign = mix(vec3(-1.0), vec3(1.0), greaterThanEqual(dir, vec3(0.0)));
    vec3 safeDir = mix(dir, dirSign * DIR_EPSILON, lessThan(abs(dir), vec3(DIR_EPSILON)));
    vec3 invDir = 1.0 / safeDir;

    vec3 t0 = (aabb.min - origin) * invDir;
    vec3 t1 = (aabb.max - origin) * invDir;
    vec3 tmin = min(t0, t1);
    vec3 tmax = max(t0, t1);
    float tEnter = max(max(tmin.x, tmin.y), tmin.z);
    float tExit  = min(min(tmax.x, tmax.y), tmax.z);
    // Clamping tEnter at 0 rejects boxes lying entirely behind the origin.
    return tExit >= max(tEnter, 0.0);
}

// ---------- Triangle Intersection ----------
// Möller–Trumbore ray/triangle test.
// Returns true on a hit and stores the ray parameter in t, so the hit point is
// orig + t * dir. t is left unwritten when the function returns false.
// Rejects triangles almost parallel to the ray (|det| < EPSILON) and hits with
// t <= EPSILON. Both faces of the triangle are accepted.
bool intersectTriangle(vec3 orig, vec3 dir, vec3 v0, vec3 v1, vec3 v2, out float t) {
    const float EPSILON = 0.000001;

    vec3 e1 = v1 - v0;
    vec3 e2 = v2 - v0;
    vec3 pvec = cross(dir, e2);
    float det = dot(e1, pvec);

    // Ray (nearly) parallel to the triangle's plane.
    if (abs(det) < EPSILON) {
        return false;
    }
    float invDet = 1.0 / det;

    // First barycentric coordinate must lie in [0, 1].
    vec3 tvec = orig - v0;
    float baryU = invDet * dot(tvec, pvec);
    bool uOutside = baryU < 0.0 || baryU > 1.0;
    if (uOutside) {
        return false;
    }

    // Second barycentric coordinate: non-negative and u + v <= 1.
    vec3 qvec = cross(tvec, e1);
    float baryV = invDet * dot(dir, qvec);
    bool vOutside = baryV < 0.0 || baryU + baryV > 1.0;
    if (vOutside) {
        return false;
    }

    t = invDet * dot(e2, qvec);
    return t > EPSILON;
}

// ---------- Non-Recursive Traversal ----------

// Capacity, in node indices, of the local traversal stacks declared in
// traceTLAS_stack and traceBLAS_stack. GLSL forbids recursion, so BVH
// traversal uses an explicit fixed-size array instead.
const int MAX_STACK_SIZE = 128;

// Iterative traversal (explicit stack) of the top-level BVH in tlasNodes.
// For every leaf whose bounds the ray hits, the referenced instance's BLAS is
// traced with traceBLAS_stack, which updates minT / hitColor in place with the
// closest hit found so far.
//   rootIndex       - index into tlasNodes; a negative value (empty TLAS) is a no-op.
//   rayOrig, rayDir - world-space ray.
//   minT, hitColor  - closest-hit state, read and updated.
// A node is a leaf when both children are negative, matching traceBLAS_stack.
// If the fixed-size stack is full, a child that does not fit is skipped, so a
// pathologically deep tree loses that subtree instead of overflowing the array.
void traceTLAS_stack(int rootIndex, vec3 rayOrig, vec3 rayDir, inout float minT, inout vec3 hitColor) {
    if (rootIndex < 0) return;

    int stack[MAX_STACK_SIZE];
    int sp = 0;
    stack[sp++] = rootIndex;

    while (sp > 0) {
        int nodeIndex = stack[--sp];
        TLASNode node = tlasNodes[nodeIndex];
        if (!intersectAABB(rayOrig, rayDir, node.bounds)) continue;

        if (node.left < 0 && node.right < 0) {
            // Leaf: trace the instance's BLAS. traceBLAS_stack inverts the
            // transform itself, so no inverse is computed here.
            TLASInstance inst = instances[node.instanceIndex];
            traceBLAS_stack(inst.rootNodeIndex, rayOrig, rayDir, inst.transform,
                            inst.materialID, inst.baseIndexOffset, minT, hitColor);
        } else {
            // Push right first so that the left child is popped first.
            if (node.right >= 0 && sp < MAX_STACK_SIZE) stack[sp++] = node.right;
            if (node.left >= 0 && sp < MAX_STACK_SIZE)  stack[sp++] = node.left;
        }
    }
}

// Iterative traversal (explicit stack) of one bottom-level BVH in blasNodes.
// The world-space ray is moved into the instance's local space with
// inverse(model). Triangle hits are mapped back to world space so that minT
// stays a world-space distance that can be compared across instances.
//   rootIndex       - index of this BLAS's root in blasNodes; negative is a no-op.
//   rayOrig, rayDir - world-space ray.
//   model           - instance transform (local -> world).
//   materialID      - index into materials; its baseColor becomes hitColor.
//   baseIndexOffset - added to each leaf entry before reading `indices`.
//   minT, hitColor  - closest-hit state, read and updated.
// A child that does not fit on the full stack is skipped rather than
// overflowing the array.
void traceBLAS_stack(int rootIndex, vec3 rayOrig, vec3 rayDir, mat4 model, int materialID, int baseIndexOffset, inout float minT, inout vec3 hitColor) {
    if (rootIndex < 0) return;

    int stack[MAX_STACK_SIZE];
    int sp = 0;
    stack[sp++] = rootIndex;

    mat4 invModel = inverse(model);
    vec3 localOrigin = vec3(invModel * vec4(rayOrig, 1.0));
    // Renormalized so intersectTriangle's t is a local-space distance. It is
    // only used to locate the hit point and is never compared with minT directly.
    vec3 localDir = normalize(mat3(invModel) * rayDir);

    while (sp > 0) {
        int nodeIndex = stack[--sp];
        BLASNode node = blasNodes[nodeIndex];
        if (!intersectAABB(localOrigin, localDir, node.bounds)) continue;

        if (node.left < 0 && node.right < 0) {
            // Leaf: up to three entries; -1 ends the list early.
            for (int i = 0; i < 3; ++i) {
                int localIndex = node.indices[i];
                if (localIndex == -1) break;
                // NOTE(review): localIndex is used directly as an offset into
                // `indices` (first of three consecutive entries), not as a
                // triangle number (which would need * 3). Confirm this matches
                // what the BVH builder stores.
                int idx = baseIndexOffset + localIndex;
                vec3 v0 = vertices[indices[idx + 0]].Position;
                vec3 v1 = vertices[indices[idx + 1]].Position;
                vec3 v2 = vertices[indices[idx + 2]].Position;

                float t;
                if (intersectTriangle(localOrigin, localDir, v0, v1, v2, t)) {
                    // Convert the local-space hit point to world space and
                    // measure the world-space distance along the original ray.
                    vec3 hitLocal = localOrigin + t * localDir;
                    vec3 hitWorld = vec3(model * vec4(hitLocal, 1.0));
                    float tWorld = length(hitWorld - rayOrig);

                    if (tWorld < minT) {
                        minT = tWorld;
                        hitColor = materials[materialID].baseColor;
                    }
                }
            }
        } else {
            // Push right first so that the left child is popped first.
            if (node.right >= 0 && sp < MAX_STACK_SIZE) stack[sp++] = node.right;
            if (node.left >= 0 && sp < MAX_STACK_SIZE)  stack[sp++] = node.left;
        }
    }
}

