#include "db.h"
#include "config.h"
#include "general.h"
#include <cstdlib>
#include <filesystem>
#include <sqlite3.h>
#include <dpp/dpp.h>

//Returns the queue row of the given guild. If the guild has no queue yet, a new
//queue named "Default" is inserted and returned carrying the id of the new row.
QueueEntity DbUtil::getQueue(dpp::snowflake guildId)
{
    sqlite3 *connection = connectDb();

    std::string guildIdString = std::to_string(guildId);
    std::string sql = "SELECT id, guild_id, name, current_item_id FROM `queue` WHERE `guild_id` = ? LIMIT 1;";
    sqlite3_stmt *statement = nullptr;

    sqlite3_prepare_v2(connection, sql.c_str(), sql.length(), &statement, nullptr);
    sqlite3_bind_text(statement, 1, guildIdString.c_str(), guildIdString.length(), SQLITE_TRANSIENT);

    const std::string defaultQueueName = "Default";
    QueueEntity queue(guildId, defaultQueueName);

    if (sqlite3_step(statement) == SQLITE_ROW)
    {
        //Existing queue: read it before the statement is finalized
        queue = QueueEntity(statement);
        closeDb(connection, statement);
        return queue;
    }

    //No queue for current guild --> create one.
    //Finalize the SELECT before preparing the next statement into the same pointer;
    //otherwise it leaks, and sqlite3_close() returns SQLITE_BUSY (leaving the
    //connection open) while unfinalized statements exist.
    sqlite3_finalize(statement);
    statement = nullptr;

    sql = "INSERT INTO `queue` (`guild_id`, `name`) VALUES (?, ?)";
    sqlite3_prepare_v2(connection, sql.c_str(), sql.length(), &statement, nullptr);
    sqlite3_bind_text(statement, 1, guildIdString.c_str(), guildIdString.length(), SQLITE_TRANSIENT);
    sqlite3_bind_text(statement, 2, defaultQueueName.c_str(), defaultQueueName.length(), SQLITE_TRANSIENT);
    sqlite3_step(statement);

    //Row id of the INSERT above. The connection was opened by this call, so this is 0
    //if the INSERT did not succeed (same value the former SELECT last_insert_rowid() gave).
    queue.id = static_cast<int>(sqlite3_last_insert_rowid(connection));

    closeDb(connection, statement);
    return queue;
}

std::vector<QueueItemEntity> DbUtil::getQueueItems(int queueId, int offset)
{
    sqlite3 *connection = connectDb();
    std::string sql = "SELECT id, queue_id, source, display_name, position, duration, ROW_NUMBER() OVER (PARTITION by `queue_id` ORDER BY `position`) AS `row_number` FROM `queue_item` WHERE `queue_id` = ? ORDER BY `position`";

    if (offset != -1)
    {
        //Add offset to query
        sql += " LIMIT " + std::to_string(offset) + ",10";
    }

    sqlite3_stmt *statement;

    sqlite3_prepare_v2(connection, sql.c_str(), sql.length(), &statement, nullptr);

    sqlite3_bind_int(statement, 1, queueId);

    //Create list of queueItems
    std::vector<QueueItemEntity> queueItems;
    while (sqlite3_step(statement) == SQLITE_ROW)
    {
        queueItems.push_back(QueueItemEntity(statement));
    }

    closeDb(connection, statement);
    return queueItems;
}

int DbUtil::getCountQueueItems(int queueId)
{
    sqlite3 *connection = connectDb();

    std::string sql = "SELECT COUNT(*) AS `count` FROM `queue_item` WHERE `queue_id` = ?";

    sqlite3_stmt *statement;

    sqlite3_prepare_v2(connection, sql.c_str(), sql.length(), &statement, nullptr);

    sqlite3_bind_int(statement, 1, queueId);

    //Get count from result
    sqlite3_step(statement);
    int count = sqlite3_column_int(statement, 0);

    closeDb(connection, statement);

    return count;
}

//Returns the queue item whose id is currentItemId, including its 1-based
//`row_number` inside its own queue (ordered by `position`). If no item has that id,
//a default-constructed QueueItemEntity is returned.
QueueItemEntity DbUtil::getCurrentQueueItem(int currentItemId)
{
    sqlite3 *connection = connectDb();

    //ROW_NUMBER() must be computed in a sub-select: a plain outer "WHERE id = ?" is
    //applied before window functions and would always number the row 1. Filtering
    //outside the sub-select (instead of "ORDER BY id=? DESC LIMIT 1") guarantees that
    //an unknown id yields no row rather than an arbitrary item of some other queue.
    //The inner WHERE restricts the numbering to the queue owning the requested item.
    std::string sql =
        "SELECT id, queue_id, source, display_name, position, duration, `row_number` FROM ("
        "SELECT id, queue_id, source, display_name, position, duration, "
        "ROW_NUMBER() OVER (PARTITION by `queue_id` ORDER BY `position`) AS `row_number` "
        "FROM `queue_item` "
        "WHERE `queue_id` = (SELECT `queue_id` FROM `queue_item` WHERE `id` = ?1)"
        ") WHERE `id` = ?1 LIMIT 1";

    sqlite3_stmt *statement = nullptr;
    sqlite3_prepare_v2(connection, sql.c_str(), sql.length(), &statement, nullptr);

    //?1 is referenced twice but bound once
    sqlite3_bind_int(statement, 1, currentItemId);

    QueueItemEntity queueItem = QueueItemEntity();
    if (sqlite3_step(statement) == SQLITE_ROW)
    {
        queueItem = QueueItemEntity(statement);
    }

    closeDb(connection, statement);
    return queueItem;
}

void DbUtil::setCurrentQueueItem(int queueId, int queueItemId)
{
    sqlite3 *connection = connectDb();

    //Create SQL INSERT command
    std::string sql = "UPDATE `queue` SET `current_item_id` = ? WHERE `id` = ?";

    sqlite3_stmt *statement;
    sqlite3_prepare_v2(connection, sql.c_str(), sql.length(), &statement, nullptr);

    sqlite3_bind_int(statement, 1, queueItemId);

    sqlite3_bind_int(statement, 2, queueId);

    sqlite3_step(statement);
    closeDb(connection, statement);
}

void DbUtil::saveQueueItems(int queueId, std::vector<QueueItemEntity> queueItems)
{
    sqlite3 *connection = connectDb();

    //Create SQL INSERT command
    std::string sql = "INSERT INTO `queue_item` VALUES ";
    for (int i = 0; i < queueItems.size(); i++)
    {
        if (queueItems[i].id > 0)
        {
            sql += "(" + std::to_string(queueItems[i].id) + ", ?, ?, ?, ?, ?, ?) ";
        }
        else
        {
            sql += "(NULL, ?, ?, ?, ?, ?, ?) ";
        }

        if (i < queueItems.size() - 1)
        {
            //Add comma except for last entry
            sql += ",";
        }
    }
    sql += " ON CONFLICT(id) DO UPDATE SET `source`=excluded.`source`, `display_name`=excluded.`display_name`, `type`=excluded.`type`, `duration`=excluded.`duration`, `position`=excluded.`position`;";

    sqlite3_stmt *statement;
    sqlite3_prepare_v2(connection, sql.c_str(), sql.length(), &statement, nullptr);

    int numberParameters = 6;
    for (int i = 0; i < queueItems.size(); i++)
    {
        QueueItemEntity queueItem = queueItems[i];

        sqlite3_bind_int(statement, numberParameters * i + 1, queueId);
        sqlite3_bind_text(statement, numberParameters * i + 2, queueItem.source.c_str(), queueItem.source.length(), SQLITE_TRANSIENT);
        sqlite3_bind_text(statement, numberParameters * i + 3, queueItem.displayName.c_str(), queueItem.displayName.length(), SQLITE_TRANSIENT);
        sqlite3_bind_text(statement, numberParameters * i + 4, queueItem.type.c_str(), queueItem.type.length(), SQLITE_TRANSIENT);
        sqlite3_bind_int(statement, numberParameters * i + 5, queueItem.duration);
        sqlite3_bind_int(statement, numberParameters * i + 6, queueItem.position);
    }

    sqlite3_step(statement);

    closeDb(connection, statement);
}

//Deletes the `queue_item` row whose id equals queueItem.id.
//Returns true when the DELETE statement ran to completion (SQLITE_DONE). That is
//also the result when no row matched the id.
bool DbUtil::removeQueueItem(QueueItemEntity queueItem)
{
    sqlite3 *db = connectDb();
    const std::string query = "DELETE FROM `queue_item` WHERE id=?";

    sqlite3_stmt *stmt;
    sqlite3_prepare_v2(db, query.c_str(), query.length(), &stmt, nullptr);
    sqlite3_bind_int(stmt, 1, queueItem.id);

    const bool finished = (sqlite3_step(stmt) == SQLITE_DONE);

    closeDb(db, stmt);
    return finished;
}

//Looks up a stored search by its exact term. Returns the matching row as a
//SearchEntity, or a default-constructed SearchEntity when no row matches.
SearchEntity DbUtil::getSearch(std::string searchTerm)
{
    sqlite3 *db = connectDb();
    const std::string query = "SELECT id, term, result FROM `search` WHERE `term` = ?;";

    sqlite3_stmt *stmt;
    sqlite3_prepare_v2(db, query.c_str(), query.length(), &stmt, nullptr);
    sqlite3_bind_text(stmt, 1, searchTerm.c_str(), searchTerm.length(), SQLITE_TRANSIENT);

    //Columns must be read before the statement is finalized by closeDb()
    SearchEntity found = (sqlite3_step(stmt) == SQLITE_ROW) ? SearchEntity(stmt) : SearchEntity();

    closeDb(db, stmt);
    return found;
}

//Finds a queue item whose display_name contains searchTerm (SQL LIKE '%term%';
//note that '%' and '_' inside searchTerm act as wildcards). The search is not
//restricted to one queue. When several items match, the one with the highest
//`position` is returned. Returns a default-constructed QueueItemEntity if none match.
QueueItemEntity DbUtil::getQueueItemByDisplayName(std::string searchTerm)
{
    const std::string query = "SELECT id, queue_id, source, display_name, position, duration, ROW_NUMBER() OVER (PARTITION by `queue_id` ORDER BY `position`) AS `row_number` FROM `queue_item`  WHERE `display_name` LIKE ? ORDER BY `position` DESC;";
    const std::string pattern = "%" + searchTerm + "%";

    sqlite3 *db = connectDb();
    sqlite3_stmt *stmt;
    sqlite3_prepare_v2(db, query.c_str(), query.length(), &stmt, nullptr);
    sqlite3_bind_text(stmt, 1, pattern.c_str(), pattern.length(), SQLITE_TRANSIENT);

    QueueItemEntity match{};
    if (sqlite3_step(stmt) == SQLITE_ROW)
    {
        match = QueueItemEntity(stmt);
    }

    closeDb(db, stmt);
    return match;
}

//Inserts a new `search` row from search.term and search.result. Its created_at
//column is set to SQLite's datetime('now').
void DbUtil::saveSearch(SearchEntity search)
{
    sqlite3 *db = connectDb();
    const std::string query = "INSERT INTO `search` (term, result, created_at) VALUES (?, ?, datetime('now'));";

    sqlite3_stmt *stmt;
    sqlite3_prepare_v2(db, query.c_str(), query.length(), &stmt, nullptr);

    const std::string &term = search.term;
    const std::string &result = search.result;
    sqlite3_bind_text(stmt, 1, term.c_str(), term.length(), SQLITE_TRANSIENT);
    sqlite3_bind_text(stmt, 2, result.c_str(), result.length(), SQLITE_TRANSIENT);

    sqlite3_step(stmt);
    closeDb(db, stmt);
}

//Saves a shortcut for a guild. shortcut.command may hold several aliases separated
//by ';'. Each alias becomes its own `shortcut` row sharing the same query and
//guild id. All rows are inserted in one transaction. Returns true only if every row
//was inserted. On any failure (or when there are no aliases) the transaction is
//rolled back and false is returned.
bool DbUtil::saveShortcut(ShortcutEntity shortcut)
{
    sqlite3 *connection = connectDb();

    std::string guildIdString = std::to_string(shortcut.guildId);

    //Create SQL INSERT command
    std::string sql = "INSERT INTO `shortcut` (`command`, `query`, `guild_id`) VALUES (?, ?, ?);";

    sqlite3_stmt *statement = nullptr;
    sqlite3_prepare_v2(connection, sql.c_str(), sql.length(), &statement, nullptr);
    //Parameters 2 and 3 are identical for all aliases; bindings survive sqlite3_reset()
    sqlite3_bind_text(statement, 2, shortcut.query.c_str(), shortcut.query.length(), SQLITE_TRANSIENT);
    sqlite3_bind_text(statement, 3, guildIdString.c_str(), guildIdString.length(), SQLITE_TRANSIENT);

    //Save aliases in separate lines
    sqlite3_exec(connection, "BEGIN TRANSACTION;", nullptr, nullptr, nullptr);

    bool ret = false;
    for (const auto &command : GeneralUtil::explode(shortcut.command, ';'))
    {
        //Binding is only allowed on a freshly prepared or reset statement; without the
        //reset the bind fails and the previous alias would be inserted again.
        sqlite3_reset(statement);
        //TRANSIENT: SQLite copies the text, so the binding never outlives `command`
        sqlite3_bind_text(statement, 1, command.c_str(), command.length(), SQLITE_TRANSIENT);

        ret = sqlite3_step(statement) == SQLITE_DONE;
        if (!ret)
        {
            //Stop at the first failure so no further rows are written
            break;
        }
    }

    //Exactly one of COMMIT / ROLLBACK ends the transaction opened above
    sqlite3_exec(connection, ret ? "COMMIT;" : "ROLLBACK;", nullptr, nullptr, nullptr);

    closeDb(connection, statement);

    return ret;
}

//Returns all shortcuts of a guild, ordered by query and then by command.
std::vector<ShortcutEntity> DbUtil::getShortcuts(dpp::snowflake guildId)
{
    sqlite3 *db = connectDb();
    const std::string query = "SELECT `id`, `command`, `query` FROM `shortcut` WHERE `guild_id` = ? ORDER BY `query`, `command`;";

    sqlite3_stmt *stmt;
    sqlite3_prepare_v2(db, query.c_str(), query.length(), &stmt, nullptr);

    //guild_id is stored and compared as text
    const std::string guild = std::to_string(guildId);
    sqlite3_bind_text(stmt, 1, guild.c_str(), guild.length(), SQLITE_TRANSIENT);

    //Collect one ShortcutEntity per result row
    std::vector<ShortcutEntity> result;
    for (int rc = sqlite3_step(stmt); rc == SQLITE_ROW; rc = sqlite3_step(stmt))
    {
        result.emplace_back(stmt);
    }

    closeDb(db, stmt);
    return result;
}

//Returns the shortcut of the given guild whose `command` equals command exactly.
//If none exists, the default-constructed (`ShortcutEntity shortcut;`) value is returned.
ShortcutEntity DbUtil::getShortcutByCommand(std::string command, dpp::snowflake guildId)
{
    sqlite3 *db = connectDb();
    const std::string query = "SELECT `id`, `command`, `query` FROM `shortcut` WHERE `guild_id` = ? AND `command` = ? ORDER BY `command`, `query`;";

    sqlite3_stmt *stmt;
    sqlite3_prepare_v2(db, query.c_str(), query.length(), &stmt, nullptr);

    const std::string guild = std::to_string(guildId);
    sqlite3_bind_text(stmt, 1, guild.c_str(), guild.length(), SQLITE_TRANSIENT);
    sqlite3_bind_text(stmt, 2, command.c_str(), command.length(), SQLITE_TRANSIENT);

    ShortcutEntity shortcut;
    const bool hasRow = (sqlite3_step(stmt) == SQLITE_ROW);
    if (hasRow)
    {
        //Only the first matching row is used
        shortcut = ShortcutEntity(stmt);
    }

    closeDb(db, stmt);
    return shortcut;
}

bool DbUtil::removeShortcutByCommandOrQuery(std::string commandOrQuery, dpp::snowflake guildId)
{
    bool ret = false;
    sqlite3 *connection = connectDb();

    std::string sql = "DELETE FROM `shortcut` WHERE `guild_id` = ? AND (`command` = ? OR `query` = ?);";

    sqlite3_stmt *statement;
    sqlite3_prepare_v2(connection, sql.c_str(), sql.length(), &statement, nullptr);

    std::string guildIdString = std::to_string(guildId);
    sqlite3_bind_text(statement, 1, guildIdString.c_str(), guildIdString.length(), SQLITE_TRANSIENT);
    sqlite3_bind_text(statement, 2, commandOrQuery.c_str(), commandOrQuery.length(), SQLITE_TRANSIENT);
    sqlite3_bind_text(statement, 3, commandOrQuery.c_str(), commandOrQuery.length(), SQLITE_TRANSIENT);

    if (sqlite3_step(statement) == SQLITE_DONE)
    {
        ret = true;
    }

    closeDb(connection, statement);

    return ret;
}

//Opens a new SQLite connection to the file named by the "sqlite_path" config option.
//Missing parent directories are created first; SQLite itself creates the database
//file if it does not exist. Terminates the process with exit code 9 if the option
//is unset, a directory cannot be created, or the database cannot be opened.
//The caller owns the returned handle and must release it with closeDb().
sqlite3 *DbUtil::connectDb()
{
    const std::string sqlitePath = ConfigUtil::get("sqlite_path");
    if (sqlitePath.empty())
    {
        std::cerr << "Please set the sqlite_path config option" << std::endl;
        std::exit(9);
    }

    if (!std::filesystem::is_regular_file(sqlitePath))
    {
        //Sqlite file not found --> make sure its directory exists.
        //parent_path() is empty for a bare file name such as "bot.db"; the former
        //rfind/substr logic treated the whole path as the directory in that case and
        //created a directory in place of the database file.
        const std::filesystem::path directory = std::filesystem::path(sqlitePath).parent_path();
        if (!directory.empty() && !std::filesystem::exists(directory))
        {
            std::error_code error;
            std::filesystem::create_directories(directory, error);
            if (error)
            {
                std::cerr << "Error creating directory for database file. Please check the sqlite_path config option" << std::endl;
                std::exit(9);
            }
        }
        //No explicit file creation needed: sqlite3_open() opens read/write and
        //creates the database file when it does not exist yet.
    }

    sqlite3 *connection = nullptr;
    if (sqlite3_open(sqlitePath.c_str(), &connection) != SQLITE_OK)
    {
        std::cerr << "Can't open database: " << sqlite3_errmsg(connection) << std::endl;
        //sqlite3_open() usually hands back a handle even on failure; release it
        sqlite3_close(connection);
        std::exit(9);
    }

    return connection;
}

//Finalizes the given prepared statement and then closes the connection via the
//single-argument closeDb overload.
void DbUtil::closeDb(sqlite3 *connection, sqlite3_stmt *statement)
{
    //sqlite3_finalize(nullptr) is a harmless no-op; the guard only makes that explicit
    if (statement != nullptr)
    {
        sqlite3_finalize(statement);
    }
    closeDb(connection);
}

//Closes a connection obtained from connectDb(). The status of sqlite3_close() is
//discarded. Per the SQLite docs, it returns SQLITE_BUSY and leaves the connection
//open while prepared statements on it are still unfinalized.
void DbUtil::closeDb(sqlite3 *connection)
{
    static_cast<void>(sqlite3_close(connection));
}
