package db import ( "context" "crypto/rand" "crypto/sha256" "database/sql" "encoding/hex" "errors" "fmt" "strings" "time" _ "github.com/go-sql-driver/mysql" ) const ( tokenLength = 32 accessTokenTTL = 24 * time.Hour ) // Account represents a verified account with its current tokens. type Account struct { ID int64 `json:"-"` Username string `json:"username"` RefreshToken string `json:"refresh_token"` AccessToken string `json:"access_token"` AccessTokenExpiry time.Time `json:"access_token_expires"` CreatedAt time.Time `json:"created_at"` } // DB wraps sql.DB for wikiapiserver. type DB struct { conn *sql.DB } // Connect opens a MariaDB connection, pings it, and returns a ready DB. func Connect(dsn string) (*DB, error) { sqlDB, err := sql.Open("mysql", dsn) if err != nil { return nil, fmt.Errorf("open mysql: %w", err) } if err := sqlDB.Ping(); err != nil { return nil, fmt.Errorf("ping mysql: %w", err) } sqlDB.SetMaxOpenConns(25) sqlDB.SetMaxIdleConns(10) sqlDB.SetConnMaxLifetime(5 * time.Minute) return &DB{conn: sqlDB}, nil } // Close shuts down the connection pool. func (d *DB) Close() error { return d.conn.Close() } // Ping checks database liveness. func (d *DB) Ping(ctx context.Context) error { return d.conn.PingContext(ctx) } // --- crypto helpers --- func randomHex(n int) (string, error) { b := make([]byte, n) if _, err := rand.Read(b); err != nil { return "", fmt.Errorf("crypto rand: %w", err) } return hex.EncodeToString(b), nil } func sha256hex(s string) string { h := sha256.Sum256([]byte(s)) return hex.EncodeToString(h[:]) } // --- error helpers --- func isDupKeyError(err error) bool { return err != nil && strings.Contains(err.Error(), "1062") } // --- queries --- // CreateAccount inserts a new account with username and plaintext password. // Tokens are not generated here; they are set later via the Wikimedia API. func (d *DB) CreateAccount(ctx context.Context, username, plaintextPW string) (*Account, error) { res, err := d.conn.ExecContext(ctx, `INSERT INTO account (username, password, refresh_token, access_token, access_token_created) VALUES (?, ?, '', '', NOW())`, username, plaintextPW, ) if err != nil { if isDupKeyError(err) { return nil, errors.New("username already exists") } return nil, fmt.Errorf("insert account: %w", err) } id, err := res.LastInsertId() if err != nil { return nil, err } now := time.Now() return &Account{ ID: id, Username: username, RefreshToken: "", AccessToken: "", AccessTokenExpiry: now, CreatedAt: now, }, nil } // Authenticate verifies plaintext credentials and returns fresh tokens. func (d *DB) Authenticate(ctx context.Context, username, plaintextPW string) (*Account, error) { var storedPW string err := d.conn.QueryRowContext(ctx, `SELECT password FROM account WHERE username = ?`, username, ).Scan(&storedPW) if err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, errors.New("invalid credentials") } return nil, fmt.Errorf("query user: %w", err) } if storedPW != plaintextPW { return nil, errors.New("invalid credentials") } // Look up account id, then rotate tokens var id int64 if err := d.conn.QueryRowContext(ctx, `SELECT id FROM account WHERE username = ?`, username, ).Scan(&id); err != nil { return nil, fmt.Errorf("lookup id: %w", err) } return d.RotateTokens(ctx, id) } // RotateTokens generates new refresh + access tokens for an account. func (d *DB) RotateTokens(ctx context.Context, id int64) (*Account, error) { rt, err := randomHex(tokenLength) if err != nil { return nil, err } at, err := randomHex(tokenLength) if err != nil { return nil, err } res, err := d.conn.ExecContext(ctx, `UPDATE account SET refresh_token = SHA2(?, 256), access_token = SHA2(?, 256), access_token_created = NOW() WHERE id = ?`, rt, at, id) if err != nil { return nil, fmt.Errorf("rotate tokens: %w", err) } if affected, _ := res.RowsAffected(); affected == 0 { return nil, errors.New("account not found") } var username string if err := d.conn.QueryRowContext(ctx, `SELECT username FROM account WHERE id = ?`, id, ).Scan(&username); err != nil { return nil, fmt.Errorf("lookup username: %w", err) } now := time.Now() return &Account{ ID: id, Username: username, RefreshToken: rt, AccessToken: at, AccessTokenExpiry: now.Add(accessTokenTTL), CreatedAt: now, }, nil } // RefreshByToken looks up an account by its refresh token and rotates both tokens. func (d *DB) RefreshByToken(ctx context.Context, refreshToken string) (*Account, error) { var id int64 err := d.conn.QueryRowContext(ctx, `SELECT id FROM account WHERE refresh_token = SHA2(?, 256)`, refreshToken, ).Scan(&id) if err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, errors.New("invalid refresh token") } return nil, fmt.Errorf("lookup refresh token: %w", err) } return d.RotateTokens(ctx, id) } // HealthCheck runs a trivial query to verify DB liveness. func (d *DB) HealthCheck(ctx context.Context) error { _, err := d.conn.ExecContext(ctx, "SELECT 1") return err }