Thanks to visit codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
User server parameters struct instead of server info bytesmut
  • Loading branch information
zainkabani committed Jun 14, 2023
commit ea669b22f531b6afe13e9bd40f15a94e0e75d107
17 changes: 9 additions & 8 deletions src/admin.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::pool::BanReason;
use crate::server::ServerParameters;
use crate::stats::pool::PoolStats;
use bytes::{Buf, BufMut, BytesMut};
use log::{error, info, trace};
Expand All @@ -17,16 +18,16 @@ use crate::pool::ClientServerMap;
use crate::pool::{get_all_pools, get_pool};
use crate::stats::{get_client_stats, get_server_stats, ClientState, ServerState};

pub fn generate_server_info_for_admin() -> BytesMut {
let mut server_info = BytesMut::new();
pub fn generate_server_parameters_for_admin() -> ServerParameters {
let mut server_parameters = ServerParameters::new();

server_info.put(server_parameter_message("application_name", ""));
server_info.put(server_parameter_message("client_encoding", "UTF8"));
server_info.put(server_parameter_message("server_encoding", "UTF8"));
server_info.put(server_parameter_message("server_version", VERSION));
server_info.put(server_parameter_message("DateStyle", "ISO, MDY"));
server_parameters.set_dynamic_param("application_name".to_string(), "".to_string());
server_parameters.set_dynamic_param("client_encoding".to_string(), "UTF8".to_string());
server_parameters.set_dynamic_param("server_encoding".to_string(), "UTF8".to_string());
server_parameters.set_dynamic_param("server_version".to_string(), VERSION.to_string());
server_parameters.set_dynamic_param("DateStyle".to_string(), "ISO, MDY".to_string());

server_info
server_parameters
}

/// Handle admin client.
Expand Down
10 changes: 5 additions & 5 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use tokio::net::TcpStream;
use tokio::sync::broadcast::Receiver;
use tokio::sync::mpsc::Sender;

use crate::admin::{generate_server_info_for_admin, handle_admin};
use crate::admin::{generate_server_parameters_for_admin, handle_admin};
use crate::auth_passthrough::refetch_auth_hash;
use crate::config::{get_config, get_idle_client_in_transaction_timeout, Address, PoolMode};
use crate::constants::*;
Expand Down Expand Up @@ -491,7 +491,7 @@ where
};

// Authenticate admin user.
let (transaction_mode, server_info) = if admin {
let (transaction_mode, server_parameters) = if admin {
let config = get_config();

// Compare server and client hashes.
Expand All @@ -510,7 +510,7 @@ where
return Err(error);
}

(false, generate_server_info_for_admin())
(false, generate_server_parameters_for_admin())
}
// Authenticate normal user.
else {
Expand Down Expand Up @@ -643,13 +643,13 @@ where
}
}

(transaction_mode, pool.server_info())
(transaction_mode, pool.server_parameters())
};

debug!("Password authentication successful");

auth_ok(&mut write).await?;
write_all(&mut write, server_info).await?;
write_all(&mut write, server_parameters.get_bytes()).await?;
backend_key_data(&mut write, process_id, secret_key).await?;
ready_for_query(&mut write).await?;

Expand Down
49 changes: 28 additions & 21 deletions src/pool.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use arc_swap::ArcSwap;
use async_trait::async_trait;
use bb8::{ManageConnection, Pool, PooledConnection, QueueStrategy};
use bytes::{BufMut, BytesMut};
use chrono::naive::NaiveDateTime;
use log::{debug, error, info, warn};
use once_cell::sync::Lazy;
Expand All @@ -25,7 +24,7 @@ use crate::errors::Error;

use crate::auth_passthrough::AuthPassthrough;
use crate::plugins::prewarmer;
use crate::server::Server;
use crate::server::{Server, ServerParameters};
use crate::sharding::ShardingFunction;
use crate::stats::{AddressStats, ClientStats, ServerStats};

Expand Down Expand Up @@ -188,10 +187,10 @@ pub struct ConnectionPool {
/// that should not be queried.
banlist: BanList,

/// The server information (K messages) have to be passed to the
/// The server information has to be passed to the
/// clients on startup. We pre-connect to all shards and replicas
/// on pool creation and save the K messages here.
server_info: Arc<RwLock<BytesMut>>,
/// on pool creation and save the startup parameters here.
original_server_parameters: ServerParameters,

/// Pool configuration.
pub settings: PoolSettings,
Expand Down Expand Up @@ -258,6 +257,7 @@ impl ConnectionPool {
.clone()
.into_keys()
.collect::<Vec<String>>();
let mut original_server_parameters = ServerParameters::new();

// Sort by shard number to ensure consistency.
shard_ids.sort_by_key(|k| k.parse::<i64>().unwrap());
Expand Down Expand Up @@ -415,6 +415,20 @@ impl ConnectionPool {
pool.build_unchecked(manager)
};

// Set original server parameters by getting a connection
// If we don't want to validate then a default set of parameters will be used
if config.general.validate_config {
match pool.get().await {
Ok(conn) => {
original_server_parameters = conn.server_parameters();
}
Err(err) => {
error!("Shard {} down or misconfigured: {:?}", address, err);
return Err(Error::ServerError);
}
};
}

pools.push(pool);
servers.push(address);
}
Expand All @@ -437,7 +451,7 @@ impl ConnectionPool {
addresses,
banlist: Arc::new(RwLock::new(banlist)),
config_hash: new_pool_hash_value,
server_info: Arc::new(RwLock::new(BytesMut::new())),
original_server_parameters,
auth_hash: pool_auth_hash,
settings: PoolSettings {
pool_mode: match user.pool_mode {
Expand Down Expand Up @@ -488,6 +502,7 @@ impl ConnectionPool {
// before setting it globally.
// Do this async and somewhere else, we don't have to wait here.
if config.general.validate_config {
// TODO: this can't be optional since we need some startup parameters to bootstrap with
let mut validate_pool = pool.clone();
tokio::task::spawn(async move {
let _ = validate_pool.validate().await;
Expand All @@ -512,30 +527,22 @@ impl ConnectionPool {
pub async fn validate(&mut self) -> Result<(), Error> {
let mut futures = Vec::new();
let validated = Arc::clone(&self.validated);
validated.store(true, Ordering::Relaxed);

for shard in 0..self.shards() {
for server in 0..self.servers(shard) {
let databases = self.databases.clone();
let validated = Arc::clone(&validated);
let pool_server_info = Arc::clone(&self.server_info);

let task = tokio::task::spawn(async move {
let connection = match databases[shard][server].get().await {
Ok(conn) => conn,
match databases[shard][server].get().await {
Ok(_) => {}
Err(err) => {
validated.store(false, Ordering::Relaxed);
error!("Shard {} down or misconfigured: {:?}", shard, err);
return;
}
};

let proxy = connection;
let server = &*proxy;
let server_info = server.server_info();

let mut guard = pool_server_info.write();
guard.clear();
guard.put(server_info.clone());
validated.store(true, Ordering::Relaxed);
});

futures.push(task);
Expand All @@ -546,7 +553,7 @@ impl ConnectionPool {

// TODO: compare server information to make sure
// all shards are running identical configurations.
if self.server_info.read().is_empty() {
if !self.validated() {
error!("Could not validate connection pool");
return Err(Error::AllServersDown);
}
Expand Down Expand Up @@ -906,8 +913,8 @@ impl ConnectionPool {
&self.addresses[shard][server]
}

pub fn server_info(&self) -> BytesMut {
self.server_info.read().clone()
pub fn server_parameters(&self) -> ServerParameters {
self.original_server_parameters.clone()
}

fn busy_connection_count(&self, address: &Address) -> u32 {
Expand Down
Loading