diff --git a/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/compiler/model/NamedQuery.kt b/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/compiler/model/NamedQuery.kt index f84dbdfad6c..69cab9ec988 100644 --- a/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/compiler/model/NamedQuery.kt +++ b/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/compiler/model/NamedQuery.kt @@ -36,9 +36,9 @@ import app.cash.sqldelight.dialect.api.PrimitiveType.TEXT import app.cash.sqldelight.dialect.api.QueryWithResults import com.alecstrong.sql.psi.core.psi.NamedElement import com.alecstrong.sql.psi.core.psi.QueryElement -import com.alecstrong.sql.psi.core.psi.SqlAnnotatedElement import com.alecstrong.sql.psi.core.psi.SqlCompoundSelectStmt import com.alecstrong.sql.psi.core.psi.SqlExpr +import com.alecstrong.sql.psi.core.psi.SqlPragmaName import com.alecstrong.sql.psi.core.psi.SqlValuesExpression import com.intellij.psi.PsiElement import com.squareup.kotlinpoet.ClassName @@ -188,6 +188,7 @@ data class NamedQuery( private fun PsiElement.functionName() = when (this) { is NamedElement -> allocateName(this) is SqlExpr -> name + is SqlPragmaName -> text else -> throw IllegalStateException("Cannot get name for type ${this.javaClass}") } @@ -220,30 +221,3 @@ data class NamedQuery( // name -> query name get() = getUniqueQueryIdentifier(statement.sqFile().let { "${it.packageName}:${it.name}:$name" }) } - -class SelectQueryable( - override val select: SqlCompoundSelectStmt, - override var statement: SqlAnnotatedElement = select, -) : QueryWithResults { - - /** - * If this query is a pure select from a table (virtual or otherwise), this returns the LazyQuery - * which points to that table (Pure meaning it has exactly the same columns in the same order). - */ - override val pureTable: NamedElement? by lazy { - fun List.flattenCompounded(): List { - return map { column -> - if (column.compounded.none { it.element != column.element || it.nullable != column.nullable }) { - column.copy(compounded = emptyList()) - } else { - column - } - } - } - - val pureColumns = select.queryExposed().singleOrNull()?.columns?.flattenCompounded() - return@lazy select.tablesAvailable(select).firstOrNull { - it.query.columns.flattenCompounded() == pureColumns - }?.tableName - } -} diff --git a/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/compiler/model/PragmaWithResults.kt b/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/compiler/model/PragmaWithResults.kt new file mode 100644 index 00000000000..e201cca247d --- /dev/null +++ b/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/compiler/model/PragmaWithResults.kt @@ -0,0 +1,24 @@ +package app.cash.sqldelight.core.compiler.model + +import app.cash.sqldelight.dialect.api.QueryWithResults +import com.alecstrong.sql.psi.core.psi.NamedElement +import com.alecstrong.sql.psi.core.psi.QueryElement +import com.alecstrong.sql.psi.core.psi.QueryElement.QueryResult +import com.alecstrong.sql.psi.core.psi.SqlAnnotatedElement +import com.alecstrong.sql.psi.core.psi.SqlPragmaStmt +import com.alecstrong.sql.psi.core.psi.impl.SqlPragmaNameImpl +import com.intellij.lang.ASTNode + +class PragmaWithResults(private val pragmaStmt: SqlPragmaStmt) : QueryWithResults { + override var statement: SqlAnnotatedElement = pragmaStmt + override val select: QueryElement = pragmaStmt.pragmaName as SqlDelightPragmaName + override val pureTable: NamedElement? = null +} + +internal class SqlDelightPragmaName(node: ASTNode?) : SqlPragmaNameImpl(node), QueryElement { + override fun queryExposed() = listOf( + QueryResult( + column = this + ) + ) +} diff --git a/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/compiler/model/SelectQueryable.kt b/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/compiler/model/SelectQueryable.kt new file mode 100644 index 00000000000..3697cce4608 --- /dev/null +++ b/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/compiler/model/SelectQueryable.kt @@ -0,0 +1,34 @@ +package app.cash.sqldelight.core.compiler.model + +import app.cash.sqldelight.dialect.api.QueryWithResults +import com.alecstrong.sql.psi.core.psi.NamedElement +import com.alecstrong.sql.psi.core.psi.QueryElement.QueryColumn +import com.alecstrong.sql.psi.core.psi.SqlAnnotatedElement +import com.alecstrong.sql.psi.core.psi.SqlCompoundSelectStmt + +class SelectQueryable( + override val select: SqlCompoundSelectStmt, + override var statement: SqlAnnotatedElement = select, +) : QueryWithResults { + + /** + * If this query is a pure select from a table (virtual or otherwise), this returns the LazyQuery + * which points to that table (Pure meaning it has exactly the same columns in the same order). + */ + override val pureTable: NamedElement? by lazy { + fun List.flattenCompounded(): List { + return map { column -> + if (column.compounded.none { it.element != column.element || it.nullable != column.nullable }) { + column.copy(compounded = emptyList()) + } else { + column + } + } + } + + val pureColumns = select.queryExposed().singleOrNull()?.columns?.flattenCompounded() + return@lazy select.tablesAvailable(select).firstOrNull { + it.query.columns.flattenCompounded() == pureColumns + }?.tableName + } +} diff --git a/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/lang/ParserUtil.kt b/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/lang/ParserUtil.kt index a3cbce95787..2a5e0aa419d 100644 --- a/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/lang/ParserUtil.kt +++ b/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/lang/ParserUtil.kt @@ -2,6 +2,7 @@ package app.cash.sqldelight.core.lang import app.cash.sqldelight.core.SqlDelightProjectService import app.cash.sqldelight.core.SqldelightParserUtil +import app.cash.sqldelight.core.compiler.model.SqlDelightPragmaName import app.cash.sqldelight.core.lang.psi.FunctionExprMixin import app.cash.sqldelight.dialect.api.SqlDelightDialect import com.alecstrong.sql.psi.core.SqlParserUtil @@ -25,6 +26,7 @@ internal class ParserUtil { SqldelightParserUtil.createElement = { when (it.elementType) { SqlTypes.FUNCTION_EXPR -> FunctionExprMixin(it) + SqlTypes.PRAGMA_NAME -> SqlDelightPragmaName(it) else -> currentElementCreation(it) } } diff --git a/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/lang/util/ExprUtil.kt b/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/lang/util/ExprUtil.kt index ac1eac4bcb8..91a9d721e6d 100644 --- a/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/lang/util/ExprUtil.kt +++ b/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/lang/util/ExprUtil.kt @@ -16,6 +16,7 @@ package app.cash.sqldelight.core.lang.util import app.cash.sqldelight.core.compiler.SqlDelightCompiler.allocateName +import app.cash.sqldelight.core.compiler.model.PragmaWithResults import app.cash.sqldelight.core.compiler.model.SelectQueryable import app.cash.sqldelight.core.lang.types.typeResolver import app.cash.sqldelight.dialect.api.IntermediateType @@ -159,6 +160,9 @@ internal class AnsiSqlTypeResolver : TypeResolver { override fun queryWithResults(sqlStmt: SqlStmt): QueryWithResults? { sqlStmt.compoundSelectStmt?.let { return SelectQueryable(it) } + sqlStmt.pragmaStmt?.let { + if (it.pragmaValue == null) return PragmaWithResults(it) + } return null } } diff --git a/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/lang/util/TreeUtil.kt b/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/lang/util/TreeUtil.kt index e79f6878deb..341a89176fb 100644 --- a/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/lang/util/TreeUtil.kt +++ b/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/lang/util/TreeUtil.kt @@ -31,6 +31,7 @@ import com.alecstrong.sql.psi.core.psi.SqlCreateViewStmt import com.alecstrong.sql.psi.core.psi.SqlCreateVirtualTableStmt import com.alecstrong.sql.psi.core.psi.SqlExpr import com.alecstrong.sql.psi.core.psi.SqlModuleArgument +import com.alecstrong.sql.psi.core.psi.SqlPragmaName import com.alecstrong.sql.psi.core.psi.SqlTableName import com.alecstrong.sql.psi.core.psi.SqlTypeName import com.alecstrong.sql.psi.core.psi.SqlTypes @@ -54,6 +55,7 @@ internal fun PsiElement.type(): IntermediateType = when (this) { is SqlTypeName -> sqFile().typeResolver.definitionType(this) is AliasElement -> source().type().copy(name = name) is ColumnDefMixin -> (columnType as ColumnTypeMixin).type() + is SqlPragmaName -> IntermediateType(TEXT) is SqlColumnName -> { when (val parentRule = parent) { is ColumnDefMixin -> parentRule.type() diff --git a/sqldelight-compiler/src/test/kotlin/app/cash/sqldelight/core/queries/SelectQueryTypeTest.kt b/sqldelight-compiler/src/test/kotlin/app/cash/sqldelight/core/queries/SelectQueryTypeTest.kt index 4f1417200c9..5fb733b4219 100644 --- a/sqldelight-compiler/src/test/kotlin/app/cash/sqldelight/core/queries/SelectQueryTypeTest.kt +++ b/sqldelight-compiler/src/test/kotlin/app/cash/sqldelight/core/queries/SelectQueryTypeTest.kt @@ -1634,4 +1634,26 @@ class SelectQueryTypeTest { |""".trimMargin() ) } + + @Test + fun `pragma with results`() { + val file = FixtureCompiler.parseSql( + """ + |getVersion: + |PRAGMA user_version; + |""".trimMargin(), + tempFolder + ) + + val query = file.namedQueries.first() + val generator = SelectQueryGenerator(query) + + assertThat(generator.customResultTypeFunction().toString()).isEqualTo( + """ + |public fun getVersion(): app.cash.sqldelight.ExecutableQuery = app.cash.sqldelight.Query(${query.id}, driver, "Test.sq", "getVersion", "PRAGMA user_version") { cursor -> + | cursor.getString(0)!! + |} + |""".trimMargin() + ) + } }