diff --git a/src/client/errors.ts b/src/client/errors.ts index 476b7a2e6..f5ba1cb9a 100644 --- a/src/client/errors.ts +++ b/src/client/errors.ts @@ -50,9 +50,7 @@ class ErrorClientVerificationFailed extends ErrorClientService { exitCode = sysexits.USAGE; } -class ErrorAuthentication extends ErrorPolykey {} - -class ErrorAuthenticationInvalidToken extends ErrorAuthentication { +class ErrorClientAuthenticationInvalidToken extends ErrorClient { static description = 'Token is invalid'; exitCode = sysexits.PROTOCOL; } @@ -69,6 +67,5 @@ export { ErrorClientServiceNotRunning, ErrorClientServiceDestroyed, ErrorClientVerificationFailed, - ErrorAuthentication, - ErrorAuthenticationInvalidToken, + ErrorClientAuthenticationInvalidToken, }; diff --git a/src/client/handlers/AuthSignToken.ts b/src/client/handlers/AuthSignToken.ts index 04910b722..4a38eb484 100644 --- a/src/client/handlers/AuthSignToken.ts +++ b/src/client/handlers/AuthSignToken.ts @@ -29,7 +29,7 @@ class AuthSignToken extends UnaryHandler< const inputToken = { payload: input.payload, signatures: input.signatures }; const incomingToken = Token.fromEncoded(inputToken); if (!('publicKey' in incomingToken.payload)) { - throw new clientErrors.ErrorAuthenticationInvalidToken( + throw new clientErrors.ErrorClientAuthenticationInvalidToken( 'Input token does not contain public key', ); } @@ -38,7 +38,7 @@ class AuthSignToken extends UnaryHandler< 'base64url', ) as PublicKey; if (!incomingToken.verifyWithPublicKey(incomingPublicKey)) { - throw new clientErrors.ErrorAuthenticationInvalidToken( + throw new clientErrors.ErrorClientAuthenticationInvalidToken( 'Incoming token does not match its signature', ); } diff --git a/tests/client/handlers/auth.test.ts b/tests/client/handlers/auth.test.ts new file mode 100644 index 000000000..37115248d --- /dev/null +++ b/tests/client/handlers/auth.test.ts @@ -0,0 +1,134 @@ +import type { + IdentityRequestData, + IdentityResponseData, +} from '#src/client/types.js'; +import type { TLSConfig } from '#network/types.js'; +import fs from 'node:fs'; +import path from 'node:path'; +import os from 'node:os'; +import Logger, { formatting, LogLevel, StreamHandler } from '@matrixai/logger'; +import { RPCClient } from '@matrixai/rpc'; +import { WebSocketClient } from '@matrixai/ws'; +import * as testsUtils from '../../utils/index.js'; +import { AuthSignToken } from '#client/handlers/index.js'; +import { authSignToken } from '#client/callers/index.js'; +import KeyRing from '#keys/KeyRing.js'; +import Token from '#tokens/Token.js'; +import ClientService from '#client/ClientService.js'; +import * as keysUtils from '#keys/utils/index.js'; +import * as networkUtils from '#network/utils.js'; +import * as clientErrors from '#client/errors.js'; + +describe('authSignToken', () => { + const logger = new Logger('authSignToken test', LogLevel.WARN, [ + new StreamHandler( + formatting.format`${formatting.level}:${formatting.keys}:${formatting.msg}`, + ), + ]); + const password = 'password'; + const localhost = '127.0.0.1'; + let dataDir: string; + let keyRing: KeyRing; + let tlsConfig: TLSConfig; + let clientService: ClientService; + let webSocketClient: WebSocketClient; + let rpcClient: RPCClient<{ + authSignToken: typeof authSignToken; + }>; + + beforeEach(async () => { + dataDir = await fs.promises.mkdtemp( + path.join(os.tmpdir(), 'polykey-test-'), + ); + const keysPath = path.join(dataDir, 'keys'); + keyRing = await KeyRing.createKeyRing({ + password, + keysPath, + passwordOpsLimit: keysUtils.passwordOpsLimits.min, + passwordMemLimit: keysUtils.passwordMemLimits.min, + strictMemoryLock: false, + logger, + }); + tlsConfig = await testsUtils.createTLSConfig(keyRing.keyPair); + clientService = new ClientService({ + tlsConfig, + logger: logger.getChild(ClientService.name), + }); + await clientService.start({ + manifest: { + authSignToken: new AuthSignToken({ + keyRing, + }), + }, + host: localhost, + }); + webSocketClient = await WebSocketClient.createWebSocketClient({ + config: { + verifyPeer: false, + }, + host: localhost, + logger: logger.getChild(WebSocketClient.name), + port: clientService.port, + }); + rpcClient = new RPCClient({ + manifest: { + authSignToken, + }, + streamFactory: () => webSocketClient.connection.newStream(), + toError: networkUtils.toError, + logger: logger.getChild(RPCClient.name), + }); + }); + + afterEach(async () => { + await keyRing.stop(); + await clientService.stop({ force: true }); + await webSocketClient.destroy({ force: true }); + await keyRing.stop(); + await fs.promises.rm(dataDir, { + force: true, + recursive: true, + }); + }); + + test('should sign a valid token', async () => { + // Create token with separate key pair + const keyPair = keysUtils.generateKeyPair(); + const token = Token.fromPayload({ + publicKey: keyPair.publicKey.toString('base64url'), + returnURL: 'test', + }); + token.signWithPrivateKey(keyPair); + + // Get the node to sign the token as well + const encodedToken = token.toEncoded(); + const identityToken = await rpcClient.methods.authSignToken(encodedToken); + + // Check the signature of both the incoming token and the original sent token + const decodedToken = Token.fromEncoded(identityToken); + const decodedPublicKey = keysUtils.publicKeyFromNodeId(keyRing.getNodeId()); + expect(decodedToken.verifyWithPublicKey(decodedPublicKey)).toBeTrue(); + const requestToken = Token.fromEncoded( + decodedToken.payload.requestToken, + ); + expect(requestToken.verifyWithPublicKey(keyPair.publicKey)).toBeTrue(); + }); + + test('should fail if public key does not match signature', async () => { + // Create token with a key pair and sign it with another + const keyPair1 = keysUtils.generateKeyPair(); + const keyPair2 = keysUtils.generateKeyPair(); + const token = Token.fromPayload({ + publicKey: keyPair1.publicKey.toString('base64url'), + returnURL: 'test', + }); + token.signWithPrivateKey(keyPair2); + + // The token should fail validation + const encodedToken = token.toEncoded(); + await testsUtils.expectRemoteError( + rpcClient.methods.authSignToken(encodedToken), + clientErrors.ErrorClientAuthenticationInvalidToken, + ); + }); +}); diff --git a/tests/client/handlers/vaults.test.ts b/tests/client/handlers/vaults.test.ts index b16e37846..91efba32f 100644 --- a/tests/client/handlers/vaults.test.ts +++ b/tests/client/handlers/vaults.test.ts @@ -94,10 +94,7 @@ describe('vaultsClone', () => { let dataDir: string; let db: DB; let keyRing: KeyRing; - let webSocketClient: WebSocketClient; - let clientService: ClientService; let vaultManager: VaultManager; - let taskManager: TaskManager; beforeEach(async () => { dataDir = await fs.promises.mkdtemp( path.join(os.tmpdir(), 'polykey-test-'), @@ -130,11 +127,7 @@ describe('vaultsClone', () => { }); }); afterEach(async () => { - await clientService?.stop({ force: true }); - await webSocketClient.destroy({ force: true }); await vaultManager.stop(); - await taskManager.stopProcessing(); - await taskManager.stopTasks(); await db.stop(); await keyRing.stop(); await fs.promises.rm(dataDir, { @@ -693,8 +686,6 @@ describe('vaultsPull', () => { let dataDir: string; let db: DB; let keyRing: KeyRing; - let webSocketClient: WebSocketClient; - let clientService: ClientService; let vaultManager: VaultManager; let taskManager: TaskManager; let acl: ACL; @@ -758,8 +749,6 @@ describe('vaultsPull', () => { }); }); afterEach(async () => { - await clientService?.stop({ force: true }); - await webSocketClient.destroy({ force: true }); await vaultManager.stop(); await notificationsManager.stop(); await gestaltGraph.stop(); @@ -884,8 +873,6 @@ describe('vaultsScan', () => { let dataDir: string; let db: DB; let keyRing: KeyRing; - let webSocketClient: WebSocketClient; - let clientService: ClientService; let vaultManager: VaultManager; beforeEach(async () => { dataDir = await fs.promises.mkdtemp( @@ -918,8 +905,6 @@ describe('vaultsScan', () => { }); }); afterEach(async () => { - await clientService?.stop({ force: true }); - await webSocketClient.destroy({ force: true }); await vaultManager.stop(); await db.stop(); await keyRing.stop();