Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Refactor to use hashmap for all params and add server parameters to c…
…lient
  • Loading branch information
zainkabani committed Jun 15, 2023
commit f0efa97c33c98fc5a2a1e8a561ae856ff2a3c45b
10 changes: 5 additions & 5 deletions src/admin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@ use crate::stats::{get_client_stats, get_server_stats, ClientState, ServerState}
pub fn generate_server_parameters_for_admin() -> ServerParameters {
let mut server_parameters = ServerParameters::new();

server_parameters.set_dynamic_param("application_name".to_string(), "".to_string());
server_parameters.set_dynamic_param("client_encoding".to_string(), "UTF8".to_string());
server_parameters.set_dynamic_param("server_encoding".to_string(), "UTF8".to_string());
server_parameters.set_dynamic_param("server_version".to_string(), VERSION.to_string());
server_parameters.set_dynamic_param("DateStyle".to_string(), "ISO, MDY".to_string());
server_parameters.set_param("application_name".to_string(), "".to_string(), false);
server_parameters.set_param("client_encoding".to_string(), "UTF8".to_string(), false);
server_parameters.set_param("server_encoding".to_string(), "UTF8".to_string(), false);
server_parameters.set_param("server_version".to_string(), VERSION.to_string(), false);
server_parameters.set_param("DateStyle".to_string(), "ISO, MDY".to_string(), false);

server_parameters
}
Expand Down
12 changes: 10 additions & 2 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use crate::messages::*;
use crate::plugins::PluginOutput;
use crate::pool::{get_pool, ClientServerMap, ConnectionPool};
use crate::query_router::{Command, QueryRouter};
use crate::server::Server;
use crate::server::{Server, ServerParameters};
use crate::stats::{ClientStats, ServerStats};
use crate::tls::Tls;

Expand Down Expand Up @@ -91,6 +91,9 @@ pub struct Client<S, T> {
/// Application name for this client (defaults to pgcat)
application_name: String,

/// Server startup and session parameters that we're going to track
server_parameters: ServerParameters,

/// Used to notify clients about an impending shutdown
shutdown: Receiver<()>,
}
Expand Down Expand Up @@ -491,7 +494,7 @@ where
};

// Authenticate admin user.
let (transaction_mode, server_parameters) = if admin {
let (transaction_mode, mut server_parameters) = if admin {
let config = get_config();

// Compare server and client hashes.
Expand Down Expand Up @@ -646,6 +649,9 @@ where
(transaction_mode, pool.server_parameters())
};

// Update the parameters to merge what the application sent and what's originally on the server
server_parameters.set_from_hashmap(&parameters, false);

debug!("Password authentication successful");

auth_ok(&mut write).await?;
Expand Down Expand Up @@ -680,6 +686,7 @@ where
pool_name: pool_name.clone(),
username: username.clone(),
application_name: application_name.to_string(),
server_parameters,
shutdown,
connected_to_server: false,
})
Expand Down Expand Up @@ -714,6 +721,7 @@ where
pool_name: String::from("undefined"),
username: String::from("undefined"),
application_name: String::from("undefined"),
server_parameters: ServerParameters::new(),
shutdown,
connected_to_server: false,
})
Expand Down
16 changes: 16 additions & 0 deletions src/messages.rs
Original file line number Diff line number Diff line change
Expand Up @@ -689,3 +689,19 @@ impl BytesMutReader for Cursor<&BytesMut> {
}
}
}

impl BytesMutReader for BytesMut {
/// Should only be used when reading strings from the message protocol.
/// Can be used to read multiple strings from the same message which are separated by the null byte
fn read_string(&mut self) -> Result<String, Error> {
let null_index = self.iter().position(|&byte| byte == b'\0');

match null_index {
Some(index) => {
let string_bytes = self.split_to(index + 1);
Ok(String::from_utf8_lossy(&string_bytes[..string_bytes.len() - 1]).to_string())
}
None => return Err(Error::ParseBytesError("Could not read string".to_string())),
}
}
}
173 changes: 73 additions & 100 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
use bytes::{Buf, BufMut, BytesMut};
use fallible_iterator::FallibleIterator;
use log::{debug, error, info, trace, warn};
use once_cell::sync::Lazy;
use parking_lot::{Mutex, RwLock};
use postgres_protocol::message;
use std::collections::HashMap;
use std::io::{Cursor, Read};
use std::collections::{HashMap, HashSet};
use std::mem;
use std::net::IpAddr;
use std::sync::Arc;
use std::sync::{Arc, Once};
use std::time::SystemTime;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, BufStream};
use tokio::net::TcpStream;
Expand Down Expand Up @@ -147,14 +147,24 @@ impl std::fmt::Display for CleanupState {
}
}

static INIT: Once = Once::new();
static TRACKED_PARAMETERS: Lazy<HashSet<String>> = Lazy::new(|| {
INIT.call_once(|| {
println!("Initializing the hashset");
});

let mut set = HashSet::new();
set.insert("client_encoding".to_string());
set.insert("datestyle".to_string());
set.insert("timezone".to_string());
set.insert("standard_conforming_strings".to_string());
set.insert("application_name".to_string());
set
});

#[derive(Debug, Clone)]
pub struct ServerParameters {
base_original: BytesMut,
client_encoding: String,
date_style: String,
timezone: String,
standard_conforming_strings: String,
application_name: String,
parameters: HashMap<String, String>,
}

impl Default for ServerParameters {
Expand All @@ -166,84 +176,39 @@ impl Default for ServerParameters {
impl ServerParameters {
pub fn new() -> Self {
ServerParameters {
base_original: BytesMut::new(),
client_encoding: "UTF8".to_string(),
date_style: "ISO".to_string(),
timezone: "UTC".to_string(),
standard_conforming_strings: "on".to_string(),
application_name: "pgcat".to_string(),
parameters: HashMap::new(),
}
}

pub fn set_dynamic_param(&mut self, key: String, value: String) {
match key.as_str() {
"client_encoding" => {
self.client_encoding = value;
}
"date_style" => {
self.date_style = value;
}
"timezone" => {
self.timezone = value;
}
"standard_conforming_strings" => {
self.standard_conforming_strings = value;
}
"application_name" => {
self.application_name = value;
// returns true if parameter was set, false if it already exists or was a non-tracked parameter
pub fn set_param(&mut self, key: String, value: String, startup: bool) -> bool {
println!("set_param: {} = {}", key, value);

if TRACKED_PARAMETERS.contains(&key) {
self.parameters.insert(key, value);
true
} else {
if startup {
self.parameters.insert(key, value);
return false;
}
_ => {}
true
}
}

fn set_param_from_bytes(&mut self, raw_bytes: BytesMut) {
let mut message_cursor = Cursor::new(&raw_bytes);

message_cursor.get_u8();
message_cursor.get_i32();

let key = match message_cursor.read_string() {
Ok(key) => key,
Err(_) => {
return;
},
};
let value = message_cursor.read_string().unwrap();

match key.as_str() {
"client_encoding" => {
self.client_encoding = value;
}
"date_style" => {
self.date_style = value;
}
"timezone" => {
self.timezone = value;
}
"standard_conforming_strings" => {
self.standard_conforming_strings = value;
}
"application_name" => {
self.application_name = value;
}
_ => {
self.base_original.extend(raw_bytes);
}
pub fn set_from_hashmap(&mut self, parameters: &HashMap<String, String>, startup: bool) {
// iterate through each and call set_param
for (key, value) in parameters {
self.set_param(key.to_string(), value.to_string(), startup);
}
}

pub fn get_bytes(&self) -> BytesMut {
let mut bytes = self.base_original.clone();

self.add_parameter_message("client_encoding", &self.client_encoding, &mut bytes);
self.add_parameter_message("date_style", &self.date_style, &mut bytes);
self.add_parameter_message("timezone", &self.timezone, &mut bytes);
self.add_parameter_message(
"standard_conforming_strings",
&self.standard_conforming_strings,
&mut bytes,
);
self.add_parameter_message("application_name", &self.application_name, &mut bytes);
let mut bytes = BytesMut::new();

for (key, value) in &self.parameters {
self.add_parameter_message(key, value, &mut bytes);
}

bytes
}
Expand Down Expand Up @@ -277,6 +242,9 @@ pub struct Server {
/// Our server response buffer. We buffer data before we give it to the client.
buffer: BytesMut,

// Original server parameters that we started with (used when we discard all)
original_server_parameters: ServerParameters,

/// Server information the server sent us over on startup.
server_parameters: ServerParameters,

Expand Down Expand Up @@ -728,15 +696,10 @@ impl Server {

// ParameterStatus
'S' => {
let mut bytes = BytesMut::with_capacity(len as usize + 1);
bytes.put_u8(code as u8);
bytes.put_i32(len);
bytes.resize(bytes.len() + len as usize - mem::size_of::<i32>(), b'0');

let slice_start = mem::size_of::<u8>() + mem::size_of::<i32>();
let slice_end = slice_start + len as usize - mem::size_of::<i32>();
let mut bytes = BytesMut::with_capacity(len as usize - 4);
bytes.resize(len as usize - mem::size_of::<i32>(), b'0');

match stream.read_exact(&mut bytes[slice_start..slice_end]).await {
match stream.read_exact(&mut bytes[..]).await {
Ok(_) => (),
Err(_) => {
return Err(Error::ServerStartupError(
Expand All @@ -746,10 +709,13 @@ impl Server {
}
};

let key = bytes.read_string().unwrap();
let value = bytes.read_string().unwrap();

// Save the parameter so we can pass it to the client later.
// These can be server_encoding, client_encoding, server timezone, Postgres version,
// and many more interesting things we should know about the Postgres server we are talking to.
server_parameters.set_param_from_bytes(bytes);
let _ = server_parameters.set_param(key, value, true);
}

// BackendKeyData
Expand Down Expand Up @@ -795,6 +761,7 @@ impl Server {
address: address.clone(),
stream: BufStream::new(stream),
buffer: BytesMut::with_capacity(8196),
original_server_parameters: server_parameters.clone(),
server_parameters,
process_id,
secret_key,
Expand Down Expand Up @@ -951,24 +918,23 @@ impl Server {

// CommandComplete
'C' => {
let mut command_tag = String::new();
match message.reader().read_to_string(&mut command_tag) {
Ok(_) => {
match message.read_string() {
Ok(command) => {
// Non-exhaustive list of commands that are likely to change session variables/resources
// which can leak between clients. This is a best effort to block bad clients
// from poisoning a transaction-mode pool by setting inappropriate session variables
match command_tag.as_str() {
"SET\0" => {
// We don't detect set statements in transactions
// No great way to differentiate between set and set local
// As a result, we will miss cases when set statements are used in transactions
// This will reduce amount of discard statements sent
if !self.in_transaction {
debug!("Server connection marked for clean up");
self.cleanup_state.needs_cleanup_set = true;
}
}
"PREPARE\0" => {
match command.as_str() {
// "SET" => {
// // We don't detect set statements in transactions
// // No great way to differentiate between set and set local
// // As a result, we will miss cases when set statements are used in transactions
// // This will reduce amount of discard statements sent
// if !self.in_transaction {
// debug!("Server connection marked for clean up");
// self.cleanup_state.needs_cleanup_set = true;
// }
// }
"PREPARE" => {
debug!("Server connection marked for clean up");
self.cleanup_state.needs_cleanup_prepare = true;
}
Expand All @@ -982,6 +948,13 @@ impl Server {
}
}

'S' => {
let key = message.read_string().unwrap();
let value = message.read_string().unwrap();

self.server_parameters.set_param(key, value, false);
}

// DataRow
'D' => {
// More data is available after this message, this is not the end of the reply.
Expand Down