diff --git a/aws-dsql-auth b/aws-dsql-auth index f33fce18f3864..3c53ec315ed61 160000 --- a/aws-dsql-auth +++ b/aws-dsql-auth @@ -1 +1 @@ -Subproject commit f33fce18f3864e9428b10ae63c018b004b18436a +Subproject commit 3c53ec315ed619715c56875e81476119c365ddac diff --git a/scripts/build-dsql.sh b/scripts/build-dsql.sh index 83173df1b8ada..2a2f9f6197cab 100755 --- a/scripts/build-dsql.sh +++ b/scripts/build-dsql.sh @@ -51,7 +51,7 @@ else echo " aws-dsql-auth submodules already initialized." fi -if [ ! -f "aws-dsql-auth/build/install/lib64/libaws-dsql-auth.a" ]; then +if [ ! -d "aws-dsql-auth/build/install/" ]; then # Build aws-dsql-auth echo " Building aws-dsql-auth library..." cd aws-dsql-auth diff --git a/scripts/install.sh b/scripts/install.sh index 7959634c7d33b..0dcbba2d306d3 100755 --- a/scripts/install.sh +++ b/scripts/install.sh @@ -196,11 +196,14 @@ install_via_zip() { # Clean up temp files rm -rf "$TEMP_DIR" - # Make the binary executable + # Make the binaries executable chmod +x "$INSTALL_PATH/bin/pdsql" + chmod +x "$INSTALL_PATH/bin/pgbench" echo "ZIP installation completed successfully!" - echo "PostgreSQL DSQL (pdsql) installed to: $INSTALL_PATH/bin/pdsql" + echo "PostgreSQL DSQL tools installed to:" + echo " - pdsql (DSQL client): $INSTALL_PATH/bin/pdsql" + echo " - pgbench (benchmark tool): $INSTALL_PATH/bin/pgbench" # Check if installation path is in PATH if [[ ":$PATH:" != *":$INSTALL_PATH/bin:"* ]]; then diff --git a/src/bin/pgbench/pgbench.c b/src/bin/pgbench/pgbench.c index 8c4617917bb34..967612a185828 100644 --- a/src/bin/pgbench/pgbench.c +++ b/src/bin/pgbench/pgbench.c @@ -68,6 +68,7 @@ #include "pgbench.h" #include "port/pg_bitutils.h" #include "portability/instr_time.h" +#include "fe-dsql-auth.h" /* X/Open (XSI) requires to provide M_PI, but core POSIX does not */ #ifndef M_PI @@ -7124,6 +7125,29 @@ main(int argc, char **argv) setenv("PGDSQL", "1", 1); is_no_vacuum = true; foreign_keys = false; + + /* Initialize DSQL token generator */ + if (dsql_initialize_token_generator() != 0) + { + pg_fatal("Failed to initialize DSQL token generator"); + } + + /* Validate AWS credentials */ + { + char *err_msg = NULL; + if (dsql_validate_aws_credentials(&err_msg) != 0) + { + if (err_msg) + { + pg_fatal("DSQL credential validation failed: %s", err_msg); + free(err_msg); + } + else + { + pg_fatal("DSQL credential validation failed"); + } + } + } } /* set default script if none */ diff --git a/src/bin/psql/startup.c b/src/bin/psql/startup.c index 291fd317b024f..abdcb57abf589 100644 --- a/src/bin/psql/startup.c +++ b/src/bin/psql/startup.c @@ -6,6 +6,7 @@ * src/bin/psql/startup.c */ #include "postgres_fe.h" +#include "libpq-fe.h" #ifndef WIN32 #include @@ -26,6 +27,8 @@ #include "mainloop.h" #include "settings.h" +#include "fe-dsql-auth.h" + /* * Global psql options */ @@ -221,6 +224,29 @@ main(int argc, char *argv[]) { setenv("PGDSQL", "1", 1); pset.getPassword = TRI_NO; + + /* Initialize DSQL token generator */ + if (dsql_initialize_token_generator() != 0) + { + pg_fatal("Failed to initialize DSQL token generator"); + } + + /* Validate AWS credentials */ + { + char *err_msg = NULL; + if (dsql_validate_aws_credentials(&err_msg) != 0) + { + if (err_msg) + { + pg_fatal("DSQL credential validation failed: %s", err_msg); + free(err_msg); + } + else + { + pg_fatal("DSQL credential validation failed"); + } + } + } } /* diff --git a/src/interfaces/libpq/Makefile b/src/interfaces/libpq/Makefile index 03556fcd12c49..883aaa78da118 100644 --- a/src/interfaces/libpq/Makefile +++ b/src/interfaces/libpq/Makefile @@ -90,15 +90,19 @@ ifeq ($(PORTNAME), linux) # Link with AWS libraries using start-group to resolve dependencies $(LD) -r -o fe-dsql-auth-with-aws.o fe-dsql-auth-temp.o --start-group $(AWS_DSQL_AUTH_ALL_LIBS) --end-group # Create a list of symbols to keep (only the public API from fe-dsql-auth.h) - echo "generate_dsql_token" > keep-symbols.txt - echo "dsql_auth_cleanup" >> keep-symbols.txt + echo "dsql_initialize_token_generator" > keep-symbols.txt + echo "dsql_generate_token" >> keep-symbols.txt + echo "dsql_validate_aws_credentials" >> keep-symbols.txt + echo "dsql_cleanup" >> keep-symbols.txt # Hide all symbols except the ones we want to keep objcopy --keep-global-symbols=keep-symbols.txt fe-dsql-auth-with-aws.o $@ rm -f fe-dsql-auth-temp.o fe-dsql-auth-with-aws.o keep-symbols.txt else ifeq ($(PORTNAME), darwin) # macOS: Use ld with exported symbols list - echo "_generate_dsql_token" > exported-symbols.txt - echo "_dsql_auth_cleanup" >> exported-symbols.txt + echo "_dsql_initialize_token_generator" > exported-symbols.txt + echo "_dsql_generate_token" >> exported-symbols.txt + echo "_dsql_validate_aws_credentials" >> exported-symbols.txt + echo "_dsql_cleanup" >> exported-symbols.txt $(LD) -r -o $@ fe-dsql-auth-temp.o -exported_symbols_list exported-symbols.txt $(AWS_DSQL_AUTH_ALL_LIBS) rm -f fe-dsql-auth-temp.o exported-symbols.txt else diff --git a/src/interfaces/libpq/exports.txt b/src/interfaces/libpq/exports.txt index 0625cf39e9af3..7ccf0b80cbd03 100644 --- a/src/interfaces/libpq/exports.txt +++ b/src/interfaces/libpq/exports.txt @@ -211,3 +211,7 @@ PQgetAuthDataHook 208 PQdefaultAuthDataHook 209 PQfullProtocolVersion 210 appendPQExpBufferVA 211 +dsql_initialize_token_generator 212 +dsql_generate_token 213 +dsql_validate_aws_credentials 214 +dsql_cleanup 215 diff --git a/src/interfaces/libpq/fe-connect.c b/src/interfaces/libpq/fe-connect.c index 682b3b064c406..55cd08889ca0e 100644 --- a/src/interfaces/libpq/fe-connect.c +++ b/src/interfaces/libpq/fe-connect.c @@ -1441,7 +1441,7 @@ pqConnectOptions2(PGconn *conn) pwhost = conn->connhost[i].hostaddr; is_admin = strcmp("admin", conn->pguser) == 0; - token = generate_dsql_token(pwhost, is_admin, &err_msg); + token = dsql_generate_token(pwhost, is_admin, &err_msg); if (!token) { libpq_append_conn_error(conn, "DSQL token generation failed for host=%s: %s", diff --git a/src/interfaces/libpq/fe-dsql-auth.c b/src/interfaces/libpq/fe-dsql-auth.c index c7ed75d300eaf..6163302a38cab 100644 --- a/src/interfaces/libpq/fe-dsql-auth.c +++ b/src/interfaces/libpq/fe-dsql-auth.c @@ -18,6 +18,8 @@ /* Include AWS DSQL Auth library functions */ #include #include +#include +#include #include #include #include @@ -37,6 +39,17 @@ static struct aws_event_loop_group *s_el_group = NULL; static struct aws_host_resolver *s_host_resolver = NULL; static struct aws_client_bootstrap *s_client_bootstrap = NULL; +/* + * DSQL Token Generator - holds state for efficient token generation + */ +struct dsql_token_generator { + struct aws_allocator *allocator; + struct aws_credentials_provider *credentials_provider; +}; + +/* Global token generator instance */ +static struct dsql_token_generator s_token_generator = {0}; + /* * Initialize DSQL logging */ @@ -47,7 +60,7 @@ initialize_dsql_logging(void) { struct aws_allocator *allocator = aws_default_allocator(); struct aws_logger_standard_options logger_options = { - .level = AWS_LOG_LEVEL_DEBUG, /* Can be controlled by environment variable */ + .level = AWS_LOG_LEVEL_NONE, /* Can be controlled by environment variable */ .file = stderr /* Log to stderr by default */ }; @@ -90,29 +103,6 @@ initialize_dsql_logging(void) } } -/* - * Clean up DSQL authentication resources - */ -void -dsql_auth_cleanup(void) -{ - if (dsql_logger_initialized) - { - aws_logger_set(NULL); - aws_logger_clean_up(&dsql_logger); - dsql_logger_initialized = false; - } - - if (aws_libs_initialized) - { - aws_sdkutils_library_clean_up(); - aws_auth_library_clean_up(); - aws_io_library_clean_up(); - aws_common_library_clean_up(); - aws_libs_initialized = false; - } -} - /* * Initialize AWS libraries if not already initialized */ @@ -121,7 +111,11 @@ initialize_aws_libs(void) { if (!aws_libs_initialized) { - struct aws_allocator *allocator = aws_default_allocator(); + struct aws_allocator *allocator; + struct aws_host_resolver_default_options resolver_options; + struct aws_client_bootstrap_options bootstrap_options; + + allocator = aws_default_allocator(); aws_common_library_init(allocator); aws_io_library_init(allocator); aws_http_library_init(allocator); @@ -137,20 +131,18 @@ initialize_aws_libs(void) goto error; } - struct aws_host_resolver_default_options resolver_options = { - .el_group = s_el_group, - .max_entries = 8, - }; + AWS_ZERO_STRUCT(resolver_options); + resolver_options.el_group = s_el_group; + resolver_options.max_entries = 8; s_host_resolver = aws_host_resolver_new_default(allocator, &resolver_options); if (!s_host_resolver) { AWS_LOGF_ERROR(AWS_LS_AUTH_GENERAL, "Failed to create host resolver"); goto error; } - struct aws_client_bootstrap_options bootstrap_options = { - .event_loop_group = s_el_group, - .host_resolver = s_host_resolver, - }; + AWS_ZERO_STRUCT(bootstrap_options); + bootstrap_options.event_loop_group = s_el_group; + bootstrap_options.host_resolver = s_host_resolver; s_client_bootstrap = aws_client_bootstrap_new(allocator, &bootstrap_options); if (!s_client_bootstrap) { AWS_LOGF_ERROR(AWS_LS_AUTH_GENERAL, "Failed to create client bootstrap"); @@ -183,41 +175,130 @@ initialize_aws_libs(void) } } +/* + * Initialize the DSQL token generator with long-lived components + */ +static int +initialize_token_generator(void) +{ + struct aws_credentials_provider_chain_default_options credentials_options; + + if (s_token_generator.allocator != NULL) { + /* Already initialized */ + return AWS_OP_SUCCESS; + } + + AWS_LOGF_DEBUG(AWS_LS_AUTH_GENERAL, "Initializing DSQL token generator"); + + s_token_generator.allocator = aws_default_allocator(); + + /* Create credentials provider with client bootstrap for IMDS */ + AWS_ZERO_STRUCT(credentials_options); + credentials_options.bootstrap = s_client_bootstrap; + + s_token_generator.credentials_provider = aws_credentials_provider_new_chain_default( + s_token_generator.allocator, &credentials_options); + + if (!s_token_generator.credentials_provider) { + AWS_LOGF_ERROR(AWS_LS_AUTH_GENERAL, "Failed to create credentials provider for token generator"); + s_token_generator.allocator = NULL; + return AWS_OP_ERR; + } + + AWS_LOGF_DEBUG(AWS_LS_AUTH_GENERAL, "DSQL token generator initialized successfully"); + return AWS_OP_SUCCESS; +} + +/* + * Public initialization function for the DSQL token generator + */ +int +dsql_initialize_token_generator(void) +{ + /* Initialize AWS libraries and logging */ + initialize_aws_libs(); + initialize_dsql_logging(); + + return initialize_token_generator(); +} + +/* + * Clean up the DSQL token generator + */ +static void +cleanup_token_generator(void) +{ + if (s_token_generator.allocator != NULL) { + AWS_LOGF_DEBUG(AWS_LS_AUTH_GENERAL, "Cleaning up DSQL token generator"); + + if (s_token_generator.credentials_provider) { + aws_credentials_provider_release(s_token_generator.credentials_provider); + s_token_generator.credentials_provider = NULL; + } + + s_token_generator.allocator = NULL; + } +} + +/* + * Clean up DSQL authentication resources + */ +void +dsql_cleanup(void) +{ + cleanup_token_generator(); + + if (dsql_logger_initialized) + { + aws_logger_set(NULL); + aws_logger_clean_up(&dsql_logger); + dsql_logger_initialized = false; + } + + if (aws_libs_initialized) + { + aws_sdkutils_library_clean_up(); + aws_auth_library_clean_up(); + aws_io_library_clean_up(); + aws_common_library_clean_up(); + aws_libs_initialized = false; + } +} + /* * Generate a DSQL authentication token for the specified endpoint. - * Uses the AWS DSQL auth library to generate a real token. + * Uses a local auth_config for thread safety and cached credentials provider for efficiency. * Returns a newly allocated string containing the token. */ char * -generate_dsql_token(const char *endpoint, bool admin, char **err_msg) +dsql_generate_token(const char *endpoint, bool admin, char **err_msg) { - struct aws_allocator *allocator; - struct aws_dsql_auth_config auth_config; + struct aws_dsql_auth_config auth_config = {0}; struct aws_dsql_auth_token auth_token = {0}; struct aws_string *aws_region = NULL; - struct aws_credentials_provider *credentials_provider = NULL; - struct aws_credentials_provider_chain_default_options credentials_options; char *token = NULL; int aws_error; const char *env_region; const char *token_str; - /* Initialize AWS libraries and logging */ - initialize_aws_libs(); - initialize_dsql_logging(); - - allocator = aws_default_allocator(); + /* Check if token generator is initialized */ + if (s_token_generator.allocator == NULL) { + if (err_msg) + *err_msg = strdup("Token generator not initialized"); + return NULL; + } AWS_LOGF_INFO(AWS_LS_AUTH_GENERAL, "Starting DSQL token generation for endpoint: %s", endpoint); - /* Initialize DSQL auth config */ + /* Initialize a local auth config for thread safety */ if (aws_dsql_auth_config_init(&auth_config) != AWS_OP_SUCCESS) { + AWS_LOGF_ERROR(AWS_LS_AUTH_GENERAL, "Failed to initialize local auth config"); if (err_msg) - *err_msg = strdup("Failed to initialize DSQL auth config"); - goto cleanup; + *err_msg = strdup("Failed to initialize auth config"); + return NULL; } - /* Set hostname */ + /* Set hostname on the local auth config */ aws_dsql_auth_config_set_hostname(&auth_config, endpoint); /* Try to get region from environment variable first */ @@ -225,7 +306,7 @@ generate_dsql_token(const char *endpoint, bool admin, char **err_msg) if (env_region != NULL && env_region[0] != '\0') { AWS_LOGF_DEBUG(AWS_LS_AUTH_GENERAL, "Using AWS_REGION from environment: %s", env_region); - aws_region = aws_string_new_from_c_str(allocator, env_region); + aws_region = aws_string_new_from_c_str(s_token_generator.allocator, env_region); if (!aws_region) { AWS_LOGF_ERROR(AWS_LS_AUTH_GENERAL, "Failed to create region string from AWS_REGION"); if (err_msg) @@ -237,7 +318,7 @@ generate_dsql_token(const char *endpoint, bool admin, char **err_msg) { AWS_LOGF_DEBUG(AWS_LS_AUTH_GENERAL, "AWS_REGION not set, attempting to infer from hostname: %s", endpoint); /* Try to infer region from hostname */ - if (aws_dsql_auth_config_infer_region(allocator, &auth_config, &aws_region) != AWS_OP_SUCCESS || + if (aws_dsql_auth_config_infer_region(s_token_generator.allocator, &auth_config, &aws_region) != AWS_OP_SUCCESS || aws_region == NULL) { AWS_LOGF_ERROR(AWS_LS_AUTH_GENERAL, "Failed to infer AWS region from hostname: %s", endpoint); @@ -249,28 +330,15 @@ generate_dsql_token(const char *endpoint, bool admin, char **err_msg) } aws_dsql_auth_config_set_region(&auth_config, aws_region); - /* Create default credentials provider with client bootstrap for IMDS */ - AWS_ZERO_STRUCT(credentials_options); - credentials_options.bootstrap = s_client_bootstrap; - - AWS_LOGF_DEBUG(AWS_LS_AUTH_GENERAL, "Creating credentials provider chain with bootstrap for IMDS"); - credentials_provider = aws_credentials_provider_new_chain_default(allocator, &credentials_options); - if (!credentials_provider) { - aws_error = aws_last_error(); - if (err_msg) - *err_msg = strdup(aws_error_str(aws_error)); - goto cleanup; - } - - /* Set credentials provider */ - aws_dsql_auth_config_set_credentials_provider(&auth_config, credentials_provider); + /* Set the cached credentials provider */ + aws_dsql_auth_config_set_credentials_provider(&auth_config, s_token_generator.credentials_provider); /* Set expiration time to 5 seconds for shorter token lifetime */ aws_dsql_auth_config_set_expires_in(&auth_config, 5); /* 5 seconds */ - /* Generate the token */ + /* Generate the token using local auth config and cached components */ AWS_ZERO_STRUCT(auth_token); - if (aws_dsql_auth_token_generate(&auth_config, admin, allocator, &auth_token) != AWS_OP_SUCCESS) + if (aws_dsql_auth_token_generate(&auth_config, admin, s_token_generator.allocator, &auth_token) != AWS_OP_SUCCESS) { aws_error = aws_last_error(); if (err_msg) @@ -283,10 +351,11 @@ generate_dsql_token(const char *endpoint, bool admin, char **err_msg) if (token_str) { token = strdup(token_str); - /* Token generation successful */ + AWS_LOGF_DEBUG(AWS_LS_AUTH_GENERAL, "DSQL token generated successfully using local auth config and cached credentials"); } else { + AWS_LOGF_ERROR(AWS_LS_AUTH_GENERAL, "Failed to get token string from generated token"); if (err_msg) *err_msg = strdup("Failed to get token string"); } @@ -295,11 +364,6 @@ generate_dsql_token(const char *endpoint, bool admin, char **err_msg) aws_dsql_auth_token_clean_up(&auth_token); aws_dsql_auth_config_clean_up(&auth_config); - if (credentials_provider) - { - aws_credentials_provider_release(credentials_provider); - } - if (aws_region) { aws_string_destroy(aws_region); @@ -307,3 +371,116 @@ generate_dsql_token(const char *endpoint, bool admin, char **err_msg) return token; } + +/* Synchronous credential retrieval state */ +struct credential_validation_state { + struct aws_credentials *credentials; + int error_code; + bool completed; + struct aws_mutex mutex; + struct aws_condition_variable condition_variable; +}; + +/* Callback for synchronous credential retrieval */ +static void +s_on_credentials_acquired(struct aws_credentials *credentials, int error_code, void *user_data) +{ + struct credential_validation_state *state = (struct credential_validation_state *)user_data; + + aws_mutex_lock(&state->mutex); + + state->credentials = credentials; + state->error_code = error_code; + state->completed = true; + + if (credentials) { + aws_credentials_acquire(credentials); + AWS_LOGF_DEBUG(AWS_LS_AUTH_GENERAL, "Credentials acquired successfully for validation"); + } else { + AWS_LOGF_ERROR(AWS_LS_AUTH_GENERAL, "Credentials acquisition failed with error: %d", error_code); + } + + aws_condition_variable_notify_one(&state->condition_variable); + aws_mutex_unlock(&state->mutex); +} + +/* + * Validate AWS credentials early for DSQL authentication. + * This initializes the token generator and validates that credentials can be obtained. + * Returns AWS_OP_SUCCESS on success, AWS_OP_ERR on failure. + */ +int +dsql_validate_aws_credentials(char **err_msg) +{ + struct credential_validation_state state = {0}; + int result = AWS_OP_ERR; + + /* Check if token generator is initialized */ + if (s_token_generator.allocator == NULL) { + AWS_LOGF_ERROR(AWS_LS_AUTH_GENERAL, "Token generator not initialized during credential validation"); + if (err_msg) + *err_msg = strdup("Token generator not initialized"); + return AWS_OP_ERR; + } + + AWS_LOGF_INFO(AWS_LS_AUTH_GENERAL, "Validating AWS credentials for DSQL authentication"); + + /* Initialize synchronization primitives */ + if (aws_mutex_init(&state.mutex) != AWS_OP_SUCCESS) { + AWS_LOGF_ERROR(AWS_LS_AUTH_GENERAL, "Failed to initialize mutex for credential validation"); + if (err_msg) + *err_msg = strdup("Failed to initialize synchronization"); + return AWS_OP_ERR; + } + + if (aws_condition_variable_init(&state.condition_variable) != AWS_OP_SUCCESS) { + AWS_LOGF_ERROR(AWS_LS_AUTH_GENERAL, "Failed to initialize condition variable for credential validation"); + aws_mutex_clean_up(&state.mutex); + if (err_msg) + *err_msg = strdup("Failed to initialize synchronization"); + return AWS_OP_ERR; + } + + /* Actually retrieve credentials to validate they exist and are accessible */ + AWS_LOGF_DEBUG(AWS_LS_AUTH_GENERAL, "Attempting to retrieve AWS credentials for validation"); + + if (aws_credentials_provider_get_credentials( + s_token_generator.credentials_provider, + s_on_credentials_acquired, + &state) != AWS_OP_SUCCESS) { + + AWS_LOGF_ERROR(AWS_LS_AUTH_GENERAL, "Failed to initiate credentials retrieval"); + if (err_msg) + *err_msg = strdup("Failed to initiate credentials retrieval"); + goto cleanup; + } + + /* Wait for credentials retrieval to complete */ + aws_mutex_lock(&state.mutex); + while (!state.completed) { + aws_condition_variable_wait(&state.condition_variable, &state.mutex); + } + aws_mutex_unlock(&state.mutex); + + /* Check if credentials were successfully retrieved */ + if (state.credentials && state.error_code == AWS_OP_SUCCESS) { + AWS_LOGF_INFO(AWS_LS_AUTH_GENERAL, "AWS credentials validation completed successfully"); + AWS_LOGF_DEBUG(AWS_LS_AUTH_GENERAL, "Token generator ready for DSQL authentication"); + result = AWS_OP_SUCCESS; + + aws_credentials_release(state.credentials); + } else { + AWS_LOGF_ERROR(AWS_LS_AUTH_GENERAL, "Failed to retrieve AWS credentials: %s", + aws_error_str(state.error_code)); + if (err_msg) { + const char *error_str = aws_error_str(state.error_code); + *err_msg = strdup(error_str ? error_str : "Unknown credential retrieval error"); + } + } + +cleanup: + aws_condition_variable_clean_up(&state.condition_variable); + aws_mutex_clean_up(&state.mutex); + + return result; +} diff --git a/src/interfaces/libpq/fe-dsql-auth.h b/src/interfaces/libpq/fe-dsql-auth.h index 23c65c2f0cd5a..d24c3a600c7d7 100644 --- a/src/interfaces/libpq/fe-dsql-auth.h +++ b/src/interfaces/libpq/fe-dsql-auth.h @@ -10,10 +10,16 @@ #include +/* Initialize the DSQL token generator */ +int dsql_initialize_token_generator(void); + /* Generate a DSQL authentication token for the specified endpoint */ -char *generate_dsql_token(const char *endpoint, bool admin, char **err_msg); +char *dsql_generate_token(const char *endpoint, bool admin, char **err_msg); + +/* Initialize and validate AWS credentials early (for startup validation) */ +int dsql_validate_aws_credentials(char **err_msg); /* Clean up DSQL authentication resources */ -void dsql_auth_cleanup(void); +void dsql_cleanup(void); #endif /* FE_DSQL_AUTH_H */