1use std::collections::HashMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use datafusion::arrow::datatypes::{DataType, Field, Schema};
6use datafusion::common::{ParamValues, ToDFSchema};
7use datafusion::error::DataFusionError;
8use datafusion::logical_expr::LogicalPlan;
9use datafusion::prelude::*;
10use datafusion::sql::parser::Statement;
11use datafusion::sql::sqlparser;
12use log::info;
13use pgwire::api::auth::noop::NoopStartupHandler;
14use pgwire::api::auth::StartupHandler;
15use pgwire::api::portal::{Format, Portal};
16use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
17use pgwire::api::results::{
18 DescribePortalResponse, DescribeResponse, DescribeStatementResponse, Response, Tag,
19};
20use pgwire::api::stmt::QueryParser;
21use pgwire::api::stmt::StoredStatement;
22use pgwire::api::{ClientInfo, ErrorHandler, PgWireServerHandlers, Type};
23use pgwire::error::{PgWireError, PgWireResult};
24use pgwire::messages::response::TransactionStatus;
25
26use crate::auth::AuthManager;
27use crate::client;
28use crate::hooks::set_show::SetShowHook;
29use crate::hooks::QueryHook;
30use arrow_pg::datatypes::df;
31use arrow_pg::datatypes::{arrow_schema_to_pg_fields, into_pg_type};
32use datafusion_pg_catalog::pg_catalog::context::{Permission, ResourceType};
33use datafusion_pg_catalog::sql::PostgresCompatibilityParser;
34
/// Startup handler for new pgwire connections.
///
/// It overrides nothing: the startup flow is entirely the default one provided
/// by pgwire's [`NoopStartupHandler`] trait.
pub struct SimpleStartupHandler;

// Empty impl on purpose: all behaviour comes from the trait's default methods.
#[async_trait::async_trait]
impl NoopStartupHandler for SimpleStartupHandler {}
41
42pub struct HandlerFactory {
43 pub session_service: Arc<DfSessionService>,
44}
45
46impl HandlerFactory {
47 pub fn new(session_context: Arc<SessionContext>, auth_manager: Arc<AuthManager>) -> Self {
48 let session_service =
49 Arc::new(DfSessionService::new(session_context, auth_manager.clone()));
50 HandlerFactory { session_service }
51 }
52
53 pub fn new_with_hooks(
54 session_context: Arc<SessionContext>,
55 auth_manager: Arc<AuthManager>,
56 query_hooks: Vec<Arc<dyn QueryHook>>,
57 ) -> Self {
58 let session_service = Arc::new(DfSessionService::new_with_hooks(
59 session_context,
60 auth_manager.clone(),
61 query_hooks,
62 ));
63 HandlerFactory { session_service }
64 }
65}
66
67impl PgWireServerHandlers for HandlerFactory {
68 fn simple_query_handler(&self) -> Arc<impl SimpleQueryHandler> {
69 self.session_service.clone()
70 }
71
72 fn extended_query_handler(&self) -> Arc<impl ExtendedQueryHandler> {
73 self.session_service.clone()
74 }
75
76 fn startup_handler(&self) -> Arc<impl StartupHandler> {
77 Arc::new(SimpleStartupHandler)
78 }
79
80 fn error_handler(&self) -> Arc<impl ErrorHandler> {
81 Arc::new(LoggingErrorHandler)
82 }
83}
84
85struct LoggingErrorHandler;
86
87impl ErrorHandler for LoggingErrorHandler {
88 fn on_error<C>(&self, _client: &C, error: &mut PgWireError)
89 where
90 C: ClientInfo,
91 {
92 info!("Sending error: {error}")
93 }
94}
95
/// pgwire query handler backed by a DataFusion [`SessionContext`].
///
/// Implements both the simple and the extended query protocol. Statements are
/// authorized through the [`AuthManager`] and offered to each [`QueryHook`]
/// before DataFusion plans and executes them.
pub struct DfSessionService {
    /// DataFusion context that plans and executes the SQL.
    session_context: Arc<SessionContext>,
    /// Extended-protocol parser; its SQL front end is also used by the simple-query path.
    parser: Arc<Parser>,
    /// Source of truth for per-user permission checks.
    auth_manager: Arc<AuthManager>,
    /// Hooks consulted, in order, before a statement reaches DataFusion.
    query_hooks: Vec<Arc<dyn QueryHook>>,
}
103
104impl DfSessionService {
105 pub fn new(
106 session_context: Arc<SessionContext>,
107 auth_manager: Arc<AuthManager>,
108 ) -> DfSessionService {
109 let hooks: Vec<Arc<dyn QueryHook>> = vec![Arc::new(SetShowHook)];
110 Self::new_with_hooks(session_context, auth_manager, hooks)
111 }
112
113 pub fn new_with_hooks(
114 session_context: Arc<SessionContext>,
115 auth_manager: Arc<AuthManager>,
116 query_hooks: Vec<Arc<dyn QueryHook>>,
117 ) -> DfSessionService {
118 let parser = Arc::new(Parser {
119 session_context: session_context.clone(),
120 sql_parser: PostgresCompatibilityParser::new(),
121 query_hooks: query_hooks.clone(),
122 });
123 DfSessionService {
124 session_context,
125 parser,
126 auth_manager,
127 query_hooks,
128 }
129 }
130
131 async fn check_query_permission<C>(&self, client: &C, query: &str) -> PgWireResult<()>
133 where
134 C: ClientInfo,
135 {
136 let username = client
138 .metadata()
139 .get("user")
140 .map(|s| s.as_str())
141 .unwrap_or("anonymous");
142
143 let query_lower = query.to_lowercase();
145 let query_trimmed = query_lower.trim();
146
147 let (required_permission, resource) = if query_trimmed.starts_with("select") {
148 (Permission::Select, self.extract_table_from_query(query))
149 } else if query_trimmed.starts_with("insert") {
150 (Permission::Insert, self.extract_table_from_query(query))
151 } else if query_trimmed.starts_with("update") {
152 (Permission::Update, self.extract_table_from_query(query))
153 } else if query_trimmed.starts_with("delete") {
154 (Permission::Delete, self.extract_table_from_query(query))
155 } else if query_trimmed.starts_with("create table")
156 || query_trimmed.starts_with("create view")
157 {
158 (Permission::Create, ResourceType::All)
159 } else if query_trimmed.starts_with("drop") {
160 (Permission::Drop, self.extract_table_from_query(query))
161 } else if query_trimmed.starts_with("alter") {
162 (Permission::Alter, self.extract_table_from_query(query))
163 } else {
164 return Ok(());
166 };
167
168 let has_permission = self
170 .auth_manager
171 .check_permission(username, required_permission, resource)
172 .await;
173
174 if !has_permission {
175 return Err(PgWireError::UserError(Box::new(
176 pgwire::error::ErrorInfo::new(
177 "ERROR".to_string(),
178 "42501".to_string(), format!("permission denied for user \"{username}\""),
180 ),
181 )));
182 }
183
184 Ok(())
185 }
186
187 fn extract_table_from_query(&self, query: &str) -> ResourceType {
189 let words: Vec<&str> = query.split_whitespace().collect();
190
191 for (i, word) in words.iter().enumerate() {
193 let word_lower = word.to_lowercase();
194 if (word_lower == "from" || word_lower == "into" || word_lower == "table")
195 && i + 1 < words.len()
196 {
197 let table_name = words[i + 1].trim_matches(|c| c == '(' || c == ')' || c == ';');
198 return ResourceType::Table(table_name.to_string());
199 }
200 }
201
202 ResourceType::All
204 }
205
206 async fn try_respond_transaction_statements<C>(
207 &self,
208 client: &C,
209 query_lower: &str,
210 ) -> PgWireResult<Option<Response>>
211 where
212 C: ClientInfo,
213 {
214 match query_lower.trim() {
217 "begin" | "begin transaction" | "begin work" | "start transaction" => {
218 match client.transaction_status() {
219 TransactionStatus::Idle => {
220 Ok(Some(Response::TransactionStart(Tag::new("BEGIN"))))
221 }
222 TransactionStatus::Transaction => {
223 log::warn!("BEGIN command ignored: already in transaction block");
226 Ok(Some(Response::Execution(Tag::new("BEGIN"))))
227 }
228 TransactionStatus::Error => {
229 Err(PgWireError::UserError(Box::new(
231 pgwire::error::ErrorInfo::new(
232 "ERROR".to_string(),
233 "25P01".to_string(),
234 "current transaction is aborted, commands ignored until end of transaction block".to_string(),
235 ),
236 )))
237 }
238 }
239 }
240 "commit" | "commit transaction" | "commit work" | "end" | "end transaction" => {
241 match client.transaction_status() {
242 TransactionStatus::Idle | TransactionStatus::Transaction => {
243 Ok(Some(Response::TransactionEnd(Tag::new("COMMIT"))))
244 }
245 TransactionStatus::Error => {
246 Ok(Some(Response::TransactionEnd(Tag::new("ROLLBACK"))))
247 }
248 }
249 }
250 "rollback" | "rollback transaction" | "rollback work" | "abort" => {
251 Ok(Some(Response::TransactionEnd(Tag::new("ROLLBACK"))))
252 }
253 _ => Ok(None),
254 }
255 }
256}
257
258#[async_trait]
259impl SimpleQueryHandler for DfSessionService {
260 async fn do_query<C>(&self, client: &mut C, query: &str) -> PgWireResult<Vec<Response>>
261 where
262 C: ClientInfo + Unpin + Send + Sync,
263 {
264 log::debug!("Received query: {query}"); let query_lower = query.to_lowercase().trim().to_string();
268 if let Some(resp) = self
269 .try_respond_transaction_statements(client, &query_lower)
270 .await?
271 {
272 return Ok(vec![resp]);
273 }
274
275 let statements = self
276 .parser
277 .sql_parser
278 .parse(query)
279 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
280
281 if statements.is_empty() {
283 return Ok(vec![Response::EmptyQuery]);
284 }
285
286 let mut results = vec![];
287 'stmt: for statement in statements {
288 let query = statement.to_string();
290 let query_lower = query.to_lowercase().trim().to_string();
291
292 if !query_lower.starts_with("set")
294 && !query_lower.starts_with("begin")
295 && !query_lower.starts_with("commit")
296 && !query_lower.starts_with("rollback")
297 && !query_lower.starts_with("start")
298 && !query_lower.starts_with("end")
299 && !query_lower.starts_with("abort")
300 && !query_lower.starts_with("show")
301 {
302 self.check_query_permission(client, &query).await?;
303 }
304
305 for hook in &self.query_hooks {
307 if let Some(result) = hook
308 .handle_simple_query(&statement, &self.session_context, client)
309 .await
310 {
311 results.push(result?);
312 continue 'stmt;
313 }
314 }
315
316 if client.transaction_status() == TransactionStatus::Error {
319 return Err(PgWireError::UserError(Box::new(
320 pgwire::error::ErrorInfo::new(
321 "ERROR".to_string(),
322 "25P01".to_string(),
323 "current transaction is aborted, commands ignored until end of transaction block".to_string(),
324 ),
325 )));
326 }
327
328 let df_result = {
329 let timeout = client::get_statement_timeout(client);
330 if let Some(timeout_duration) = timeout {
331 tokio::time::timeout(timeout_duration, self.session_context.sql(&query))
332 .await
333 .map_err(|_| {
334 PgWireError::UserError(Box::new(pgwire::error::ErrorInfo::new(
335 "ERROR".to_string(),
336 "57014".to_string(), "canceling statement due to statement timeout".to_string(),
338 )))
339 })?
340 } else {
341 self.session_context.sql(&query).await
342 }
343 };
344
345 let df = match df_result {
347 Ok(df) => df,
348 Err(e) => {
349 return Err(PgWireError::ApiError(Box::new(e)));
350 }
351 };
352
353 if query_lower.starts_with("insert into") {
354 let resp = map_rows_affected_for_insert(&df).await?;
355 results.push(resp);
356 } else {
357 let resp = df::encode_dataframe(df, &Format::UnifiedText).await?;
359 results.push(Response::Query(resp));
360 }
361 }
362 Ok(results)
363 }
364}
365
366#[async_trait]
367impl ExtendedQueryHandler for DfSessionService {
368 type Statement = (String, Option<(sqlparser::ast::Statement, LogicalPlan)>);
369 type QueryParser = Parser;
370
371 fn query_parser(&self) -> Arc<Self::QueryParser> {
372 self.parser.clone()
373 }
374
375 async fn do_describe_statement<C>(
376 &self,
377 _client: &mut C,
378 target: &StoredStatement<Self::Statement>,
379 ) -> PgWireResult<DescribeStatementResponse>
380 where
381 C: ClientInfo + Unpin + Send + Sync,
382 {
383 if let (_, Some((_, plan))) = &target.statement {
384 let schema = plan.schema();
385 let fields = arrow_schema_to_pg_fields(schema.as_arrow(), &Format::UnifiedBinary)?;
386 let params = plan
387 .get_parameter_types()
388 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
389
390 let mut param_types = Vec::with_capacity(params.len());
391 for param_type in ordered_param_types(¶ms).iter() {
392 if let Some(datatype) = param_type {
394 let pgtype = into_pg_type(datatype)?;
395 param_types.push(pgtype);
396 } else {
397 param_types.push(Type::UNKNOWN);
398 }
399 }
400
401 Ok(DescribeStatementResponse::new(param_types, fields))
402 } else {
403 Ok(DescribeStatementResponse::no_data())
404 }
405 }
406
407 async fn do_describe_portal<C>(
408 &self,
409 _client: &mut C,
410 target: &Portal<Self::Statement>,
411 ) -> PgWireResult<DescribePortalResponse>
412 where
413 C: ClientInfo + Unpin + Send + Sync,
414 {
415 if let (_, Some((_, plan))) = &target.statement.statement {
416 let format = &target.result_column_format;
417 let schema = plan.schema();
418 let fields = arrow_schema_to_pg_fields(schema.as_arrow(), format)?;
419
420 Ok(DescribePortalResponse::new(fields))
421 } else {
422 Ok(DescribePortalResponse::no_data())
423 }
424 }
425
426 async fn do_query<C>(
427 &self,
428 client: &mut C,
429 portal: &Portal<Self::Statement>,
430 _max_rows: usize,
431 ) -> PgWireResult<Response>
432 where
433 C: ClientInfo + Unpin + Send + Sync,
434 {
435 let query = portal
436 .statement
437 .statement
438 .0
439 .to_lowercase()
440 .trim()
441 .to_string();
442 log::debug!("Received execute extended query: {query}"); if !self.query_hooks.is_empty() {
446 if let (_, Some((statement, plan))) = &portal.statement.statement {
447 let param_types = plan
449 .get_parameter_types()
450 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
451
452 let param_values: ParamValues =
453 df::deserialize_parameters(portal, &ordered_param_types(¶m_types))?;
454
455 for hook in &self.query_hooks {
456 if let Some(result) = hook
457 .handle_extended_query(
458 statement,
459 plan,
460 ¶m_values,
461 &self.session_context,
462 client,
463 )
464 .await
465 {
466 return result;
467 }
468 }
469 }
470 }
471
472 if !query.starts_with("set") && !query.starts_with("show") {
474 self.check_query_permission(client, &portal.statement.statement.0)
475 .await?;
476 }
477
478 if let Some(resp) = self
479 .try_respond_transaction_statements(client, &query)
480 .await?
481 {
482 return Ok(resp);
483 }
484
485 if client.transaction_status() == TransactionStatus::Error {
488 return Err(PgWireError::UserError(Box::new(
489 pgwire::error::ErrorInfo::new(
490 "ERROR".to_string(),
491 "25P01".to_string(),
492 "current transaction is aborted, commands ignored until end of transaction block".to_string(),
493 ),
494 )));
495 }
496
497 if let (_, Some((_, plan))) = &portal.statement.statement {
498 let param_types = plan
499 .get_parameter_types()
500 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
501
502 let param_values =
503 df::deserialize_parameters(portal, &ordered_param_types(¶m_types))?; let plan = plan
506 .clone()
507 .replace_params_with_values(¶m_values)
508 .map_err(|e| PgWireError::ApiError(Box::new(e)))?; let optimised = self
511 .session_context
512 .state()
513 .optimize(&plan)
514 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
515
516 let dataframe = {
517 let timeout = client::get_statement_timeout(client);
518 if let Some(timeout_duration) = timeout {
519 tokio::time::timeout(
520 timeout_duration,
521 self.session_context.execute_logical_plan(optimised),
522 )
523 .await
524 .map_err(|_| {
525 PgWireError::UserError(Box::new(pgwire::error::ErrorInfo::new(
526 "ERROR".to_string(),
527 "57014".to_string(), "canceling statement due to statement timeout".to_string(),
529 )))
530 })?
531 .map_err(|e| PgWireError::ApiError(Box::new(e)))?
532 } else {
533 self.session_context
534 .execute_logical_plan(optimised)
535 .await
536 .map_err(|e| PgWireError::ApiError(Box::new(e)))?
537 }
538 };
539
540 if query.starts_with("insert into") {
541 let resp = map_rows_affected_for_insert(&dataframe).await?;
542
543 Ok(resp)
544 } else {
545 let resp = df::encode_dataframe(dataframe, &portal.result_column_format).await?;
547 Ok(Response::Query(resp))
548 }
549 } else {
550 Ok(Response::EmptyQuery)
551 }
552 }
553}
554
555async fn map_rows_affected_for_insert(df: &DataFrame) -> PgWireResult<Response> {
556 let result = df
559 .clone()
560 .collect()
561 .await
562 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
563
564 let rows_affected = result
566 .first()
567 .and_then(|batch| batch.column_by_name("count"))
568 .and_then(|col| {
569 col.as_any()
570 .downcast_ref::<datafusion::arrow::array::UInt64Array>()
571 })
572 .map_or(0, |array| array.value(0) as usize);
573
574 let tag = Tag::new("INSERT").with_oid(0).with_rows(rows_affected);
576 Ok(Response::Execution(tag))
577}
578
/// Extended-protocol [`QueryParser`] that turns SQL text into a DataFusion [`LogicalPlan`].
pub struct Parser {
    /// Context whose session state plans parsed statements.
    session_context: Arc<SessionContext>,
    /// Postgres-compatible SQL front end; the simple-query path also uses it directly.
    sql_parser: PostgresCompatibilityParser,
    /// Hooks that may supply a plan before DataFusion planning runs.
    query_hooks: Vec<Arc<dyn QueryHook>>,
}
584
585impl Parser {
586 fn try_shortcut_parse_plan(&self, sql: &str) -> Result<Option<LogicalPlan>, DataFusionError> {
587 let sql_lower = sql.to_lowercase();
589 let sql_trimmed = sql_lower.trim();
590
591 if matches!(
592 sql_trimmed,
593 "" | "begin"
594 | "begin transaction"
595 | "begin work"
596 | "start transaction"
597 | "commit"
598 | "commit transaction"
599 | "commit work"
600 | "end"
601 | "end transaction"
602 | "rollback"
603 | "rollback transaction"
604 | "rollback work"
605 | "abort"
606 ) {
607 let dummy_schema = datafusion::common::DFSchema::empty();
609 return Ok(Some(LogicalPlan::EmptyRelation(
610 datafusion::logical_expr::EmptyRelation {
611 produce_one_row: false,
612 schema: Arc::new(dummy_schema),
613 },
614 )));
615 }
616
617 if sql_trimmed.starts_with("show") {
619 let show_schema =
620 Arc::new(Schema::new(vec![Field::new("show", DataType::Utf8, false)]));
621 let df_schema = show_schema.to_dfschema()?;
622 return Ok(Some(LogicalPlan::EmptyRelation(
623 datafusion::logical_expr::EmptyRelation {
624 produce_one_row: true,
625 schema: Arc::new(df_schema),
626 },
627 )));
628 }
629
630 Ok(None)
631 }
632}
633
634#[async_trait]
635impl QueryParser for Parser {
636 type Statement = (String, Option<(sqlparser::ast::Statement, LogicalPlan)>);
637
638 async fn parse_sql<C>(
639 &self,
640 client: &C,
641 sql: &str,
642 _types: &[Type],
643 ) -> PgWireResult<Self::Statement>
644 where
645 C: ClientInfo + Unpin + Send + Sync,
646 {
647 log::debug!("Received parse extended query: {sql}"); let mut statements = self
650 .sql_parser
651 .parse(sql)
652 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
653 if statements.is_empty() {
654 return Ok((sql.to_string(), None));
655 }
656
657 let statement = statements.remove(0);
658
659 if let Some(plan) = self
661 .try_shortcut_parse_plan(sql)
662 .map_err(|e| PgWireError::ApiError(Box::new(e)))?
663 {
664 return Ok((sql.to_string(), Some((statement, plan))));
665 }
666
667 let query = statement.to_string();
668
669 let context = &self.session_context;
670 let state = context.state();
671
672 for hook in &self.query_hooks {
673 if let Some(logical_plan) = hook
674 .handle_extended_parse_query(&statement, context, client)
675 .await
676 {
677 return Ok((query, Some((statement, logical_plan?))));
678 }
679 }
680
681 let logical_plan = state
682 .statement_to_plan(Statement::Statement(Box::new(statement.clone())))
683 .await
684 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
685 Ok((query, Some((statement, logical_plan))))
686 }
687}
688
/// Returns parameter types in placeholder order (`$1`, `$2`, …).
///
/// Parameter types arrive as a map keyed by placeholder name (e.g. from
/// `LogicalPlan::get_parameter_types`). Sorting those keys as plain strings
/// puts `$10` before `$2`, which binds values to the wrong parameters once a
/// statement has ten or more placeholders.
///
/// Ordering rules:
/// - `$N` names are ordered by their numeric index N.
/// - Any other names come after them, in lexical order.
/// - `None` entries (unresolved types) keep their position.
fn ordered_param_types<T>(types: &HashMap<String, Option<T>>) -> Vec<Option<&T>> {
    // Sort key: `$N` placeholders first (by N), then other names lexically.
    fn sort_key(name: &str) -> (bool, Option<usize>, &str) {
        let index = name
            .strip_prefix('$')
            .and_then(|digits| digits.parse::<usize>().ok());
        (index.is_none(), index, name)
    }

    let mut entries: Vec<_> = types.iter().collect();
    entries.sort_by(|(a, _), (b, _)| sort_key(a).cmp(&sort_key(b)));
    entries.into_iter().map(|(_, ty)| ty.as_ref()).collect()
}
696
#[cfg(test)]
mod tests {
    use datafusion::prelude::SessionContext;

    use super::*;
    use crate::testing::MockClient;

    /// Hook that claims any simple-protocol statement whose text contains
    /// `magic`, answering it with an empty response. It never handles
    /// extended-protocol traffic.
    struct TestHook;

    #[async_trait]
    impl QueryHook for TestHook {
        async fn handle_simple_query(
            &self,
            statement: &sqlparser::ast::Statement,
            _ctx: &SessionContext,
            _client: &mut (dyn ClientInfo + Sync + Send),
        ) -> Option<PgWireResult<Response>> {
            let is_magic = statement.to_string().contains("magic");
            is_magic.then_some(Ok(Response::EmptyQuery))
        }

        async fn handle_extended_parse_query(
            &self,
            _statement: &sqlparser::ast::Statement,
            _session_context: &SessionContext,
            _client: &(dyn ClientInfo + Send + Sync),
        ) -> Option<PgWireResult<LogicalPlan>> {
            None
        }

        async fn handle_extended_query(
            &self,
            _statement: &sqlparser::ast::Statement,
            _logical_plan: &LogicalPlan,
            _params: &ParamValues,
            _session_context: &SessionContext,
            _client: &mut (dyn ClientInfo + Send + Sync),
        ) -> Option<PgWireResult<Response>> {
            None
        }
    }

    #[tokio::test]
    async fn test_query_hooks() {
        let hook = TestHook;
        let ctx = SessionContext::new();
        let mut client = MockClient::new();
        let parser = PostgresCompatibilityParser::new();

        let magic = parser.parse("SELECT magic").unwrap();
        let handled = hook.handle_simple_query(&magic[0], &ctx, &mut client).await;
        assert!(handled.is_some());

        let plain = parser.parse("SELECT 1").unwrap();
        let handled = hook.handle_simple_query(&plain[0], &ctx, &mut client).await;
        assert!(handled.is_none());
    }

    #[tokio::test]
    async fn test_multiple_statements_with_hook_continue() {
        let session_context = Arc::new(SessionContext::new());
        let auth_manager = Arc::new(AuthManager::new());
        let hooks: Vec<Arc<dyn QueryHook>> = vec![Arc::new(TestHook)];
        let service = DfSessionService::new_with_hooks(session_context, auth_manager, hooks);
        let mut client = MockClient::new();

        let results = <DfSessionService as SimpleQueryHandler>::do_query(
            &service,
            &mut client,
            "SELECT magic; SELECT 1; SELECT magic; SELECT 1",
        )
        .await
        .unwrap();

        assert_eq!(results.len(), 4, "Expected 4 responses");
        // Hook-handled `magic` statements alternate with DataFusion-executed ones.
        for (idx, resp) in results.iter().enumerate() {
            if idx % 2 == 0 {
                assert!(matches!(resp, Response::EmptyQuery));
            } else {
                assert!(matches!(resp, Response::Query(_)));
            }
        }
    }
}