package db import ( "context" "bytes" "database/sql" "encoding/json" "errors" "fmt" "io" "net/http" "net/url" "strings" "time" _ "github.com/go-sql-driver/mysql" ) const accessTokenTTL = 24 * time.Hour // Account represents a verified account with its current tokens. type Account struct { ID int64 `json:"-"` Username string `json:"username"` RefreshToken string `json:"refresh_token"` AccessToken string `json:"access_token"` AccessTokenExpiry time.Time `json:"access_token_expires"` CreatedAt time.Time `json:"created_at"` } // DB wraps sql.DB for wikiapiserver. type DB struct { conn *sql.DB } // Connect opens a MariaDB connection, pings it, and returns a ready DB. func Connect(dsn string) (*DB, error) { sqlDB, err := sql.Open("mysql", dsn) if err != nil { return nil, fmt.Errorf("open mysql: %w", err) } if err := sqlDB.Ping(); err != nil { return nil, fmt.Errorf("ping mysql: %w", err) } sqlDB.SetMaxOpenConns(25) sqlDB.SetMaxIdleConns(10) sqlDB.SetConnMaxLifetime(5 * time.Minute) return &DB{conn: sqlDB}, nil } // Close shuts down the connection pool. func (d *DB) Close() error { return d.conn.Close() } // Ping checks database liveness. func (d *DB) Ping(ctx context.Context) error { return d.conn.PingContext(ctx) } // WikimediaTokens holds the tokens returned by the Wikimedia auth API. type WikimediaTokens struct { RefreshToken string `json:"refresh_token"` AccessToken string `json:"access_token"` } // Wikimedialogin sends credentials to the Wikimedia Enterprise auth API // and returns the refresh and access tokens. func WikimediaLogin(ctx context.Context, username, password string) (*WikimediaTokens, error) { body := fmt.Sprintf("username=%s&password=%s", url.PathEscape(username), url.PathEscape(password)) req, err := http.NewRequestWithContext(ctx, "POST", "https://auth.enterprise.wikimedia.com/v1/login", strings.NewReader(body)) if err != nil { return nil, fmt.Errorf("new request: %w", err) } req.Header.Set("Content-Type", "application/x-www-form-urlencoded") resp, err := http.DefaultClient.Do(req) if err != nil { return nil, fmt.Errorf("wikimedia login request: %w", err) } defer resp.Body.Close() if resp.StatusCode < 200 || resp.StatusCode >= 300 { bs, _ := io.ReadAll(resp.Body) return nil, fmt.Errorf("wikimedia login failed (status %d): %s", resp.StatusCode, string(bs)) } var tokens WikimediaTokens if err := json.NewDecoder(resp.Body).Decode(&tokens); err != nil { return nil, fmt.Errorf("decode wikimedia tokens: %w", err) } return &tokens, nil } // WikimediaTokenRefresh posts to the Wikimedia token-refresh endpoint // and returns new tokens from the response. func WikimediaTokenRefresh(ctx context.Context, username, refreshToken string) (*WikimediaTokens, error) { body, err := json.Marshal(map[string]string{ "username": username, "refresh_token": refreshToken, }) if err != nil { return nil, fmt.Errorf("marshal token refresh body: %w", err) } req, err := http.NewRequestWithContext(ctx, "POST", "https://auth.enterprise.wikimedia.com/v1/token-refresh", bytes.NewReader(body)) if err != nil { return nil, fmt.Errorf("new request: %w", err) } req.Header.Set("Content-Type", "application/json") resp, err := http.DefaultClient.Do(req) if err != nil { return nil, fmt.Errorf("token refresh request: %w", err) } defer resp.Body.Close() if resp.StatusCode < 200 || resp.StatusCode >= 300 { bs, _ := io.ReadAll(resp.Body) return nil, fmt.Errorf("token refresh failed (status %d): %s", resp.StatusCode, string(bs)) } var tokens WikimediaTokens if err := json.NewDecoder(resp.Body).Decode(&tokens); err != nil { return nil, fmt.Errorf("decode token refresh response: %w", err) } return &tokens, nil } // --- error helpers --- func isDupKeyError(err error) bool { return err != nil && strings.Contains(err.Error(), "1062") } // --- queries --- // Register authenticates the user via the Wikimedia auth API, // then persists the account and tokens to the database. // If the Wikimedia API call fails, registration fails. func (d *DB) Register(ctx context.Context, username, plaintextPW string) (*Account, error) { tokens, err := WikimediaLogin(ctx, username, plaintextPW) if err != nil { return nil, fmt.Errorf("wikimedia login: %w", err) } acct, err := d.CreateAccount(ctx, username, plaintextPW, tokens.RefreshToken, tokens.AccessToken) if err != nil { return nil, err } return acct, nil } // CreateAccount inserts an account row with the given tokens. func (d *DB) CreateAccount(ctx context.Context, username, plaintextPW, refreshToken, accessToken string) (*Account, error) { res, err := d.conn.ExecContext(ctx, `INSERT INTO account (username, password, refresh_token, access_token, refresh_token_created, access_token_created) VALUES (?, ?, ?, ?, NOW(), NOW())`, username, plaintextPW, refreshToken, accessToken, ) if err != nil { if isDupKeyError(err) { return nil, errors.New("username already exists") } return nil, fmt.Errorf("insert account: %w", err) } id, err := res.LastInsertId() if err != nil { return nil, err } now := time.Now() return &Account{ ID: id, Username: username, RefreshToken: refreshToken, AccessToken: accessToken, AccessTokenExpiry: now.Add(accessTokenTTL), CreatedAt: now, }, nil } // Authenticate verifies plaintext credentials and returns an account // with fresh tokens obtained from the Wikimedia auth API. func (d *DB) Authenticate(ctx context.Context, username, plaintextPW string) (*Account, error) { var storedPW string err := d.conn.QueryRowContext(ctx, `SELECT password FROM account WHERE username = ?`, username, ).Scan(&storedPW) if err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, errors.New("invalid credentials") } return nil, fmt.Errorf("query user: %w", err) } if storedPW != plaintextPW { return nil, errors.New("invalid credentials") } tokens, err := WikimediaLogin(ctx, username, plaintextPW) if err != nil { return nil, fmt.Errorf("wikimedia login: %w", err) } if err := d.updateTokens(ctx, username, tokens.RefreshToken, tokens.AccessToken); err != nil { return nil, err } return d.getAccountByUsername(ctx, username) } // RefreshTokens checks the age of the stored tokens and refreshes // them via the Wikimedia auth API as needed: // - refresh_token older than 90 days: full re-auth via WikimediaLogin // - access_token older than 24 hours: token refresh via WikimediaTokenRefresh func (d *DB) RefreshTokens(ctx context.Context, username, providedRT string) (*Account, error) { var id int64 var password, storedRT, accessToken string var refreshCreated, accessCreated time.Time err := d.conn.QueryRowContext(ctx, `SELECT id, password, refresh_token, access_token, refresh_token_created, access_token_created FROM account WHERE username = ?`, username, ).Scan(&id, &password, &storedRT, &accessToken, &refreshCreated, &accessCreated) if err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, errors.New("account not found") } return nil, fmt.Errorf("query account: %w", err) } // Verify the provided refresh token matches if storedRT != providedRT { return nil, errors.New("invalid refresh token") } // Refresh token older than 90 days: re-authenticate if time.Since(refreshCreated) > 90*24*time.Hour { tokens, err := WikimediaLogin(ctx, username, password) if err != nil { return nil, fmt.Errorf("wikimedia re-auth: %w", err) } if err := d.updateTokens(ctx, username, tokens.RefreshToken, tokens.AccessToken); err != nil { return nil, err } return d.getAccountByUsername(ctx, username) } // Access token older than 24 hours: refresh via token-refresh endpoint if time.Since(accessCreated) > 24*time.Hour { tokens, err := WikimediaTokenRefresh(ctx, username, storedRT) if err != nil { return nil, fmt.Errorf("wikimedia token refresh: %w", err) } // Update both tokens (refresh response may include a new refresh token) if err := d.updateAccessToken(ctx, username, tokens.AccessToken); err != nil { return nil, err } return d.getAccountByUsername(ctx, username) } // Tokens are still valid return d.getAccountByUsername(ctx, username) } // updateTokens writes new refresh and access tokens for a user. func (d *DB) updateTokens(ctx context.Context, username, refreshToken, accessToken string) error { _, err := d.conn.ExecContext(ctx, `UPDATE account SET refresh_token = ?, access_token = ?, refresh_token_created = NOW(), access_token_created = NOW() WHERE username = ?`, refreshToken, accessToken, username) if err != nil { return fmt.Errorf("update tokens: %w", err) } return nil } // updateAccessToken writes a new access token for a user. func (d *DB) updateAccessToken(ctx context.Context, username, accessToken string) error { _, err := d.conn.ExecContext(ctx, `UPDATE account SET access_token = ?, access_token_created = NOW() WHERE username = ?`, accessToken, username) if err != nil { return fmt.Errorf("update access token: %w", err) } return nil } // getAccountByUsername fetches the current account state. func (d *DB) getAccountByUsername(ctx context.Context, username string) (*Account, error) { var acct Account var refreshCreated, accessCreated time.Time err := d.conn.QueryRowContext(ctx, `SELECT id, username, refresh_token, access_token, refresh_token_created, access_token_created, created_at FROM account WHERE username = ?`, username, ).Scan(&acct.ID, &acct.Username, &acct.RefreshToken, &acct.AccessToken, &refreshCreated, &accessCreated, &acct.CreatedAt) if err != nil { return nil, fmt.Errorf("get account: %w", err) } acct.AccessTokenExpiry = accessCreated return &acct, nil } // HealthCheck runs a trivial query to verify DB liveness. func (d *DB) HealthCheck(ctx context.Context) error { _, err := d.conn.ExecContext(ctx, "SELECT 1") return err }