Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion new-ui/src/shared/components/LocationCard/LocationCard.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import { LocationCardMfaMobileView } from './views/LocationCardMfaMobileView/Loc
import { LocationCardMfaOidcView } from './views/LocationCardMfaOidcView/LocationCardMfaOidcView';
import { LocationCardMfaSettings } from './views/LocationCardMfaSettings/LocationCardMfaSettings';
import { LocationCardMfaTotpView } from './views/LocationCardMfaTotpView/LocationCardMfaTotpView';
import { LocationCardPostureCheckFailView } from './views/LocationCardPostureCheckFailView/LocationCardPostureCheckFailView';

interface Props {
location: LocationInfo;
Expand All @@ -39,7 +40,7 @@ const views: Record<LocationCardViewsValue, ReactNode> = {
[LocationCardViews.MfaSettings]: <LocationCardMfaSettings />,
[LocationCardViews.Connecting]: null,
[LocationCardViews.Connected]: <ConnectedView />,
[LocationCardViews.PostureCheckFail]: null,
[LocationCardViews.PostureCheckFail]: <LocationCardPostureCheckFailView />,
};

interface InnerProps {
Expand Down
20 changes: 20 additions & 0 deletions new-ui/src/shared/components/LocationCard/api/connectError.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import z from 'zod';

const connectErrorSchema = z.discriminatedUnion('kind', [
z.object({
kind: z.literal('postureCheckFailed'),
message: z.string(),
}),
z.object({
kind: z.literal('other'),
message: z.string(),
}),
]);

export type ConnectError = z.infer<typeof connectErrorSchema>;

export const parseConnectError = (err: unknown): ConnectError | null => {
const result = connectErrorSchema.safeParse(err);

return result.success ? result.data : null;
};
121 changes: 121 additions & 0 deletions new-ui/src/shared/components/LocationCard/api/startClientMfaSession.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
import { fetch } from '@tauri-apps/plugin-http';
import { api } from '../../../rust-api/api';
import type {
EdgeRequestHeaders,
InstanceInfo,
LocationInfo,
} from '../../../rust-api/types';

export const CLIENT_MFA_ENDPOINT = 'api/v1/client-mfa';

/** Error raised when the MFA start request or its prerequisites fail. */
export class MfaStartError extends Error {
public readonly status?: number;

constructor(message: string, status?: number) {
super(message);
this.name = 'MfaStartError';
this.status = status;
}
}

/** MFA method identifiers expected by the desktop-client MFA API */
export const MfaStartMethod = {
Totp: 0,
Email: 1,
Oidc: 2,
MobileApprove: 4,
} as const;

export type MfaStartMethod = (typeof MfaStartMethod)[keyof typeof MfaStartMethod];

/** Successful MFA start response returned by the proxy. */
export type MfaStartResponse = {
token: string;
challenge?: string;
};

/** Error response shape returned by the proxy for MFA start failures. */
type MfaStartErrorResponse = {
error?: string;
};

/** Narrows MFA start errors that should open the posture failure view. */
export const shouldShowPostureError = (
err: unknown,
location: LocationInfo,
): err is MfaStartError =>
err instanceof MfaStartError && err.status === 403 && location.posture_check_required;

/** Input required to start a desktop-client MFA session. */
type StartClientMfaSessionParams = {
instance: InstanceInfo;
location: LocationInfo;
method: MfaStartMethod;
};

/** MFA start response plus request headers required by later MFA calls. */
type StartClientMfaSessionResult = {
response: MfaStartResponse;
headers: EdgeRequestHeaders;
};

/** Starts an MFA session, including posture data when the location requires it. */
export const startClientMfaSession = async ({
instance,
location,
method,
}: StartClientMfaSessionParams): Promise<StartClientMfaSessionResult> => {
let headers: EdgeRequestHeaders;
try {
headers = await api.getEdgeRequestHeaders();
} catch {
throw new MfaStartError('Failed to load request headers');
}

let posture_data: unknown;
try {
posture_data = location.posture_check_required
? await api.getPostureData()
: undefined;
} catch {
throw new MfaStartError('Failed to load posture data');
}

try {
const response = await fetch(`${instance.proxy_url}${CLIENT_MFA_ENDPOINT}/start`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
...headers,
},
body: JSON.stringify({
method,
pubkey: instance.pubkey,
location_id: location.network_id,
posture_data,
}),
});

if (!response.ok) {
let message = 'Failed to start MFA';
try {
const data = (await response.json()) as MfaStartErrorResponse;
message = data.error ?? message;
} catch {
// Keep the response status even if the proxy sends a malformed error body.
}
throw new MfaStartError(message, response.status);
}

return {
response: (await response.json()) as MfaStartResponse,
headers,
};
} catch (err) {
if (err instanceof MfaStartError) {
throw err;
}
throw new MfaStartError('Failed to reach server');
}
};
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,29 @@ import { useMutation } from '@tanstack/react-query';
import clsx from 'clsx';
import { api } from '../../../../rust-api/api';
import { LocationMfaMode } from '../../../../rust-api/types';
import { parseConnectError } from '../../api/connectError';
import { useLocationCardContext } from '../../context/context';
import { LocationCardViews } from '../../context/types';

export const ConnectButton = () => {
const { location, setView, startMfa } = useLocationCardContext();
const { location, setPostureError, setView, startMfa } = useLocationCardContext();

const { mutate: connect } = useMutation({
mutationFn: api.connect,
onSuccess: () => {
setView(LocationCardViews.Connected);
},
onError: (err) => {
const connectError = parseConnectError(err);

if (
location.posture_check_required &&
connectError?.kind === 'postureCheckFailed'
) {
setPostureError(connectError.message);
setView(LocationCardViews.PostureCheckFail);
}
},
meta: {
invalidate: ['locations'],
},
Expand Down
5 changes: 5 additions & 0 deletions new-ui/src/shared/components/LocationCard/context/context.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ interface LocationCardContextValue {
instance: InstanceInfo;
currentView: LocationCardViewsValue;
previousView: LocationCardViewsValue | null;
postureError: string | null;
setView: (view: LocationCardViewsValue) => void;
setPostureError: (error: string | null) => void;
startMfa: () => void;
}

Expand All @@ -34,6 +36,7 @@ export const LocationCardProvider = ({
children,
}: LocationCardProviderProps) => {
const [previousView, setPreviousView] = useState<LocationCardViewsValue | null>(null);
const [postureError, setPostureError] = useState<string | null>(null);
const [currentView, setCurrentView] = useState<LocationCardViewsValue>(
location.active ? LocationCardViews.Connected : LocationCardViews.Default,
);
Expand Down Expand Up @@ -68,7 +71,9 @@ export const LocationCardProvider = ({
value={{
currentView,
previousView,
postureError,
setView,
setPostureError,
location,
instance,
startMfa,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import type { LocationInfo } from '../../../rust-api/types';
import { shouldShowPostureError } from '../api/startClientMfaSession';
import { LocationCardViews, type LocationCardViewsValue } from '../context/types';

type HandleMfaStartErrorParams = {
err: unknown;
location: LocationInfo;
setPostureError: (error: string | null) => void;
setView: (view: LocationCardViewsValue) => void;
};

/** Handles MFA start posture failures and reports whether the error was consumed. */
export const handleMfaStartError = ({
err,
location,
setPostureError,
setView,
}: HandleMfaStartErrorParams): boolean => {
if (!shouldShowPostureError(err, location)) {
return false;
}

setPostureError(err.message);
setView(LocationCardViews.PostureCheckFail);
return true;
};
77 changes: 27 additions & 50 deletions new-ui/src/shared/components/LocationCard/hooks/useMfaConnect.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,16 @@ import { fetch } from '@tauri-apps/plugin-http';
import { error } from '@tauri-apps/plugin-log';
import { useCallback, useEffect, useRef, useState } from 'react';
import { api } from '../../../rust-api/api';
import {
getInstancesQueryOptions,
getPlatformHeaderQueryOptions,
} from '../../../rust-api/query';
import { getInstancesQueryOptions } from '../../../rust-api/query';
import type { EdgeRequestHeaders } from '../../../rust-api/types';
import {
CLIENT_MFA_ENDPOINT,
type MfaStartMethod,
startClientMfaSession,
} from '../api/startClientMfaSession';
import { useLocationCardContext } from '../context/context';
import { LocationCardViews } from '../context/types';

const MFA_ENDPOINT = 'api/v1/client-mfa';

type MfaStartResponse = {
token: string;
challenge?: string;
};
import { handleMfaStartError } from './handleMfaStartError';

type MfaFinishResponse = {
preshared_key: string;
Expand All @@ -26,8 +22,10 @@ type MfaErrorResponse = {
error: string;
};

export const useMfaConnect = (method: 0 | 1) => {
const { location, setView } = useLocationCardContext();
type CodeMfaStartMethod = Extract<MfaStartMethod, 0 | 1>;

export const useMfaConnect = (method: CodeMfaStartMethod) => {
const { location, setPostureError, setView } = useLocationCardContext();

const [token, setToken] = useState<string | null>(null);
const [isStarting, setIsStarting] = useState(false);
Expand All @@ -37,7 +35,6 @@ export const useMfaConnect = (method: 0 | 1) => {
const [requestHeaders, setRequestHeaders] = useState<EdgeRequestHeaders | null>(null);

const { data: instances } = useQuery(getInstancesQueryOptions);
const { data: platformHeader } = useQuery(getPlatformHeaderQueryOptions);

const instance = instances?.find((i) => i.id === location.instance_id);

Expand All @@ -52,67 +49,47 @@ export const useMfaConnect = (method: 0 | 1) => {
},
});

// Fire the /start request exactly once when instance + platformHeader are ready.
// Fire the /start request exactly once when instance data is ready.
const startCalled = useRef(false);

// biome-ignore lint/correctness/useExhaustiveDependencies: intentional one-shot trigger via startCalled ref
useEffect(() => {
if (!instance || !platformHeader || startCalled.current) return;
if (!instance || startCalled.current) return;
startCalled.current = true;

setIsStarting(true);

(async () => {
let headers: EdgeRequestHeaders;
try {
headers = await api.getEdgeRequestHeaders();
setRequestHeaders(headers);
} catch {
setStartError('Failed to load request headers');
setIsStarting(false);
return;
}

try {
const res = await fetch(`${instance.proxy_url}${MFA_ENDPOINT}/start`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
...headers,
},
body: JSON.stringify({
method,
pubkey: instance.pubkey,
location_id: location.network_id,
}),
const { response, headers } = await startClientMfaSession({
instance,
location,
method,
});

if (res.ok) {
const data = (await res.json()) as MfaStartResponse;
setToken(data.token);
} else {
const data = (await res.json()) as MfaErrorResponse;
setStartError(data.error ?? 'Failed to start MFA');
setRequestHeaders(headers);
setToken(response.token);
} catch (err) {
if (handleMfaStartError({ err, location, setPostureError, setView })) {
return;
}
} catch {
setStartError('Failed to reach server');
setStartError(err instanceof Error ? err.message : 'Failed to start MFA');
} finally {
setIsStarting(false);
}
})();
}, [instance, platformHeader]);
}, [instance]);

const verifyCode = useCallback(
async (code: string) => {
if (!token || !instance || !platformHeader || !requestHeaders) return;
if (!token || !instance || !requestHeaders) return;

setIsVerifying(true);
setVerifyError(null);

const body = JSON.stringify({ token, code });

try {
const res = await fetch(`${instance.proxy_url}${MFA_ENDPOINT}/finish`, {
const res = await fetch(`${instance.proxy_url}${CLIENT_MFA_ENDPOINT}/finish`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
Expand Down Expand Up @@ -148,7 +125,7 @@ export const useMfaConnect = (method: 0 | 1) => {
setIsVerifying(false);
}
},
[token, instance, platformHeader, requestHeaders, location, connectMutate, setView],
[token, instance, requestHeaders, location, connectMutate, setView],
);

return { token, isStarting, startError, verifyCode, isVerifying, verifyError };
Expand Down
Loading
Loading