summaryrefslogtreecommitdiff
path: root/db
diff options
context:
space:
mode:
Diffstat (limited to 'db')
-rw-r--r--db/db.go191
1 files changed, 129 insertions, 62 deletions
diff --git a/db/db.go b/db/db.go
index ce0dbfd..6fed1b0 100644
--- a/db/db.go
+++ b/db/db.go
@@ -2,9 +2,8 @@ package db
import (
"context"
- "crypto/rand"
+ "bytes"
"database/sql"
- "encoding/hex"
"encoding/json"
"errors"
"fmt"
@@ -17,10 +16,7 @@ import (
_ "github.com/go-sql-driver/mysql"
)
-const (
- tokenLength = 32
- accessTokenTTL = 24 * time.Hour
-)
+const accessTokenTTL = 24 * time.Hour
// Account represents a verified account with its current tokens.
type Account struct {
@@ -62,16 +58,6 @@ func (d *DB) Ping(ctx context.Context) error {
return d.conn.PingContext(ctx)
}
-// --- crypto helpers ---
-
-func randomHex(n int) (string, error) {
- b := make([]byte, n)
- if _, err := rand.Read(b); err != nil {
- return "", fmt.Errorf("crypto rand: %w", err)
- }
- return hex.EncodeToString(b), nil
-}
-
// WikimediaTokens holds the tokens returned by the Wikimedia auth API.
type WikimediaTokens struct {
RefreshToken string `json:"refresh_token"`
@@ -111,6 +97,44 @@ func WikimediaLogin(ctx context.Context, username, password string) (*WikimediaT
return &tokens, nil
}
+// WikimediaTokenRefresh posts to the Wikimedia token-refresh endpoint
+// and returns new tokens from the response.
+func WikimediaTokenRefresh(ctx context.Context, username, refreshToken string) (*WikimediaTokens, error) {
+ body, err := json.Marshal(map[string]string{
+ "username": username,
+ "refresh_token": refreshToken,
+ })
+ if err != nil {
+ return nil, fmt.Errorf("marshal token refresh body: %w", err)
+ }
+
+ req, err := http.NewRequestWithContext(ctx, "POST",
+ "https://auth.enterprise.wikimedia.com/v1/token-refresh",
+ bytes.NewReader(body))
+ if err != nil {
+ return nil, fmt.Errorf("new request: %w", err)
+ }
+ req.Header.Set("Content-Type", "application/json")
+
+ resp, err := http.DefaultClient.Do(req)
+ if err != nil {
+ return nil, fmt.Errorf("token refresh request: %w", err)
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode < 200 || resp.StatusCode >= 300 {
+ bs, _ := io.ReadAll(resp.Body)
+ return nil, fmt.Errorf("token refresh failed (status %d): %s", resp.StatusCode, string(bs))
+ }
+
+ var tokens WikimediaTokens
+ if err := json.NewDecoder(resp.Body).Decode(&tokens); err != nil {
+ return nil, fmt.Errorf("decode token refresh response: %w", err)
+ }
+
+ return &tokens, nil
+}
+
// --- error helpers ---
func isDupKeyError(err error) bool {
@@ -166,7 +190,8 @@ func (d *DB) CreateAccount(ctx context.Context, username, plaintextPW, refreshTo
}, nil
}
-// Authenticate verifies plaintext credentials and returns fresh tokens.
+// Authenticate verifies plaintext credentials and returns an account
+// with fresh tokens obtained from the Wikimedia auth API.
func (d *DB) Authenticate(ctx context.Context, username, plaintextPW string) (*Account, error) {
var storedPW string
err := d.conn.QueryRowContext(ctx,
@@ -183,69 +208,111 @@ func (d *DB) Authenticate(ctx context.Context, username, plaintextPW string) (*A
return nil, errors.New("invalid credentials")
}
- // Look up account id, then rotate tokens
- var id int64
- if err := d.conn.QueryRowContext(ctx,
- `SELECT id FROM account WHERE username = ?`, username,
- ).Scan(&id); err != nil {
- return nil, fmt.Errorf("lookup id: %w", err)
+ tokens, err := WikimediaLogin(ctx, username, plaintextPW)
+ if err != nil {
+ return nil, fmt.Errorf("wikimedia login: %w", err)
}
- return d.RotateTokens(ctx, id)
+ if err := d.updateTokens(ctx, username, tokens.RefreshToken, tokens.AccessToken); err != nil {
+ return nil, err
+ }
+
+ return d.getAccountByUsername(ctx, username)
}
-// RotateTokens generates new refresh + access tokens for an account.
-func (d *DB) RotateTokens(ctx context.Context, id int64) (*Account, error) {
- rt, err := randomHex(tokenLength)
+// RefreshTokens checks the age of the stored tokens and refreshes
+// them via the Wikimedia auth API as needed:
+// - refresh_token older than 90 days: full re-auth via WikimediaLogin
+// - access_token older than 24 hours: token refresh via WikimediaTokenRefresh
+func (d *DB) RefreshTokens(ctx context.Context, username, providedRT string) (*Account, error) {
+ var id int64
+ var password, storedRT, accessToken string
+ var refreshCreated, accessCreated time.Time
+
+ err := d.conn.QueryRowContext(ctx,
+ `SELECT id, password, refresh_token, access_token, refresh_token_created, access_token_created
+ FROM account WHERE username = ?`, username,
+ ).Scan(&id, &password, &storedRT, &accessToken, &refreshCreated, &accessCreated)
if err != nil {
- return nil, err
+ if errors.Is(err, sql.ErrNoRows) {
+ return nil, errors.New("account not found")
+ }
+ return nil, fmt.Errorf("query account: %w", err)
}
- at, err := randomHex(tokenLength)
- if err != nil {
- return nil, err
+
+ // Verify the provided refresh token matches
+ if storedRT != providedRT {
+ return nil, errors.New("invalid refresh token")
}
- res, err := d.conn.ExecContext(ctx,
- `UPDATE account SET refresh_token = ?, access_token = ?, refresh_token_created = NOW(), access_token_created = NOW()
- WHERE id = ?`, rt, at, id)
- if err != nil {
- return nil, fmt.Errorf("rotate tokens: %w", err)
+ // Refresh token older than 90 days: re-authenticate
+ if time.Since(refreshCreated) > 90*24*time.Hour {
+ tokens, err := WikimediaLogin(ctx, username, password)
+ if err != nil {
+ return nil, fmt.Errorf("wikimedia re-auth: %w", err)
+ }
+ if err := d.updateTokens(ctx, username, tokens.RefreshToken, tokens.AccessToken); err != nil {
+ return nil, err
+ }
+ return d.getAccountByUsername(ctx, username)
}
- if affected, _ := res.RowsAffected(); affected == 0 {
- return nil, errors.New("account not found")
+
+ // Access token older than 24 hours: refresh via token-refresh endpoint
+ if time.Since(accessCreated) > 24*time.Hour {
+ tokens, err := WikimediaTokenRefresh(ctx, username, storedRT)
+ if err != nil {
+ return nil, fmt.Errorf("wikimedia token refresh: %w", err)
+ }
+
+ // Update both tokens (refresh response may include a new refresh token)
+ if err := d.updateAccessToken(ctx, username, tokens.AccessToken); err != nil {
+ return nil, err
+ }
+ return d.getAccountByUsername(ctx, username)
}
- var username string
- if err := d.conn.QueryRowContext(ctx,
- `SELECT username FROM account WHERE id = ?`, id,
- ).Scan(&username); err != nil {
- return nil, fmt.Errorf("lookup username: %w", err)
+ // Tokens are still valid
+ return d.getAccountByUsername(ctx, username)
+}
+
+// updateTokens writes new refresh and access tokens for a user.
+func (d *DB) updateTokens(ctx context.Context, username, refreshToken, accessToken string) error {
+ _, err := d.conn.ExecContext(ctx,
+ `UPDATE account SET refresh_token = ?, access_token = ?, refresh_token_created = NOW(), access_token_created = NOW()
+ WHERE username = ?`, refreshToken, accessToken, username)
+ if err != nil {
+ return fmt.Errorf("update tokens: %w", err)
}
+ return nil
+}
- now := time.Now()
- return &Account{
- ID: id,
- Username: username,
- RefreshToken: rt,
- AccessToken: at,
- AccessTokenExpiry: now.Add(accessTokenTTL),
- CreatedAt: now,
- }, nil
+// updateAccessToken writes a new access token for a user.
+func (d *DB) updateAccessToken(ctx context.Context, username, accessToken string) error {
+ _, err := d.conn.ExecContext(ctx,
+ `UPDATE account SET access_token = ?, access_token_created = NOW()
+ WHERE username = ?`, accessToken, username)
+ if err != nil {
+ return fmt.Errorf("update access token: %w", err)
+ }
+ return nil
}
-// RefreshByToken looks up an account by its refresh token and rotates both tokens.
-func (d *DB) RefreshByToken(ctx context.Context, refreshToken string) (*Account, error) {
- var id int64
+// getAccountByUsername fetches the current account state.
+func (d *DB) getAccountByUsername(ctx context.Context, username string) (*Account, error) {
+ var acct Account
+ var refreshCreated, accessCreated time.Time
+
err := d.conn.QueryRowContext(ctx,
- `SELECT id FROM account WHERE refresh_token = ?`, refreshToken,
- ).Scan(&id)
+ `SELECT id, username, refresh_token, access_token, refresh_token_created, access_token_created, created_at
+ FROM account WHERE username = ?`, username,
+ ).Scan(&acct.ID, &acct.Username, &acct.RefreshToken, &acct.AccessToken,
+ &refreshCreated, &accessCreated, &acct.CreatedAt)
if err != nil {
- if errors.Is(err, sql.ErrNoRows) {
- return nil, errors.New("invalid refresh token")
- }
- return nil, fmt.Errorf("lookup refresh token: %w", err)
+ return nil, fmt.Errorf("get account: %w", err)
}
- return d.RotateTokens(ctx, id)
+
+ acct.AccessTokenExpiry = accessCreated
+ return &acct, nil
}
// HealthCheck runs a trivial query to verify DB liveness.