diff options
Diffstat (limited to 'db')
| -rw-r--r-- | db/db.go | 210 |
1 files changed, 210 insertions, 0 deletions
diff --git a/db/db.go b/db/db.go new file mode 100644 index 0000000..e011334 --- /dev/null +++ b/db/db.go @@ -0,0 +1,210 @@ +package db + +import ( + "context" + "crypto/rand" + "crypto/sha256" + "database/sql" + "encoding/hex" + "errors" + "fmt" + "strings" + "time" + + _ "github.com/go-sql-driver/mysql" +) + +const ( + tokenLength = 32 // bytes → 64 hex chars + accessTokenTTL = 24 * time.Hour +) + +// Account represents a verified account with its current tokens. +type Account struct { + ID int64 `json:"-"` + Username string `json:"username"` + RefreshToken string `json:"refresh_token"` + AccessToken string `json:"access_token"` + AccessTokenExpiry time.Time `json:"access_token_expires"` + CreatedAt time.Time `json:"created_at"` +} + +// DB wraps sql.DB for wikiapiserver. +type DB struct { + conn *sql.DB +} + +// Connect opens a MariaDB connection, pings it, and returns a ready DB. +func Connect(dsn string) (*DB, error) { + sqlDB, err := sql.Open("mysql", dsn) + if err != nil { + return nil, fmt.Errorf("open mysql: %w", err) + } + if err := sqlDB.Ping(); err != nil { + return nil, fmt.Errorf("ping mysql: %w", err) + } + sqlDB.SetMaxOpenConns(25) + sqlDB.SetMaxIdleConns(10) + sqlDB.SetConnMaxLifetime(5 * time.Minute) + return &DB{conn: sqlDB}, nil +} + +// Close shuts down the connection pool. +func (d *DB) Close() error { + return d.conn.Close() +} + +// Ping checks database liveness. +func (d *DB) Ping(ctx context.Context) error { + return d.conn.PingContext(ctx) +} + +// --- crypto helpers --- + +func randomHex(n int) (string, error) { + b := make([]byte, n) + if _, err := rand.Read(b); err != nil { + return "", fmt.Errorf("crypto rand: %w", err) + } + return hex.EncodeToString(b), nil +} + +func sha256hex(s string) string { + h := sha256.Sum256([]byte(s)) + return hex.EncodeToString(h[:]) +} + +// --- error helpers --- + +func isDupKeyError(err error) bool { + return err != nil && strings.Contains(err.Error(), "1062") +} + +// --- queries --- + +// CreateAccount inserts a new row with hashed credentials and fresh tokens. +func (d *DB) CreateAccount(ctx context.Context, username, plaintextPW string) (*Account, error) { + rt, err := randomHex(tokenLength) + if err != nil { + return nil, err + } + at, err := randomHex(tokenLength) + if err != nil { + return nil, err + } + + res, err := d.conn.ExecContext(ctx, + `INSERT INTO account (username, password, refresh_token, access_token, access_token_created) + VALUES (?, SHA2(?, 256), SHA2(?, 256), SHA2(?, 256), NOW())`, + username, plaintextPW, rt, at, + ) + if err != nil { + if isDupKeyError(err) { + return nil, errors.New("username already exists") + } + return nil, fmt.Errorf("insert account: %w", err) + } + + id, err := res.LastInsertId() + if err != nil { + return nil, err + } + + now := time.Now() + return &Account{ + ID: id, + Username: username, + RefreshToken: rt, + AccessToken: at, + AccessTokenExpiry: now.Add(accessTokenTTL), + CreatedAt: now, + }, nil +} + +// Authenticate verifies plaintext credentials and returns fresh tokens. +func (d *DB) Authenticate(ctx context.Context, username, plaintextPW string) (*Account, error) { + var storedHash string + err := d.conn.QueryRowContext(ctx, + `SELECT password FROM account WHERE username = ?`, username, + ).Scan(&storedHash) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, errors.New("invalid credentials") + } + return nil, fmt.Errorf("query user: %w", err) + } + + if storedHash != sha256hex(plaintextPW) { + return nil, errors.New("invalid credentials") + } + + // Look up account id, then rotate tokens + var id int64 + if err := d.conn.QueryRowContext(ctx, + `SELECT id FROM account WHERE username = ?`, username, + ).Scan(&id); err != nil { + return nil, fmt.Errorf("lookup id: %w", err) + } + + return d.RotateTokens(ctx, id) +} + +// RotateTokens generates new refresh + access tokens for an account. +func (d *DB) RotateTokens(ctx context.Context, id int64) (*Account, error) { + rt, err := randomHex(tokenLength) + if err != nil { + return nil, err + } + at, err := randomHex(tokenLength) + if err != nil { + return nil, err + } + + res, err := d.conn.ExecContext(ctx, + `UPDATE account SET refresh_token = SHA2(?, 256), access_token = SHA2(?, 256), access_token_created = NOW() + WHERE id = ?`, rt, at, id) + if err != nil { + return nil, fmt.Errorf("rotate tokens: %w", err) + } + if affected, _ := res.RowsAffected(); affected == 0 { + return nil, errors.New("account not found") + } + + var username string + if err := d.conn.QueryRowContext(ctx, + `SELECT username FROM account WHERE id = ?`, id, + ).Scan(&username); err != nil { + return nil, fmt.Errorf("lookup username: %w", err) + } + + now := time.Now() + return &Account{ + ID: id, + Username: username, + RefreshToken: rt, + AccessToken: at, + AccessTokenExpiry: now.Add(accessTokenTTL), + CreatedAt: now, + }, nil +} + +// RefreshByToken looks up an account by its refresh token and rotates both tokens. +func (d *DB) RefreshByToken(ctx context.Context, refreshToken string) (*Account, error) { + var id int64 + err := d.conn.QueryRowContext(ctx, + `SELECT id FROM account WHERE refresh_token = SHA2(?, 256)`, refreshToken, + ).Scan(&id) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, errors.New("invalid refresh token") + } + return nil, fmt.Errorf("lookup refresh token: %w", err) + } + return d.RotateTokens(ctx, id) +} + +// HealthCheck runs a trivial query to verify DB liveness. +func (d *DB) HealthCheck(ctx context.Context) error { + _, err := d.conn.ExecContext(ctx, "SELECT 1") + return err +} |
