tests failing
This commit is contained in:
parent
412ea3bb36
commit
cec1f0bc96
1
go.mod
1
go.mod
@ -3,6 +3,7 @@ module github.com/dalu/mongostore
|
||||
go 1.14
|
||||
|
||||
require (
|
||||
github.com/gorilla/securecookie v1.1.1
|
||||
github.com/gorilla/sessions v1.2.0
|
||||
go.mongodb.org/mongo-driver v1.3.1
|
||||
)
|
||||
|
46
store.go
46
store.go
@ -4,13 +4,13 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/securecookie"
|
||||
"github.com/gorilla/sessions"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
)
|
||||
|
||||
@ -24,31 +24,51 @@ type MongoStore struct {
|
||||
dbTimeout time.Duration
|
||||
}
|
||||
|
||||
func NewMongoStore(client *mongo.Client, dbDatabase, dbCollection string, dbTimeout time.Duration) *MongoStore {
|
||||
func NewMongoStore(client *mongo.Client, dbDatabase, dbCollection string, dbTimeout time.Duration, keyPairs ...[]byte) *MongoStore {
|
||||
s := new(MongoStore)
|
||||
s.client = client
|
||||
s.dbDatabaseName = dbDatabase
|
||||
s.dbCollectionName = dbCollection
|
||||
s.dbTimeout = dbTimeout
|
||||
s.Options = &sessions.Options{
|
||||
Path: "/",
|
||||
MaxAge: 86400 * 30 * 365,
|
||||
}
|
||||
s.Codecs = securecookie.CodecsFromPairs(keyPairs...)
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *MongoStore) MaxAge(age int) {
|
||||
s.Options.MaxAge = age
|
||||
|
||||
for _, codec := range s.Codecs {
|
||||
if sc, ok := codec.(*securecookie.SecureCookie); ok {
|
||||
sc.MaxAge(age)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *MongoStore) New(r *http.Request, name string) (*sessions.Session, error) {
|
||||
session := sessions.NewSession(s, name)
|
||||
options := *s.Options
|
||||
session.Options = &options
|
||||
session.Options = &sessions.Options{
|
||||
Path: s.Options.Path,
|
||||
MaxAge: s.Options.MaxAge,
|
||||
Domain: s.Options.Domain,
|
||||
Secure: s.Options.Secure,
|
||||
HttpOnly: s.Options.HttpOnly,
|
||||
}
|
||||
session.IsNew = true
|
||||
cookie, e := r.Cookie(name)
|
||||
if e != nil {
|
||||
return nil, e
|
||||
return session, e
|
||||
}
|
||||
e = securecookie.DecodeMulti(name, cookie.Value, &session.ID, s.Codecs...)
|
||||
if e != nil {
|
||||
return nil, e
|
||||
return session, e
|
||||
}
|
||||
e = s.load(session)
|
||||
if e != nil {
|
||||
return nil, e
|
||||
return session, e
|
||||
} else {
|
||||
session.IsNew = false
|
||||
}
|
||||
@ -129,8 +149,14 @@ func (s *MongoStore) save(session *sessions.Session) error {
|
||||
ns.ID = oid
|
||||
ns.Data = encoded
|
||||
ns.Modified = modified
|
||||
if _, e := collection.InsertOne(ctx, ns); e != nil {
|
||||
return e
|
||||
if session.IsNew {
|
||||
if _, e := collection.InsertOne(ctx, ns); e != nil {
|
||||
return e
|
||||
}
|
||||
} else {
|
||||
if _, e := collection.UpdateOne(ctx, bson.M{"_id": oid}, ns); e != nil {
|
||||
return e
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
176
store_test.go
Normal file
176
store_test.go
Normal file
@ -0,0 +1,176 @@
|
||||
// Copyright (c) 2013 Gregor Robinson.
|
||||
// Copyright (c) 2013 Brian Jones.
|
||||
// All rights reserved.
|
||||
// Use of this source code is governed by a MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package mongostore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/gob"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.mongodb.org/mongo-driver/mongo/options"
|
||||
"go.mongodb.org/mongo-driver/mongo/readpref"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/sessions"
|
||||
)
|
||||
|
||||
type FlashMessage struct {
|
||||
Type int
|
||||
Message string
|
||||
}
|
||||
|
||||
func TestMongoStore(t *testing.T) {
|
||||
var req *http.Request
|
||||
var rsp *httptest.ResponseRecorder
|
||||
var hdr http.Header
|
||||
var err error
|
||||
var ok bool
|
||||
var cookies []string
|
||||
var session *sessions.Session
|
||||
var flashes []interface{}
|
||||
|
||||
// Copyright 2012 The Gorilla Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Round 1 ----------------------------------------------------------------
|
||||
ctx, _ := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
client, err := mongo.Connect(ctx, options.Client().ApplyURI("mongodb://localhost:27017"))
|
||||
err = client.Ping(ctx, readpref.Primary())
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
store := NewMongoStore(client, "test", "sessions", 3*time.Second)
|
||||
store.Options.Path = "/"
|
||||
store.Options.MaxAge = 86400 * 30 * 365
|
||||
defer store.Close()
|
||||
|
||||
req, _ = http.NewRequest("GET", "http://localhost:8080/", nil)
|
||||
rsp = httptest.NewRecorder()
|
||||
// Get a session.
|
||||
if session, err = store.Get(req, "session-key"); err != nil {
|
||||
t.Fatalf("Error getting session: %v", err)
|
||||
}
|
||||
// Get a flash.
|
||||
flashes = session.Flashes()
|
||||
if len(flashes) != 0 {
|
||||
t.Errorf("Expected empty flashes; Got %v", flashes)
|
||||
}
|
||||
// Add some flashes.
|
||||
session.AddFlash("foo")
|
||||
session.AddFlash("bar")
|
||||
// Custom key.
|
||||
session.AddFlash("baz", "custom_key")
|
||||
// Save.
|
||||
if err = sessions.Save(req, rsp); err != nil {
|
||||
t.Fatalf("Error saving session: %v", err)
|
||||
}
|
||||
hdr = rsp.Header()
|
||||
cookies, ok = hdr["Set-Cookie"]
|
||||
if !ok || len(cookies) != 1 {
|
||||
t.Fatalf("No cookies. Header: %v", hdr)
|
||||
}
|
||||
|
||||
// Round 2 ----------------------------------------------------------------
|
||||
|
||||
req, _ = http.NewRequest("GET", "http://localhost:8080/", nil)
|
||||
req.Header.Add("Cookie", cookies[0])
|
||||
rsp = httptest.NewRecorder()
|
||||
// Get a session.
|
||||
if session, err = store.Get(req, "session-key"); err != nil {
|
||||
t.Fatalf("Error getting session: %v", err)
|
||||
}
|
||||
// Check all saved values.
|
||||
flashes = session.Flashes()
|
||||
if len(flashes) != 2 {
|
||||
t.Fatalf("Expected flashes; Got %v", flashes)
|
||||
}
|
||||
if flashes[0] != "foo" || flashes[1] != "bar" {
|
||||
t.Errorf("Expected foo,bar; Got %v", flashes)
|
||||
}
|
||||
flashes = session.Flashes()
|
||||
if len(flashes) != 0 {
|
||||
t.Errorf("Expected dumped flashes; Got %v", flashes)
|
||||
}
|
||||
// Custom key.
|
||||
flashes = session.Flashes("custom_key")
|
||||
if len(flashes) != 1 {
|
||||
t.Errorf("Expected flashes; Got %v", flashes)
|
||||
} else if flashes[0] != "baz" {
|
||||
t.Errorf("Expected baz; Got %v", flashes)
|
||||
}
|
||||
flashes = session.Flashes("custom_key")
|
||||
if len(flashes) != 0 {
|
||||
t.Errorf("Expected dumped flashes; Got %v", flashes)
|
||||
}
|
||||
|
||||
session.Options.MaxAge = -1
|
||||
// Save.
|
||||
if err = sessions.Save(req, rsp); err != nil {
|
||||
t.Fatalf("Error saving session: %v", err)
|
||||
}
|
||||
|
||||
// Round 3 ----------------------------------------------------------------
|
||||
// Custom type
|
||||
|
||||
req, _ = http.NewRequest("GET", "http://localhost:8080/", nil)
|
||||
rsp = httptest.NewRecorder()
|
||||
// Get a session.
|
||||
if session, err = store.Get(req, "session-key"); err != nil {
|
||||
t.Fatalf("Error getting session: %v", err)
|
||||
}
|
||||
// Get a flash.
|
||||
flashes = session.Flashes()
|
||||
if len(flashes) != 0 {
|
||||
t.Errorf("Expected empty flashes; Got %v", flashes)
|
||||
}
|
||||
// Add some flashes.
|
||||
session.AddFlash(&FlashMessage{42, "foo"})
|
||||
// Save.
|
||||
if err = sessions.Save(req, rsp); err != nil {
|
||||
t.Fatalf("Error saving session: %v", err)
|
||||
}
|
||||
hdr = rsp.Header()
|
||||
cookies, ok = hdr["Set-Cookie"]
|
||||
if !ok || len(cookies) != 1 {
|
||||
t.Fatalf("No cookies. Header: %v", hdr)
|
||||
}
|
||||
|
||||
// Round 4 ----------------------------------------------------------------
|
||||
// Custom type
|
||||
|
||||
req, _ = http.NewRequest("GET", "http://localhost:8080/", nil)
|
||||
req.Header.Add("Cookie", cookies[0])
|
||||
rsp = httptest.NewRecorder()
|
||||
// Get a session.
|
||||
if session, err = store.Get(req, "session-key"); err != nil {
|
||||
t.Fatalf("Error getting session: %v", err)
|
||||
}
|
||||
// Check all saved values.
|
||||
flashes = session.Flashes()
|
||||
if len(flashes) != 1 {
|
||||
t.Fatalf("Expected flashes; Got %v", flashes)
|
||||
}
|
||||
custom := flashes[0].(FlashMessage)
|
||||
if custom.Type != 42 || custom.Message != "foo" {
|
||||
t.Errorf("Expected %#v, got %#v", FlashMessage{42, "foo"}, custom)
|
||||
}
|
||||
|
||||
// Delete session.
|
||||
session.Options.MaxAge = -1
|
||||
// Save.
|
||||
if err = sessions.Save(req, rsp); err != nil {
|
||||
t.Fatalf("Error saving session: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func init() {
|
||||
gob.Register(FlashMessage{})
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user