#include "TemporalFilter2.h"


void TemporalFilter2::Settings::setArgs(Args& args) const {
    // Continuous parameters are passed to the shader as uniforms.
    args.setUniform("hysteresis", hysteresis);
    args.setUniform("neighbourhoodRadius", neighbourhoodRadius);
    args.setUniform("falloffStartDistance", falloffStartDistance);
    args.setUniform("falloffEndDistance", falloffEndDistance);

    // Boolean feature switches are passed as preprocessor macros.
    args.setMacro("USE_VIRTUAL_DISTANCE", useVirtualDistance);
    args.setMacro("USE_COLOR_CLIPPING", useColorClipping);
    args.setMacro("REJECT_IF_VERY_DIFFERENT", rejectIfVeryDifferent);
    args.setMacro("SAVE_WEIGHT", saveWeightToBuffer);
}


void TemporalFilter2::Settings::makeGui(GuiPane* parent) {
    // Every control edits the corresponding Settings field in place.
    GuiPane* const pane = parent;

    // Hysteresis and falloff sliders all span [0, 1].
    const float unitMin = 0.0f;
    const float unitMax = 1.0f;

    pane->addNumberBox("Hysteresis", &hysteresis, "", GuiTheme::LINEAR_SLIDER, unitMin, unitMax);
    pane->addNumberBox("Neighbor Radius", &neighbourhoodRadius, "px", GuiTheme::LINEAR_SLIDER, 0, 3);

    pane->addCheckBox("Use Color Clipping", &useColorClipping);
    pane->addCheckBox("Use Virtual Dist.", &useVirtualDistance);
    pane->addCheckBox("Reject on Depth", &rejectIfVeryDifferent);

    pane->addNumberBox("Falloff Start", &falloffStartDistance, "m", GuiTheme::LINEAR_SLIDER, unitMin, unitMax);
    pane->addNumberBox("Falloff End", &falloffEndDistance, "m", GuiTheme::LINEAR_SLIDER, unitMin, unitMax);

    pane->addCheckBox("Save Weight", &saveWeightToBuffer);
}


shared_ptr<Texture> TemporalFilter2::apply(RenderDevice* rd, const shared_ptr<Camera>& camera, const shared_ptr<Texture>& unfilteredValue,
    const shared_ptr<Texture>& depth, const shared_ptr<Texture>& ssVelocity, const Vector2& guardBandSize, 
    const shared_ptr<Texture>& weightBuffer, int numFilterComponents, const Settings& settings) {

    // Convenience overload: derives all camera-dependent constants from `camera`
    // and forwards to the explicit-constant overload of apply().

    // Camera-to-world transforms for this frame and the previous one.
    const CoordinateFrame& cameraToWorld         = camera->frame();
    const CoordinateFrame& cameraToWorldPrevious = camera->previousFrame();

    // Depth-reconstruction constants; the projection constants are sized to the depth buffer.
    const Vector3 clipInfo = camera->projection().reconstructFromDepthClipInfo();
    const Vector4 projInfo = camera->projection().reconstructFromDepthProjInfo(depth->width(), depth->height());

    // Pixel projection for the current viewport, composed with a y-axis flip and the
    // inverse of the previous camera frame.
    Matrix4 projectPixel;
    camera->getProjectPixelMatrix(rd->viewport(), projectPixel);
    const Matrix4 worldToScreenPrevious =
        projectPixel * Matrix4::diagonal(1, -1, 1, 1) * cameraToWorldPrevious.inverse();

    return apply(rd, clipInfo, projInfo, cameraToWorld, cameraToWorldPrevious, worldToScreenPrevious,
                 unfilteredValue, depth, ssVelocity, guardBandSize, weightBuffer, numFilterComponents, settings);
}


/*
 Temporally filters `unfilteredValue` against this filter's stored history
 (previous result and previous depth) and returns the filtered texture.

 - When settings.hysteresis == 0 the input is returned unchanged.
 - When no usable history exists (first call, or the input/depth size or format
   changed), the history is re-seeded from the current inputs and a copy of the
   input is returned.
 - The returned texture is owned by this filter and is overwritten on the next call.
 - If settings.saveWeightToBuffer is set and `weightBuffer` is non-null, per-pixel
   weights are also written to `weightBuffer` (bound as COLOR1).
*/
shared_ptr<Texture> TemporalFilter2::apply
    (RenderDevice* rd,
     const Vector3&                  clipConstant,
     const Vector4&                  projConstant,
     const CFrame&                   currentCameraFrame,
     const CFrame&                   prevCameraFrame,
     const Matrix4&                  prevWorldToScreen,
     const shared_ptr<Texture>&      unfilteredValue,
     const shared_ptr<Texture>&      depth,
     const shared_ptr<Texture>&      ssVelocity,
     const Vector2&                  guardBandSize,
     const shared_ptr<Texture>&      weightBuffer,
     int                             numFilterComponents,
     const Settings&                 settings) {

    // Zero hysteresis keeps no history, so the input passes straight through.
    if (settings.hysteresis == 0.0f) {
        return unfilteredValue;
    }

    alwaysAssertM((settings.hysteresis >= 0.0f) && (settings.hysteresis <= 1.0f), "TemporalFilter2::Settings::hysteresis must be in [0.0, 1.0]");
    alwaysAssertM(notNull(unfilteredValue) && notNull(depth) && notNull(ssVelocity), "Sent null buffer to TemporalFilter2::apply");
    alwaysAssertM((numFilterComponents >= 1) && (numFilterComponents <= 4), "numFilterComponents must be between 1 and 4");

    // Weights can only be written when requested AND a destination buffer was supplied.
    const bool saveWeight = settings.saveWeightToBuffer && notNull(weightBuffer);

    // History is unusable if it was never created, or if the incoming buffers changed
    // size or format since it was captured.
    const bool historyInvalid =
        isNull(m_previousDepthBuffer) ||
        isNull(m_previousTexture) ||
        isNull(m_resultFramebuffer) ||
        (m_previousDepthBuffer->vector2Bounds() != depth->vector2Bounds()) ||
        (m_previousTexture->vector2Bounds() != unfilteredValue->vector2Bounds()) ||
        (m_previousDepthBuffer->format() != depth->format()) ||
        (m_resultFramebuffer->texture(0)->format() != unfilteredValue->format());

    if (historyInvalid) {
        // Drop stale history so copyInto() allocates textures matching the new inputs
        // (copyInto() is already relied upon to allocate when the destination is null).
        m_previousTexture     = nullptr;
        m_previousDepthBuffer = nullptr;

        unfilteredValue->copyInto(m_previousTexture);
        depth->copyInto(m_previousDepthBuffer);
        m_resultFramebuffer = Framebuffer::create(Texture::createEmpty("G3D::TemporalFilter2::m_resultFramebuffer", m_previousTexture->width(), m_previousTexture->height(), 
            unfilteredValue->format()));
        
        // Seed the result with the unfiltered input; there is nothing to blend with yet.
        Texture::copy(m_previousTexture, m_resultFramebuffer->texture(0));

        if (saveWeight) {
            rd->setColorClearValue(Color4::zero());
            weightBuffer->clear();
        }

        return m_resultFramebuffer->texture(0);
    }

    if (saveWeight) {
        m_resultFramebuffer->set(Framebuffer::COLOR1, weightBuffer);
    } else if (m_resultFramebuffer->has(Framebuffer::COLOR1)) {
        // Weight output was disabled (or no buffer given): detach the previously bound
        // weight buffer so the pass does not render undefined values into it or keep it alive.
        m_resultFramebuffer->set(Framebuffer::COLOR1, nullptr);
    }

    rd->push2D(m_resultFramebuffer); {
        Args args;
        args.setMacro("FILTER_COMPONENT_COUNT", numFilterComponents);
        ssVelocity->setShaderArgs(args, "ssVelocity_", Sampler::buffer());
        unfilteredValue->setShaderArgs(args, "unfilteredValue_", Sampler::buffer());
        depth->setShaderArgs(args, "depth_", Sampler::buffer());
        m_previousDepthBuffer->setShaderArgs(args, "previousDepth_", Sampler::video());
        m_previousTexture->setShaderArgs(args, "previousValue_", Sampler::video());

        args.setUniform("guardBandSize", guardBandSize);

        args.setUniform("worldToScreenPrevious", prevWorldToScreen);

        args.setUniform("cameraToWorld", currentCameraFrame);
        args.setUniform("cameraToWorldPrevious", prevCameraFrame);

        args.setUniform("clipInfo", clipConstant);
        args.setUniform("projInfo", projConstant);

        settings.setArgs(args);
        // Settings may request weight output, but it is impossible without a destination buffer.
        if (isNull(weightBuffer)) {
            args.setMacro("SAVE_WEIGHT", false);
        }
        args.setRect(rd->viewport());

        LAUNCH_SHADER("TemporalFilter2_apply.*", args);

        // The new result becomes next frame's history.
        m_resultFramebuffer->texture(0)->copyInto(m_previousTexture);
        depth->copyInto(m_previousDepthBuffer);
    } rd->pop2D();
    return m_resultFramebuffer->texture(0);
}


/*
 Parses a `TemporalFilter2::Settings { ... }` table. Keys that are absent keep the
 values from a default-constructed Settings; unknown keys are rejected by verifyDone().
 Every field written by toAny() is read here, so settings survive a round trip.
*/
TemporalFilter2::Settings::Settings(const Any& a) {
    *this = Settings();

    a.verifyName("TemporalFilter2::Settings");

    AnyTableReader r(a);
    r.getIfPresent("hysteresis", hysteresis);
    r.getIfPresent("neighbourhoodRadius", neighbourhoodRadius);
    r.getIfPresent("useVirtualDistance", useVirtualDistance);
    r.getIfPresent("useColorClipping", useColorClipping);
    r.getIfPresent("rejectIfVeryDifferent", rejectIfVeryDifferent);
    r.getIfPresent("saveWeightToBuffer", saveWeightToBuffer);
    r.getIfPresent("falloffStartDistance", falloffStartDistance);
    r.getIfPresent("falloffEndDistance", falloffEndDistance);
    r.verifyDone();
}


/*
 Serializes every Settings field to a `TemporalFilter2::Settings` table that the
 Any constructor above can read back.
*/
Any TemporalFilter2::Settings::toAny() const {
    Any a(Any::TABLE, "TemporalFilter2::Settings");

    a["hysteresis"] = hysteresis;
    a["neighbourhoodRadius"] = neighbourhoodRadius;
    a["useVirtualDistance"] = useVirtualDistance;
    a["useColorClipping"] = useColorClipping;
    a["rejectIfVeryDifferent"] = rejectIfVeryDifferent;
    a["saveWeightToBuffer"] = saveWeightToBuffer;
    a["falloffStartDistance"] = falloffStartDistance;
    a["falloffEndDistance"] = falloffEndDistance;
    return a;
}

