package me.liuwj.ktorm.dsl

import me.liuwj.ktorm.database.Database
import me.liuwj.ktorm.database.prepareStatement
import me.liuwj.ktorm.expression.*
import me.liuwj.ktorm.schema.Column
import me.liuwj.ktorm.schema.ColumnDeclaring
import me.liuwj.ktorm.schema.SqlType
import me.liuwj.ktorm.schema.Table
import java.util.*
import kotlin.collections.ArrayList

/**
 * Updates records of this table and returns the number of affected rows.
 *
 * [block] runs with an [UpdateStatementBuilder] as receiver and this table as argument; it collects the
 * column assignments and, optionally, a condition. Table aliases are stripped from the resulting
 * expression before it is executed.
 */
fun <T : Table<*>> T.update(block: UpdateStatementBuilder.(T) -> Unit): Int {
    val assignments = ArrayList<ColumnAssignmentExpression<*>>()
    val builder = UpdateStatementBuilder(assignments)
    builder.block(this)

    val target = asExpression()
    val condition = builder.where?.asExpression()
    val expression = AliasRemover.visit(UpdateExpression(target, assignments, condition))

    expression.prepareStatement { statement, logger ->
        val effects = statement.executeUpdate()
        logger.debug("Effects: {}", effects)
        return effects
    }
}

/**
 * Executes several update statements as one batch and returns the array produced by the batch execution.
 *
 * Each statement is declared with [BatchUpdateStatementBuilder.item] inside [block]. When no item was
 * declared, an empty array is returned and nothing is executed.
 */
fun <T : Table<*>> T.batchUpdate(block: BatchUpdateStatementBuilder<T>.() -> Unit): IntArray {
    val builder = BatchUpdateStatementBuilder(this)
    builder.block()

    // Strip table aliases from every collected expression before execution.
    val expressions = builder.expressions.map { expr -> AliasRemover.visit(expr) }

    return if (expressions.isEmpty()) IntArray(0) else expressions.executeBatch()
}

/**
 * Runs all expressions of this list as one batch on a single prepared statement of the global database.
 *
 * The SQL text is taken from the first expression only; every expression is then formatted again to obtain
 * its own arguments, which are bound to the shared statement before it is added to the batch. The list must
 * not be empty because its first element is read unconditionally.
 *
 * @return the array returned by the statement's `executeBatch()`.
 */
private fun List<SqlExpression>.executeBatch(): IntArray {
    val batch = this
    val database = Database.global
    val logger = database.logger
    val (sql, _) = database.formatExpression(batch[0])

    if (logger.isDebugEnabled) {
        logger.debug("SQL: $sql")
    }

    database.useConnection { conn ->
        conn.prepareStatement(sql).use { statement ->
            batch.forEach { expr ->
                val (_, args) = database.formatExpression(expr)

                if (logger.isDebugEnabled) {
                    val rendered = args.map { arg -> "${arg.value}(${arg.sqlType.typeName})" }
                    logger.debug("Parameters: $rendered")
                }

                // Statement parameter positions are 1-based, hence the `+ 1`.
                args.forEachIndexed { index, arg ->
                    @Suppress("UNCHECKED_CAST")
                    val sqlType = arg.sqlType as SqlType<Any>
                    sqlType.setParameter(statement, index + 1, arg.value)
                }

                statement.addBatch()
            }

            val effects = statement.executeBatch()

            if (logger.isDebugEnabled) {
                logger.debug("Effects: {}", effects?.contentToString())
            }

            return effects
        }
    }
}

/**
 * Inserts one record into this table and returns the number of affected rows.
 *
 * [block] runs with an [AssignmentsBuilder] as receiver and this table as argument; it collects the column
 * assignments of the new record.
 */
fun <T : Table<*>> T.insert(block: AssignmentsBuilder.(T) -> Unit): Int {
    val assignments = ArrayList<ColumnAssignmentExpression<*>>()
    val builder = AssignmentsBuilder(assignments)
    builder.block(this)

    val expression = AliasRemover.visit(InsertExpression(asExpression(), assignments))

    expression.prepareStatement { statement, logger ->
        val effects = statement.executeUpdate()
        logger.debug("Effects: {}", effects)
        return effects
    }
}

/**
 * Inserts several records as one batch and returns the array produced by the batch execution.
 *
 * Each record is declared with [BatchInsertStatementBuilder.item] inside [block]. When no item was declared,
 * an empty array is returned and nothing is executed.
 */
fun <T : Table<*>> T.batchInsert(block: BatchInsertStatementBuilder<T>.() -> Unit): IntArray {
    val builder = BatchInsertStatementBuilder(this)
    builder.block()

    // Strip table aliases from every collected expression before execution.
    val expressions = builder.expressions.map { expr -> AliasRemover.visit(expr) }

    return if (expressions.isEmpty()) IntArray(0) else expressions.executeBatch()
}

/**
 * Inserts one record into this table and returns the key generated by the database.
 *
 * The table's primary key is resolved before the statement is executed, so a table without a primary key
 * is rejected up front rather than after the row has already been written.
 *
 * @param block collects the column assignments; it runs with an [AssignmentsBuilder] as receiver and this
 * table as argument.
 * @return the generated key, read from the first column of the generated-keys result set with the SQL type
 * of the primary key.
 * @throws IllegalStateException if the table has no primary key, if the database returns no generated key,
 * or if the generated key is null.
 */
fun <T : Table<*>> T.insertAndGenerateKey(block: AssignmentsBuilder.(T) -> Unit): Any {
    val assignments = ArrayList<ColumnAssignmentExpression<*>>()
    AssignmentsBuilder(assignments).apply { block(this@insertAndGenerateKey) }

    val expression = AliasRemover.visit(InsertExpression(asExpression(), assignments))

    // Fail fast: without a primary key the generated key cannot be read, so do not execute the insert at all.
    val sqlType = primaryKey?.sqlType ?: error("Table $tableName must have a primary key.")

    expression.prepareStatement(autoGeneratedKeys = true) { statement, logger ->
        statement.executeUpdate().also { logger.debug("Effects: {}", it) }

        statement.generatedKeys.use { rs ->
            if (!rs.next()) {
                error("No generated key returns by database.")
            }

            return sqlType.getResult(rs, 1) ?: error("Generated key is null.")
        }
    }
}

/**
 * Inserts the rows produced by this query into [table] and returns the number of affected rows.
 *
 * @param table the table that receives the rows.
 * @param columns the target columns passed to the insert-from-query expression.
 */
fun Query.insertTo(table: Table<*>, vararg columns: Column<*>): Int {
    val targetTable = table.asExpression()
    val targetColumns = columns.map { column -> column.asExpression() }

    val insertExpression = InsertFromQueryExpression(
        table = targetTable,
        columns = targetColumns,
        query = this.expression
    )

    insertExpression.prepareStatement { statement, logger ->
        val effects = statement.executeUpdate()
        logger.debug("Effects: {}", effects)
        return effects
    }
}

/**
 * Deletes the records of this table that match the condition produced by [block] and returns the number
 * of affected rows.
 */
fun <T : Table<*>> T.delete(block: (T) -> ColumnDeclaring<Boolean>): Int {
    val target = asExpression()
    val condition = block(this).asExpression()
    val expression = AliasRemover.visit(DeleteExpression(target, condition))

    expression.prepareStatement { statement, logger ->
        val effects = statement.executeUpdate()
        logger.debug("Effects: {}", effects)
        return effects
    }
}

/**
 * Deletes all records of this table (the delete expression carries no condition) and returns the number
 * of affected rows.
 */
fun Table<*>.deleteAll(): Int {
    val unconditionalDelete = DeleteExpression(asExpression(), where = null)
    val expression = AliasRemover.visit(unconditionalDelete)

    expression.prepareStatement { statement, logger ->
        val effects = statement.executeUpdate()
        logger.debug("Effects: {}", effects)
        return effects
    }
}

/**
 * DSL marker annotation applied to the builder classes declared in this file ([AssignmentsBuilder],
 * [UpdateStatementBuilder], [BatchUpdateStatementBuilder] and [BatchInsertStatementBuilder]).
 */
@DslMarker
annotation class KtormDsl

/**
 * Receiver of the assignment DSL: every `column to value` written inside the DSL block appends a
 * [ColumnAssignmentExpression] to the list passed to the constructor.
 */
@KtormDsl
open class AssignmentsBuilder(private val assignments: MutableList<ColumnAssignmentExpression<*>>) {

    /**
     * Assigns the expression [expr] to this column.
     */
    infix fun <C : Any> Column<C>.to(expr: ColumnDeclaring<C>) {
        val assignment = ColumnAssignmentExpression(asExpression(), expr.asExpression())
        assignments.add(assignment)
    }

    /**
     * Assigns the plain value [argument] (possibly `null`) to this column by wrapping it first.
     */
    infix fun <C : Any> Column<C>.to(argument: C?) {
        val wrapped = wrapArgument(argument)
        this to wrapped
    }

    /**
     * Fallback for a star-projected column: only `null` can be assigned, any other value is rejected.
     *
     * @throws IllegalArgumentException if [argument] is not `null`.
     */
    @Suppress("UNCHECKED_CAST")
    @JvmName("toAny")
    infix fun Column<*>.to(argument: Any?) {
        if (argument != null) {
            throw IllegalArgumentException("Argument type ${argument.javaClass.name} cannot assign to ${sqlType.typeName}")
        }

        (this as Column<Any>) to (null as Any?)
    }
}

/**
 * Assignment builder for update statements that additionally records an optional condition.
 */
@KtormDsl
class UpdateStatementBuilder(
    assignments: MutableList<ColumnAssignmentExpression<*>>
) : AssignmentsBuilder(assignments) {

    /** Condition recorded by the `where` function; stays `null` while none has been declared. */
    internal var where: ColumnDeclaring<Boolean>? = null

    /**
     * Evaluates [block] immediately and stores its result as the condition, replacing any earlier one.
     */
    fun where(block: () -> ColumnDeclaring<Boolean>) {
        val condition = block()
        where = condition
    }
}

/**
 * Collects the update statements of one batch operation on [table].
 *
 * Every item must format to the same SQL text as the first one; the formatted SQL of accepted items is
 * remembered in [sqls] and the corresponding expressions in [expressions].
 */
@KtormDsl
class BatchUpdateStatementBuilder<T : Table<*>>(internal val table: T) {
    internal val expressions = ArrayList<SqlExpression>()
    internal val sqls = HashSet<String>()

    /**
     * Declares one update statement of the batch.
     *
     * @throws IllegalArgumentException if the statement's SQL differs from the SQL of the earlier items.
     */
    fun item(block: UpdateStatementBuilder.(T) -> Unit) {
        val assignments = ArrayList<ColumnAssignmentExpression<*>>()
        val builder = UpdateStatementBuilder(assignments)
        builder.block(table)

        val condition = builder.where?.asExpression()
        val expr = UpdateExpression(table.asExpression(), assignments, condition)
        val (sql, _) = Database.global.formatExpression(expr, beautifySql = true)

        // The first item is always accepted; later ones must produce an SQL text that is already known.
        require(sqls.isEmpty() || sql in sqls) {
            "Every item in a batch operation must be the same. SQL: \n\n$sql"
        }

        sqls.add(sql)
        expressions.add(expr)
    }
}

/**
 * Collects the insert statements of one batch operation on [table].
 *
 * Every item must format to the same SQL text as the first one; the formatted SQL of accepted items is
 * remembered in [sqls] and the corresponding expressions in [expressions].
 */
@KtormDsl
class BatchInsertStatementBuilder<T : Table<*>>(internal val table: T) {
    internal val expressions = ArrayList<SqlExpression>()
    internal val sqls = HashSet<String>()

    /**
     * Declares one insert statement of the batch.
     *
     * @throws IllegalArgumentException if the statement's SQL differs from the SQL of the earlier items.
     */
    fun item(block: AssignmentsBuilder.(T) -> Unit) {
        val assignments = ArrayList<ColumnAssignmentExpression<*>>()
        val builder = AssignmentsBuilder(assignments)
        builder.block(table)

        val expr = InsertExpression(table.asExpression(), assignments)
        val (sql, _) = Database.global.formatExpression(expr, beautifySql = true)

        // The first item is always accepted; later ones must produce an SQL text that is already known.
        require(sqls.isEmpty() || sql in sqls) {
            "Every item in a batch operation must be the same. SQL: \n\n$sql"
        }

        sqls.add(sql)
        expressions.add(expr)
    }
}

/**
 * Expression visitor that clears the table alias of table and column expressions.
 *
 * Expressions that carry no alias are returned as they are; otherwise a copy without the alias is returned.
 */
internal object AliasRemover : SqlExpressionVisitor() {

    override fun visitTable(expr: TableExpression): TableExpression =
        if (expr.tableAlias != null) expr.copy(tableAlias = null) else expr

    override fun <T : Any> visitColumn(expr: ColumnExpression<T>): ColumnExpression<T> =
        if (expr.tableAlias != null) expr.copy(tableAlias = null) else expr
}