summaryrefslogtreecommitdiff
path: root/db
diff options
context:
space:
mode:
Diffstat (limited to 'db')
-rw-r--r--db/db.go210
1 files changed, 210 insertions, 0 deletions
diff --git a/db/db.go b/db/db.go
new file mode 100644
index 0000000..e011334
--- /dev/null
+++ b/db/db.go
@@ -0,0 +1,210 @@
+package db
+
+import (
+ "context"
+ "crypto/rand"
+ "crypto/sha256"
+ "database/sql"
+ "encoding/hex"
+ "errors"
+ "fmt"
+ "strings"
+ "time"
+
+ _ "github.com/go-sql-driver/mysql"
+)
+
+const (
+ tokenLength = 32 // bytes → 64 hex chars
+ accessTokenTTL = 24 * time.Hour
+)
+
+// Account represents a verified account with its current tokens.
+type Account struct {
+ ID int64 `json:"-"`
+ Username string `json:"username"`
+ RefreshToken string `json:"refresh_token"`
+ AccessToken string `json:"access_token"`
+ AccessTokenExpiry time.Time `json:"access_token_expires"`
+ CreatedAt time.Time `json:"created_at"`
+}
+
+// DB wraps sql.DB for wikiapiserver.
+type DB struct {
+ conn *sql.DB
+}
+
+// Connect opens a MariaDB connection, pings it, and returns a ready DB.
+func Connect(dsn string) (*DB, error) {
+ sqlDB, err := sql.Open("mysql", dsn)
+ if err != nil {
+ return nil, fmt.Errorf("open mysql: %w", err)
+ }
+ if err := sqlDB.Ping(); err != nil {
+ return nil, fmt.Errorf("ping mysql: %w", err)
+ }
+ sqlDB.SetMaxOpenConns(25)
+ sqlDB.SetMaxIdleConns(10)
+ sqlDB.SetConnMaxLifetime(5 * time.Minute)
+ return &DB{conn: sqlDB}, nil
+}
+
+// Close shuts down the connection pool.
+func (d *DB) Close() error {
+ return d.conn.Close()
+}
+
+// Ping checks database liveness.
+func (d *DB) Ping(ctx context.Context) error {
+ return d.conn.PingContext(ctx)
+}
+
+// --- crypto helpers ---
+
+func randomHex(n int) (string, error) {
+ b := make([]byte, n)
+ if _, err := rand.Read(b); err != nil {
+ return "", fmt.Errorf("crypto rand: %w", err)
+ }
+ return hex.EncodeToString(b), nil
+}
+
+func sha256hex(s string) string {
+ h := sha256.Sum256([]byte(s))
+ return hex.EncodeToString(h[:])
+}
+
+// --- error helpers ---
+
+func isDupKeyError(err error) bool {
+ return err != nil && strings.Contains(err.Error(), "1062")
+}
+
+// --- queries ---
+
+// CreateAccount inserts a new row with hashed credentials and fresh tokens.
+func (d *DB) CreateAccount(ctx context.Context, username, plaintextPW string) (*Account, error) {
+ rt, err := randomHex(tokenLength)
+ if err != nil {
+ return nil, err
+ }
+ at, err := randomHex(tokenLength)
+ if err != nil {
+ return nil, err
+ }
+
+ res, err := d.conn.ExecContext(ctx,
+ `INSERT INTO account (username, password, refresh_token, access_token, access_token_created)
+ VALUES (?, SHA2(?, 256), SHA2(?, 256), SHA2(?, 256), NOW())`,
+ username, plaintextPW, rt, at,
+ )
+ if err != nil {
+ if isDupKeyError(err) {
+ return nil, errors.New("username already exists")
+ }
+ return nil, fmt.Errorf("insert account: %w", err)
+ }
+
+ id, err := res.LastInsertId()
+ if err != nil {
+ return nil, err
+ }
+
+ now := time.Now()
+ return &Account{
+ ID: id,
+ Username: username,
+ RefreshToken: rt,
+ AccessToken: at,
+ AccessTokenExpiry: now.Add(accessTokenTTL),
+ CreatedAt: now,
+ }, nil
+}
+
+// Authenticate verifies plaintext credentials and returns fresh tokens.
+func (d *DB) Authenticate(ctx context.Context, username, plaintextPW string) (*Account, error) {
+ var storedHash string
+ err := d.conn.QueryRowContext(ctx,
+ `SELECT password FROM account WHERE username = ?`, username,
+ ).Scan(&storedHash)
+ if err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
+ return nil, errors.New("invalid credentials")
+ }
+ return nil, fmt.Errorf("query user: %w", err)
+ }
+
+ if storedHash != sha256hex(plaintextPW) {
+ return nil, errors.New("invalid credentials")
+ }
+
+ // Look up account id, then rotate tokens
+ var id int64
+ if err := d.conn.QueryRowContext(ctx,
+ `SELECT id FROM account WHERE username = ?`, username,
+ ).Scan(&id); err != nil {
+ return nil, fmt.Errorf("lookup id: %w", err)
+ }
+
+ return d.RotateTokens(ctx, id)
+}
+
+// RotateTokens generates new refresh + access tokens for an account.
+func (d *DB) RotateTokens(ctx context.Context, id int64) (*Account, error) {
+ rt, err := randomHex(tokenLength)
+ if err != nil {
+ return nil, err
+ }
+ at, err := randomHex(tokenLength)
+ if err != nil {
+ return nil, err
+ }
+
+ res, err := d.conn.ExecContext(ctx,
+ `UPDATE account SET refresh_token = SHA2(?, 256), access_token = SHA2(?, 256), access_token_created = NOW()
+ WHERE id = ?`, rt, at, id)
+ if err != nil {
+ return nil, fmt.Errorf("rotate tokens: %w", err)
+ }
+ if affected, _ := res.RowsAffected(); affected == 0 {
+ return nil, errors.New("account not found")
+ }
+
+ var username string
+ if err := d.conn.QueryRowContext(ctx,
+ `SELECT username FROM account WHERE id = ?`, id,
+ ).Scan(&username); err != nil {
+ return nil, fmt.Errorf("lookup username: %w", err)
+ }
+
+ now := time.Now()
+ return &Account{
+ ID: id,
+ Username: username,
+ RefreshToken: rt,
+ AccessToken: at,
+ AccessTokenExpiry: now.Add(accessTokenTTL),
+ CreatedAt: now,
+ }, nil
+}
+
+// RefreshByToken looks up an account by its refresh token and rotates both tokens.
+func (d *DB) RefreshByToken(ctx context.Context, refreshToken string) (*Account, error) {
+ var id int64
+ err := d.conn.QueryRowContext(ctx,
+ `SELECT id FROM account WHERE refresh_token = SHA2(?, 256)`, refreshToken,
+ ).Scan(&id)
+ if err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
+ return nil, errors.New("invalid refresh token")
+ }
+ return nil, fmt.Errorf("lookup refresh token: %w", err)
+ }
+ return d.RotateTokens(ctx, id)
+}
+
+// HealthCheck runs a trivial query to verify DB liveness.
+func (d *DB) HealthCheck(ctx context.Context) error {
+ _, err := d.conn.ExecContext(ctx, "SELECT 1")
+ return err
+}