Skip to content

Commit 69cbe52

Browse files
DOsingaDouwe Osinga
authored andcommitted
Session manager fixes (aaif-goose#6809)
Co-authored-by: Douwe Osinga <douwe@squareup.com>
1 parent 72b3c3d commit 69cbe52

1 file changed

Lines changed: 37 additions & 24 deletions

File tree

crates/goose/src/session/session_manager.rs

Lines changed: 37 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -724,7 +724,9 @@ impl SessionStorage {
724724
}
725725

726726
async fn run_migrations(pool: &Pool<Sqlite>) -> Result<()> {
727-
let current_version = Self::get_schema_version(pool).await?;
727+
let mut tx = pool.begin().await?;
728+
729+
let current_version = Self::get_schema_version(&mut tx).await?;
728730

729731
if current_version < CURRENT_SCHEMA_VERSION {
730732
info!(
@@ -734,18 +736,19 @@ impl SessionStorage {
734736

735737
for version in (current_version + 1)..=CURRENT_SCHEMA_VERSION {
736738
info!(" Applying migration v{}...", version);
737-
Self::apply_migration(pool, version).await?;
738-
Self::update_schema_version(pool, version).await?;
739+
Self::apply_migration(&mut tx, version).await?;
740+
Self::update_schema_version(&mut tx, version).await?;
739741
info!(" ✓ Migration v{} complete", version);
740742
}
741743

742744
info!("All migrations complete");
743745
}
744746

747+
tx.commit().await?;
745748
Ok(())
746749
}
747750

748-
async fn get_schema_version(pool: &Pool<Sqlite>) -> Result<i32> {
751+
async fn get_schema_version(tx: &mut sqlx::Transaction<'_, Sqlite>) -> Result<i32> {
749752
let table_exists = sqlx::query_scalar::<_, bool>(
750753
r#"
751754
SELECT EXISTS (
@@ -754,30 +757,33 @@ impl SessionStorage {
754757
)
755758
"#,
756759
)
757-
.fetch_one(pool)
760+
.fetch_one(&mut **tx)
758761
.await?;
759762

760763
if !table_exists {
761764
return Ok(0);
762765
}
763766

764767
let version = sqlx::query_scalar::<_, i32>("SELECT MAX(version) FROM schema_version")
765-
.fetch_one(pool)
768+
.fetch_one(&mut **tx)
766769
.await?;
767770

768771
Ok(version)
769772
}
770773

771-
async fn update_schema_version(pool: &Pool<Sqlite>, version: i32) -> Result<()> {
774+
async fn update_schema_version(
775+
tx: &mut sqlx::Transaction<'_, Sqlite>,
776+
version: i32,
777+
) -> Result<()> {
772778
sqlx::query("INSERT INTO schema_version (version) VALUES (?)")
773779
.bind(version)
774-
.execute(pool)
780+
.execute(&mut **tx)
775781
.await?;
776782
Ok(())
777783
}
778784

779785
#[allow(clippy::too_many_lines)]
780-
async fn apply_migration(pool: &Pool<Sqlite>, version: i32) -> Result<()> {
786+
async fn apply_migration(tx: &mut sqlx::Transaction<'_, Sqlite>, version: i32) -> Result<()> {
781787
match version {
782788
1 => {
783789
sqlx::query(
@@ -788,7 +794,7 @@ impl SessionStorage {
788794
)
789795
"#,
790796
)
791-
.execute(pool)
797+
.execute(&mut **tx)
792798
.await?;
793799
}
794800
2 => {
@@ -797,7 +803,7 @@ impl SessionStorage {
797803
ALTER TABLE sessions ADD COLUMN user_recipe_values_json TEXT
798804
"#,
799805
)
800-
.execute(pool)
806+
.execute(&mut **tx)
801807
.await?;
802808
}
803809
3 => {
@@ -806,7 +812,7 @@ impl SessionStorage {
806812
ALTER TABLE messages ADD COLUMN metadata_json TEXT
807813
"#,
808814
)
809-
.execute(pool)
815+
.execute(&mut **tx)
810816
.await?;
811817
}
812818
4 => {
@@ -815,15 +821,15 @@ impl SessionStorage {
815821
ALTER TABLE sessions ADD COLUMN name TEXT DEFAULT ''
816822
"#,
817823
)
818-
.execute(pool)
824+
.execute(&mut **tx)
819825
.await?;
820826

821827
sqlx::query(
822828
r#"
823829
ALTER TABLE sessions ADD COLUMN user_set_name BOOLEAN DEFAULT FALSE
824830
"#,
825831
)
826-
.execute(pool)
832+
.execute(&mut **tx)
827833
.await?;
828834
}
829835
5 => {
@@ -832,11 +838,11 @@ impl SessionStorage {
832838
ALTER TABLE sessions ADD COLUMN session_type TEXT NOT NULL DEFAULT 'user'
833839
"#,
834840
)
835-
.execute(pool)
841+
.execute(&mut **tx)
836842
.await?;
837843

838844
sqlx::query("CREATE INDEX idx_sessions_type ON sessions(session_type)")
839-
.execute(pool)
845+
.execute(&mut **tx)
840846
.await?;
841847
}
842848
6 => {
@@ -845,15 +851,15 @@ impl SessionStorage {
845851
ALTER TABLE sessions ADD COLUMN provider_name TEXT
846852
"#,
847853
)
848-
.execute(pool)
854+
.execute(&mut **tx)
849855
.await?;
850856

851857
sqlx::query(
852858
r#"
853859
ALTER TABLE sessions ADD COLUMN model_config_json TEXT
854860
"#,
855861
)
856-
.execute(pool)
862+
.execute(&mut **tx)
857863
.await?;
858864
}
859865
7 => {
@@ -862,7 +868,7 @@ impl SessionStorage {
862868
ALTER TABLE messages ADD COLUMN message_id TEXT
863869
"#,
864870
)
865-
.execute(pool)
871+
.execute(&mut **tx)
866872
.await?;
867873

868874
sqlx::query(
@@ -871,11 +877,11 @@ impl SessionStorage {
871877
SET message_id = 'msg_' || session_id || '_' || id
872878
"#,
873879
)
874-
.execute(pool)
880+
.execute(&mut **tx)
875881
.await?;
876882

877883
sqlx::query("CREATE INDEX idx_messages_message_id ON messages(message_id)")
878-
.execute(pool)
884+
.execute(&mut **tx)
879885
.await?;
880886
}
881887
_ => {
@@ -1158,12 +1164,18 @@ impl SessionStorage {
11581164
for message in conversation.messages() {
11591165
let metadata_json = serde_json::to_string(&message.metadata)?;
11601166

1167+
let message_id = message
1168+
.id
1169+
.clone()
1170+
.unwrap_or_else(|| format!("msg_{}_{}", session_id, uuid::Uuid::new_v4()));
1171+
11611172
sqlx::query(
11621173
r#"
1163-
INSERT INTO messages (session_id, role, content_json, created_timestamp, metadata_json)
1164-
VALUES (?, ?, ?, ?, ?)
1174+
INSERT INTO messages (message_id, session_id, role, content_json, created_timestamp, metadata_json)
1175+
VALUES (?, ?, ?, ?, ?, ?)
11651176
"#,
11661177
)
1178+
.bind(message_id)
11671179
.bind(session_id)
11681180
.bind(role_to_string(&message.role))
11691181
.bind(serde_json::to_string(&message.content)?)
@@ -1403,7 +1415,8 @@ impl SessionStorage {
14031415
crate::conversation::message::MessageMetadata,
14041416
) -> crate::conversation::message::MessageMetadata,
14051417
{
1406-
let mut tx = self.pool.begin().await?;
1418+
let pool = self.pool().await?;
1419+
let mut tx = pool.begin().await?;
14071420

14081421
let current_metadata_json = sqlx::query_scalar::<_, String>(
14091422
"SELECT metadata_json FROM messages WHERE message_id = ? AND session_id = ?",

0 commit comments

Comments
 (0)