1
0
mirror of https://github.com/lxc/incus.git synced 2026-02-05 09:46:19 +01:00
Files
incus/client/incus_oidc.go
Stéphane Graber 39d8708751 client: Use server-advertised OIDC scopes
Closes #2249

Signed-off-by: Stéphane Graber <stgraber@stgraber.org>
2025-07-28 14:25:56 -04:00

329 lines
9.7 KiB
Go

package incus
import (
"context"
"crypto/rand"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"os"
"os/signal"
"strings"
"syscall"
"time"
"github.com/gorilla/websocket"
"github.com/zitadel/oidc/v3/pkg/client/rp"
httphelper "github.com/zitadel/oidc/v3/pkg/http"
"github.com/zitadel/oidc/v3/pkg/oidc"
"golang.org/x/oauth2"
"github.com/lxc/incus/v6/shared/util"
)
// ErrOIDCExpired is returned when the token is expired and we can't retry the request ourselves.
//
// It is only returned after the tokens have been renewed (through refresh or a new device flow
// authentication), so the caller is expected to simply re-issue the request.
var ErrOIDCExpired = errors.New("OIDC token expired, please re-try the request")
// setupOIDCClient lazily creates the OIDC (OpenID Connect) client from the given tokens.
// A freshly created client shares the protocol's HTTP client; an already configured client is left untouched.
func (r *ProtocolIncus) setupOIDCClient(token *oidc.Tokens[*oidc.IDTokenClaims]) {
	if r.oidcClient == nil {
		client := newOIDCClient(token)
		client.httpClient = r.http
		r.oidcClient = client
	}
}
// GetOIDCTokens returns the tokens currently held by the OIDC client, or nil when no OIDC client has been set up.
//
// This should only be used by internal Incus tools when it's not possible to get the tokens from a Config struct.
func (r *ProtocolIncus) GetOIDCTokens() *oidc.Tokens[*oidc.IDTokenClaims] {
	if client := r.oidcClient; client != nil {
		return client.tokens
	}

	return nil
}
// oidcTransport is a custom HTTP transport that injects the audience field into requests directed at the device authorization endpoint.
type oidcTransport struct {
	// deviceAuthorizationEndpoint is the exact URL of the requests that get the audience parameter.
	deviceAuthorizationEndpoint string

	// audience is the value of the injected "audience" form parameter; empty disables injection.
	audience string
}

// RoundTrip sends the request through http.DefaultTransport, adding the audience form parameter when the
// request targets the device authorization endpoint and an audience is configured.
//
// The caller's request is never modified (as required by the http.RoundTripper contract) other than
// having its body consumed and closed; the rewritten form is sent on a copy of the request.
func (o *oidcTransport) RoundTrip(r *http.Request) (*http.Response, error) {
	// Don't touch the request if it's not to the device authorization endpoint, or there is no
	// audience to inject.
	if o.audience == "" || r.URL.String() != o.deviceAuthorizationEndpoint {
		return http.DefaultTransport.RoundTrip(r)
	}

	// Work on a copy so the caller's request (form values, body, content length) stays intact.
	req := r.Clone(r.Context())

	err := req.ParseForm()

	// The original body has been read by ParseForm (or is about to be dropped), it won't be handed
	// to the underlying transport so it must be closed here.
	if r.Body != nil {
		_ = r.Body.Close()
	}

	if err != nil {
		return nil, err
	}

	req.Form.Add("audience", o.audience)

	// Update the body with the new form parameters. GetBody is replaced too so that a retry or
	// redirect replays the rewritten body rather than the original one (which no longer matches
	// the content length).
	body := req.Form.Encode()
	req.Body = io.NopCloser(strings.NewReader(body))
	req.GetBody = func() (io.ReadCloser, error) {
		return io.NopCloser(strings.NewReader(body)), nil
	}

	req.ContentLength = int64(len(body))

	return http.DefaultTransport.RoundTrip(req)
}
// errRefreshAccessToken is the generic error returned by refresh whenever the access token can't be renewed.
var errRefreshAccessToken = errors.New("Failed refreshing access token")
// oidcClient is a structure encapsulating an HTTP client, OIDC transport, and a token for OpenID Connect (OIDC) operations.
type oidcClient struct {
// httpClient is used both for the actual requests and for talking to the OIDC provider.
httpClient *http.Client
// oidcTransport is temporarily installed as httpClient's transport during device flow authentication.
oidcTransport *oidcTransport
// tokens holds the current OIDC tokens; newOIDCClient guarantees it is non-nil.
tokens *oidc.Tokens[*oidc.IDTokenClaims]
}
// newOIDCClient builds an oidcClient around the given tokens with a default HTTP client and an empty OIDC transport.
// A nil tokens argument is replaced by an empty token set so that later code can dereference it safely.
func newOIDCClient(tokens *oidc.Tokens[*oidc.IDTokenClaims]) *oidcClient {
	// Never store a nil token set, authenticate() would panic on it.
	if tokens == nil {
		tokens = &oidc.Tokens[*oidc.IDTokenClaims]{}
	}

	return &oidcClient{
		httpClient:    &http.Client{},
		oidcTransport: &oidcTransport{},
		tokens:        tokens,
	}
}
// getAccessToken returns the current access token, or an empty string when the client holds no token yet.
func (o *oidcClient) getAccessToken() string {
	if o.tokens != nil && o.tokens.Token != nil {
		return o.tokens.AccessToken
	}

	return ""
}
// do executes an HTTP request using the oidcClient's http client, and manages authorization by refreshing or authenticating as needed.
//
// Any response other than HTTP 401, or a 401 which doesn't carry the OIDC issuer and client ID headers,
// is returned untouched. Otherwise the access token is refreshed (falling back to a full device flow
// authentication if the refresh fails) and the request is re-issued with the new token.
//
// Requests which can't be replayed (non-GET without a GetBody function) aren't retried; the original
// 401 response is returned together with ErrOIDCExpired so the caller can re-issue the request.
func (o *oidcClient) do(req *http.Request) (*http.Response, error) {
	resp, err := o.httpClient.Do(req)
	if err != nil {
		return nil, err
	}

	// Return immediately if the status is not HTTP unauthorized.
	if resp.StatusCode != http.StatusUnauthorized {
		return resp, nil
	}

	issuer := resp.Header.Get("X-Incus-OIDC-issuer")
	clientID := resp.Header.Get("X-Incus-OIDC-clientid")
	audience := resp.Header.Get("X-Incus-OIDC-audience")
	scopes := resp.Header.Get("X-Incus-OIDC-scopes")

	// Servers which don't advertise scopes get the default set.
	if scopes == "" {
		scopes = "openid,offline_access"
	}

	// Not an OIDC challenge, let the caller deal with the 401.
	if issuer == "" || clientID == "" {
		return resp, nil
	}

	// Refresh the token, falling back to a full authentication.
	err = o.refresh(issuer, clientID, scopes)
	if err != nil {
		err = o.authenticate(issuer, clientID, audience, scopes)
		if err != nil {
			// The response isn't handed back to the caller, release it.
			_ = resp.Body.Close()

			return nil, err
		}
	}

	// If not dealing with something we can retry, return a clear error.
	if req.Method != http.MethodGet && req.GetBody == nil {
		return resp, ErrOIDCExpired
	}

	// The 401 response is about to be replaced by the retried one, release its connection.
	_ = resp.Body.Close()

	// Set the new access token in the header.
	req.Header.Set("Authorization", "Bearer "+o.tokens.AccessToken)

	// Reset the request body.
	if req.GetBody != nil {
		body, err := req.GetBody()
		if err != nil {
			return nil, err
		}

		req.Body = body
	}

	resp, err = o.httpClient.Do(req)
	if err != nil {
		return nil, err
	}

	return resp, nil
}
// dial opens a websocket connection, transparently handling OIDC refresh and authentication.
//
// When the first attempt doesn't yield a usable connection and the server's response carries the OIDC
// issuer and client ID headers, the token is refreshed (or a full authentication is performed if the
// refresh fails) and the connection is attempted a second time with the new access token.
func (o *oidcClient) dial(dialer websocket.Dialer, uri string, req *http.Request) (*websocket.Conn, *http.Response, error) {
	conn, resp, err := dialer.Dial(uri, req.Header)

	// Without a response there is nothing to inspect, report the failure as-is.
	if resp == nil && err != nil {
		return nil, nil, err
	}

	// Return immediately if we got a connection and the status is not HTTP unauthorized.
	if conn != nil && resp.StatusCode != http.StatusUnauthorized {
		return conn, resp, nil
	}

	headers := resp.Header
	issuer := headers.Get("X-Incus-OIDC-issuer")
	clientID := headers.Get("X-Incus-OIDC-clientid")
	audience := headers.Get("X-Incus-OIDC-audience")

	scopes := headers.Get("X-Incus-OIDC-scopes")
	if scopes == "" {
		scopes = "openid,offline_access"
	}

	// Not an OIDC challenge, hand back the original outcome.
	if issuer == "" || clientID == "" {
		return nil, resp, err
	}

	// Try the cheap refresh first and only fall back to a full authentication if it fails.
	if o.refresh(issuer, clientID, scopes) != nil {
		authErr := o.authenticate(issuer, clientID, audience, scopes)
		if authErr != nil {
			return nil, resp, authErr
		}
	}

	// Retry with the new access token in the header.
	req.Header.Set("Authorization", "Bearer "+o.tokens.AccessToken)

	return dialer.Dial(uri, req.Header)
}
// getProvider creates an OpenID Connect Relying Party for the given issuer and client ID.
// The comma-separated scopes are split into a list, and the relying party is configured with a cookie
// handler using freshly generated random hash and encryption keys, PKCE, a 5 second issued-at offset
// and the oidcClient's HTTP client.
func (o *oidcClient) getProvider(issuer string, clientID string, scopes string) (rp.RelyingParty, error) {
	// randomKey returns 16 bytes from the system's secure random source.
	randomKey := func() ([]byte, error) {
		key := make([]byte, 16)

		if _, err := rand.Read(key); err != nil {
			return nil, err
		}

		return key, nil
	}

	hashKey, err := randomKey()
	if err != nil {
		return nil, err
	}

	encryptKey, err := randomKey()
	if err != nil {
		return nil, err
	}

	cookies := httphelper.NewCookieHandler(hashKey, encryptKey, httphelper.WithUnsecure())
	scopeList := strings.Split(scopes, ",")

	relyingParty, err := rp.NewRelyingPartyOIDC(
		context.TODO(), issuer, clientID, "", "", scopeList,
		rp.WithCookieHandler(cookies),
		rp.WithVerifierOpts(rp.WithIssuedAtOffset(5*time.Second)),
		rp.WithPKCE(cookies),
		rp.WithHTTPClient(o.httpClient),
	)
	if err != nil {
		return nil, err
	}

	return relyingParty, nil
}
// refresh renews the access token using the stored refresh token.
// Every failure (no token, empty refresh token, provider setup error, refresh request error) is reported
// as errRefreshAccessToken. On success the access token, token type and expiry are updated, and the
// refresh token is replaced only if the provider issued a new one.
func (o *oidcClient) refresh(issuer string, clientID string, scopes string) error {
	current := o.tokens

	// Nothing to refresh with.
	if current.Token == nil || current.RefreshToken == "" {
		return errRefreshAccessToken
	}

	provider, err := o.getProvider(issuer, clientID, scopes)
	if err != nil {
		return errRefreshAccessToken
	}

	renewed, err := rp.RefreshTokens[*oidc.IDTokenClaims](context.TODO(), provider, current.RefreshToken, "", "")
	if err != nil {
		return errRefreshAccessToken
	}

	current.AccessToken = renewed.AccessToken
	current.TokenType = renewed.TokenType
	current.Expiry = renewed.Expiry

	// Keep the existing refresh token unless a new one was handed out.
	if renewed.RefreshToken != "" {
		current.RefreshToken = renewed.RefreshToken
	}

	return nil
}
// authenticate initiates the OpenID Connect device flow authentication process for the client.
// It presents a verification URL and user code for the end user to enter on a device that has web access,
// attempts to open that URL in a browser, and waits for the user to complete the authentication
// (or for an interrupt signal), subsequently updating the client's tokens upon success.
//
// While it runs, the HTTP client's transport is swapped for the OIDC transport so that the audience
// gets injected into the device authorization request; the previous transport is restored on return.
func (o *oidcClient) authenticate(issuer string, clientID string, audience string, scopes string) error {
	// Store the old transport and restore it in the end.
	oldTransport := o.httpClient.Transport
	o.oidcTransport.audience = audience
	o.httpClient.Transport = o.oidcTransport

	defer func() {
		o.httpClient.Transport = oldTransport
	}()

	provider, err := o.getProvider(issuer, clientID, scopes)
	if err != nil {
		return err
	}

	o.oidcTransport.deviceAuthorizationEndpoint = provider.GetDeviceAuthorizationEndpoint()

	resp, err := rp.DeviceAuthorization(context.TODO(), strings.Split(scopes, ","), provider, nil)
	if err != nil {
		return err
	}

	// verification_uri_complete is optional in the device authorization response, fall back to the
	// plain verification URI (the user code is printed below either way).
	verificationURI := resp.VerificationURIComplete
	if verificationURI == "" {
		verificationURI = resp.VerificationURI
	}

	u, err := url.Parse(verificationURI)
	if err != nil {
		return fmt.Errorf("Failed to parse OIDC verification URL: %w", err)
	}

	fmt.Printf("URL: %s\n", u.String())
	fmt.Printf("Code: %s\n\n", resp.UserCode)

	// Opening the browser is best-effort, the URL has been printed for manual use.
	_ = util.OpenBrowser(u.String())

	// Allow the user to abort the wait with an interrupt.
	ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGINT)
	defer stop()

	// The polling interval is optional in the response, the device flow specification mandates a
	// 5 seconds default when it's missing.
	interval := time.Duration(resp.Interval) * time.Second
	if interval <= 0 {
		interval = 5 * time.Second
	}

	token, err := rp.DeviceAccessToken(ctx, resp.DeviceCode, interval, provider)
	if err != nil {
		return err
	}

	if o.tokens.Token == nil {
		o.tokens.Token = &oauth2.Token{}
	}

	// ExpiresIn is a number of seconds, not a time.Duration.
	o.tokens.Expiry = time.Now().Add(time.Duration(token.ExpiresIn) * time.Second)
	o.tokens.IDToken = token.IDToken
	o.tokens.AccessToken = token.AccessToken
	o.tokens.TokenType = token.TokenType

	// Keep the existing refresh token unless a new one was handed out.
	if token.RefreshToken != "" {
		o.tokens.RefreshToken = token.RefreshToken
	}

	return nil
}