# Mirror of https://github.com/python/cpython.git, synced 2025-11-04 11:49:12 +00:00.
			
		
		
		
	
		
			
				
	
	
		
			89 lines
		
	
	
	
		
			2.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			89 lines
		
	
	
	
		
			2.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
import asyncio
 | 
						|
import unittest
 | 
						|
 | 
						|
from test.test_asyncio import functional as func_tests
 | 
						|
 | 
						|
 | 
						|
def tearDownModule():
    """Module-level unittest hook: restore the default event loop policy.

    Passing ``None`` to ``asyncio.set_event_loop_policy`` resets the policy,
    so tests in this module do not leak a custom policy into later modules.
    """
    default_policy = None
    asyncio.set_event_loop_policy(default_policy)
 | 
						|
 | 
						|
 | 
						|
class ReceiveStuffProto(asyncio.BufferedProtocol):
    """Minimal BufferedProtocol that forwards every received chunk to a callback.

    Args:
        cb: callable invoked with a copy of each chunk of received bytes.
        con_lost_fut: future resolved (result ``None``) on a clean close, or
            failed with the exception passed to ``connection_lost``.
    """

    def __init__(self, cb, con_lost_fut):
        self.cb = cb
        self.con_lost_fut = con_lost_fut

    def get_buffer(self, sizehint):
        # A fresh, deliberately small 100-byte buffer is handed out on every
        # call (sizehint is ignored), so large transfers arrive in many chunks.
        fresh = bytearray(100)
        self.buffer = fresh
        return fresh

    def buffer_updated(self, nbytes):
        # Slicing copies the filled prefix, so the callback may keep it even
        # after the next get_buffer() call replaces self.buffer.
        chunk = self.buffer[:nbytes]
        self.cb(chunk)

    def connection_lost(self, exc):
        if exc is not None:
            self.con_lost_fut.set_exception(exc)
            return
        self.con_lost_fut.set_result(None)
 | 
						|
 | 
						|
 | 
						|
class BaseTestBufferedProtocol(func_tests.FunctionalTestCaseMixin):
    """Shared BufferedProtocol tests, run against a loop chosen by subclasses.

    Concrete subclasses supply the event loop under test by overriding
    ``new_loop()``.
    """

    def new_loop(self):
        """Return the event loop to test; must be overridden by subclasses."""
        raise NotImplementedError

    def test_buffered_proto_create_connection(self):
        # 9 bytes * 1024 = 9216 bytes, much larger than the 100-byte buffers
        # ReceiveStuffProto hands out, so the client receives many chunks.
        payload = b'12345678+' * 1024

        async def on_server_client(reader, writer):
            # Stream the payload, wait for the client's one-byte ack, then
            # close the connection from the server side.
            writer.write(payload)
            await reader.readexactly(1)
            writer.close()
            await writer.wait_closed()

        async def client(address):
            received = b''

            def accumulate(chunk):
                nonlocal received
                received += chunk
                # Ack once the whole payload has arrived; the server then
                # closes, and the test waits for connection_lost to
                # resolve lost_fut.
                if received == payload:
                    transport.write(b'1')

            lost_fut = self.loop.create_future()

            def make_protocol():
                return ReceiveStuffProto(accumulate, lost_fut)

            transport, _ = await self.loop.create_connection(
                make_protocol, *address)
            await lost_fut

        server = self.loop.run_until_complete(
            asyncio.start_server(on_server_client, '127.0.0.1', 0))
        address = server.sockets[0].getsockname()

        self.loop.run_until_complete(
            asyncio.wait_for(client(address), timeout=5))

        server.close()
        self.loop.run_until_complete(server.wait_closed())
 | 
						|
 | 
						|
 | 
						|
class BufferedProtocolSelectorTests(BaseTestBufferedProtocol,
                                    unittest.TestCase):
    """Run the buffered-protocol tests on the selector-based event loop."""

    def new_loop(self):
        loop = asyncio.SelectorEventLoop()
        return loop
 | 
						|
 | 
						|
 | 
						|
@unittest.skipIf(not hasattr(asyncio, 'ProactorEventLoop'), 'Windows only')
class BufferedProtocolProactorTests(BaseTestBufferedProtocol,
                                    unittest.TestCase):
    """Run the buffered-protocol tests on the proactor event loop.

    Skipped wherever ``asyncio`` does not expose ``ProactorEventLoop``.
    """

    def new_loop(self):
        loop = asyncio.ProactorEventLoop()
        return loop
 | 
						|
 | 
						|
 | 
						|
if __name__ == '__main__':
    # Allow running this test module directly as a script.
    unittest.main(module='__main__')
 |