#version 450
#pragma shader_stage(compute)

// Brute-force primary-ray tracer: each invocation shades one pixel of resultImage
// (declared below). Invocations are grouped into 16x16 workgroups.
layout(local_size_x = 16, local_size_y = 16) in;

// binding = 0: Vertex
// Per-vertex attributes. Only Position is read by this shader (in main()).
// std430 offsets in bytes: Position 0, tangentW 12, Normal 16, Tangent 32,
// TexCoords 48 (4 bytes of padding after Tangent, because vec2 is 8-byte aligned).
// The struct alignment is 16, so the array stride is 64 bytes. The host-side vertex
// struct must reproduce this layout exactly.
struct Vertex {
    vec3 Position;
    float tangentW; // fills the 4 bytes after Position (offset 12); presumably the tangent's w/handedness — not used here
    vec3 Normal;
    vec3 Tangent;
    vec2 TexCoords;
};
layout(std430, set = 0, binding = 0) readonly buffer VertexBuffer {
    Vertex vertices[];
};

// binding = 1: Index
// Indices into vertices[]. main() reads them three at a time as triangles,
// starting at each object's indexOffset.
layout(std430, set = 0, binding = 1) readonly buffer IndexBuffer {
    uint indices[];
};

// binding = 2: Material
// Indexed by Object.materialID. Only baseColor is read by this shader.
// std430 layout: every vec3 is paired with a trailing 4-byte scalar, so there is no
// padding. Offsets: 0, 12, 16, 28, 32, 36, 40, 44; stride 48 bytes.
struct Material {
    vec3 baseColor;
    int baseColorTexture;  // -1 means no texture

    vec3 normal;
    int normalTexture;

    float metallic;
    int metallicTexture;

    float roughness;
    int roughnessTexture;
};
layout(std430, set = 0, binding = 2) readonly buffer MaterialBuffer {
    Material materials[];
};

// binding = 3: Object
// One entry per drawable object; main() iterates over all of them.
//   transform   - applied to each vertex as transform * vec4(Position, 1.0)
//   indexOffset - first entry in indices[] belonging to this object
//   indexCount  - number of indices (read as triangles; a trailing partial triangle is skipped)
//   materialID  - index into materials[]
// std430 offsets: transform 0..63, indexOffset 64, indexCount 68, materialID 72.
// The size of 76 is padded to a 16-byte-aligned stride of 80.
struct Object {
    mat4 transform;
    uint indexOffset;
    uint indexCount;
    uint materialID;
};

layout(std430, set = 0, binding = 3) readonly buffer ObjectBuffer {
    Object objects[];
};

// binding = 4: camera
// No explicit packing qualifier. NOTE(review): Vulkan GLSL uniform blocks default to
// std140 (mat4 at offset 0, vec3 at offset 64) — confirm the host struct matches.
layout(set = 0, binding = 4) uniform Camera {
    mat4 viewProjectionMatrix; // view-projection matrix; main() inverts it to unproject pixel rays
    vec3 cameraPosition;       // camera position, used as the ray origin
};

// binding = 5: FaceLight
// Not read by this shader.
// NOTE(review): the comment on `vertices` describes a counter-clockwise vertex order,
// but the field is a single vec3. A face (area) light needs several corners, for
// example `vec3 vertices[3]`. Confirm against the host-side definition before using
// this struct.
// std430 offsets as declared: vertices 0, color 16, intensity 28; stride 32.
struct FaceLight {
    vec3 vertices; // order: counter-clockwise
    vec3 color;
    float intensity;
};
layout(std430, set = 0, binding = 5) readonly buffer FaceLightBuffer {
    FaceLight faceLights[];
};

// binding = 6: Directional
// Not read by this shader.
// std430 offsets: direction 0, intensity 12, color 16; stride 32.
struct DirectionalLight {
    vec3 direction; // unit vector
    float intensity;
    vec3 color;
};
layout(std430, set = 0, binding = 6) readonly buffer DirectionalLightBuffer {
    DirectionalLight directionalLights[];
};

// set = 1, binding = 0: output image (stores the final colour).
// rgba8 storage image; main() writes exactly one texel per invocation.
layout(set = 1, binding = 0, rgba8) uniform writeonly image2D resultImage;


// Forward declaration; the definition follows main().
bool intersectTriangle(vec3 orig, vec3 dir, vec3 v0, vec3 v1, vec3 v2, out float t);

// Entry point: one invocation per pixel of resultImage.
// Casts a primary ray from cameraPosition through the centre of the pixel and tests it
// against every triangle of every object (no acceleration structure). Writes the
// baseColor of the nearest hit's material, or opaque black when nothing is hit.
void main() {
    uvec2 pixelCoord = gl_GlobalInvocationID.xy;
    // Named imgSize (not imageSize) so the local does not shadow the built-in function.
    ivec2 imgSize = imageSize(resultImage);
    // The dispatch covers whole 16x16 workgroups; skip invocations past the image edge.
    if (pixelCoord.x >= uint(imgSize.x) || pixelCoord.y >= uint(imgSize.y))
        return;

    // 1. Generate the primary ray.
    // The +0.5 samples the pixel centre. Without it every ray is shifted half a pixel
    // toward the image origin and the NDC range is asymmetric ([-1, 1 - 2/size]).
    vec2 ndc = ((vec2(pixelCoord) + 0.5) / vec2(imgSize)) * 2.0 - 1.0;
    // Unproject a point on this pixel's ray out of clip space. Only its direction
    // from the camera is used.
    // NOTE(review): z = -1 is the near plane for an OpenGL [-1, 1] depth range; verify
    // the unprojected point stays in front of the camera for the projection in use.
    vec4 rayClip = vec4(ndc, -1.0, 1.0);
    vec4 rayWorld = inverse(viewProjectionMatrix) * rayClip;
    rayWorld /= rayWorld.w;
    vec3 rayDir = normalize(rayWorld.xyz - cameraPosition);
    vec3 rayOrigin = cameraPosition;

    float minT = 1e20;       // distance to the nearest hit so far
    bool anyHit = false;
    uint hitMaterial = 0u;   // materialID of the nearest hit (valid only if anyHit)

    // 2. Iterate over all objects and their triangles, keeping the nearest hit.
    for (uint objID = 0u; objID < uint(objects.length()); ++objID) {
        Object obj = objects[objID];
        mat4 model = obj.transform;
        uint idxStart = obj.indexOffset;
        uint idxEnd = idxStart + obj.indexCount;

        // Indices are consumed three at a time; a trailing partial triangle is ignored.
        for (uint i = idxStart; i + 2u < idxEnd; i += 3u) {
            uint i0 = indices[i];
            uint i1 = indices[i + 1u];
            uint i2 = indices[i + 2u];

            // Bring the corners into the ray's space using the object's transform.
            vec3 v0 = (model * vec4(vertices[i0].Position, 1.0)).xyz;
            vec3 v1 = (model * vec4(vertices[i1].Position, 1.0)).xyz;
            vec3 v2 = (model * vec4(vertices[i2].Position, 1.0)).xyz;

            float t;
            // && short-circuits, so t is only read when intersectTriangle wrote it.
            if (intersectTriangle(rayOrigin, rayDir, v0, v1, v2, t) && t < minT) {
                minT = t;
                hitMaterial = obj.materialID;
                anyHit = true;
            }
        }
    }

    // Fetch the material once for the final nearest hit, not on every closer candidate.
    vec3 hitColor = anyHit ? materials[hitMaterial].baseColor : vec3(0.0);
    imageStore(resultImage, ivec2(pixelCoord), vec4(hitColor, 1.0));
}


// Möller–Trumbore ray/triangle intersection.
//   orig, dir  - ray origin and direction (same space as the corners)
//   v0, v1, v2 - triangle corners
//   t          - on a true return, the ray parameter of the hit (t > EPSILON)
// Returns false for rays parallel to the plane, hits outside the triangle, and hits at
// or behind the origin. On the early-out paths t is never written, so callers must
// only read it after a true return.
bool intersectTriangle(vec3 orig, vec3 dir, vec3 v0, vec3 v1, vec3 v2, out float t) {
    const float EPSILON = 0.000001;

    vec3 e1 = v1 - v0;
    vec3 e2 = v2 - v0;
    vec3 pvec = cross(dir, e2);
    float det = dot(e1, pvec);

    // A near-zero determinant means the ray is parallel to the triangle's plane.
    if (abs(det) < EPSILON) {
        return false;
    }
    float invDet = 1.0 / det;

    // First barycentric coordinate must lie in [0, 1].
    vec3 tvec = orig - v0;
    float baryU = invDet * dot(tvec, pvec);
    if (baryU < 0.0 || baryU > 1.0) {
        return false;
    }

    // Second barycentric coordinate must keep the point inside the triangle.
    vec3 qvec = cross(tvec, e1);
    float baryV = invDet * dot(dir, qvec);
    if (baryV < 0.0 || baryU + baryV > 1.0) {
        return false;
    }

    // Distance along the ray; reject hits at or behind the origin.
    t = invDet * dot(e2, qvec);
    return t > EPSILON;
}
