Thanks for visiting codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -104,4 +104,4 @@ go test ./...

## License

MIT[./LICENSE]
[MIT](./LICENSE)
19 changes: 13 additions & 6 deletions internal/db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ func (db *DB) SetString(k, v string, expireAt time.Time) {
}

// SetStringWithOptions sets key to value honouring NX/XX/KEEPTTL and optional expiry.
// Returns true if the value was stored, false if the preconditions failed (e.g. NX with existing key).
func (db *DB) SetStringWithOptions(now time.Time, k, v string, opts SetOptions) bool {
// Returns (stored, prevValue, prevExists). When stored is false the key was not updated.
func (db *DB) SetStringWithOptions(now time.Time, k, v string, opts SetOptions) (bool, string, bool) {
db.mu.Lock()
defer db.mu.Unlock()

Expand All @@ -63,22 +63,29 @@ func (db *DB) SetStringWithOptions(now time.Time, k, v string, opts SetOptions)
}

if opts.NX && exists {
return false
return false, "", false
}
if opts.XX && !exists {
return false
return false, "", false
}

prev := ""
prevExists := false
if exists && e != nil && e.typ == TString {
prev = e.s
prevExists = true
}

expireAt := time.Time{}
if opts.KeepTTL && exists {
if opts.KeepTTL && exists && e != nil {
expireAt = e.expireAt
}
if opts.HasExpire {
expireAt = opts.ExpireAt
}

db.entries[k] = &entry{typ: TString, s: v, expireAt: expireAt}
return true
return true, prev, prevExists
}

// GetString fetches string value if key exists and is not expired.
Expand Down
14 changes: 8 additions & 6 deletions internal/db/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -239,19 +239,19 @@ func TestStore_SetStringWithOptions(t *testing.T) {
st := New()
st.SetString("exists", "v1", time.Time{})

if ok := st.SetStringWithOptions(now, "exists", "nx", SetOptions{NX: true}); ok {
if ok, _, _ := st.SetStringWithOptions(now, "exists", "nx", SetOptions{NX: true}); ok {
t.Fatalf("expected NX to fail when key exists")
}
if got, _ := st.GetString(time.Time{}, "exists"); got != "v1" {
t.Fatalf("expected value to remain v1, got %q", got)
}

if ok := st.SetStringWithOptions(now, "missing", "xx", SetOptions{XX: true}); ok {
if ok, _, _ := st.SetStringWithOptions(now, "missing", "xx", SetOptions{XX: true}); ok {
t.Fatalf("expected XX to fail when key missing")
}

st.SetString("keepttl", "v2", now.Add(10*time.Second))
if ok := st.SetStringWithOptions(now, "keepttl", "v3", SetOptions{KeepTTL: true, XX: true}); !ok {
if ok, _, _ := st.SetStringWithOptions(now, "keepttl", "v3", SetOptions{KeepTTL: true, XX: true}); !ok {
t.Fatalf("expected keepttl set to succeed")
}
if ttl := st.TTL(now, "keepttl"); ttl != 10 {
Expand All @@ -265,7 +265,7 @@ func TestStore_SetStringWithOptions(t *testing.T) {
st := New()
st.SetString("stale", "old", now.Add(-time.Second))

if ok := st.SetStringWithOptions(now, "stale", "fresh", SetOptions{NX: true}); !ok {
if ok, _, _ := st.SetStringWithOptions(now, "stale", "fresh", SetOptions{NX: true}); !ok {
t.Fatalf("expected NX to succeed when existing key is expired")
}
if got, _ := st.GetString(time.Time{}, "stale"); got != "fresh" {
Expand All @@ -278,7 +278,7 @@ func TestStore_SetStringWithOptions(t *testing.T) {

st := New()
exp := now.Add(5 * time.Second)
if ok := st.SetStringWithOptions(now, "foo", "bar", SetOptions{
if ok, _, _ := st.SetStringWithOptions(now, "foo", "bar", SetOptions{
HasExpire: true,
ExpireAt: exp,
}); !ok {
Expand All @@ -294,8 +294,10 @@ func TestStore_SetStringWithOptions(t *testing.T) {

st := New()
st.SetString("foo", "old", time.Time{})
if ok := st.SetStringWithOptions(now, "foo", "new", SetOptions{XX: true}); !ok {
if ok, prev, prevExists := st.SetStringWithOptions(now, "foo", "new", SetOptions{XX: true}); !ok {
t.Fatalf("expected XX to succeed for existing key")
} else if !prevExists || prev != "old" {
t.Fatalf("expected prev=old, ok=true; got prev=%q prevExists=%v", prev, prevExists)
}
if got, _ := st.GetString(time.Time{}, "foo"); got != "new" {
t.Fatalf("expected value new, got %q", got)
Expand Down
56 changes: 55 additions & 1 deletion internal/server/cmd_set.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ func (s *Server) cmdSet(w *resp.Writer, r *request) error {
now := s.Now()

opts := db.SetOptions{}
returnOld := false
for i := 3; i < len(r.args); i++ {
opt := strings.ToUpper(string(r.args[i]))
switch opt {
Expand Down Expand Up @@ -70,18 +71,71 @@ func (s *Server) cmdSet(w *resp.Writer, r *request) error {
}
opts.HasExpire = true
opts.ExpireAt = now.Add(time.Duration(ms) * time.Millisecond)
case "EXAT":
if opts.HasExpire || opts.KeepTTL {
return w.WriteErrorAndFlush(ErrSyntax)
}
i++
if i >= len(r.args) {
return w.WriteErrorAndFlush(ErrSyntax)
}
sec, ok := resp.ParseInt(r.args[i])
if !ok {
return w.WriteErrorAndFlush(ErrValueNotInteger)
}
if sec <= 0 {
return w.WriteErrorAndFlush(ErrInvalidExpireTime)
}
opts.HasExpire = true
opts.ExpireAt = time.Unix(sec, 0)
case "PXAT":
if opts.HasExpire || opts.KeepTTL {
return w.WriteErrorAndFlush(ErrSyntax)
}
i++
if i >= len(r.args) {
return w.WriteErrorAndFlush(ErrSyntax)
}
ms, ok := resp.ParseInt(r.args[i])
if !ok {
return w.WriteErrorAndFlush(ErrValueNotInteger)
}
if ms <= 0 {
return w.WriteErrorAndFlush(ErrInvalidExpireTime)
}
opts.HasExpire = true
opts.ExpireAt = time.UnixMilli(ms)
case "GET":
if returnOld {
return w.WriteErrorAndFlush(ErrSyntax)
}
returnOld = true
default:
return w.WriteErrorAndFlush(ErrSyntax)
}
}

stored := s.db(r.session).SetStringWithOptions(now, key, val, opts)
stored, prev, prevExists := s.db(r.session).SetStringWithOptions(now, key, val, opts)
if !stored {
if err := w.WriteNull(); err != nil {
return err
}
return nil
}

if returnOld {
if prevExists {
if err := w.WriteBulk([]byte(prev)); err != nil {
return err
}
} else {
if err := w.WriteNull(); err != nil {
return err
}
}
return nil
}

if err := w.WriteString("OK"); err != nil {
return err
}
Expand Down
118 changes: 117 additions & 1 deletion internal/server/cmd_set_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package server
import (
"bufio"
"bytes"
"strconv"
"testing"
"time"

Expand All @@ -15,7 +16,7 @@ import (
func TestServer_cmdSet(t *testing.T) {
t.Parallel()

now := time.Now()
now := time.Unix(1_000, 0)

tcs := []struct {
name string
Expand Down Expand Up @@ -74,6 +75,38 @@ func TestServer_cmdSet(t *testing.T) {
},
want: "+OK\r\n",
},
{
name: "sets expire with EXAT",
args: resp.Args{
[]byte("set"),
[]byte("foo"),
[]byte("bar"),
[]byte("EXAT"),
[]byte(strconv.FormatInt(now.Add(15*time.Second).Unix(), 10)),
},
assert: func(t *testing.T, db *db.DB) {
if ttl := db.TTL(now, "foo"); ttl != 15 {
t.Fatalf("expected ttl 15, got %d", ttl)
}
},
want: "+OK\r\n",
},
{
name: "sets expire with PXAT",
args: resp.Args{
[]byte("set"),
[]byte("foo"),
[]byte("bar"),
[]byte("pxat"),
[]byte(strconv.FormatInt(now.Add(1500*time.Millisecond).UnixMilli(), 10)),
},
assert: func(t *testing.T, db *db.DB) {
if ttl := db.TTL(now, "foo"); ttl != 1 {
t.Fatalf("expected ttl 1, got %d", ttl)
}
},
want: "+OK\r\n",
},
{
name: "rejects invalid expire time",
args: resp.Args{
Expand Down Expand Up @@ -122,6 +155,89 @@ func TestServer_cmdSet(t *testing.T) {
},
want: "+OK\r\n",
},
{
name: "returns old value with GET option",
args: resp.Args{
[]byte("set"),
[]byte("foo"),
[]byte("baz"),
[]byte("GET"),
},
arrange: func(db *db.DB) {
db.SetString("foo", "bar", time.Time{})
},
assert: func(t *testing.T, db *db.DB) {
got, ok := db.GetString(time.Time{}, "foo")
if !ok || got != "baz" {
t.Fatalf("expected foo=baz, got %q ok=%v", got, ok)
}
},
want: "$3\r\nbar\r\n",
},
{
name: "returns null with GET when key missing",
args: resp.Args{
[]byte("set"),
[]byte("foo"),
[]byte("baz"),
[]byte("GET"),
},
assert: func(t *testing.T, db *db.DB) {
got, ok := db.GetString(time.Time{}, "foo")
if !ok || got != "baz" {
t.Fatalf("expected foo=baz, got %q ok=%v", got, ok)
}
},
want: "$-1\r\n",
},
{
name: "returns null with GET when NX succeeds on fresh key",
args: resp.Args{
[]byte("set"),
[]byte("foo"),
[]byte("baz"),
[]byte("NX"),
[]byte("GET"),
},
assert: func(t *testing.T, db *db.DB) {
got, ok := db.GetString(time.Time{}, "foo")
if !ok || got != "baz" {
t.Fatalf("expected foo=baz, got %q ok=%v", got, ok)
}
},
want: "$-1\r\n",
},
{
name: "returns null with GET when NX condition fails",
args: resp.Args{
[]byte("set"),
[]byte("foo"),
[]byte("new"),
[]byte("NX"),
[]byte("GET"),
},
arrange: func(db *db.DB) {
db.SetString("foo", "old", time.Time{})
},
assert: func(t *testing.T, db *db.DB) {
got, ok := db.GetString(time.Time{}, "foo")
if !ok || got != "old" {
t.Fatalf("expected foo to remain old, got %q ok=%v", got, ok)
}
},
want: "$-1\r\n",
},
{
name: "rejects duplicate GET option",
args: resp.Args{
[]byte("set"),
[]byte("foo"),
[]byte("bar"),
[]byte("GET"),
[]byte("GET"),
},
want: "-ERR syntax error\r\n",
},
{
name: "returns null when NX condition fails",
args: resp.Args{
Expand Down