From 985e7af8352b1ff73ee899321e5a11a10f28d507 Mon Sep 17 00:00:00 2001 From: Buravit Yenjit Date: Fri, 4 Apr 2025 22:15:13 +0700 Subject: [PATCH] test: add authentication and JWT test suites --- backend/cmd/generate_keys/main.go | 21 +++ backend/go.mod | 4 + backend/go.sum | 2 + backend/internal/api/auth_test.go | 239 +++++++++++++++++++++++++ backend/internal/utilities/jwt_test.go | 44 +++++ 5 files changed, 310 insertions(+) create mode 100644 backend/cmd/generate_keys/main.go create mode 100644 backend/internal/api/auth_test.go create mode 100644 backend/internal/utilities/jwt_test.go diff --git a/backend/cmd/generate_keys/main.go b/backend/cmd/generate_keys/main.go new file mode 100644 index 0000000..4ccf5fb --- /dev/null +++ b/backend/cmd/generate_keys/main.go @@ -0,0 +1,21 @@ +package main + +import ( + "crypto/rand" + "encoding/base64" + "fmt" + "os" +) + +func main() { + key := make([]byte, 64) + _, err := rand.Read(key) + if err != nil { + fmt.Println("Error generating key:", err) + os.Exit(1) + } + + secret := base64.StdEncoding.EncodeToString(key) + fmt.Println("Generated JWT Secret (add to your .env as JWT_SECRET_KEY):") + fmt.Println(secret) +} diff --git a/backend/go.mod b/backend/go.mod index ff93bfa..f02fce1 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -18,6 +18,7 @@ require ( github.com/rabbitmq/amqp091-go v1.10.0 github.com/spf13/cobra v1.8.1 github.com/spf13/viper v1.19.0 + github.com/stretchr/testify v1.10.0 golang.org/x/crypto v0.36.0 google.golang.org/api v0.186.0 ) @@ -30,6 +31,7 @@ require ( cloud.google.com/go/compute/metadata v0.3.0 // indirect cloud.google.com/go/longrunning v0.5.7 // indirect github.com/asaskevich/govalidator v0.0.0-20200108200545-475eaeb16496 // indirect + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/fsnotify/fsnotify v1.7.0 // indirect github.com/go-logr/logr v1.4.1 // indirect @@ -49,6 +51,7 @@ require ( github.com/mfridman/interpolate v0.0.2 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/pelletier/go-toml/v2 v2.2.3 // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/rogpeppe/go-internal v1.12.0 // indirect github.com/sagikazarmark/locafero v0.4.0 // indirect github.com/sagikazarmark/slog-shim v0.1.0 // indirect @@ -57,6 +60,7 @@ require ( github.com/spf13/afero v1.11.0 // indirect github.com/spf13/cast v1.6.0 // indirect github.com/spf13/pflag v1.0.6 // indirect + github.com/stretchr/objx v0.5.2 // indirect github.com/subosito/gotenv v1.6.0 // indirect go.opencensus.io v0.24.0 // indirect go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.51.0 // indirect diff --git a/backend/go.sum b/backend/go.sum index 95e37d6..e2f302f 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -156,6 +156,8 @@ github.com/spf13/viper v1.19.0/go.mod h1:GQUN9bilAbhU/jgc1bKs99f/suXKeUMct8Adx5+ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= diff --git a/backend/internal/api/auth_test.go b/backend/internal/api/auth_test.go new file mode 100644 index 0000000..f35c534 --- /dev/null +++ b/backend/internal/api/auth_test.go @@ -0,0 +1,239 @@ +package api + +import ( + "context" + "github.com/danielgtaylor/huma/v2" + "github.com/forfarm/backend/internal/utilities" + "golang.org/x/crypto/bcrypt" + "log/slog" + "os" + "testing" + + "github.com/forfarm/backend/internal/domain" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +type MockUserRepository struct { + mock.Mock +} + +type EmailPasswordInput struct { + Email string `json:"email" example:"Email address of the user"` + Password string `json:"password" example:"Password of the user"` +} + +func (m *MockUserRepository) GetByID(ctx context.Context, id int64) (domain.User, error) { + args := m.Called(ctx, id) + return args.Get(0).(domain.User), args.Error(1) +} + +func (m *MockUserRepository) GetByUUID(ctx context.Context, uuid string) (domain.User, error) { + args := m.Called(ctx, uuid) + return args.Get(0).(domain.User), args.Error(1) +} + +func (m *MockUserRepository) GetByUsername(ctx context.Context, username string) (domain.User, error) { + args := m.Called(ctx, username) + return args.Get(0).(domain.User), args.Error(1) +} + +func (m *MockUserRepository) GetByEmail(ctx context.Context, email string) (domain.User, error) { + args := m.Called(ctx, email) + return args.Get(0).(domain.User), args.Error(1) +} + +func (m *MockUserRepository) CreateOrUpdate(ctx context.Context, u *domain.User) error { + args := m.Called(ctx, u) + return args.Error(0) +} + +func (m *MockUserRepository) Delete(ctx context.Context, id int64) error { + args := m.Called(ctx, id) + return args.Error(0) +} + +func TestRegisterHandler(t *testing.T) { + var tests = []struct { + name string + input RegisterInput + mockSetup func(*MockUserRepository) + expectedError error + }{ + { + name: "successful registration", + input: RegisterInput{ + Body: EmailPasswordInput{ + Email: "test@example.com", + Password: "ValidPass123!", + }, + }, + mockSetup: func(m *MockUserRepository) { + m.On("GetByEmail", mock.Anything, "test@example.com").Return(domain.User{}, domain.ErrNotFound) + m.On("CreateOrUpdate", mock.Anything, mock.AnythingOfType("*domain.User")).Return(nil) + }, + expectedError: nil, + }, + + { + name: "existing email", + input: RegisterInput{ + Body: struct { + Email string `json:"email" example:"Email address of the user"` + Password string `json:"password" example:"Password of the user"` + }(struct { + Email string `json:"email"` + Password string `json:"password"` + }{ + Email: "existing@example.com", + Password: "ValidPass123!", + }), + }, + mockSetup: func(m *MockUserRepository) { + m.On("GetByEmail", mock.Anything, "existing@example.com").Return(domain.User{ + Email: "existing@example.com", + }, nil) + }, + expectedError: huma.Error409Conflict("User with this email already exists"), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockRepo := &MockUserRepository{} + if tt.mockSetup != nil { + tt.mockSetup(mockRepo) + } + + api := &api{ + userRepo: mockRepo, + logger: nil, + } + + _, err := api.registerHandler(context.Background(), &tt.input) + + if tt.expectedError == nil { + assert.NoError(t, err) + } else { + assert.EqualError(t, err, tt.expectedError.Error()) + } + + mockRepo.AssertExpectations(t) + }) + } +} + +func TestLoginHandler(t *testing.T) { + correctPassword := "ValidPass123!" + hashedPassword, err := bcrypt.GenerateFromPassword([]byte(correctPassword), bcrypt.DefaultCost) + if err != nil { + t.Fatalf("Failed to generate bcrypt hash: %v", err) + } + + userUUID := uuid.New().String() + testUser := domain.User{ + UUID: userUUID, + Email: "test@example.com", + Password: string(hashedPassword), + IsActive: true, + } + + logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{ + Level: slog.LevelError, + })) + + tests := []struct { + name string + input LoginInput + mockSetup func(*MockUserRepository) + expectedError error + }{ + { + name: "successful login", + input: LoginInput{ + Body: EmailPasswordInput{ + Email: "test@example.com", + Password: correctPassword, + }, + }, + mockSetup: func(m *MockUserRepository) { + m.On("GetByEmail", mock.Anything, "test@example.com").Return(testUser, nil) + }, + expectedError: nil, + }, + { + name: "invalid credentials", + input: LoginInput{ + Body: EmailPasswordInput{ + Email: "test@example.com", + Password: "wrongpassword", + }, + }, + mockSetup: func(m *MockUserRepository) { + m.On("GetByEmail", mock.Anything, "test@example.com").Return(testUser, nil) + }, + expectedError: huma.Error401Unauthorized("Invalid email or password"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockRepo := &MockUserRepository{} + if tt.mockSetup != nil { + tt.mockSetup(mockRepo) + } + + api := &api{ + userRepo: mockRepo, + logger: logger, + } + + _, err := api.loginHandler(context.Background(), &tt.input) + + if tt.expectedError == nil { + assert.NoError(t, err) + } else { + assert.EqualError(t, err, tt.expectedError.Error()) + } + + mockRepo.AssertExpectations(t) + }) + } +} + +func TestLoginHandler_TokenGeneration(t *testing.T) { + userUUID := uuid.New().String() + hashedPassword, _ := bcrypt.GenerateFromPassword([]byte("ValidPass123!"), bcrypt.DefaultCost) + testUser := domain.User{ + UUID: userUUID, + Email: "test@example.com", + Password: string(hashedPassword), + IsActive: true, + } + + mockRepo := &MockUserRepository{} + mockRepo.On("GetByEmail", mock.Anything, "test@example.com").Return(testUser, nil) + + api := &api{ + userRepo: mockRepo, + logger: nil, + } + + input := &LoginInput{ + Body: EmailPasswordInput{ + Email: "test@example.com", + Password: "ValidPass123!", + }, + } + + output, err := api.loginHandler(context.Background(), input) + assert.NoError(t, err) + assert.NotEmpty(t, output.Body.Token) + + err = utilities.VerifyJwtToken(output.Body.Token) + assert.NoError(t, err) + + extractedUUID, err := utilities.ExtractUUIDFromToken(output.Body.Token) + assert.NoError(t, err) + assert.Equal(t, userUUID, extractedUUID) +} diff --git a/backend/internal/utilities/jwt_test.go b/backend/internal/utilities/jwt_test.go new file mode 100644 index 0000000..441adf7 --- /dev/null +++ b/backend/internal/utilities/jwt_test.go @@ -0,0 +1,44 @@ +package utilities + +import ( + "github.com/golang-jwt/jwt/v5" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestJWTTokenCreationAndVerification(t *testing.T) { + testUUID := "123e4567-e89b-12d3-a456-426614174000" + + token, err := CreateJwtToken(testUUID) + assert.NoError(t, err) + assert.NotEmpty(t, token) + + err = VerifyJwtToken(token) + assert.NoError(t, err) + + uuid, err := ExtractUUIDFromToken(token) + assert.NoError(t, err) + assert.Equal(t, testUUID, uuid) +} + +func TestExpiredJWTToken(t *testing.T) { + + oldKey := defaultSecretKey + defaultSecretKey = []byte("test-secret-key-1234567890-1234567890") + defer func() { defaultSecretKey = oldKey }() + + testUUID := "123e4567-e89b-12d3-a456-426614174000" + + token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ + "uuid": testUUID, + "exp": time.Now().Add(-time.Hour).Unix(), + }) + tokenString, err := token.SignedString(defaultSecretKey) + assert.NoError(t, err) + + err = VerifyJwtToken(tokenString) + assert.Error(t, err) + assert.Contains(t, err.Error(), "token is expired") +}