Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 47 additions & 4 deletions cmd/altinity-mcp/oauth_regression_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,11 +208,11 @@ func TestOAuthMCPAuthInjectorPassesBearerContext(t *testing.T) {
return httptest.NewRecorder(), req
}

t.Run("any bearer reaches handler verbatim", func(t *testing.T) {
t.Run("unexpired bearer reaches handler verbatim", func(t *testing.T) {
t.Parallel()
// MCP no longer validates the JWT here — even a wrong-audience or
// expired token reaches the inner handler. The sidecar / token
// processor is the cryptographic gate.
// MCP does only an exp-only check here (no signature/aud/scope) — a
// wrong-audience but unexpired token still reaches the inner handler.
// The sidecar / token processor is the cryptographic gate.
tok := provider.issueIDToken(t, map[string]interface{}{
"sub": "u-good",
"iss": provider.server.URL,
Expand All @@ -231,6 +231,49 @@ func TestOAuthMCPAuthInjectorPassesBearerContext(t *testing.T) {
require.Equal(t, http.StatusOK, rr.Code)
})

t.Run("expired bearer rejected with 401 on tools/list", func(t *testing.T) {
t.Parallel()
// Regression for the claude.ai "stuck with one tool" state
// (anthropics/claude-code#46328): an expired token must yield a
// transport-level 401 even on tools/list, so the client re-authorizes.
tok := provider.issueIDToken(t, map[string]interface{}{
"sub": "u-stale",
"iss": provider.server.URL,
"aud": "clickhouse-api",
"exp": time.Now().Add(-time.Hour).Unix(),
"iat": time.Now().Add(-2 * time.Hour).Unix(),
})
body := strings.NewReader(`{"jsonrpc":"2.0","id":1,"method":"tools/list"}`)
req := httptest.NewRequest(http.MethodPost, "https://mcp.example.com/", body)
req.Header.Set("Authorization", "Bearer "+tok)
req.Header.Set("Content-Type", "application/json")
rr := httptest.NewRecorder()
called := false
app.createMCPAuthInjector(app.config)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
called = true
w.WriteHeader(http.StatusOK)
})).ServeHTTP(rr, req)
require.False(t, called, "expired token must not reach inner handler")
require.Equal(t, http.StatusUnauthorized, rr.Code)
challenge := rr.Header().Get("WWW-Authenticate")
require.Contains(t, challenge, `error="invalid_token"`)
require.Contains(t, challenge, `error_description="OAuth token expired"`)
})

t.Run("opaque bearer soft-passes", func(t *testing.T) {
t.Parallel()
// Non-JWT (opaque) tokens cannot be exp-checked locally; they must
// soft-pass and let the CH sidecar validate, preserving forward-mode.
rr, req := mkReq("opaque-access-token-not-a-jwt")
called := false
app.createMCPAuthInjector(app.config)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
called = true
w.WriteHeader(http.StatusOK)
})).ServeHTTP(rr, req)
require.True(t, called, "opaque token must reach inner handler")
require.Equal(t, http.StatusOK, rr.Code)
})

t.Run("missing bearer rejected with 401", func(t *testing.T) {
t.Parallel()
rr, req := mkReq("")
Expand Down
20 changes: 14 additions & 6 deletions cmd/altinity-mcp/oauth_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -519,12 +519,20 @@ func (a *application) createMCPAuthInjector(cfg config.Config) func(http.Handler
a.writeOAuthError(w, r, altinitymcp.ErrMissingOAuthToken)
return
}
// The CH-side ch-jwt-verify sidecar performs signature/iss/aud/
// exp/scope validation against the upstream JWKS for every query.
// Per-request validation here would duplicate work for opaque
// tokens and add a JWKS hop on the hot path. Claims
// stay unset on the context — ClaimsFromContext returns nil
// either way, so no downstream nil-deref risk.
// Reject expired tokens at the transport layer so EVERY MCP
// method — including tools/list — returns a 401 with
// WWW-Authenticate: Bearer error="invalid_token". We issue no
// refresh tokens (#115), and claude.ai otherwise gets stuck
// showing stale tools instead of re-authorizing when its token
// lapses (anthropics/claude-code#46328). This is an exp-only
// check (signature NOT verified): the CH-side ch-jwt-verify
// sidecar still does full signature/iss/aud/scope validation per
// query, so we avoid a JWKS hop on the hot path. Opaque/non-JWT
// tokens soft-pass (OAuthTokenExpired returns false).
if a.mcpServer.OAuthTokenExpired(oauthToken) {
a.writeOAuthError(w, r, altinitymcp.ErrOAuthTokenExpired)
return
}
ctx = context.WithValue(ctx, altinitymcp.OAuthTokenKey, oauthToken)
}

Expand Down
64 changes: 64 additions & 0 deletions pkg/server/server_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"fmt"
"net/http"
"strings"
"time"

"github.com/altinity/altinity-mcp/pkg/clickhouse"
"github.com/altinity/altinity-mcp/pkg/config"
Expand Down Expand Up @@ -386,3 +387,66 @@ func emailFromUnverifiedJWT(token string) (string, bool) {
}
return "", false
}

// oauthExpiryClockSkewSecs tolerates small clock differences between this
// server and the IdP that signed the token so a token expiring within seconds
// does not bounce clients with slightly fast clocks.
const oauthExpiryClockSkewSecs = 60

// unverifiedExp decodes the JWT payload WITHOUT verifying the signature and
// returns the `exp` claim. isJWT is false for anything that is not a
// three-segment JWT with a decodable payload (opaque tokens, malformed input) —
// callers treat that as "cannot tell, soft-pass". Crypto validation (signature,
// iss, aud, scope) stays delegated to the CH-side ch-jwt-verify sidecar per
// query; this only catches the expiry case so the MCP transport can return a
// 401 on tools/list instead of letting an expired session linger as stale tools.
func unverifiedExp(token string) (exp int64, isJWT bool) {
parts := strings.Split(strings.TrimSpace(token), ".")
if len(parts) != 3 {
return 0, false
}
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
// Some IdPs emit padded segments; try the std encoding as a fallback.
if payload, err = base64.URLEncoding.DecodeString(parts[1]); err != nil {
log.Debug().Err(err).Msg("oauth: failed to base64-decode JWT payload for exp")
return 0, false
}
}
var raw struct {
Exp json.Number `json:"exp"`
}
dec := json.NewDecoder(strings.NewReader(string(payload)))
dec.UseNumber()
if err := dec.Decode(&raw); err != nil {
log.Debug().Err(err).Msg("oauth: failed to JSON-parse JWT payload for exp")
return 0, false
}
if raw.Exp == "" {
// Valid JWT but no exp claim — nothing to expire on.
return 0, true
}
v, err := raw.Exp.Int64()
if err != nil {
// exp present but non-numeric/float; fall back to float parse.
f, ferr := raw.Exp.Float64()
if ferr != nil {
log.Debug().Err(err).Msg("oauth: JWT exp claim is not numeric")
return 0, true
}
v = int64(f)
}
return v, true
}

// OAuthTokenExpired reports whether token is a JWT whose `exp` claim is in the
// past (with clock skew). Opaque/non-JWT tokens and JWTs without an exp claim
// return false (soft-pass) so forward-mode and opaque-token deployments keep
// working unchanged. Signature is NOT verified — see unverifiedExp.
func (s *ClickHouseJWEServer) OAuthTokenExpired(token string) bool {
exp, isJWT := unverifiedExp(token)
if !isJWT || exp <= 0 {
return false
}
return time.Now().Unix() > exp+oauthExpiryClockSkewSecs
}
78 changes: 78 additions & 0 deletions pkg/server/server_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ package server
import (
"encoding/base64"
"encoding/json"
"fmt"
"testing"
"time"

"github.com/altinity/altinity-mcp/pkg/config"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -218,3 +220,79 @@ func TestEmailFromUnverifiedJWT(t *testing.T) {
require.Equal(t, "alice@example.com", got)
})
}

// jwtWithClaims builds an unsigned-but-well-formed three-segment JWT string
// (the signature segment is a dummy) for exp-claim tests. unverifiedExp never
// checks the signature, so this is sufficient.
func jwtWithClaims(t *testing.T, claims map[string]interface{}) string {
t.Helper()
body, err := json.Marshal(claims)
require.NoError(t, err)
return "header." + base64.RawURLEncoding.EncodeToString(body) + ".sig"
}

func TestUnverifiedExpAndOAuthTokenExpired(t *testing.T) {
t.Parallel()
srv := &ClickHouseJWEServer{Config: config.Config{
Server: config.ServerConfig{OAuth: config.OAuthConfig{Enabled: true}},
}}

t.Run("expired_jwt", func(t *testing.T) {
t.Parallel()
exp := time.Now().Add(-time.Hour).Unix()
tok := jwtWithClaims(t, map[string]interface{}{"exp": exp})
gotExp, isJWT := unverifiedExp(tok)
require.True(t, isJWT)
require.Equal(t, exp, gotExp)
require.True(t, srv.OAuthTokenExpired(tok))
})

t.Run("unexpired_jwt", func(t *testing.T) {
t.Parallel()
tok := jwtWithClaims(t, map[string]interface{}{"exp": time.Now().Add(time.Hour).Unix()})
_, isJWT := unverifiedExp(tok)
require.True(t, isJWT)
require.False(t, srv.OAuthTokenExpired(tok))
})

t.Run("within_clock_skew_not_expired", func(t *testing.T) {
t.Parallel()
// Expired 30s ago — inside the 60s skew window, so NOT treated as expired.
tok := jwtWithClaims(t, map[string]interface{}{"exp": time.Now().Add(-30 * time.Second).Unix()})
require.False(t, srv.OAuthTokenExpired(tok))
})

t.Run("opaque_token_softpasses", func(t *testing.T) {
t.Parallel()
_, isJWT := unverifiedExp("opaque-access-token")
require.False(t, isJWT)
require.False(t, srv.OAuthTokenExpired("opaque-access-token"))
})

t.Run("jwt_without_exp_softpasses", func(t *testing.T) {
t.Parallel()
tok := jwtWithClaims(t, map[string]interface{}{"sub": "u-1"})
gotExp, isJWT := unverifiedExp(tok)
require.True(t, isJWT)
require.Zero(t, gotExp)
require.False(t, srv.OAuthTokenExpired(tok))
})

t.Run("exp_as_float_parsed", func(t *testing.T) {
t.Parallel()
// json numbers can decode as float; ensure fractional exp still works.
raw := time.Now().Add(-time.Hour).Unix()
tok := "header." + base64.RawURLEncoding.EncodeToString(
[]byte(fmt.Sprintf(`{"exp":%d.5}`, raw))) + ".sig"
gotExp, isJWT := unverifiedExp(tok)
require.True(t, isJWT)
require.Equal(t, raw, gotExp)
require.True(t, srv.OAuthTokenExpired(tok))
})

t.Run("malformed_payload_softpasses", func(t *testing.T) {
t.Parallel()
_, isJWT := unverifiedExp("head.!!notbase64!!.sig")
require.False(t, isJWT)
})
}