Thanks for visiting codestin.com
Credit goes to docs.rs

sqlmo/
migrate.rs

1use std::collections::HashMap;
2
3use crate::query::{AlterTable, Update};
4use anyhow::Result;
5
6use crate::query::AlterAction;
7use crate::query::CreateIndex;
8use crate::query::CreateTable;
9use crate::query::DropTable;
10use crate::schema::{Constraint, Schema};
11use crate::{Dialect, ToSql};
12use topo_sort::{SortResults, TopoSort};
13
/// Knobs that control how [`migrate`] builds a [`Migration`].
#[derive(Debug, Clone, Default)]
pub struct MigrationOptions {
    /// Debug flag.
    ///
    /// NOTE(review): not read by `migrate` in this file — confirm where it is consumed.
    pub debug: bool,
    /// When `true`, tables present in the current schema but absent from the desired
    /// schema are dropped. When `false` (the default), no `DROP TABLE` is emitted and
    /// the skip is recorded as [`DebugResults::SkippedDropTable`] instead.
    pub allow_destructive: bool,
}
19
20pub fn migrate(current: Schema, desired: Schema, options: &MigrationOptions) -> Result<Migration> {
21    let current_tables = current
22        .tables
23        .iter()
24        .map(|t| (&t.name, t))
25        .collect::<HashMap<_, _>>();
26    let desired_tables = desired
27        .tables
28        .iter()
29        .map(|t| (&t.name, t))
30        .collect::<HashMap<_, _>>();
31
32    let mut debug_results = vec![];
33    let mut statements = Vec::new();
34    // new tables
35    for (_name, table) in desired_tables
36        .iter()
37        .filter(|(name, _)| !current_tables.contains_key(*name))
38    {
39        let statement = Statement::CreateTable(CreateTable::from_table(table));
40        statements.push(statement);
41    }
42
43    // alter existing tables
44    for (name, desired_table) in desired_tables
45        .iter()
46        .filter(|(name, _)| current_tables.contains_key(*name))
47    {
48        let current_table = current_tables[name];
49        let current_columns = current_table
50            .columns
51            .iter()
52            .map(|c| (&c.name, c))
53            .collect::<HashMap<_, _>>();
54        // add columns
55        let mut actions = vec![];
56        for desired_column in desired_table.columns.iter() {
57            if let Some(current) = current_columns.get(&desired_column.name) {
58                if current.nullable != desired_column.nullable {
59                    actions.push(AlterAction::set_nullable(
60                        desired_column.name.clone(),
61                        desired_column.nullable,
62                    ));
63                }
64                if !desired_column.typ.lossy_eq(&current.typ) {
65                    actions.push(AlterAction::set_type(
66                        desired_column.name.clone(),
67                        desired_column.typ.clone(),
68                    ));
69                };
70                if desired_column.constraint.is_some() && current.constraint.is_none() {
71                    if let Some(c) = &desired_column.constraint {
72                        let name = desired_column.name.clone();
73                        actions.push(AlterAction::add_constraint(
74                            &desired_table.name,
75                            name,
76                            c.clone(),
77                        ));
78                    }
79                }
80            } else {
81                // add the column can be in 1 step if the column is nullable
82                if desired_column.nullable {
83                    actions.push(AlterAction::AddColumn {
84                        column: desired_column.clone(),
85                    });
86                } else {
87                    let mut nullable = desired_column.clone();
88                    nullable.nullable = true;
89                    statements.push(Statement::AlterTable(AlterTable {
90                        schema: desired_table.schema.clone(),
91                        name: desired_table.name.clone(),
92                        actions: vec![AlterAction::AddColumn { column: nullable }],
93                    }));
94                    statements.push(Statement::Update(
95                        Update::new(name)
96                            .set(
97                                &desired_column.name,
98                                "/* TODO set a value before setting the column to null */",
99                            )
100                            .where_(crate::query::Where::raw("true")),
101                    ));
102                    statements.push(Statement::AlterTable(AlterTable {
103                        schema: desired_table.schema.clone(),
104                        name: desired_table.name.clone(),
105                        actions: vec![AlterAction::AlterColumn {
106                            name: desired_column.name.clone(),
107                            action: crate::query::AlterColumnAction::SetNullable(false),
108                        }],
109                    }));
110                }
111            }
112        }
113        if actions.is_empty() {
114            debug_results.push(DebugResults::TablesIdentical(name.to_string()));
115        } else {
116            statements.push(Statement::AlterTable(AlterTable {
117                schema: desired_table.schema.clone(),
118                name: desired_table.name.clone(),
119                actions,
120            }));
121        }
122    }
123
124    for (_name, current_table) in current_tables
125        .iter()
126        .filter(|(name, _)| !desired_tables.contains_key(*name))
127    {
128        if options.allow_destructive {
129            statements.push(Statement::DropTable(DropTable {
130                schema: current_table.schema.clone(),
131                name: current_table.name.clone(),
132            }));
133        } else {
134            debug_results.push(DebugResults::SkippedDropTable(current_table.name.clone()));
135        }
136    }
137
138    // Sort statements topologically based on foreign key dependencies
139    let sorted_statements = topologically_sort_statements(&statements, &desired_tables);
140
141    Ok(Migration {
142        statements: sorted_statements,
143        debug_results,
144    })
145}
146
147/// Topologically sorts the migration statements based on foreign key dependencies
148fn topologically_sort_statements(
149    statements: &[Statement],
150    tables: &HashMap<&String, &crate::schema::Table>,
151) -> Vec<Statement> {
152    // First, extract create table statements
153    let create_statements: Vec<_> = statements
154        .iter()
155        .filter(|s| matches!(s, Statement::CreateTable(_)))
156        .collect();
157
158    if create_statements.is_empty() {
159        // If there are no create statements, just return the original
160        return statements.to_vec();
161    }
162
163    // Build a map of table name to index in the statements array
164    let mut table_to_index = HashMap::new();
165    for (i, stmt) in create_statements.iter().enumerate() {
166        if let Statement::CreateTable(create) = stmt {
167            table_to_index.insert(create.name.clone(), i);
168        }
169    }
170
171    // Set up topological sort
172    let mut topo_sort = TopoSort::new();
173
174    // Find table dependencies and add them to topo_sort
175    for stmt in &create_statements {
176        if let Statement::CreateTable(create) = stmt {
177            let table_name = &create.name;
178            let mut dependencies = Vec::new();
179
180            // Get the actual table from the tables map
181            if let Some(table) = tables.values().find(|t| &t.name == table_name) {
182                // Check all columns for foreign key constraints
183                for column in &table.columns {
184                    if let Some(Constraint::ForeignKey(fk)) = &column.constraint {
185                        dependencies.push(fk.table.clone());
186                    }
187                }
188            }
189
190            // Add this table and its dependencies to the topo_sort
191            dbg!(table_name, &dependencies);
192            topo_sort.insert(table_name.clone(), dependencies);
193        }
194    }
195
196    // Perform the sort
197    let table_order = match topo_sort.into_vec_nodes() {
198        SortResults::Full(nodes) => nodes,
199        SortResults::Partial(nodes) => {
200            // Return partial results even if there's a cycle
201            nodes
202        }
203    };
204
205    // First create a sorted list of CREATE TABLE statements
206    let mut sorted_statements = Vec::new();
207    for table_name in &table_order {
208        if let Some(&idx) = table_to_index.get(table_name) {
209            sorted_statements.push(create_statements[idx].clone());
210        }
211    }
212
213    // Add remaining statements (non-create-table) in their original order
214    for stmt in statements {
215        if !matches!(stmt, Statement::CreateTable(_)) {
216            sorted_statements.push(stmt.clone());
217        }
218    }
219
220    sorted_statements
221}
222
223#[derive(Debug)]
224pub struct Migration {
225    pub statements: Vec<Statement>,
226    pub debug_results: Vec<DebugResults>,
227}
228
229impl Migration {
230    pub fn is_empty(&self) -> bool {
231        self.statements.is_empty()
232    }
233
234    pub fn set_schema(&mut self, schema_name: &str) {
235        for statement in &mut self.statements {
236            statement.set_schema(schema_name);
237        }
238    }
239}
240
241#[derive(Debug, Clone, PartialEq, Eq)]
242pub enum Statement {
243    CreateTable(CreateTable),
244    CreateIndex(CreateIndex),
245    AlterTable(AlterTable),
246    DropTable(DropTable),
247    Update(Update),
248}
249
250impl Statement {
251    pub fn set_schema(&mut self, schema_name: &str) {
252        match self {
253            Statement::CreateTable(s) => {
254                s.schema = Some(schema_name.to_string());
255            }
256            Statement::AlterTable(s) => {
257                s.schema = Some(schema_name.to_string());
258            }
259            Statement::DropTable(s) => {
260                s.schema = Some(schema_name.to_string());
261            }
262            Statement::CreateIndex(s) => {
263                s.schema = Some(schema_name.to_string());
264            }
265            Statement::Update(s) => {
266                s.schema = Some(schema_name.to_string());
267            }
268        }
269    }
270
271    pub fn table_name(&self) -> &str {
272        match self {
273            Statement::CreateTable(s) => &s.name,
274            Statement::AlterTable(s) => &s.name,
275            Statement::DropTable(s) => &s.name,
276            Statement::CreateIndex(s) => &s.table,
277            Statement::Update(s) => &s.table,
278        }
279    }
280}
281
282impl ToSql for Statement {
283    fn write_sql(&self, buf: &mut String, dialect: Dialect) {
284        use Statement::*;
285        match self {
286            CreateTable(c) => c.write_sql(buf, dialect),
287            CreateIndex(c) => c.write_sql(buf, dialect),
288            AlterTable(a) => a.write_sql(buf, dialect),
289            DropTable(d) => d.write_sql(buf, dialect),
290            Update(u) => u.write_sql(buf, dialect),
291        }
292    }
293}
294
/// Informational note produced by [`migrate`] for a table that got no statements.
#[derive(Debug)]
pub enum DebugResults {
    /// The named table exists in both schemas and needed no changes.
    TablesIdentical(String),
    /// The named table would have been dropped, but destructive operations were disabled.
    SkippedDropTable(String),
}
300
301impl DebugResults {
302    pub fn table_name(&self) -> &str {
303        match self {
304            DebugResults::TablesIdentical(name) => name,
305            DebugResults::SkippedDropTable(name) => name,
306        }
307    }
308}
309
#[cfg(test)]
mod tests {
    use super::*;

    use crate::schema::{Column, Constraint, ForeignKey};
    use crate::Table;
    use crate::Type;

    /// Builds a non-nullable `I32` column without a default value.
    fn int_column(name: &str, primary_key: bool, constraint: Option<Constraint>) -> Column {
        Column {
            name: name.to_string(),
            typ: Type::I32,
            nullable: false,
            primary_key,
            default: None,
            constraint,
        }
    }

    /// Index of the `CREATE TABLE` statement for `table`; panics if there is none.
    fn create_position(migration: &Migration, table: &str) -> usize {
        migration
            .statements
            .iter()
            .position(|stmt| matches!(stmt, Statement::CreateTable(create) if create.name == table))
            .unwrap_or_else(|| panic!("no CREATE TABLE statement for {}", table))
    }

    #[test]
    fn test_drop_table() {
        let table = Table::new("new_table");
        let mut current = Schema::default();
        current.tables.push(table.clone());
        let options = MigrationOptions {
            allow_destructive: true,
            ..MigrationOptions::default()
        };

        let migration = migrate(current, Schema::default(), &options).unwrap();

        let expected = Statement::DropTable(DropTable {
            schema: table.schema,
            name: table.name,
        });
        assert_eq!(migration.statements.last(), Some(&expected));
    }

    #[test]
    fn test_drop_table_without_destructive_operations() {
        let mut current = Schema::default();
        current.tables.push(Table::new("new_table"));

        let migration =
            migrate(current, Schema::default(), &MigrationOptions::default()).unwrap();

        assert!(migration.statements.is_empty());
    }

    #[test]
    fn test_topological_sort_statements() {
        // `user.team_id` references `team.id`, so `team` must be created first.
        let team = Table::new("team").column(int_column("id", true, None));
        let user = Table::new("user")
            .column(int_column("id", true, None))
            .column(int_column(
                "team_id",
                false,
                Some(Constraint::ForeignKey(ForeignKey {
                    table: "team".to_string(),
                    columns: vec!["id".to_string()],
                })),
            ));

        // The dependent table is listed first, so a correct result cannot simply
        // mirror the schema order.
        let mut desired = Schema::default();
        desired.tables.push(user);
        desired.tables.push(team);

        let migration =
            migrate(Schema::default(), desired, &MigrationOptions::default()).unwrap();

        let team_at = create_position(&migration, "team");
        let user_at = create_position(&migration, "user");
        assert!(
            team_at < user_at,
            "Team table should be created before User table"
        );
    }
}