diff --git a/libsql/src/database.rs b/libsql/src/database.rs index bb07bb189d..7069799caa 100644 --- a/libsql/src/database.rs +++ b/libsql/src/database.rs @@ -712,6 +712,7 @@ impl Database { read_your_writes: *read_your_writes, context: db.sync_ctx.clone().unwrap(), state: std::sync::Arc::new(Mutex::new(State::Init)), + needs_pull: std::sync::atomic::AtomicBool::new(false).into(), }; let conn = std::sync::Arc::new(synced); diff --git a/libsql/src/sync/connection.rs b/libsql/src/sync/connection.rs index 807e49a2f4..c2809c0f57 100644 --- a/libsql/src/sync/connection.rs +++ b/libsql/src/sync/connection.rs @@ -8,7 +8,10 @@ use crate::{ sync::SyncContext, BatchRows, Error, Result, Statement, Transaction, TransactionBehavior, }; -use std::sync::Arc; +use std::sync::{ + atomic::{AtomicBool, Ordering}, + Arc, +}; use std::time::Duration; use tokio::sync::Mutex; @@ -21,6 +24,7 @@ pub struct SyncedConnection { pub read_your_writes: bool, pub context: Arc>, pub state: Arc>, + pub needs_pull: Arc, } impl SyncedConnection { @@ -89,7 +93,7 @@ impl SyncedConnection { _ => { *state = predicted_end_state; false - }, + } }; Ok(should_execute_local) @@ -106,6 +110,11 @@ impl Conn for SyncedConnection { async fn execute_batch(&self, sql: &str) -> Result { if self.should_execute_local(sql).await? { + if self.needs_pull.load(Ordering::Relaxed) { + let mut context = self.context.lock().await; + crate::sync::try_pull(&mut context, &self.local).await?; + self.needs_pull.store(false, Ordering::Relaxed); + } self.local.execute_batch(sql) } else { self.remote.execute_batch(sql).await @@ -114,6 +123,11 @@ impl Conn for SyncedConnection { async fn execute_transactional_batch(&self, sql: &str) -> Result { if self.should_execute_local(sql).await? { + if self.needs_pull.load(Ordering::Relaxed) { + let mut context = self.context.lock().await; + crate::sync::try_pull(&mut context, &self.local).await?; + self.needs_pull.store(false, Ordering::Relaxed); + } self.local.execute_transactional_batch(sql)?; Ok(BatchRows::empty()) } else { @@ -123,8 +137,17 @@ impl Conn for SyncedConnection { async fn prepare(&self, sql: &str) -> Result { if self.should_execute_local(sql).await? { - Ok(Statement { + let stmt = Statement { inner: Box::new(LibsqlStmt(self.local.prepare(sql)?)), + }; + + Ok(Statement { + inner: Box::new(SyncedStatement { + conn: self.local.clone(), + inner: stmt, + context: self.context.clone(), + needs_pull: self.needs_pull.clone(), + }), }) } else { let stmt = Statement { @@ -132,16 +155,10 @@ impl Conn for SyncedConnection { }; if self.read_your_writes { - Ok(Statement { - inner: Box::new(SyncedStatement { - conn: self.local.clone(), - context: self.context.clone(), - inner: stmt, - }), - }) - } else { - Ok(stmt) + self.needs_pull.store(true, Ordering::Relaxed); } + + Ok(stmt) } } diff --git a/libsql/src/sync/statement.rs b/libsql/src/sync/statement.rs index 679933298b..ad2183b0f4 100644 --- a/libsql/src/sync/statement.rs +++ b/libsql/src/sync/statement.rs @@ -4,13 +4,14 @@ use crate::{ statement::Stmt, sync::SyncContext, Column, Result, Rows, Statement, }; -use std::sync::Arc; +use std::sync::{atomic::{AtomicBool, Ordering}, Arc}; use tokio::sync::Mutex; pub struct SyncedStatement { pub conn: local::Connection, - pub context: Arc>, pub inner: Statement, + pub context: Arc>, + pub needs_pull: Arc, } #[async_trait::async_trait] @@ -20,24 +21,30 @@ impl Stmt for SyncedStatement { } async fn execute(&mut self, params: &Params) -> Result { - let result = self.inner.execute(params).await; - let mut context = self.context.lock().await; - crate::sync::try_pull(&mut context, &self.conn).await?; - result + if self.needs_pull.load(Ordering::Relaxed) { + let mut context = self.context.lock().await; + crate::sync::try_pull(&mut context, &self.conn).await?; + self.needs_pull.store(false, Ordering::Relaxed); + } + self.inner.execute(params).await } async fn query(&mut self, params: &Params) -> Result { - let result = self.inner.query(params).await; - let mut context = self.context.lock().await; - crate::sync::try_pull(&mut context, &self.conn).await?; - result + if self.needs_pull.load(Ordering::Relaxed) { + let mut context = self.context.lock().await; + crate::sync::try_pull(&mut context, &self.conn).await?; + self.needs_pull.store(false, Ordering::Relaxed); + } + self.inner.query(params).await } async fn run(&mut self, params: &Params) -> Result<()> { - let result = self.inner.run(params).await; - let mut context = self.context.lock().await; - crate::sync::try_pull(&mut context, &self.conn).await?; - result + if self.needs_pull.load(Ordering::Relaxed) { + let mut context = self.context.lock().await; + crate::sync::try_pull(&mut context, &self.conn).await?; + self.needs_pull.store(false, Ordering::Relaxed); + } + self.inner.run(params).await } fn interrupt(&mut self) -> Result<()> { @@ -64,3 +71,4 @@ impl Stmt for SyncedStatement { self.inner.columns() } } +