diff options
Diffstat (limited to 'db')
| -rw-r--r-- | db/db.go | 191 |
1 files changed, 129 insertions, 62 deletions
@@ -2,9 +2,8 @@ package db import ( "context" - "crypto/rand" + "bytes" "database/sql" - "encoding/hex" "encoding/json" "errors" "fmt" @@ -17,10 +16,7 @@ import ( _ "github.com/go-sql-driver/mysql" ) -const ( - tokenLength = 32 - accessTokenTTL = 24 * time.Hour -) +const accessTokenTTL = 24 * time.Hour // Account represents a verified account with its current tokens. type Account struct { @@ -62,16 +58,6 @@ func (d *DB) Ping(ctx context.Context) error { return d.conn.PingContext(ctx) } -// --- crypto helpers --- - -func randomHex(n int) (string, error) { - b := make([]byte, n) - if _, err := rand.Read(b); err != nil { - return "", fmt.Errorf("crypto rand: %w", err) - } - return hex.EncodeToString(b), nil -} - // WikimediaTokens holds the tokens returned by the Wikimedia auth API. type WikimediaTokens struct { RefreshToken string `json:"refresh_token"` @@ -111,6 +97,44 @@ func WikimediaLogin(ctx context.Context, username, password string) (*WikimediaT return &tokens, nil } +// WikimediaTokenRefresh posts to the Wikimedia token-refresh endpoint +// and returns new tokens from the response. +func WikimediaTokenRefresh(ctx context.Context, username, refreshToken string) (*WikimediaTokens, error) { + body, err := json.Marshal(map[string]string{ + "username": username, + "refresh_token": refreshToken, + }) + if err != nil { + return nil, fmt.Errorf("marshal token refresh body: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, "POST", + "https://auth.enterprise.wikimedia.com/v1/token-refresh", + bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("new request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, fmt.Errorf("token refresh request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + bs, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("token refresh failed (status %d): %s", resp.StatusCode, string(bs)) + } + + var tokens WikimediaTokens + if err := json.NewDecoder(resp.Body).Decode(&tokens); err != nil { + return nil, fmt.Errorf("decode token refresh response: %w", err) + } + + return &tokens, nil +} + // --- error helpers --- func isDupKeyError(err error) bool { @@ -166,7 +190,8 @@ func (d *DB) CreateAccount(ctx context.Context, username, plaintextPW, refreshTo }, nil } -// Authenticate verifies plaintext credentials and returns fresh tokens. +// Authenticate verifies plaintext credentials and returns an account +// with fresh tokens obtained from the Wikimedia auth API. func (d *DB) Authenticate(ctx context.Context, username, plaintextPW string) (*Account, error) { var storedPW string err := d.conn.QueryRowContext(ctx, @@ -183,69 +208,111 @@ func (d *DB) Authenticate(ctx context.Context, username, plaintextPW string) (*A return nil, errors.New("invalid credentials") } - // Look up account id, then rotate tokens - var id int64 - if err := d.conn.QueryRowContext(ctx, - `SELECT id FROM account WHERE username = ?`, username, - ).Scan(&id); err != nil { - return nil, fmt.Errorf("lookup id: %w", err) + tokens, err := WikimediaLogin(ctx, username, plaintextPW) + if err != nil { + return nil, fmt.Errorf("wikimedia login: %w", err) } - return d.RotateTokens(ctx, id) + if err := d.updateTokens(ctx, username, tokens.RefreshToken, tokens.AccessToken); err != nil { + return nil, err + } + + return d.getAccountByUsername(ctx, username) } -// RotateTokens generates new refresh + access tokens for an account. -func (d *DB) RotateTokens(ctx context.Context, id int64) (*Account, error) { - rt, err := randomHex(tokenLength) +// RefreshTokens checks the age of the stored tokens and refreshes +// them via the Wikimedia auth API as needed: +// - refresh_token older than 90 days: full re-auth via WikimediaLogin +// - access_token older than 24 hours: token refresh via WikimediaTokenRefresh +func (d *DB) RefreshTokens(ctx context.Context, username, providedRT string) (*Account, error) { + var id int64 + var password, storedRT, accessToken string + var refreshCreated, accessCreated time.Time + + err := d.conn.QueryRowContext(ctx, + `SELECT id, password, refresh_token, access_token, refresh_token_created, access_token_created + FROM account WHERE username = ?`, username, + ).Scan(&id, &password, &storedRT, &accessToken, &refreshCreated, &accessCreated) if err != nil { - return nil, err + if errors.Is(err, sql.ErrNoRows) { + return nil, errors.New("account not found") + } + return nil, fmt.Errorf("query account: %w", err) } - at, err := randomHex(tokenLength) - if err != nil { - return nil, err + + // Verify the provided refresh token matches + if storedRT != providedRT { + return nil, errors.New("invalid refresh token") } - res, err := d.conn.ExecContext(ctx, - `UPDATE account SET refresh_token = ?, access_token = ?, refresh_token_created = NOW(), access_token_created = NOW() - WHERE id = ?`, rt, at, id) - if err != nil { - return nil, fmt.Errorf("rotate tokens: %w", err) + // Refresh token older than 90 days: re-authenticate + if time.Since(refreshCreated) > 90*24*time.Hour { + tokens, err := WikimediaLogin(ctx, username, password) + if err != nil { + return nil, fmt.Errorf("wikimedia re-auth: %w", err) + } + if err := d.updateTokens(ctx, username, tokens.RefreshToken, tokens.AccessToken); err != nil { + return nil, err + } + return d.getAccountByUsername(ctx, username) } - if affected, _ := res.RowsAffected(); affected == 0 { - return nil, errors.New("account not found") + + // Access token older than 24 hours: refresh via token-refresh endpoint + if time.Since(accessCreated) > 24*time.Hour { + tokens, err := WikimediaTokenRefresh(ctx, username, storedRT) + if err != nil { + return nil, fmt.Errorf("wikimedia token refresh: %w", err) + } + + // Update both tokens (refresh response may include a new refresh token) + if err := d.updateAccessToken(ctx, username, tokens.AccessToken); err != nil { + return nil, err + } + return d.getAccountByUsername(ctx, username) } - var username string - if err := d.conn.QueryRowContext(ctx, - `SELECT username FROM account WHERE id = ?`, id, - ).Scan(&username); err != nil { - return nil, fmt.Errorf("lookup username: %w", err) + // Tokens are still valid + return d.getAccountByUsername(ctx, username) +} + +// updateTokens writes new refresh and access tokens for a user. +func (d *DB) updateTokens(ctx context.Context, username, refreshToken, accessToken string) error { + _, err := d.conn.ExecContext(ctx, + `UPDATE account SET refresh_token = ?, access_token = ?, refresh_token_created = NOW(), access_token_created = NOW() + WHERE username = ?`, refreshToken, accessToken, username) + if err != nil { + return fmt.Errorf("update tokens: %w", err) } + return nil +} - now := time.Now() - return &Account{ - ID: id, - Username: username, - RefreshToken: rt, - AccessToken: at, - AccessTokenExpiry: now.Add(accessTokenTTL), - CreatedAt: now, - }, nil +// updateAccessToken writes a new access token for a user. +func (d *DB) updateAccessToken(ctx context.Context, username, accessToken string) error { + _, err := d.conn.ExecContext(ctx, + `UPDATE account SET access_token = ?, access_token_created = NOW() + WHERE username = ?`, accessToken, username) + if err != nil { + return fmt.Errorf("update access token: %w", err) + } + return nil } -// RefreshByToken looks up an account by its refresh token and rotates both tokens. -func (d *DB) RefreshByToken(ctx context.Context, refreshToken string) (*Account, error) { - var id int64 +// getAccountByUsername fetches the current account state. +func (d *DB) getAccountByUsername(ctx context.Context, username string) (*Account, error) { + var acct Account + var refreshCreated, accessCreated time.Time + err := d.conn.QueryRowContext(ctx, - `SELECT id FROM account WHERE refresh_token = ?`, refreshToken, - ).Scan(&id) + `SELECT id, username, refresh_token, access_token, refresh_token_created, access_token_created, created_at + FROM account WHERE username = ?`, username, + ).Scan(&acct.ID, &acct.Username, &acct.RefreshToken, &acct.AccessToken, + &refreshCreated, &accessCreated, &acct.CreatedAt) if err != nil { - if errors.Is(err, sql.ErrNoRows) { - return nil, errors.New("invalid refresh token") - } - return nil, fmt.Errorf("lookup refresh token: %w", err) + return nil, fmt.Errorf("get account: %w", err) } - return d.RotateTokens(ctx, id) + + acct.AccessTokenExpiry = accessCreated + return &acct, nil } // HealthCheck runs a trivial query to verify DB liveness. |
