Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions foundry.lock
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
{
"lib/ds-test": {
"rev": "e282159d5170298eb2455a6c05280ab5a73a4ef0"
},
"lib/forge-std": {
"rev": "978ac6fadb62f5f0b723c996f64be52eddba6801"
},
"lib/solmate": {
"rev": "c892309933b25c03d32b1b0d674df7ae292ba925"
}
}
3 changes: 2 additions & 1 deletion src/test/validium/FastWithdrawVault.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,8 @@ contract FastWithdrawVaultTest is ValidiumTestBase {
address(messenger),
address(template),
address(factory),
address(rollup)
address(rollup),
address(1) // wethGateway - placeholder for tests
)
)
);
Expand Down
52 changes: 35 additions & 17 deletions src/test/validium/L1ERC20GatewayValidium.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import {AddressAliasHelper} from "../../libraries/common/AddressAliasHelper.sol"
import {IL1ERC20GatewayValidium} from "../../validium/IL1ERC20GatewayValidium.sol";
import {IL2ERC20GatewayValidium} from "../../validium/IL2ERC20GatewayValidium.sol";
import {L1ERC20GatewayValidium} from "../../validium/L1ERC20GatewayValidium.sol";
import {L1WETHGatewayValidium} from "../../validium/L1WETHGatewayValidium.sol";
import {ScrollChainValidium} from "../../validium/ScrollChainValidium.sol";

import {TransferReentrantToken} from "../mocks/tokens/TransferReentrantToken.sol";
Expand Down Expand Up @@ -46,6 +47,7 @@ contract L1ERC20GatewayValidiumTest is ValidiumTestBase {
ScrollStandardERC20 private template;
ScrollStandardERC20Factory private factory;
L2StandardERC20Gateway private counterpartGateway;
L1WETHGatewayValidium private wethGateway = L1WETHGatewayValidium(address(1)); // placeholder for tests

MockERC20 private l1Token;
MockERC20 private l2Token;
Expand Down Expand Up @@ -120,23 +122,43 @@ contract L1ERC20GatewayValidiumTest is ValidiumTestBase {
_deposit(address(this), amount, recipient, gasLimit);
}

function testDepositERC20WithSender(
address sender,
function testDepositERC20WrongKey(
uint256 amount,
bytes memory recipient,
uint256 gasLimit
) public {
_deposit(sender, amount, recipient, gasLimit);
(uint256 keyId, ) = rollup.getLatestEncryptionKey();
hevm.expectRevert(ScrollChainValidium.ErrorUnknownEncryptionKey.selector);
gateway.depositERC20(address(l1Token), recipient, amount, gasLimit, keyId + 1);
}

function testDepositERC20WrongKey(
function testDepositERC20WithRealSenderUnauthorized(
address attacker,
address victim,
uint256 amount,
bytes memory recipient,
uint256 gasLimit
) public {
hevm.assume(attacker != address(0));
hevm.assume(attacker != address(wethGateway));

amount = bound(amount, 1, l1Token.balanceOf(address(this)));
gasLimit = bound(gasLimit, defaultGasLimit / 2, defaultGasLimit);
(uint256 keyId, ) = rollup.getLatestEncryptionKey();
hevm.expectRevert(ScrollChainValidium.ErrorUnknownEncryptionKey.selector);
gateway.depositERC20(address(l1Token), recipient, amount, gasLimit, keyId + 1);

// Transfer tokens to attacker
l1Token.transfer(attacker, amount);

// Attacker approves gateway
hevm.startPrank(attacker);
l1Token.approve(address(gateway), amount);

// Attacker tries to call depositERC20 with victim as _realSender
// This should revert with ErrorCallerNotWethGateway
hevm.expectRevert(L1ERC20GatewayValidium.ErrorCallerNotWethGateway.selector);
gateway.depositERC20(address(l1Token), victim, recipient, amount, gasLimit, keyId);

hevm.stopPrank();
}

function testDepositReentrantToken(uint256 amount) public {
Expand Down Expand Up @@ -407,12 +429,11 @@ contract L1ERC20GatewayValidiumTest is ValidiumTestBase {
if (amount == 0) {
(uint256 keyId, ) = rollup.getLatestEncryptionKey();
hevm.expectRevert(L1ERC20GatewayValidium.ErrorAmountIsZero.selector);
if (from == address(this)) {
gateway.depositERC20(address(l1Token), recipient, amount, gasLimit, keyId);
} else {
gateway.depositERC20(address(l1Token), from, recipient, amount, gasLimit, keyId);
}
gateway.depositERC20(address(l1Token), recipient, amount, gasLimit, keyId);
} else {
// Note: from parameter is only used in event expectations, not in actual calls
// The depositERC20 function always uses msg.sender for actual deposits

// emit QueueTransaction from L1MessageQueueV2
{
hevm.expectEmit(true, true, false, true);
Expand All @@ -434,11 +455,7 @@ contract L1ERC20GatewayValidiumTest is ValidiumTestBase {
uint256 feeVaultBalance = address(feeVault).balance;
assertEq(l1Messenger.messageSendTimestamp(keccak256(xDomainCalldata)), 0);
(uint256 keyId, ) = rollup.getLatestEncryptionKey();
if (from == address(this)) {
gateway.depositERC20(address(l1Token), recipient, amount, gasLimit, keyId);
} else {
gateway.depositERC20(address(l1Token), from, recipient, amount, gasLimit, keyId);
}
gateway.depositERC20(address(l1Token), recipient, amount, gasLimit, keyId);
assertEq(amount + gatewayBalance, l1Token.balanceOf(address(gateway)));
assertEq(feeVaultBalance, address(feeVault).balance);
assertGt(l1Messenger.messageSendTimestamp(keccak256(xDomainCalldata)), 0);
Expand All @@ -456,7 +473,8 @@ contract L1ERC20GatewayValidiumTest is ValidiumTestBase {
address(messenger),
address(template),
address(factory),
address(rollup)
address(rollup),
address(wethGateway)
)
)
);
Expand Down
34 changes: 16 additions & 18 deletions src/test/validium/L1WETHGatewayValidium.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,24 @@ contract L1WETHGatewayValidiumTest is ValidiumTestBase {
counterpartGateway = new L2StandardERC20Gateway(address(1), address(1), address(1), address(factory));

// Deploy L1 contracts
gateway = _deployGateway(address(l1Messenger));
gateway = L1ERC20GatewayValidium(_deployProxy(address(0)));
wethGateway = new L1WETHGatewayValidium(address(weth), address(gateway));

// Upgrade gateway implementation with actual wethGateway address
admin.upgrade(
ITransparentUpgradeableProxy(address(gateway)),
address(
new L1ERC20GatewayValidium(
address(counterpartGateway),
address(l1Messenger),
address(template),
address(factory),
address(rollup),
address(wethGateway)
)
)
);

// Initialize L1 contracts
gateway.initialize();

Expand Down Expand Up @@ -146,21 +161,4 @@ contract L1WETHGatewayValidiumTest is ValidiumTestBase {
assertGt(l1Messenger.messageSendTimestamp(keccak256(xDomainCalldata)), 0);
}
}

function _deployGateway(address messenger) internal returns (L1ERC20GatewayValidium _gateway) {
_gateway = L1ERC20GatewayValidium(_deployProxy(address(0)));

admin.upgrade(
ITransparentUpgradeableProxy(address(_gateway)),
address(
new L1ERC20GatewayValidium(
address(counterpartGateway),
address(messenger),
address(template),
address(factory),
address(rollup)
)
)
);
}
}
15 changes: 14 additions & 1 deletion src/validium/L1ERC20GatewayValidium.sol
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ contract L1ERC20GatewayValidium is ScrollGatewayBase, IL1ERC20GatewayValidium {
/// @dev Error thrown when amount is zero.
error ErrorAmountIsZero();

/// @dev Error thrown when caller is not the WETH gateway.
error ErrorCallerNotWethGateway();

/*************
* Constants *
*************/
Expand All @@ -47,6 +50,9 @@ contract L1ERC20GatewayValidium is ScrollGatewayBase, IL1ERC20GatewayValidium {
/// @notice The address of ScrollChainValidium contract in L2.
address public immutable scrollChainValidium;

/// @notice The address of L1WETHGatewayValidium contract.
address public immutable wethGateway;

/*************
* Variables *
*************/
Expand All @@ -67,18 +73,22 @@ contract L1ERC20GatewayValidium is ScrollGatewayBase, IL1ERC20GatewayValidium {
/// @param _messenger The address of `L1ScrollMessenger` contract in L1.
/// @param _l2TokenImplementation The address of `ScrollStandardERC20` implementation in L2.
/// @param _l2TokenFactory The address of `ScrollStandardERC20Factory` contract in L2.
/// @param _scrollChainValidium The address of `ScrollChainValidium` contract in L2.
/// @param _wethGateway The address of `L1WETHGatewayValidium` contract.
constructor(
address _counterpart,
address _messenger,
address _l2TokenImplementation,
address _l2TokenFactory,
address _scrollChainValidium
address _scrollChainValidium,
address _wethGateway
) ScrollGatewayBase(_counterpart, address(0), _messenger) {
_disableInitializers();

l2TokenImplementation = _l2TokenImplementation;
l2TokenFactory = _l2TokenFactory;
scrollChainValidium = _scrollChainValidium;
wethGateway = _wethGateway;
}

/// @notice Initialize the storage of L1ERC20GatewayValidium.
Expand Down Expand Up @@ -123,6 +133,9 @@ contract L1ERC20GatewayValidium is ScrollGatewayBase, IL1ERC20GatewayValidium {
uint256 _gasLimit,
uint256 _keyId
) external payable override {
// Only the WETH gateway can call this function to preserve the real sender
if (_msgSender() != wethGateway) revert ErrorCallerNotWethGateway();

_deposit(_token, _realSender, _to, _amount, new bytes(0), _gasLimit, _keyId);
}

Expand Down
Loading