diff --git a/core/strings.go b/core/strings.go
index f260387..2e3d21f 100644
--- a/core/strings.go
+++ b/core/strings.go
@@ -13,6 +13,8 @@ func toSnake(str string) string {
 				i++
 			}
 			p = i
+		} else if !(c >= 'a' && c <= 'z') {
+			p = i
 		}
 	}
 	return string(runes)
diff --git a/core/strings_test.go b/core/strings_test.go
index 990165d..44570dd 100644
--- a/core/strings_test.go
+++ b/core/strings_test.go
@@ -8,4 +8,6 @@ func TestToSnake(t *testing.T) {
 	assertEqual(t, toSnake("TestID"), "test_id")
 	assertEqual(t, toSnake("ImageURL"), "image_url")
 	assertEqual(t, toSnake("AbCdEf"), "ab_cd_ef")
+	assertEqual(t, toSnake("models.ID"), "models.id")
+	assertEqual(t, toSnake("`models`.`ID`"), "`models`.`id`")
 }
diff --git a/cursor.go b/cursor.go
index 1af1c2e..a463385 100644
--- a/cursor.go
+++ b/cursor.go
@@ -150,12 +150,8 @@ func (cursor *Cursor) Scope() func(db *gorm.DB) *gorm.DB {
 		}
 
 		db = db.InstanceSet("pageboy:cursor", cursor)
-
-		if cursor.Reverse {
-			db = db.Order(pbc.OrderClauseBuilder(cursor.columns...)(pbc.ReverseOrders(cursor.rawOrders)...))
-		} else {
-			db = db.Order(pbc.OrderClauseBuilder(cursor.columns...)(cursor.rawOrders...))
-		}
+		// ORDER BY is applied in cursorHandleBeforeQuery, where the cursor
+		// columns are qualified with the statement's table name.
 		return db.Limit(cursor.Limit)
 	}
 }
@@ -212,16 +208,39 @@ func cursorHandleBeforeQuery(db *gorm.DB) {
 		ty = ty.Elem()
 	}
 
+	// Qualify each cursor column with the statement's table name and quote it
+	// for the current dialect, so the WHERE and ORDER BY clauses built below
+	// stay unambiguous when the query joins tables that share column names
+	// (e.g. id, created_at).
+	table := db.Statement.Table
+	columns := make([]string, len(cursor.columns))
+	for i, column := range cursor.columns {
+		// Leave the column unqualified when there is no table name (the
+		// prefix would render as ".column") or when it is already qualified.
+		if table != "" && !strings.Contains(column, ".") {
+			column = table + "." + column
+		}
+		var quoted strings.Builder
+		db.Dialector.QuoteTo(&quoted, column)
+		columns[i] = quoted.String()
+	}
+
 	if cursor.Before != "" {
 		segments := pbc.NewCursorSegments(cursor.Before)
 		args := segments.Interface(ty, cursor.columns...)
-		db = db.Scopes(pbc.MakeComparisonScope(cursor.columns, cursor.comparisons(true), cursor.nullsOrders, args))
+		db = db.Scopes(pbc.MakeComparisonScope(columns, cursor.comparisons(true), cursor.nullsOrders, args))
 	}
 
 	if cursor.After != "" {
 		segments := pbc.NewCursorSegments(cursor.After)
 		args := segments.Interface(ty, cursor.columns...)
-		db = db.Scopes(pbc.MakeComparisonScope(cursor.columns, cursor.comparisons(false), cursor.nullsOrders, args))
+		db = db.Scopes(pbc.MakeComparisonScope(columns, cursor.comparisons(false), cursor.nullsOrders, args))
+	}
+
+	if cursor.Reverse {
+		db = db.Order(pbc.OrderClauseBuilder(columns...)(pbc.ReverseOrders(cursor.rawOrders)...))
+	} else {
+		db = db.Order(pbc.OrderClauseBuilder(columns...)(cursor.rawOrders...))
 	}
 
 	limit, ok := db.Statement.Clauses[new(clause.Limit).Name()]
diff --git a/tests/cursor_test.go b/tests/cursor_test.go
index 1c21c4a..12b0f59 100644
--- a/tests/cursor_test.go
+++ b/tests/cursor_test.go
@@ -20,6 +20,13 @@ type cursorModel struct {
 	Time  *time.Time
 }
 
+// childModel is joined to cursorModel via SubID in
+// TestCursor_where_clause_is_ambiguous. Both embed gorm.Model, so the join
+// yields duplicate column names such as id and created_at.
+type childModel struct {
+	gorm.Model
+}
+
 var (
 	DESC string = "DESC"
 	ASC  string = "ASC"
@@ -1175,6 +1182,94 @@ func TestCursorPaginateReverse(t *testing.T) {
 	})
 }
 
+// TestCursor_where_clause_is_ambiguous paginates a query that joins
+// child_models, whose gorm.Model columns (id, created_at, ...) collide with
+// cursor_models', so unqualified cursor columns would be ambiguous.
+func TestCursor_where_clause_is_ambiguous(t *testing.T) {
+	db := openDB()
+	assertNoError(t, db.Migrator().DropTable(&cursorModel{}))
+	assertNoError(t, db.Migrator().DropTable(&childModel{}))
+	assertNoError(t, db.AutoMigrate(&cursorModel{}))
+	assertNoError(t, db.AutoMigrate(&childModel{}))
+
+	baseURL, err := url.Parse("https://example.com/users?a=1")
+	assertNoError(t, err)
+
+	now := time.Now()
+	var childModels [2]childModel
+	for i := range childModels {
+		assertNoError(t, db.Create(&childModels[i]).Error)
+	}
+
+	// Attach rows to both children; the INNER JOIN must keep every row.
+	create := func(createdAt time.Time, child int) *cursorModel {
+		model := &cursorModel{
+			SubID: &childModels[child].ID,
+			Model: gorm.Model{
+				CreatedAt: createdAt,
+			},
+		}
+		assertNoError(t, db.Create(model).Error)
+		return model
+	}
+	// model1 and model2 share CreatedAt, so ID breaks the tie in the ordering.
+	model1 := create(now, 0)
+	model2 := create(now, 1)
+	model3 := create(now.Add(2 * time.Second), 0)
+	model4 := create(now.Add(4 * time.Second), 1)
+
+	whereClause := func(db *gorm.DB) *gorm.DB {
+		return db.Where("cursor_models.ID < ?", 999).Joins("INNER JOIN child_models as ch ON ch.id = cursor_models.sub_id")
+	}
+
+	var models []*cursorModel
+	cursor := &pageboy.Cursor{
+		Limit: 1,
+	}
+	url := buildURL(cursor, *baseURL)
+	assertNoError(t, db.Scopes(whereClause, cursor.Paginate("CreatedAt", "ID").Order(ASC, ASC).Scope()).Find(&models).Error)
+	assertEqual(t, len(models), 1)
+	assertEqual(t, models[0].ID, model1.ID)
+	assertEqual(t, cursor.GetNextAfter(), pbc.FormatCursorString(&models[0].CreatedAt, models[0].ID))
+	assertEqual(t, cursor.GetNextBefore(), pbc.FormatCursorString(&models[0].CreatedAt, models[0].ID))
+	assertEqual(t, *cursor.BuildNextPagingUrls(url), pageboy.CursorPagingUrls{
+		Next: "https://example.com/users?a=1" +
+			"&after=" + string(pbc.FormatCursorString(&models[0].CreatedAt, models[0].ID)) +
+			"&limit=1",
+	})
+
+	cursor = &pageboy.Cursor{
+		After: cursor.GetNextAfter(),
+		Limit: 2,
+	}
+	url = buildURL(cursor, *baseURL)
+	assertNoError(t, db.Scopes(whereClause, cursor.Paginate("CreatedAt", "ID").Order(ASC, ASC).Scope()).Find(&models).Error)
+	assertEqual(t, len(models), 2)
+	assertEqual(t, models[0].ID, model2.ID)
+	assertEqual(t, models[1].ID, model3.ID)
+	assertEqual(t, cursor.GetNextAfter(), pbc.FormatCursorString(&models[1].CreatedAt, models[1].ID))
+	assertEqual(t, cursor.GetNextBefore(), pbc.FormatCursorString(&models[0].CreatedAt, models[0].ID))
+	assertEqual(t, *cursor.BuildNextPagingUrls(url), pageboy.CursorPagingUrls{
+		Next: "https://example.com/users?a=1" +
+			"&after=" + string(pbc.FormatCursorString(&models[1].CreatedAt, models[1].ID)) +
+			"&limit=2",
+	})
+
+	cursor = &pageboy.Cursor{
+		After: cursor.GetNextAfter(),
+		Limit: 2,
+	}
+	url = buildURL(cursor, *baseURL)
+	assertNoError(t, db.Scopes(whereClause, cursor.Paginate("CreatedAt", "ID").Order(ASC, ASC).Scope()).Find(&models).Error)
+	assertEqual(t, len(models), 1)
+	assertEqual(t, models[0].ID, model4.ID)
+	assertEqual(t, cursor.GetNextAfter(), pbc.FormatCursorString(&models[0].CreatedAt, models[0].ID))
+	assertEqual(t, cursor.GetNextBefore(), pbc.FormatCursorString(&models[0].CreatedAt, models[0].ID))
+	assertEqual(t, *cursor.BuildNextPagingUrls(url), pageboy.CursorPagingUrls{
+		Next: "",
+	})
+}
+
 func ExampleCursor() {
 	db := openDB()
 