package db import ( "context" "crypto/rand" "database/sql" "encoding/hex" "encoding/json" "errors" "fmt" "io" "net/http" "net/url" "strings" "time" _ "github.com/go-sql-driver/mysql" ) const ( tokenLength = 32 accessTokenTTL = 24 * time.Hour ) // Account represents a verified account with its current tokens. type Account struct { ID int64 `json:"-"` Username string `json:"username"` RefreshToken string `json:"refresh_token"` AccessToken string `json:"access_token"` AccessTokenExpiry time.Time `json:"access_token_expires"` CreatedAt time.Time `json:"created_at"` } // DB wraps sql.DB for wikiapiserver. type DB struct { conn *sql.DB } // Connect opens a MariaDB connection, pings it, and returns a ready DB. func Connect(dsn string) (*DB, error) { sqlDB, err := sql.Open("mysql", dsn) if err != nil { return nil, fmt.Errorf("open mysql: %w", err) } if err := sqlDB.Ping(); err != nil { return nil, fmt.Errorf("ping mysql: %w", err) } sqlDB.SetMaxOpenConns(25) sqlDB.SetMaxIdleConns(10) sqlDB.SetConnMaxLifetime(5 * time.Minute) return &DB{conn: sqlDB}, nil } // Close shuts down the connection pool. func (d *DB) Close() error { return d.conn.Close() } // Ping checks database liveness. func (d *DB) Ping(ctx context.Context) error { return d.conn.PingContext(ctx) } // --- crypto helpers --- func randomHex(n int) (string, error) { b := make([]byte, n) if _, err := rand.Read(b); err != nil { return "", fmt.Errorf("crypto rand: %w", err) } return hex.EncodeToString(b), nil } // WikimediaTokens holds the tokens returned by the Wikimedia auth API. type WikimediaTokens struct { RefreshToken string `json:"refresh_token"` AccessToken string `json:"access_token"` } // Wikimedialogin sends credentials to the Wikimedia Enterprise auth API // and returns the refresh and access tokens. func WikimediaLogin(ctx context.Context, username, password string) (*WikimediaTokens, error) { body := fmt.Sprintf("username=%s&password=%s", url.PathEscape(username), url.PathEscape(password)) req, err := http.NewRequestWithContext(ctx, "POST", "https://auth.enterprise.wikimedia.com/v1/login", strings.NewReader(body)) if err != nil { return nil, fmt.Errorf("new request: %w", err) } req.Header.Set("Content-Type", "application/x-www-form-urlencoded") resp, err := http.DefaultClient.Do(req) if err != nil { return nil, fmt.Errorf("wikimedia login request: %w", err) } defer resp.Body.Close() if resp.StatusCode < 200 || resp.StatusCode >= 300 { bs, _ := io.ReadAll(resp.Body) return nil, fmt.Errorf("wikimedia login failed (status %d): %s", resp.StatusCode, string(bs)) } var tokens WikimediaTokens if err := json.NewDecoder(resp.Body).Decode(&tokens); err != nil { return nil, fmt.Errorf("decode wikimedia tokens: %w", err) } return &tokens, nil } // --- error helpers --- func isDupKeyError(err error) bool { return err != nil && strings.Contains(err.Error(), "1062") } // --- queries --- // Register authenticates the user via the Wikimedia auth API, // then persists the account and tokens to the database. // If the Wikimedia API call fails, registration fails. func (d *DB) Register(ctx context.Context, username, plaintextPW string) (*Account, error) { tokens, err := WikimediaLogin(ctx, username, plaintextPW) if err != nil { return nil, fmt.Errorf("wikimedia login: %w", err) } acct, err := d.CreateAccount(ctx, username, plaintextPW, tokens.RefreshToken, tokens.AccessToken) if err != nil { return nil, err } return acct, nil } // CreateAccount inserts an account row with the given tokens. func (d *DB) CreateAccount(ctx context.Context, username, plaintextPW, refreshToken, accessToken string) (*Account, error) { res, err := d.conn.ExecContext(ctx, `INSERT INTO account (username, password, refresh_token, access_token, refresh_token_created, access_token_created) VALUES (?, ?, ?, ?, NOW(), NOW())`, username, plaintextPW, refreshToken, accessToken, ) if err != nil { if isDupKeyError(err) { return nil, errors.New("username already exists") } return nil, fmt.Errorf("insert account: %w", err) } id, err := res.LastInsertId() if err != nil { return nil, err } now := time.Now() return &Account{ ID: id, Username: username, RefreshToken: refreshToken, AccessToken: accessToken, AccessTokenExpiry: now.Add(accessTokenTTL), CreatedAt: now, }, nil } // Authenticate verifies plaintext credentials and returns fresh tokens. func (d *DB) Authenticate(ctx context.Context, username, plaintextPW string) (*Account, error) { var storedPW string err := d.conn.QueryRowContext(ctx, `SELECT password FROM account WHERE username = ?`, username, ).Scan(&storedPW) if err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, errors.New("invalid credentials") } return nil, fmt.Errorf("query user: %w", err) } if storedPW != plaintextPW { return nil, errors.New("invalid credentials") } // Look up account id, then rotate tokens var id int64 if err := d.conn.QueryRowContext(ctx, `SELECT id FROM account WHERE username = ?`, username, ).Scan(&id); err != nil { return nil, fmt.Errorf("lookup id: %w", err) } return d.RotateTokens(ctx, id) } // RotateTokens generates new refresh + access tokens for an account. func (d *DB) RotateTokens(ctx context.Context, id int64) (*Account, error) { rt, err := randomHex(tokenLength) if err != nil { return nil, err } at, err := randomHex(tokenLength) if err != nil { return nil, err } res, err := d.conn.ExecContext(ctx, `UPDATE account SET refresh_token = ?, access_token = ?, refresh_token_created = NOW(), access_token_created = NOW() WHERE id = ?`, rt, at, id) if err != nil { return nil, fmt.Errorf("rotate tokens: %w", err) } if affected, _ := res.RowsAffected(); affected == 0 { return nil, errors.New("account not found") } var username string if err := d.conn.QueryRowContext(ctx, `SELECT username FROM account WHERE id = ?`, id, ).Scan(&username); err != nil { return nil, fmt.Errorf("lookup username: %w", err) } now := time.Now() return &Account{ ID: id, Username: username, RefreshToken: rt, AccessToken: at, AccessTokenExpiry: now.Add(accessTokenTTL), CreatedAt: now, }, nil } // RefreshByToken looks up an account by its refresh token and rotates both tokens. func (d *DB) RefreshByToken(ctx context.Context, refreshToken string) (*Account, error) { var id int64 err := d.conn.QueryRowContext(ctx, `SELECT id FROM account WHERE refresh_token = ?`, refreshToken, ).Scan(&id) if err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, errors.New("invalid refresh token") } return nil, fmt.Errorf("lookup refresh token: %w", err) } return d.RotateTokens(ctx, id) } // HealthCheck runs a trivial query to verify DB liveness. func (d *DB) HealthCheck(ctx context.Context) error { _, err := d.conn.ExecContext(ctx, "SELECT 1") return err }