Thanks for visiting codestin.com
Credit goes to docs.rs

datafusion_postgres/
handlers.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use datafusion::arrow::datatypes::{DataType, Field, Schema};
6use datafusion::common::{ParamValues, ToDFSchema};
7use datafusion::error::DataFusionError;
8use datafusion::logical_expr::LogicalPlan;
9use datafusion::prelude::*;
10use datafusion::sql::parser::Statement;
11use datafusion::sql::sqlparser;
12use log::info;
13use pgwire::api::auth::noop::NoopStartupHandler;
14use pgwire::api::auth::StartupHandler;
15use pgwire::api::portal::{Format, Portal};
16use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
17use pgwire::api::results::{
18    DescribePortalResponse, DescribeResponse, DescribeStatementResponse, Response, Tag,
19};
20use pgwire::api::stmt::QueryParser;
21use pgwire::api::stmt::StoredStatement;
22use pgwire::api::{ClientInfo, ErrorHandler, PgWireServerHandlers, Type};
23use pgwire::error::{PgWireError, PgWireResult};
24use pgwire::messages::response::TransactionStatus;
25
26use crate::auth::AuthManager;
27use crate::client;
28use crate::hooks::set_show::SetShowHook;
29use crate::hooks::QueryHook;
30use arrow_pg::datatypes::df;
31use arrow_pg::datatypes::{arrow_schema_to_pg_fields, into_pg_type};
32use datafusion_pg_catalog::pg_catalog::context::{Permission, ResourceType};
33use datafusion_pg_catalog::sql::PostgresCompatibilityParser;
34
35/// Simple startup handler that does no authentication
36/// For production, use DfAuthSource with proper pgwire authentication handlers
37pub struct SimpleStartupHandler;
38
39#[async_trait::async_trait]
40impl NoopStartupHandler for SimpleStartupHandler {}
41
42pub struct HandlerFactory {
43    pub session_service: Arc<DfSessionService>,
44}
45
46impl HandlerFactory {
47    pub fn new(session_context: Arc<SessionContext>, auth_manager: Arc<AuthManager>) -> Self {
48        let session_service =
49            Arc::new(DfSessionService::new(session_context, auth_manager.clone()));
50        HandlerFactory { session_service }
51    }
52
53    pub fn new_with_hooks(
54        session_context: Arc<SessionContext>,
55        auth_manager: Arc<AuthManager>,
56        query_hooks: Vec<Arc<dyn QueryHook>>,
57    ) -> Self {
58        let session_service = Arc::new(DfSessionService::new_with_hooks(
59            session_context,
60            auth_manager.clone(),
61            query_hooks,
62        ));
63        HandlerFactory { session_service }
64    }
65}
66
67impl PgWireServerHandlers for HandlerFactory {
68    fn simple_query_handler(&self) -> Arc<impl SimpleQueryHandler> {
69        self.session_service.clone()
70    }
71
72    fn extended_query_handler(&self) -> Arc<impl ExtendedQueryHandler> {
73        self.session_service.clone()
74    }
75
76    fn startup_handler(&self) -> Arc<impl StartupHandler> {
77        Arc::new(SimpleStartupHandler)
78    }
79
80    fn error_handler(&self) -> Arc<impl ErrorHandler> {
81        Arc::new(LoggingErrorHandler)
82    }
83}
84
85struct LoggingErrorHandler;
86
87impl ErrorHandler for LoggingErrorHandler {
88    fn on_error<C>(&self, _client: &C, error: &mut PgWireError)
89    where
90        C: ClientInfo,
91    {
92        info!("Sending error: {error}")
93    }
94}
95
96/// The pgwire handler backed by a datafusion `SessionContext`
97pub struct DfSessionService {
98    session_context: Arc<SessionContext>,
99    parser: Arc<Parser>,
100    auth_manager: Arc<AuthManager>,
101    query_hooks: Vec<Arc<dyn QueryHook>>,
102}
103
104impl DfSessionService {
105    pub fn new(
106        session_context: Arc<SessionContext>,
107        auth_manager: Arc<AuthManager>,
108    ) -> DfSessionService {
109        let hooks: Vec<Arc<dyn QueryHook>> = vec![Arc::new(SetShowHook)];
110        Self::new_with_hooks(session_context, auth_manager, hooks)
111    }
112
113    pub fn new_with_hooks(
114        session_context: Arc<SessionContext>,
115        auth_manager: Arc<AuthManager>,
116        query_hooks: Vec<Arc<dyn QueryHook>>,
117    ) -> DfSessionService {
118        let parser = Arc::new(Parser {
119            session_context: session_context.clone(),
120            sql_parser: PostgresCompatibilityParser::new(),
121            query_hooks: query_hooks.clone(),
122        });
123        DfSessionService {
124            session_context,
125            parser,
126            auth_manager,
127            query_hooks,
128        }
129    }
130
131    /// Check if the current user has permission to execute a query
132    async fn check_query_permission<C>(&self, client: &C, query: &str) -> PgWireResult<()>
133    where
134        C: ClientInfo,
135    {
136        // Get the username from client metadata
137        let username = client
138            .metadata()
139            .get("user")
140            .map(|s| s.as_str())
141            .unwrap_or("anonymous");
142
143        // Parse query to determine required permissions
144        let query_lower = query.to_lowercase();
145        let query_trimmed = query_lower.trim();
146
147        let (required_permission, resource) = if query_trimmed.starts_with("select") {
148            (Permission::Select, self.extract_table_from_query(query))
149        } else if query_trimmed.starts_with("insert") {
150            (Permission::Insert, self.extract_table_from_query(query))
151        } else if query_trimmed.starts_with("update") {
152            (Permission::Update, self.extract_table_from_query(query))
153        } else if query_trimmed.starts_with("delete") {
154            (Permission::Delete, self.extract_table_from_query(query))
155        } else if query_trimmed.starts_with("create table")
156            || query_trimmed.starts_with("create view")
157        {
158            (Permission::Create, ResourceType::All)
159        } else if query_trimmed.starts_with("drop") {
160            (Permission::Drop, self.extract_table_from_query(query))
161        } else if query_trimmed.starts_with("alter") {
162            (Permission::Alter, self.extract_table_from_query(query))
163        } else {
164            // For other queries (SHOW, EXPLAIN, etc.), allow all users
165            return Ok(());
166        };
167
168        // Check permission
169        let has_permission = self
170            .auth_manager
171            .check_permission(username, required_permission, resource)
172            .await;
173
174        if !has_permission {
175            return Err(PgWireError::UserError(Box::new(
176                pgwire::error::ErrorInfo::new(
177                    "ERROR".to_string(),
178                    "42501".to_string(), // insufficient_privilege
179                    format!("permission denied for user \"{username}\""),
180                ),
181            )));
182        }
183
184        Ok(())
185    }
186
187    /// Extract table name from query (simplified parsing)
188    fn extract_table_from_query(&self, query: &str) -> ResourceType {
189        let words: Vec<&str> = query.split_whitespace().collect();
190
191        // Simple heuristic to find table names
192        for (i, word) in words.iter().enumerate() {
193            let word_lower = word.to_lowercase();
194            if (word_lower == "from" || word_lower == "into" || word_lower == "table")
195                && i + 1 < words.len()
196            {
197                let table_name = words[i + 1].trim_matches(|c| c == '(' || c == ')' || c == ';');
198                return ResourceType::Table(table_name.to_string());
199            }
200        }
201
202        // If we can't determine the table, default to All
203        ResourceType::All
204    }
205
206    async fn try_respond_transaction_statements<C>(
207        &self,
208        client: &C,
209        query_lower: &str,
210    ) -> PgWireResult<Option<Response>>
211    where
212        C: ClientInfo,
213    {
214        // Transaction handling based on pgwire example:
215        // https://github.com/sunng87/pgwire/blob/master/examples/transaction.rs#L57
216        match query_lower.trim() {
217            "begin" | "begin transaction" | "begin work" | "start transaction" => {
218                match client.transaction_status() {
219                    TransactionStatus::Idle => {
220                        Ok(Some(Response::TransactionStart(Tag::new("BEGIN"))))
221                    }
222                    TransactionStatus::Transaction => {
223                        // PostgreSQL behavior: ignore nested BEGIN, just return SUCCESS
224                        // This matches PostgreSQL's handling of nested transaction blocks
225                        log::warn!("BEGIN command ignored: already in transaction block");
226                        Ok(Some(Response::Execution(Tag::new("BEGIN"))))
227                    }
228                    TransactionStatus::Error => {
229                        // Can't start new transaction from failed state
230                        Err(PgWireError::UserError(Box::new(
231                            pgwire::error::ErrorInfo::new(
232                                "ERROR".to_string(),
233                                "25P01".to_string(),
234                                "current transaction is aborted, commands ignored until end of transaction block".to_string(),
235                            ),
236                        )))
237                    }
238                }
239            }
240            "commit" | "commit transaction" | "commit work" | "end" | "end transaction" => {
241                match client.transaction_status() {
242                    TransactionStatus::Idle | TransactionStatus::Transaction => {
243                        Ok(Some(Response::TransactionEnd(Tag::new("COMMIT"))))
244                    }
245                    TransactionStatus::Error => {
246                        Ok(Some(Response::TransactionEnd(Tag::new("ROLLBACK"))))
247                    }
248                }
249            }
250            "rollback" | "rollback transaction" | "rollback work" | "abort" => {
251                Ok(Some(Response::TransactionEnd(Tag::new("ROLLBACK"))))
252            }
253            _ => Ok(None),
254        }
255    }
256}
257
258#[async_trait]
259impl SimpleQueryHandler for DfSessionService {
260    async fn do_query<C>(&self, client: &mut C, query: &str) -> PgWireResult<Vec<Response>>
261    where
262        C: ClientInfo + Unpin + Send + Sync,
263    {
264        log::debug!("Received query: {query}"); // Log the query for debugging
265
266        // Check for transaction commands early to avoid SQL parsing issues with ABORT
267        let query_lower = query.to_lowercase().trim().to_string();
268        if let Some(resp) = self
269            .try_respond_transaction_statements(client, &query_lower)
270            .await?
271        {
272            return Ok(vec![resp]);
273        }
274
275        let statements = self
276            .parser
277            .sql_parser
278            .parse(query)
279            .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
280
281        // empty query
282        if statements.is_empty() {
283            return Ok(vec![Response::EmptyQuery]);
284        }
285
286        let mut results = vec![];
287        'stmt: for statement in statements {
288            // TODO: improve statement check by using statement directly
289            let query = statement.to_string();
290            let query_lower = query.to_lowercase().trim().to_string();
291
292            // Check permissions for the query (skip for SET, transaction, and SHOW statements)
293            if !query_lower.starts_with("set")
294                && !query_lower.starts_with("begin")
295                && !query_lower.starts_with("commit")
296                && !query_lower.starts_with("rollback")
297                && !query_lower.starts_with("start")
298                && !query_lower.starts_with("end")
299                && !query_lower.starts_with("abort")
300                && !query_lower.starts_with("show")
301            {
302                self.check_query_permission(client, &query).await?;
303            }
304
305            // Call query hooks with the parsed statement
306            for hook in &self.query_hooks {
307                if let Some(result) = hook
308                    .handle_simple_query(&statement, &self.session_context, client)
309                    .await
310                {
311                    results.push(result?);
312                    continue 'stmt;
313                }
314            }
315
316            // Check if we're in a failed transaction and block non-transaction
317            // commands
318            if client.transaction_status() == TransactionStatus::Error {
319                return Err(PgWireError::UserError(Box::new(
320                pgwire::error::ErrorInfo::new(
321                    "ERROR".to_string(),
322                    "25P01".to_string(),
323                    "current transaction is aborted, commands ignored until end of transaction block".to_string(),
324                ),
325            )));
326            }
327
328            let df_result = {
329                let timeout = client::get_statement_timeout(client);
330                if let Some(timeout_duration) = timeout {
331                    tokio::time::timeout(timeout_duration, self.session_context.sql(&query))
332                        .await
333                        .map_err(|_| {
334                            PgWireError::UserError(Box::new(pgwire::error::ErrorInfo::new(
335                                "ERROR".to_string(),
336                                "57014".to_string(), // query_canceled error code
337                                "canceling statement due to statement timeout".to_string(),
338                            )))
339                        })?
340                } else {
341                    self.session_context.sql(&query).await
342                }
343            };
344
345            // Handle query execution errors and transaction state
346            let df = match df_result {
347                Ok(df) => df,
348                Err(e) => {
349                    return Err(PgWireError::ApiError(Box::new(e)));
350                }
351            };
352
353            if query_lower.starts_with("insert into") {
354                let resp = map_rows_affected_for_insert(&df).await?;
355                results.push(resp);
356            } else {
357                // For non-INSERT queries, return a regular Query response
358                let resp = df::encode_dataframe(df, &Format::UnifiedText).await?;
359                results.push(Response::Query(resp));
360            }
361        }
362        Ok(results)
363    }
364}
365
366#[async_trait]
367impl ExtendedQueryHandler for DfSessionService {
368    type Statement = (String, Option<(sqlparser::ast::Statement, LogicalPlan)>);
369    type QueryParser = Parser;
370
371    fn query_parser(&self) -> Arc<Self::QueryParser> {
372        self.parser.clone()
373    }
374
375    async fn do_describe_statement<C>(
376        &self,
377        _client: &mut C,
378        target: &StoredStatement<Self::Statement>,
379    ) -> PgWireResult<DescribeStatementResponse>
380    where
381        C: ClientInfo + Unpin + Send + Sync,
382    {
383        if let (_, Some((_, plan))) = &target.statement {
384            let schema = plan.schema();
385            let fields = arrow_schema_to_pg_fields(schema.as_arrow(), &Format::UnifiedBinary)?;
386            let params = plan
387                .get_parameter_types()
388                .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
389
390            let mut param_types = Vec::with_capacity(params.len());
391            for param_type in ordered_param_types(&params).iter() {
392                // Fixed: Use &params
393                if let Some(datatype) = param_type {
394                    let pgtype = into_pg_type(datatype)?;
395                    param_types.push(pgtype);
396                } else {
397                    param_types.push(Type::UNKNOWN);
398                }
399            }
400
401            Ok(DescribeStatementResponse::new(param_types, fields))
402        } else {
403            Ok(DescribeStatementResponse::no_data())
404        }
405    }
406
407    async fn do_describe_portal<C>(
408        &self,
409        _client: &mut C,
410        target: &Portal<Self::Statement>,
411    ) -> PgWireResult<DescribePortalResponse>
412    where
413        C: ClientInfo + Unpin + Send + Sync,
414    {
415        if let (_, Some((_, plan))) = &target.statement.statement {
416            let format = &target.result_column_format;
417            let schema = plan.schema();
418            let fields = arrow_schema_to_pg_fields(schema.as_arrow(), format)?;
419
420            Ok(DescribePortalResponse::new(fields))
421        } else {
422            Ok(DescribePortalResponse::no_data())
423        }
424    }
425
426    async fn do_query<C>(
427        &self,
428        client: &mut C,
429        portal: &Portal<Self::Statement>,
430        _max_rows: usize,
431    ) -> PgWireResult<Response>
432    where
433        C: ClientInfo + Unpin + Send + Sync,
434    {
435        let query = portal
436            .statement
437            .statement
438            .0
439            .to_lowercase()
440            .trim()
441            .to_string();
442        log::debug!("Received execute extended query: {query}"); // Log for debugging
443
444        // Check query hooks first
445        if !self.query_hooks.is_empty() {
446            if let (_, Some((statement, plan))) = &portal.statement.statement {
447                // TODO: in the case where query hooks all return None, we do the param handling again later.
448                let param_types = plan
449                    .get_parameter_types()
450                    .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
451
452                let param_values: ParamValues =
453                    df::deserialize_parameters(portal, &ordered_param_types(&param_types))?;
454
455                for hook in &self.query_hooks {
456                    if let Some(result) = hook
457                        .handle_extended_query(
458                            statement,
459                            plan,
460                            &param_values,
461                            &self.session_context,
462                            client,
463                        )
464                        .await
465                    {
466                        return result;
467                    }
468                }
469            }
470        }
471
472        // Check permissions for the query (skip for SET and SHOW statements)
473        if !query.starts_with("set") && !query.starts_with("show") {
474            self.check_query_permission(client, &portal.statement.statement.0)
475                .await?;
476        }
477
478        if let Some(resp) = self
479            .try_respond_transaction_statements(client, &query)
480            .await?
481        {
482            return Ok(resp);
483        }
484
485        // Check if we're in a failed transaction and block non-transaction
486        // commands
487        if client.transaction_status() == TransactionStatus::Error {
488            return Err(PgWireError::UserError(Box::new(
489                pgwire::error::ErrorInfo::new(
490                    "ERROR".to_string(),
491                    "25P01".to_string(),
492                    "current transaction is aborted, commands ignored until end of transaction block".to_string(),
493                ),
494            )));
495        }
496
497        if let (_, Some((_, plan))) = &portal.statement.statement {
498            let param_types = plan
499                .get_parameter_types()
500                .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
501
502            let param_values =
503                df::deserialize_parameters(portal, &ordered_param_types(&param_types))?; // Fixed: Use &param_types
504
505            let plan = plan
506                .clone()
507                .replace_params_with_values(&param_values)
508                .map_err(|e| PgWireError::ApiError(Box::new(e)))?; // Fixed: Use
509                                                                   // &param_values
510            let optimised = self
511                .session_context
512                .state()
513                .optimize(&plan)
514                .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
515
516            let dataframe = {
517                let timeout = client::get_statement_timeout(client);
518                if let Some(timeout_duration) = timeout {
519                    tokio::time::timeout(
520                        timeout_duration,
521                        self.session_context.execute_logical_plan(optimised),
522                    )
523                    .await
524                    .map_err(|_| {
525                        PgWireError::UserError(Box::new(pgwire::error::ErrorInfo::new(
526                            "ERROR".to_string(),
527                            "57014".to_string(), // query_canceled error code
528                            "canceling statement due to statement timeout".to_string(),
529                        )))
530                    })?
531                    .map_err(|e| PgWireError::ApiError(Box::new(e)))?
532                } else {
533                    self.session_context
534                        .execute_logical_plan(optimised)
535                        .await
536                        .map_err(|e| PgWireError::ApiError(Box::new(e)))?
537                }
538            };
539
540            if query.starts_with("insert into") {
541                let resp = map_rows_affected_for_insert(&dataframe).await?;
542
543                Ok(resp)
544            } else {
545                // For non-INSERT queries, return a regular Query response
546                let resp = df::encode_dataframe(dataframe, &portal.result_column_format).await?;
547                Ok(Response::Query(resp))
548            }
549        } else {
550            Ok(Response::EmptyQuery)
551        }
552    }
553}
554
555async fn map_rows_affected_for_insert(df: &DataFrame) -> PgWireResult<Response> {
556    // For INSERT queries, we need to execute the query to get the row count
557    // and return an Execution response with the proper tag
558    let result = df
559        .clone()
560        .collect()
561        .await
562        .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
563
564    // Extract count field from the first batch
565    let rows_affected = result
566        .first()
567        .and_then(|batch| batch.column_by_name("count"))
568        .and_then(|col| {
569            col.as_any()
570                .downcast_ref::<datafusion::arrow::array::UInt64Array>()
571        })
572        .map_or(0, |array| array.value(0) as usize);
573
574    // Create INSERT tag with the affected row count
575    let tag = Tag::new("INSERT").with_oid(0).with_rows(rows_affected);
576    Ok(Response::Execution(tag))
577}
578
579pub struct Parser {
580    session_context: Arc<SessionContext>,
581    sql_parser: PostgresCompatibilityParser,
582    query_hooks: Vec<Arc<dyn QueryHook>>,
583}
584
585impl Parser {
586    fn try_shortcut_parse_plan(&self, sql: &str) -> Result<Option<LogicalPlan>, DataFusionError> {
587        // Check for transaction commands that shouldn't be parsed by DataFusion
588        let sql_lower = sql.to_lowercase();
589        let sql_trimmed = sql_lower.trim();
590
591        if matches!(
592            sql_trimmed,
593            "" | "begin"
594                | "begin transaction"
595                | "begin work"
596                | "start transaction"
597                | "commit"
598                | "commit transaction"
599                | "commit work"
600                | "end"
601                | "end transaction"
602                | "rollback"
603                | "rollback transaction"
604                | "rollback work"
605                | "abort"
606        ) {
607            // Return a dummy plan for transaction commands - they'll be handled by transaction handler
608            let dummy_schema = datafusion::common::DFSchema::empty();
609            return Ok(Some(LogicalPlan::EmptyRelation(
610                datafusion::logical_expr::EmptyRelation {
611                    produce_one_row: false,
612                    schema: Arc::new(dummy_schema),
613                },
614            )));
615        }
616
617        // show statement may not be supported by datafusion
618        if sql_trimmed.starts_with("show") {
619            let show_schema =
620                Arc::new(Schema::new(vec![Field::new("show", DataType::Utf8, false)]));
621            let df_schema = show_schema.to_dfschema()?;
622            return Ok(Some(LogicalPlan::EmptyRelation(
623                datafusion::logical_expr::EmptyRelation {
624                    produce_one_row: true,
625                    schema: Arc::new(df_schema),
626                },
627            )));
628        }
629
630        Ok(None)
631    }
632}
633
634#[async_trait]
635impl QueryParser for Parser {
636    type Statement = (String, Option<(sqlparser::ast::Statement, LogicalPlan)>);
637
638    async fn parse_sql<C>(
639        &self,
640        client: &C,
641        sql: &str,
642        _types: &[Type],
643    ) -> PgWireResult<Self::Statement>
644    where
645        C: ClientInfo + Unpin + Send + Sync,
646    {
647        log::debug!("Received parse extended query: {sql}"); // Log for debugging
648
649        let mut statements = self
650            .sql_parser
651            .parse(sql)
652            .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
653        if statements.is_empty() {
654            return Ok((sql.to_string(), None));
655        }
656
657        let statement = statements.remove(0);
658
659        // Check for transaction commands that shouldn't be parsed by DataFusion
660        if let Some(plan) = self
661            .try_shortcut_parse_plan(sql)
662            .map_err(|e| PgWireError::ApiError(Box::new(e)))?
663        {
664            return Ok((sql.to_string(), Some((statement, plan))));
665        }
666
667        let query = statement.to_string();
668
669        let context = &self.session_context;
670        let state = context.state();
671
672        for hook in &self.query_hooks {
673            if let Some(logical_plan) = hook
674                .handle_extended_parse_query(&statement, context, client)
675                .await
676            {
677                return Ok((query, Some((statement, logical_plan?))));
678            }
679        }
680
681        let logical_plan = state
682            .statement_to_plan(Statement::Statement(Box::new(statement.clone())))
683            .await
684            .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
685        Ok((query, Some((statement, logical_plan))))
686    }
687}
688
/// Returns the parameter types of a plan ordered by parameter position.
///
/// DataFusion stores the parameters as a map. In our case the keys are `$1`,
/// `$2`, etc. and the values are the (possibly unknown) parameter types.
///
/// Keys are ordered by their numeric position, so `$2` comes before `$10`. A
/// plain string sort would place `$10` right after `$1` and misalign every
/// parameter from the tenth on. Keys that are not of the form `$<number>` are
/// placed last, ordered by name.
///
/// The result has exactly one entry per map entry. Gaps in the numbering (e.g.
/// only `$1` and `$3` used) are not filled in.
fn ordered_param_types<D>(types: &HashMap<String, Option<D>>) -> Vec<Option<&D>> {
    /// Position encoded in a `$<n>` key; non-positional keys sort last.
    fn position(key: &str) -> usize {
        key.strip_prefix('$')
            .and_then(|digits| digits.parse::<usize>().ok())
            .unwrap_or(usize::MAX)
    }

    let mut entries: Vec<(&String, &Option<D>)> = types.iter().collect();
    entries.sort_by_key(|&(key, _)| (position(key), key));
    entries.into_iter().map(|(_, ty)| ty.as_ref()).collect()
}
696
#[cfg(test)]
mod tests {
    use datafusion::prelude::SessionContext;

    use super::*;
    use crate::testing::MockClient;

    /// Hook that answers any simple query whose text contains "magic" with an
    /// `EmptyQuery` response and declines everything else, including all
    /// extended-protocol requests.
    struct TestHook;

    #[async_trait]
    impl QueryHook for TestHook {
        async fn handle_simple_query(
            &self,
            statement: &sqlparser::ast::Statement,
            _ctx: &SessionContext,
            _client: &mut (dyn ClientInfo + Sync + Send),
        ) -> Option<PgWireResult<Response>> {
            let mentions_magic = statement.to_string().contains("magic");
            mentions_magic.then_some(Ok(Response::EmptyQuery))
        }

        async fn handle_extended_parse_query(
            &self,
            _statement: &sqlparser::ast::Statement,
            _session_context: &SessionContext,
            _client: &(dyn ClientInfo + Send + Sync),
        ) -> Option<PgWireResult<LogicalPlan>> {
            None
        }

        async fn handle_extended_query(
            &self,
            _statement: &sqlparser::ast::Statement,
            _logical_plan: &LogicalPlan,
            _params: &ParamValues,
            _session_context: &SessionContext,
            _client: &mut (dyn ClientInfo + Send + Sync),
        ) -> Option<PgWireResult<Response>> {
            None
        }
    }

    #[tokio::test]
    async fn test_query_hooks() {
        let hook = TestHook;
        let ctx = SessionContext::new();
        let mut client = MockClient::new();
        let parser = PostgresCompatibilityParser::new();

        // (sql, should the hook intercept it?)
        for (sql, intercepted) in [("SELECT magic", true), ("SELECT 1", false)] {
            let statements = parser.parse(sql).unwrap();
            let outcome = hook
                .handle_simple_query(&statements[0], &ctx, &mut client)
                .await;
            assert_eq!(outcome.is_some(), intercepted, "unexpected hook outcome for {sql:?}");
        }
    }

    #[tokio::test]
    async fn test_multiple_statements_with_hook_continue() {
        // Regression test for bug #227: a hook returning a result used `break 'stmt`,
        // which left the statement loop and dropped every later statement.
        let hooks: Vec<Arc<dyn QueryHook>> = vec![Arc::new(TestHook)];
        let service = DfSessionService::new_with_hooks(
            Arc::new(SessionContext::new()),
            Arc::new(AuthManager::new()),
            hooks,
        );
        let mut client = MockClient::new();

        // Alternate hook-handled and DataFusion-handled statements.
        let results = <DfSessionService as SimpleQueryHandler>::do_query(
            &service,
            &mut client,
            "SELECT magic; SELECT 1; SELECT magic; SELECT 1",
        )
        .await
        .unwrap();

        assert_eq!(results.len(), 4, "Expected 4 responses");
        for (idx, response) in results.iter().enumerate() {
            if idx % 2 == 0 {
                assert!(matches!(response, Response::EmptyQuery));
            } else {
                assert!(matches!(response, Response::Query(_)));
            }
        }
    }
}