@@ -65,6 +65,22 @@ impl Default for Options {
6565 }
6666}
6767
68+ impl Options {
69+ /// Set the on_connected callback
70+ pub fn with_on_connected < F , Fut > ( mut self , f : F ) -> Self
71+ where
72+ F : Fn ( Endpoint , Connection ) -> Fut + Send + Sync + ' static ,
73+ Fut : std:: future:: Future < Output = io:: Result < ( ) > > + Send + ' static ,
74+ {
75+ self . on_connected = Some ( Arc :: new ( move |ep, conn| {
76+ let ep = ep. clone ( ) ;
77+ let conn = conn. clone ( ) ;
78+ Box :: pin ( f ( ep, conn) )
79+ } ) ) ;
80+ self
81+ }
82+ }
83+
6884/// A reference to a connection that is owned by a connection pool.
6985#[ derive( Debug ) ]
7086pub struct ConnectionRef {
@@ -524,9 +540,9 @@ mod tests {
524540
525541 use iroh:: {
526542 discovery:: static_provider:: StaticProvider ,
527- endpoint:: Connection ,
543+ endpoint:: { Connection , ConnectionType } ,
528544 protocol:: { AcceptError , ProtocolHandler , Router } ,
529- NodeAddr , NodeId , SecretKey , Watcher ,
545+ Endpoint , NodeAddr , NodeId , SecretKey , Watcher ,
530546 } ;
531547 use n0_future:: { io, stream, BufferedStreamExt , StreamExt } ;
532548 use n0_snafu:: ResultExt ;
@@ -770,37 +786,41 @@ mod tests {
770786 Ok ( ( ) )
771787 }
772788
773- /// Uses an on_connected callback that delays for a long time.
774- ///
775- /// This checks that the pool timeout includes on_connected delay.
789+ /// Uses an on_connected callback to ensure that the connection is direct.
776790 #[ tokio:: test]
777791 // #[traced_test]
778- async fn on_connected_timeout ( ) -> TestResult < ( ) > {
792+ async fn on_connected_direct ( ) -> TestResult < ( ) > {
779793 let n = 1 ;
780794 let ( ids, routers, discovery) = echo_servers ( n) . await ?;
781795 let endpoint = iroh:: Endpoint :: builder ( )
782796 . discovery ( discovery)
783797 . bind ( )
784798 . await ?;
785- let on_connected: OnConnected = Arc :: new ( |_, _| {
786- Box :: pin ( async {
787- tokio:: time:: sleep ( Duration :: from_secs ( 20 ) ) . await ;
788- Ok ( ( ) )
789- } )
790- } ) ;
799+ let on_connected = |ep : Endpoint , conn : Connection | async move {
800+ let Ok ( id) = conn. remote_node_id ( ) else {
801+ return Err ( io:: Error :: other ( "unable to get node id" ) ) ;
802+ } ;
803+ let Some ( watcher) = ep. conn_type ( id) else {
804+ return Err ( io:: Error :: other ( "unable to get conn_type watcher" ) ) ;
805+ } ;
806+ let mut stream = watcher. stream ( ) ;
807+ while let Some ( status) = stream. next ( ) . await {
808+ if let ConnectionType :: Direct { .. } = status {
809+ return Ok ( ( ) ) ;
810+ }
811+ }
812+ Err ( io:: Error :: other ( "connection closed before becoming direct" ) )
813+ } ;
791814 let pool = ConnectionPool :: new (
792815 endpoint,
793816 ECHO_ALPN ,
794- Options {
795- on_connected : Some ( on_connected) ,
796- ..test_options ( )
797- } ,
817+ test_options ( ) . with_on_connected ( on_connected) ,
798818 ) ;
799819 let client = EchoClient { pool } ;
800820 let msg = b"Hello, pool!" . to_vec ( ) ;
801821 for id in & ids {
802822 let res = client. echo ( * id, msg. clone ( ) ) . await ;
803- assert ! ( matches! ( res, Err ( PoolConnectError :: Timeout ) ) ) ;
823+ assert ! ( res. is_ok ( ) ) ;
804824 }
805825 shutdown_routers ( routers) . await ;
806826 Ok ( ( ) )
0 commit comments