1use std::collections::HashMap;
2
3use crate::query::{AlterTable, Update};
4use anyhow::Result;
5
6use crate::query::AlterAction;
7use crate::query::CreateIndex;
8use crate::query::CreateTable;
9use crate::query::DropTable;
10use crate::schema::{Constraint, Schema};
11use crate::{Dialect, ToSql};
12use topo_sort::{SortResults, TopoSort};
13
/// Switches that tune how `migrate` builds a migration.
///
/// `Default` turns every switch off.
#[derive(Debug, Clone, Default)]
pub struct MigrationOptions {
    /// Debug switch carried for callers; `migrate` itself does not read it.
    pub debug: bool,
    /// When `true`, tables absent from the desired schema are dropped. When `false`,
    /// `migrate` records the skipped drop as a debug result instead.
    pub allow_destructive: bool,
}
19
20pub fn migrate(current: Schema, desired: Schema, options: &MigrationOptions) -> Result<Migration> {
21 let current_tables = current
22 .tables
23 .iter()
24 .map(|t| (&t.name, t))
25 .collect::<HashMap<_, _>>();
26 let desired_tables = desired
27 .tables
28 .iter()
29 .map(|t| (&t.name, t))
30 .collect::<HashMap<_, _>>();
31
32 let mut debug_results = vec![];
33 let mut statements = Vec::new();
34 for (_name, table) in desired_tables
36 .iter()
37 .filter(|(name, _)| !current_tables.contains_key(*name))
38 {
39 let statement = Statement::CreateTable(CreateTable::from_table(table));
40 statements.push(statement);
41 }
42
43 for (name, desired_table) in desired_tables
45 .iter()
46 .filter(|(name, _)| current_tables.contains_key(*name))
47 {
48 let current_table = current_tables[name];
49 let current_columns = current_table
50 .columns
51 .iter()
52 .map(|c| (&c.name, c))
53 .collect::<HashMap<_, _>>();
54 let mut actions = vec![];
56 for desired_column in desired_table.columns.iter() {
57 if let Some(current) = current_columns.get(&desired_column.name) {
58 if current.nullable != desired_column.nullable {
59 actions.push(AlterAction::set_nullable(
60 desired_column.name.clone(),
61 desired_column.nullable,
62 ));
63 }
64 if !desired_column.typ.lossy_eq(¤t.typ) {
65 actions.push(AlterAction::set_type(
66 desired_column.name.clone(),
67 desired_column.typ.clone(),
68 ));
69 };
70 if desired_column.constraint.is_some() && current.constraint.is_none() {
71 if let Some(c) = &desired_column.constraint {
72 let name = desired_column.name.clone();
73 actions.push(AlterAction::add_constraint(
74 &desired_table.name,
75 name,
76 c.clone(),
77 ));
78 }
79 }
80 } else {
81 if desired_column.nullable {
83 actions.push(AlterAction::AddColumn {
84 column: desired_column.clone(),
85 });
86 } else {
87 let mut nullable = desired_column.clone();
88 nullable.nullable = true;
89 statements.push(Statement::AlterTable(AlterTable {
90 schema: desired_table.schema.clone(),
91 name: desired_table.name.clone(),
92 actions: vec![AlterAction::AddColumn { column: nullable }],
93 }));
94 statements.push(Statement::Update(
95 Update::new(name)
96 .set(
97 &desired_column.name,
98 "/* TODO set a value before setting the column to null */",
99 )
100 .where_(crate::query::Where::raw("true")),
101 ));
102 statements.push(Statement::AlterTable(AlterTable {
103 schema: desired_table.schema.clone(),
104 name: desired_table.name.clone(),
105 actions: vec![AlterAction::AlterColumn {
106 name: desired_column.name.clone(),
107 action: crate::query::AlterColumnAction::SetNullable(false),
108 }],
109 }));
110 }
111 }
112 }
113 if actions.is_empty() {
114 debug_results.push(DebugResults::TablesIdentical(name.to_string()));
115 } else {
116 statements.push(Statement::AlterTable(AlterTable {
117 schema: desired_table.schema.clone(),
118 name: desired_table.name.clone(),
119 actions,
120 }));
121 }
122 }
123
124 for (_name, current_table) in current_tables
125 .iter()
126 .filter(|(name, _)| !desired_tables.contains_key(*name))
127 {
128 if options.allow_destructive {
129 statements.push(Statement::DropTable(DropTable {
130 schema: current_table.schema.clone(),
131 name: current_table.name.clone(),
132 }));
133 } else {
134 debug_results.push(DebugResults::SkippedDropTable(current_table.name.clone()));
135 }
136 }
137
138 let sorted_statements = topologically_sort_statements(&statements, &desired_tables);
140
141 Ok(Migration {
142 statements: sorted_statements,
143 debug_results,
144 })
145}
146
147fn topologically_sort_statements(
149 statements: &[Statement],
150 tables: &HashMap<&String, &crate::schema::Table>,
151) -> Vec<Statement> {
152 let create_statements: Vec<_> = statements
154 .iter()
155 .filter(|s| matches!(s, Statement::CreateTable(_)))
156 .collect();
157
158 if create_statements.is_empty() {
159 return statements.to_vec();
161 }
162
163 let mut table_to_index = HashMap::new();
165 for (i, stmt) in create_statements.iter().enumerate() {
166 if let Statement::CreateTable(create) = stmt {
167 table_to_index.insert(create.name.clone(), i);
168 }
169 }
170
171 let mut topo_sort = TopoSort::new();
173
174 for stmt in &create_statements {
176 if let Statement::CreateTable(create) = stmt {
177 let table_name = &create.name;
178 let mut dependencies = Vec::new();
179
180 if let Some(table) = tables.values().find(|t| &t.name == table_name) {
182 for column in &table.columns {
184 if let Some(Constraint::ForeignKey(fk)) = &column.constraint {
185 dependencies.push(fk.table.clone());
186 }
187 }
188 }
189
190 dbg!(table_name, &dependencies);
192 topo_sort.insert(table_name.clone(), dependencies);
193 }
194 }
195
196 let table_order = match topo_sort.into_vec_nodes() {
198 SortResults::Full(nodes) => nodes,
199 SortResults::Partial(nodes) => {
200 nodes
202 }
203 };
204
205 let mut sorted_statements = Vec::new();
207 for table_name in &table_order {
208 if let Some(&idx) = table_to_index.get(table_name) {
209 sorted_statements.push(create_statements[idx].clone());
210 }
211 }
212
213 for stmt in statements {
215 if !matches!(stmt, Statement::CreateTable(_)) {
216 sorted_statements.push(stmt.clone());
217 }
218 }
219
220 sorted_statements
221}
222
223#[derive(Debug)]
224pub struct Migration {
225 pub statements: Vec<Statement>,
226 pub debug_results: Vec<DebugResults>,
227}
228
229impl Migration {
230 pub fn is_empty(&self) -> bool {
231 self.statements.is_empty()
232 }
233
234 pub fn set_schema(&mut self, schema_name: &str) {
235 for statement in &mut self.statements {
236 statement.set_schema(schema_name);
237 }
238 }
239}
240
241#[derive(Debug, Clone, PartialEq, Eq)]
242pub enum Statement {
243 CreateTable(CreateTable),
244 CreateIndex(CreateIndex),
245 AlterTable(AlterTable),
246 DropTable(DropTable),
247 Update(Update),
248}
249
250impl Statement {
251 pub fn set_schema(&mut self, schema_name: &str) {
252 match self {
253 Statement::CreateTable(s) => {
254 s.schema = Some(schema_name.to_string());
255 }
256 Statement::AlterTable(s) => {
257 s.schema = Some(schema_name.to_string());
258 }
259 Statement::DropTable(s) => {
260 s.schema = Some(schema_name.to_string());
261 }
262 Statement::CreateIndex(s) => {
263 s.schema = Some(schema_name.to_string());
264 }
265 Statement::Update(s) => {
266 s.schema = Some(schema_name.to_string());
267 }
268 }
269 }
270
271 pub fn table_name(&self) -> &str {
272 match self {
273 Statement::CreateTable(s) => &s.name,
274 Statement::AlterTable(s) => &s.name,
275 Statement::DropTable(s) => &s.name,
276 Statement::CreateIndex(s) => &s.table,
277 Statement::Update(s) => &s.table,
278 }
279 }
280}
281
282impl ToSql for Statement {
283 fn write_sql(&self, buf: &mut String, dialect: Dialect) {
284 use Statement::*;
285 match self {
286 CreateTable(c) => c.write_sql(buf, dialect),
287 CreateIndex(c) => c.write_sql(buf, dialect),
288 AlterTable(a) => a.write_sql(buf, dialect),
289 DropTable(d) => d.write_sql(buf, dialect),
290 Update(u) => u.write_sql(buf, dialect),
291 }
292 }
293}
294
/// Diagnostic notes produced while computing a migration.
#[derive(Debug)]
pub enum DebugResults {
    /// A table present in both schemas for which no changes were generated.
    TablesIdentical(String),
    /// A table that would have been dropped, but destructive changes were disabled.
    SkippedDropTable(String),
}

impl DebugResults {
    /// Name of the table this note refers to.
    pub fn table_name(&self) -> &str {
        match self {
            Self::TablesIdentical(table) | Self::SkippedDropTable(table) => table,
        }
    }
}
309
#[cfg(test)]
mod tests {
    use super::*;

    use crate::schema::{Column, Constraint, ForeignKey};
    use crate::Table;
    use crate::Type;

    /// Builds a non-nullable `I32` column without a default.
    fn i32_column(name: &str, primary_key: bool, constraint: Option<Constraint>) -> Column {
        Column {
            name: name.to_string(),
            typ: Type::I32,
            nullable: false,
            primary_key,
            default: None,
            constraint,
        }
    }

    /// Builds a schema containing exactly `tables`.
    fn schema_with(tables: Vec<Table>) -> Schema {
        let mut schema = Schema::default();
        schema.tables.extend(tables);
        schema
    }

    /// Position of the `CREATE TABLE` statement for `table`, if there is one.
    fn create_table_position(migration: &Migration, table: &str) -> Option<usize> {
        migration.statements.iter().position(|statement| {
            matches!(statement, Statement::CreateTable(create) if create.name == table)
        })
    }

    #[test]
    fn test_drop_table() {
        let table = Table::new("new_table");
        let options = MigrationOptions {
            allow_destructive: true,
            ..MigrationOptions::default()
        };

        let mut migration =
            migrate(schema_with(vec![table.clone()]), Schema::default(), &options).unwrap();

        let last = migration.statements.pop().unwrap();
        let expected = Statement::DropTable(DropTable {
            schema: table.schema,
            name: table.name,
        });
        assert_eq!(last, expected);
    }

    #[test]
    fn test_drop_table_without_destructive_operations() {
        let migration = migrate(
            schema_with(vec![Table::new("new_table")]),
            Schema::default(),
            &MigrationOptions::default(),
        )
        .unwrap();

        assert!(migration.statements.is_empty());
    }

    #[test]
    fn test_topological_sort_statements() {
        let team = Table::new("team").column(i32_column("id", true, None));
        let user = Table::new("user")
            .column(i32_column("id", true, None))
            .column(i32_column(
                "team_id",
                false,
                Some(Constraint::ForeignKey(ForeignKey {
                    table: "team".to_string(),
                    columns: vec!["id".to_string()],
                })),
            ));

        // `user` precedes `team` in the schema, so the expected order is not just input order.
        let migration = migrate(
            Schema::default(),
            schema_with(vec![user, team]),
            &MigrationOptions::default(),
        )
        .unwrap();

        let team_index = create_table_position(&migration, "team").unwrap();
        let user_index = create_table_position(&migration, "user").unwrap();
        assert!(
            team_index < user_index,
            "Team table should be created before User table"
        );
    }
}