@@ -125,42 +125,47 @@ def test_get_conn_returns_socket(self):
125125 def test_sock_returns_socket (self ):
126126 self .assertIs (self .iface .sock , self .mock_sock )
127127
128- if not _IS_SYNC :
129128
130- def _make_async_iface (self ):
129+ if not _IS_SYNC :
130+
131+ class TestAsyncNetworkingInterface (AsyncUnitTest ):
132+ def _make_iface (self ):
131133 mock_transport = MagicMock ()
132134 mock_protocol = MagicMock ()
133135 mock_protocol .gettimeout = 10.0
134136 return AsyncNetworkingInterface ((mock_transport , mock_protocol ))
135137
136- def test_async_gettimeout_returns_protocol_timeout (self ):
137- iface = self ._make_async_iface ()
138+ def test_gettimeout_returns_protocol_timeout (self ):
139+ iface = self ._make_iface ()
138140 self .assertEqual (iface .gettimeout , 10.0 )
139141
140- def test_async_settimeout_delegates_to_protocol (self ):
141- iface = self ._make_async_iface ()
142+ def test_settimeout_delegates_to_protocol (self ):
143+ iface = self ._make_iface ()
142144 iface .settimeout (7.0 )
143145 iface .conn [1 ].settimeout .assert_called_once_with (7.0 )
144146
145- def test_async_is_closing_delegates_to_transport (self ):
146- iface = self ._make_async_iface ()
147+ def test_is_closing_delegates_to_transport (self ):
148+ iface = self ._make_iface ()
147149 iface .conn [0 ].is_closing .return_value = False
148150 self .assertFalse (iface .is_closing ())
149151
150- def test_async_get_conn_returns_protocol (self ):
151- iface = self ._make_async_iface ()
152+ def test_get_conn_returns_protocol (self ):
153+ iface = self ._make_iface ()
152154 self .assertIs (iface .get_conn , iface .conn [1 ])
153155
154- def test_async_sock_returns_transport_socket (self ):
155- iface = self ._make_async_iface ()
156+ def test_sock_returns_transport_socket (self ):
157+ iface = self ._make_iface ()
156158 sentinel = object ()
157159 iface .conn [0 ].get_extra_info .return_value = sentinel
158160 self .assertIs (iface .sock , sentinel )
159161 iface .conn [0 ].get_extra_info .assert_called_once_with ("socket" )
160162
161-
162- class TestPyMongoProtocolTimeout (AsyncUnitTest ):
163- if not _IS_SYNC :
163+ class TestPyMongoProtocol (AsyncUnitTest ):
164+ async def _make_proto_with_header (self , header_bytes , max_size = MAX_MESSAGE_SIZE ):
165+ proto = await _make_protocol ()
166+ proto ._max_message_size = max_size
167+ proto ._header = memoryview (bytearray (header_bytes ))
168+ return proto
164169
165170 async def test_initial_timeout_from_constructor (self ):
166171 proto = await _make_protocol (timeout = 3.0 )
@@ -175,16 +180,6 @@ async def test_default_timeout_is_none(self):
175180 proto = await _make_protocol ()
176181 self .assertIsNone (proto .gettimeout )
177182
178-
179- class TestPyMongoProtocolProcessHeader (AsyncUnitTest ):
180- if not _IS_SYNC :
181-
182- async def _make_proto_with_header (self , header_bytes , max_size = MAX_MESSAGE_SIZE ):
183- proto = await _make_protocol ()
184- proto ._max_message_size = max_size
185- proto ._header = memoryview (bytearray (header_bytes ))
186- return proto
187-
188183 async def test_normal_op_msg (self ):
189184 hdr = _make_header (32 , 1 , 99 , 2013 )
190185 proto = await self ._make_proto_with_header (hdr )
@@ -229,11 +224,7 @@ async def test_op_reply_op_code(self):
229224 self .assertEqual (op_code , 1 )
230225 self .assertFalse (expecting_compression )
231226
232-
233- class TestPyMongoProtocolProcessCompressionHeader (AsyncUnitTest ):
234- if not _IS_SYNC :
235-
236- async def test_returns_op_code_and_compressor_id (self ):
227+ async def test_compression_header_returns_op_code_and_compressor_id (self ):
237228 proto = await _make_protocol ()
238229 # op_code=2013, uncompressed_size=0, compressor_id=1 (snappy)
239230 data = struct .pack ("<iiB" , 2013 , 0 , 1 )
@@ -242,127 +233,13 @@ async def test_returns_op_code_and_compressor_id(self):
242233 self .assertEqual (op_code , 2013 )
243234 self .assertEqual (compressor_id , 1 )
244235
245- async def test_zlib_compressor_id (self ):
236+ async def test_compression_header_zlib_compressor_id (self ):
246237 proto = await _make_protocol ()
247238 data = struct .pack ("<iiB" , 2013 , 0 , 2 )
248239 proto ._compression_header = memoryview (bytearray (data ))
249240 _ , compressor_id = proto .process_compression_header ()
250241 self .assertEqual (compressor_id , 2 )
251242
252-
253- class TestPyMongoProtocolGetBuffer (AsyncUnitTest ):
254- if not _IS_SYNC :
255-
256- async def test_expecting_header_returns_full_header_slice (self ):
257- proto = await _make_protocol ()
258- proto ._expecting_header = True
259- proto ._header_index = 0
260- self .assertEqual (len (proto .get_buffer (0 )), 16 )
261-
262- async def test_expecting_header_partial_returns_remaining (self ):
263- proto = await _make_protocol ()
264- proto ._expecting_header = True
265- proto ._header_index = 8
266- self .assertEqual (len (proto .get_buffer (0 )), 8 )
267-
268- async def test_expecting_compression_returns_compression_slice (self ):
269- proto = await _make_protocol ()
270- proto ._expecting_header = False
271- proto ._expecting_compression = True
272- proto ._compression_index = 0
273- self .assertEqual (len (proto .get_buffer (0 )), 9 )
274-
275- async def test_expecting_compression_partial (self ):
276- proto = await _make_protocol ()
277- proto ._expecting_header = False
278- proto ._expecting_compression = True
279- proto ._compression_index = 5
280- self .assertEqual (len (proto .get_buffer (0 )), 4 )
281-
282- async def test_message_body_returns_remaining_slice (self ):
283- proto = await _make_protocol ()
284- proto ._expecting_header = False
285- proto ._expecting_compression = False
286- proto ._message = memoryview (bytearray (100 ))
287- proto ._message_index = 0
288- self .assertEqual (len (proto .get_buffer (0 )), 100 )
289-
290- async def test_connection_lost_allocates_drain_buffer (self ):
291- proto = await _make_protocol ()
292- proto ._connection_lost = True
293- proto ._message = None
294- self .assertEqual (len (proto .get_buffer (0 )), 2 ** 14 )
295-
296- async def test_connection_lost_reuses_existing_buffer (self ):
297- proto = await _make_protocol ()
298- proto ._connection_lost = True
299- proto ._message = memoryview (bytearray (50 ))
300- self .assertEqual (len (proto .get_buffer (0 )), 50 )
301-
302-
303- class TestPyMongoProtocolBufferUpdated (AsyncUnitTest ):
304- if not _IS_SYNC :
305-
306- async def test_zero_bytes_closes_connection (self ):
307- proto = await _make_protocol ()
308- proto .buffer_updated (0 )
309- self .assertTrue (proto ._connection_lost )
310-
311- async def test_connection_lost_returns_early (self ):
312- proto = await _make_protocol ()
313- proto ._connection_lost = True
314- proto ._header_index = 3
315- proto .buffer_updated (5 )
316- self .assertEqual (proto ._header_index , 3 )
317-
318- async def test_partial_header_increments_index (self ):
319- proto = await _make_protocol ()
320- proto ._expecting_header = True
321- proto ._header_index = 0
322- proto .buffer_updated (8 )
323- self .assertEqual (proto ._header_index , 8 )
324-
325- async def test_full_header_transitions_to_message (self ):
326- proto = await _make_protocol ()
327- proto ._expecting_header = True
328- hdr = _make_header (32 , 1 , 0 , 2013 )
329- proto ._header = memoryview (bytearray (hdr ))
330- proto ._header_index = 0
331- proto .buffer_updated (16 )
332- self .assertFalse (proto ._expecting_header )
333- self .assertEqual (proto ._message_size , 16 )
334-
335- async def test_invalid_header_closes_connection (self ):
336- proto = await _make_protocol ()
337- proto ._expecting_header = True
338- # length=16 (not > 16) triggers ProtocolError
339- hdr = _make_header (16 , 1 , 0 , 2013 )
340- proto ._header = memoryview (bytearray (hdr ))
341- proto ._header_index = 0
342- proto .buffer_updated (16 )
343- self .assertTrue (proto ._connection_lost )
344-
345- async def test_compression_header_processing (self ):
346- proto = await _make_protocol ()
347- proto ._expecting_header = False
348- proto ._expecting_compression = True
349- comp_hdr = struct .pack ("<iiB" , 2013 , 0 , 2 )
350- proto ._compression_header = memoryview (bytearray (comp_hdr ))
351- proto ._compression_index = 0
352- proto .buffer_updated (9 )
353- self .assertFalse (proto ._expecting_compression )
354- self .assertEqual (proto ._op_code , 2013 )
355- self .assertEqual (proto ._compressor_id , 2 )
356-
357- async def test_partial_compression_header_increments_index (self ):
358- proto = await _make_protocol ()
359- proto ._expecting_header = False
360- proto ._expecting_compression = True
361- proto ._compression_index = 0
362- proto .buffer_updated (4 )
363- self .assertEqual (proto ._compression_index , 4 )
364- self .assertTrue (proto ._expecting_compression )
365-
366243 async def test_message_complete_resolves_pending_future (self ):
367244 proto = await _make_protocol ()
368245 proto ._expecting_header = False
@@ -384,51 +261,11 @@ async def test_message_complete_resolves_pending_future(self):
384261 self .assertIsNone (compressor_id )
385262 self .assertEqual (response_to , 42 )
386263
387- async def test_message_complete_no_pending_creates_new_future (self ):
388- proto = await _make_protocol ()
389- proto ._expecting_header = False
390- proto ._expecting_compression = False
391- proto ._message_size = 5
392- proto ._message = memoryview (bytearray (5 ))
393- proto ._message_index = 0
394- proto ._op_code = 2013
395- proto ._compressor_id = None
396- proto ._response_to = 0
397-
398- self .assertFalse (proto ._pending_messages )
399- proto .buffer_updated (5 )
400- self .assertEqual (len (proto ._done_messages ), 1 )
401-
402- async def test_partial_message_increments_index (self ):
403- proto = await _make_protocol ()
404- proto ._expecting_header = False
405- proto ._expecting_compression = False
406- proto ._message_size = 20
407- proto ._message = memoryview (bytearray (20 ))
408- proto ._message_index = 0
409- proto .buffer_updated (7 )
410- self .assertEqual (proto ._message_index , 7 )
411-
412-
413- class TestPyMongoProtocolClose (AsyncUnitTest ):
414- if not _IS_SYNC :
415-
416- async def test_close_sets_connection_lost_flag (self ):
417- proto = await _make_protocol ()
418- proto .close ()
419- self .assertTrue (proto ._connection_lost )
420-
421264 async def test_close_aborts_transport (self ):
422265 proto = await _make_protocol ()
423266 proto .close ()
424267 self .assertTrue (proto .transport .abort .called )
425268
426- async def test_connection_lost_resolves_closed_future (self ):
427- proto = await _make_protocol ()
428- self .assertFalse (proto ._closed .done ())
429- proto .connection_lost (None )
430- self .assertTrue (proto ._closed .done ())
431-
432269 async def test_connection_lost_twice_does_not_raise (self ):
433270 proto = await _make_protocol ()
434271 proto .connection_lost (None )
@@ -444,10 +281,7 @@ async def test_close_with_exception_propagates_to_pending(self):
444281 await fut
445282 self .assertIn ("connection reset" , str (ctx .exception ))
446283
447-
448- class TestAsyncSocketReceive (AsyncUnitTest ):
449- if not _IS_SYNC :
450-
284+ class TestAsyncSocketReceive (AsyncUnitTest ):
451285 async def test_reads_full_data_in_one_call (self ):
452286 data = b"hello world!"
453287 length = len (data )
@@ -468,15 +302,16 @@ async def test_reads_data_in_multiple_chunks(self):
468302 chunk1 , chunk2 = data [:4 ], data [4 :]
469303 mock_sock = MagicMock ()
470304 loop = asyncio .get_running_loop ()
471- calls = [ 0 ]
305+ calls = 0
472306
473307 async def fake_recv_into (sock , buf ):
474- if calls [0 ] == 0 :
308+ nonlocal calls
309+ if calls == 0 :
475310 buf [: len (chunk1 )] = chunk1
476- calls [ 0 ] += 1
311+ calls += 1
477312 return len (chunk1 )
478313 buf [: len (chunk2 )] = chunk2
479- calls [ 0 ] += 1
314+ calls += 1
480315 return len (chunk2 )
481316
482317 with patch .object (loop , "sock_recv_into" , new = AsyncMock (side_effect = fake_recv_into )):
0 commit comments