"""Tests for the sampling profiler (profile.sample).""" import contextlib import io import marshal import os import socket import subprocess import sys import tempfile import unittest from unittest import mock from profile.pstats_collector import PstatsCollector from profile.stack_collector import ( CollapsedStackCollector, ) from test.support.os_helper import unlink from test.support import force_not_colorized_test_class, SHORT_TIMEOUT from test.support.socket_helper import find_unused_port from test.support import requires_subprocess PROCESS_VM_READV_SUPPORTED = False try: from _remote_debugging import PROCESS_VM_READV_SUPPORTED import _remote_debugging except ImportError: raise unittest.SkipTest( "Test only runs when _remote_debugging is available" ) else: import profile.sample from profile.sample import SampleProfiler class MockFrameInfo: """Mock FrameInfo for testing since the real one isn't accessible.""" def __init__(self, filename, lineno, funcname): self.filename = filename self.lineno = lineno self.funcname = funcname def __repr__(self): return f"MockFrameInfo(filename='{self.filename}', lineno={self.lineno}, funcname='{self.funcname}')" skip_if_not_supported = unittest.skipIf( ( sys.platform != "darwin" and sys.platform != "linux" and sys.platform != "win32" ), "Test only runs on Linux, Windows and MacOS", ) @contextlib.contextmanager def test_subprocess(script): # Find an unused port for socket communication port = find_unused_port() # Inject socket connection code at the beginning of the script socket_code = f''' import socket _test_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) _test_sock.connect(('localhost', {port})) _test_sock.sendall(b"ready") ''' # Combine socket code with user script full_script = socket_code + script # Create server socket to wait for process to be ready server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) server_socket.bind(("localhost", port)) server_socket.settimeout(SHORT_TIMEOUT) server_socket.listen(1) proc = subprocess.Popen( [sys.executable, "-c", full_script], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, ) client_socket = None try: # Wait for process to connect and send ready signal client_socket, _ = server_socket.accept() server_socket.close() response = client_socket.recv(1024) if response != b"ready": raise RuntimeError(f"Unexpected response from subprocess: {response}") yield proc finally: if client_socket is not None: client_socket.close() if proc.poll() is None: proc.kill() proc.wait() def close_and_unlink(file): file.close() unlink(file.name) class TestSampleProfilerComponents(unittest.TestCase): """Unit tests for individual profiler components.""" def test_mock_frame_info_with_empty_and_unicode_values(self): """Test MockFrameInfo handles empty strings, unicode characters, and very long names correctly.""" # Test with empty strings frame = MockFrameInfo("", 0, "") self.assertEqual(frame.filename, "") self.assertEqual(frame.lineno, 0) self.assertEqual(frame.funcname, "") self.assertIn("filename=''", repr(frame)) # Test with unicode characters frame = MockFrameInfo("文件.py", 42, "函数名") self.assertEqual(frame.filename, "文件.py") self.assertEqual(frame.funcname, "函数名") # Test with very long names long_filename = "x" * 1000 + ".py" long_funcname = "func_" + "x" * 1000 frame = MockFrameInfo(long_filename, 999999, long_funcname) self.assertEqual(frame.filename, long_filename) self.assertEqual(frame.lineno, 999999) self.assertEqual(frame.funcname, long_funcname) def test_pstats_collector_with_extreme_intervals_and_empty_data(self): """Test PstatsCollector handles zero/large intervals, empty frames, None thread IDs, and duplicate frames.""" # Test with zero interval collector = PstatsCollector(sample_interval_usec=0) self.assertEqual(collector.sample_interval_usec, 0) # Test with very large interval collector = PstatsCollector(sample_interval_usec=1000000000) self.assertEqual(collector.sample_interval_usec, 1000000000) # Test collecting empty frames list collector = PstatsCollector(sample_interval_usec=1000) collector.collect([]) self.assertEqual(len(collector.result), 0) # Test collecting frames with None thread id test_frames = [(None, [MockFrameInfo("file.py", 10, "func")])] collector.collect(test_frames) # Should still process the frames self.assertEqual(len(collector.result), 1) # Test collecting duplicate frames in same sample test_frames = [ ( 1, [ MockFrameInfo("file.py", 10, "func1"), MockFrameInfo("file.py", 10, "func1"), # Duplicate ], ) ] collector = PstatsCollector(sample_interval_usec=1000) collector.collect(test_frames) # Should count both occurrences self.assertEqual( collector.result[("file.py", 10, "func1")]["cumulative_calls"], 2 ) def test_pstats_collector_single_frame_stacks(self): """Test PstatsCollector with single-frame call stacks to trigger len(frames) <= 1 branch.""" collector = PstatsCollector(sample_interval_usec=1000) # Test with exactly one frame (should trigger the <= 1 condition) single_frame = [(1, [MockFrameInfo("single.py", 10, "single_func")])] collector.collect(single_frame) # Should record the single frame with inline call self.assertEqual(len(collector.result), 1) single_key = ("single.py", 10, "single_func") self.assertIn(single_key, collector.result) self.assertEqual(collector.result[single_key]["direct_calls"], 1) self.assertEqual(collector.result[single_key]["cumulative_calls"], 1) # Test with empty frames (should also trigger <= 1 condition) empty_frames = [(1, [])] collector.collect(empty_frames) # Should not add any new entries self.assertEqual( len(collector.result), 1 ) # Still just the single frame # Test mixed single and multi-frame stacks mixed_frames = [ ( 1, [MockFrameInfo("single2.py", 20, "single_func2")], ), # Single frame ( 2, [ # Multi-frame stack MockFrameInfo("multi.py", 30, "multi_func1"), MockFrameInfo("multi.py", 40, "multi_func2"), ], ), ] collector.collect(mixed_frames) # Should have recorded all functions self.assertEqual( len(collector.result), 4 ) # single + single2 + multi1 + multi2 # Verify single frame handling single2_key = ("single2.py", 20, "single_func2") self.assertIn(single2_key, collector.result) self.assertEqual(collector.result[single2_key]["direct_calls"], 1) self.assertEqual(collector.result[single2_key]["cumulative_calls"], 1) # Verify multi-frame handling still works multi1_key = ("multi.py", 30, "multi_func1") multi2_key = ("multi.py", 40, "multi_func2") self.assertIn(multi1_key, collector.result) self.assertIn(multi2_key, collector.result) self.assertEqual(collector.result[multi1_key]["direct_calls"], 1) self.assertEqual( collector.result[multi2_key]["cumulative_calls"], 1 ) # Called from multi1 def test_collapsed_stack_collector_with_empty_and_deep_stacks(self): """Test CollapsedStackCollector handles empty frames, single-frame stacks, and very deep call stacks.""" collector = CollapsedStackCollector() # Test with empty frames collector.collect([]) self.assertEqual(len(collector.call_trees), 0) # Test with single frame stack test_frames = [(1, [("file.py", 10, "func")])] collector.collect(test_frames) self.assertEqual(len(collector.call_trees), 1) self.assertEqual(collector.call_trees[0], [("file.py", 10, "func")]) # Test with very deep stack deep_stack = [(f"file{i}.py", i, f"func{i}") for i in range(100)] test_frames = [(1, deep_stack)] collector = CollapsedStackCollector() collector.collect(test_frames) self.assertEqual(len(collector.call_trees[0]), 100) # Check it's properly reversed self.assertEqual( collector.call_trees[0][0], ("file99.py", 99, "func99") ) self.assertEqual(collector.call_trees[0][-1], ("file0.py", 0, "func0")) def test_pstats_collector_basic(self): """Test basic PstatsCollector functionality.""" collector = PstatsCollector(sample_interval_usec=1000) # Test empty state self.assertEqual(len(collector.result), 0) self.assertEqual(len(collector.stats), 0) # Test collecting sample data test_frames = [ ( 1, [ MockFrameInfo("file.py", 10, "func1"), MockFrameInfo("file.py", 20, "func2"), ], ) ] collector.collect(test_frames) # Should have recorded calls for both functions self.assertEqual(len(collector.result), 2) self.assertIn(("file.py", 10, "func1"), collector.result) self.assertIn(("file.py", 20, "func2"), collector.result) # Top-level function should have direct call self.assertEqual( collector.result[("file.py", 10, "func1")]["direct_calls"], 1 ) self.assertEqual( collector.result[("file.py", 10, "func1")]["cumulative_calls"], 1 ) # Calling function should have cumulative call but no direct calls self.assertEqual( collector.result[("file.py", 20, "func2")]["cumulative_calls"], 1 ) self.assertEqual( collector.result[("file.py", 20, "func2")]["direct_calls"], 0 ) def test_pstats_collector_create_stats(self): """Test PstatsCollector stats creation.""" collector = PstatsCollector( sample_interval_usec=1000000 ) # 1 second intervals test_frames = [ ( 1, [ MockFrameInfo("file.py", 10, "func1"), MockFrameInfo("file.py", 20, "func2"), ], ) ] collector.collect(test_frames) collector.collect(test_frames) # Collect twice collector.create_stats() # Check stats format: (direct_calls, cumulative_calls, tt, ct, callers) func1_stats = collector.stats[("file.py", 10, "func1")] self.assertEqual(func1_stats[0], 2) # direct_calls (top of stack) self.assertEqual(func1_stats[1], 2) # cumulative_calls self.assertEqual( func1_stats[2], 2.0 ) # tt (total time - 2 samples * 1 sec) self.assertEqual(func1_stats[3], 2.0) # ct (cumulative time) func2_stats = collector.stats[("file.py", 20, "func2")] self.assertEqual( func2_stats[0], 0 ) # direct_calls (never top of stack) self.assertEqual( func2_stats[1], 2 ) # cumulative_calls (appears in stack) self.assertEqual(func2_stats[2], 0.0) # tt (no direct calls) self.assertEqual(func2_stats[3], 2.0) # ct (cumulative time) def test_collapsed_stack_collector_basic(self): collector = CollapsedStackCollector() # Test empty state self.assertEqual(len(collector.call_trees), 0) self.assertEqual(len(collector.function_samples), 0) # Test collecting sample data test_frames = [ (1, [("file.py", 10, "func1"), ("file.py", 20, "func2")]) ] collector.collect(test_frames) # Should store call tree (reversed) self.assertEqual(len(collector.call_trees), 1) expected_tree = [("file.py", 20, "func2"), ("file.py", 10, "func1")] self.assertEqual(collector.call_trees[0], expected_tree) # Should count function samples self.assertEqual( collector.function_samples[("file.py", 10, "func1")], 1 ) self.assertEqual( collector.function_samples[("file.py", 20, "func2")], 1 ) def test_collapsed_stack_collector_export(self): collapsed_out = tempfile.NamedTemporaryFile(delete=False) self.addCleanup(close_and_unlink, collapsed_out) collector = CollapsedStackCollector() test_frames1 = [ (1, [("file.py", 10, "func1"), ("file.py", 20, "func2")]) ] test_frames2 = [ (1, [("file.py", 10, "func1"), ("file.py", 20, "func2")]) ] # Same stack test_frames3 = [(1, [("other.py", 5, "other_func")])] collector.collect(test_frames1) collector.collect(test_frames2) collector.collect(test_frames3) collector.export(collapsed_out.name) # Check file contents with open(collapsed_out.name, "r") as f: content = f.read() lines = content.strip().split("\n") self.assertEqual(len(lines), 2) # Two unique stacks # Check collapsed format: file:func:line;file:func:line count stack1_expected = "file.py:func2:20;file.py:func1:10 2" stack2_expected = "other.py:other_func:5 1" self.assertIn(stack1_expected, lines) self.assertIn(stack2_expected, lines) def test_pstats_collector_export(self): collector = PstatsCollector( sample_interval_usec=1000000 ) # 1 second intervals test_frames1 = [ ( 1, [ MockFrameInfo("file.py", 10, "func1"), MockFrameInfo("file.py", 20, "func2"), ], ) ] test_frames2 = [ ( 1, [ MockFrameInfo("file.py", 10, "func1"), MockFrameInfo("file.py", 20, "func2"), ], ) ] # Same stack test_frames3 = [(1, [MockFrameInfo("other.py", 5, "other_func")])] collector.collect(test_frames1) collector.collect(test_frames2) collector.collect(test_frames3) pstats_out = tempfile.NamedTemporaryFile( suffix=".pstats", delete=False ) self.addCleanup(close_and_unlink, pstats_out) collector.export(pstats_out.name) # Check file can be loaded with marshal with open(pstats_out.name, "rb") as f: stats_data = marshal.load(f) # Should be a dictionary with the sampled marker self.assertIsInstance(stats_data, dict) self.assertIn(("__sampled__",), stats_data) self.assertTrue(stats_data[("__sampled__",)]) # Should have function data function_entries = [ k for k in stats_data.keys() if k != ("__sampled__",) ] self.assertGreater(len(function_entries), 0) # Check specific function stats format: (cc, nc, tt, ct, callers) func1_key = ("file.py", 10, "func1") func2_key = ("file.py", 20, "func2") other_key = ("other.py", 5, "other_func") self.assertIn(func1_key, stats_data) self.assertIn(func2_key, stats_data) self.assertIn(other_key, stats_data) # Check func1 stats (should have 2 samples) func1_stats = stats_data[func1_key] self.assertEqual(func1_stats[0], 2) # total_calls self.assertEqual(func1_stats[1], 2) # nc (non-recursive calls) self.assertEqual(func1_stats[2], 2.0) # tt (total time) self.assertEqual(func1_stats[3], 2.0) # ct (cumulative time) class TestSampleProfiler(unittest.TestCase): """Test the SampleProfiler class.""" def test_sample_profiler_initialization(self): """Test SampleProfiler initialization with various parameters.""" from profile.sample import SampleProfiler # Mock RemoteUnwinder to avoid permission issues with mock.patch( "_remote_debugging.RemoteUnwinder" ) as mock_unwinder_class: mock_unwinder_class.return_value = mock.MagicMock() # Test basic initialization profiler = SampleProfiler( pid=12345, sample_interval_usec=1000, all_threads=False ) self.assertEqual(profiler.pid, 12345) self.assertEqual(profiler.sample_interval_usec, 1000) self.assertEqual(profiler.all_threads, False) # Test with all_threads=True profiler = SampleProfiler( pid=54321, sample_interval_usec=5000, all_threads=True ) self.assertEqual(profiler.pid, 54321) self.assertEqual(profiler.sample_interval_usec, 5000) self.assertEqual(profiler.all_threads, True) def test_sample_profiler_sample_method_timing(self): """Test that the sample method respects duration and handles timing correctly.""" from profile.sample import SampleProfiler # Mock the unwinder to avoid needing a real process mock_unwinder = mock.MagicMock() mock_unwinder.get_stack_trace.return_value = [ ( 1, [ mock.MagicMock( filename="test.py", lineno=10, funcname="test_func" ) ], ) ] with mock.patch( "_remote_debugging.RemoteUnwinder" ) as mock_unwinder_class: mock_unwinder_class.return_value = mock_unwinder profiler = SampleProfiler( pid=12345, sample_interval_usec=100000, all_threads=False ) # 100ms interval # Mock collector mock_collector = mock.MagicMock() # Mock time to control the sampling loop start_time = 1000.0 times = [ start_time + i * 0.1 for i in range(12) ] # 0, 0.1, 0.2, ..., 1.1 seconds with mock.patch("time.perf_counter", side_effect=times): with io.StringIO() as output: with mock.patch("sys.stdout", output): profiler.sample(mock_collector, duration_sec=1) result = output.getvalue() # Should have captured approximately 10 samples (1 second / 0.1 second interval) self.assertIn("Captured", result) self.assertIn("samples", result) # Verify collector was called multiple times self.assertGreaterEqual(mock_collector.collect.call_count, 5) self.assertLessEqual(mock_collector.collect.call_count, 11) def test_sample_profiler_error_handling(self): """Test that the sample method handles errors gracefully.""" from profile.sample import SampleProfiler # Mock unwinder that raises errors mock_unwinder = mock.MagicMock() error_sequence = [ RuntimeError("Process died"), [ ( 1, [ mock.MagicMock( filename="test.py", lineno=10, funcname="test_func" ) ], ) ], UnicodeDecodeError("utf-8", b"", 0, 1, "invalid"), [ ( 1, [ mock.MagicMock( filename="test.py", lineno=20, funcname="test_func2", ) ], ) ], OSError("Permission denied"), ] mock_unwinder.get_stack_trace.side_effect = error_sequence with mock.patch( "_remote_debugging.RemoteUnwinder" ) as mock_unwinder_class: mock_unwinder_class.return_value = mock_unwinder profiler = SampleProfiler( pid=12345, sample_interval_usec=10000, all_threads=False ) mock_collector = mock.MagicMock() # Control timing to run exactly 5 samples times = [0.0, 0.01, 0.02, 0.03, 0.04, 0.05, 0.06] with mock.patch("time.perf_counter", side_effect=times): with io.StringIO() as output: with mock.patch("sys.stdout", output): profiler.sample(mock_collector, duration_sec=0.05) result = output.getvalue() # Should report error rate self.assertIn("Error rate:", result) self.assertIn("%", result) # Collector should have been called only for successful samples (should be > 0) self.assertGreater(mock_collector.collect.call_count, 0) self.assertLessEqual(mock_collector.collect.call_count, 3) def test_sample_profiler_missed_samples_warning(self): """Test that the profiler warns about missed samples when sampling is too slow.""" from profile.sample import SampleProfiler mock_unwinder = mock.MagicMock() mock_unwinder.get_stack_trace.return_value = [ ( 1, [ mock.MagicMock( filename="test.py", lineno=10, funcname="test_func" ) ], ) ] with mock.patch( "_remote_debugging.RemoteUnwinder" ) as mock_unwinder_class: mock_unwinder_class.return_value = mock_unwinder # Use very short interval that we'll miss profiler = SampleProfiler( pid=12345, sample_interval_usec=1000, all_threads=False ) # 1ms interval mock_collector = mock.MagicMock() # Simulate slow sampling where we miss many samples times = [ 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, ] # Extra time points to avoid StopIteration with mock.patch("time.perf_counter", side_effect=times): with io.StringIO() as output: with mock.patch("sys.stdout", output): profiler.sample(mock_collector, duration_sec=0.5) result = output.getvalue() # Should warn about missed samples self.assertIn("Warning: missed", result) self.assertIn("samples from the expected total", result) @force_not_colorized_test_class class TestPrintSampledStats(unittest.TestCase): """Test the print_sampled_stats function.""" def setUp(self): """Set up test data.""" # Mock stats data self.mock_stats = mock.MagicMock() self.mock_stats.stats = { ("file1.py", 10, "func1"): ( 100, 100, 0.5, 0.5, {}, ), # cc, nc, tt, ct, callers ("file2.py", 20, "func2"): (50, 50, 0.25, 0.3, {}), ("file3.py", 30, "func3"): (200, 200, 1.5, 2.0, {}), ("file4.py", 40, "func4"): ( 10, 10, 0.001, 0.001, {}, ), # millisecond range ("file5.py", 50, "func5"): ( 5, 5, 0.000001, 0.000002, {}, ), # microsecond range } def test_print_sampled_stats_basic(self): """Test basic print_sampled_stats functionality.""" from profile.sample import print_sampled_stats # Capture output with io.StringIO() as output: with mock.patch("sys.stdout", output): print_sampled_stats(self.mock_stats, sample_interval_usec=100) result = output.getvalue() # Check header is present self.assertIn("Profile Stats:", result) self.assertIn("nsamples", result) self.assertIn("tottime", result) self.assertIn("cumtime", result) # Check functions are present self.assertIn("func1", result) self.assertIn("func2", result) self.assertIn("func3", result) def test_print_sampled_stats_sorting(self): """Test different sorting options.""" from profile.sample import print_sampled_stats # Test sort by calls with io.StringIO() as output: with mock.patch("sys.stdout", output): print_sampled_stats( self.mock_stats, sort=0, sample_interval_usec=100 ) result = output.getvalue() lines = result.strip().split("\n") # Find the data lines (skip header) data_lines = [l for l in lines if "file" in l and ".py" in l] # func3 should be first (200 calls) self.assertIn("func3", data_lines[0]) # Test sort by time with io.StringIO() as output: with mock.patch("sys.stdout", output): print_sampled_stats( self.mock_stats, sort=1, sample_interval_usec=100 ) result = output.getvalue() lines = result.strip().split("\n") data_lines = [l for l in lines if "file" in l and ".py" in l] # func3 should be first (1.5s time) self.assertIn("func3", data_lines[0]) def test_print_sampled_stats_limit(self): """Test limiting output rows.""" from profile.sample import print_sampled_stats with io.StringIO() as output: with mock.patch("sys.stdout", output): print_sampled_stats( self.mock_stats, limit=2, sample_interval_usec=100 ) result = output.getvalue() # Count function entries in the main stats section (not in summary) lines = result.split("\n") # Find where the main stats section ends (before summary) main_section_lines = [] for line in lines: if "Summary of Interesting Functions:" in line: break main_section_lines.append(line) # Count function entries only in main section func_count = sum( 1 for line in main_section_lines if "func" in line and ".py" in line ) self.assertEqual(func_count, 2) def test_print_sampled_stats_time_units(self): """Test proper time unit selection.""" from profile.sample import print_sampled_stats with io.StringIO() as output: with mock.patch("sys.stdout", output): print_sampled_stats(self.mock_stats, sample_interval_usec=100) result = output.getvalue() # Should use seconds for the header since max time is > 1s self.assertIn("tottime (s)", result) self.assertIn("cumtime (s)", result) # Test with only microsecond-range times micro_stats = mock.MagicMock() micro_stats.stats = { ("file1.py", 10, "func1"): (100, 100, 0.000005, 0.000010, {}), } with io.StringIO() as output: with mock.patch("sys.stdout", output): print_sampled_stats(micro_stats, sample_interval_usec=100) result = output.getvalue() # Should use microseconds self.assertIn("tottime (μs)", result) self.assertIn("cumtime (μs)", result) def test_print_sampled_stats_summary(self): """Test summary section generation.""" from profile.sample import print_sampled_stats with io.StringIO() as output: with mock.patch("sys.stdout", output): print_sampled_stats( self.mock_stats, show_summary=True, sample_interval_usec=100, ) result = output.getvalue() # Check summary sections are present self.assertIn("Summary of Interesting Functions:", result) self.assertIn( "Functions with Highest Direct/Cumulative Ratio (Hot Spots):", result, ) self.assertIn( "Functions with Highest Call Frequency (Indirect Calls):", result ) self.assertIn( "Functions with Highest Call Magnification (Cumulative/Direct):", result, ) def test_print_sampled_stats_no_summary(self): """Test disabling summary output.""" from profile.sample import print_sampled_stats with io.StringIO() as output: with mock.patch("sys.stdout", output): print_sampled_stats( self.mock_stats, show_summary=False, sample_interval_usec=100, ) result = output.getvalue() # Summary should not be present self.assertNotIn("Summary of Interesting Functions:", result) def test_print_sampled_stats_empty_stats(self): """Test with empty stats.""" from profile.sample import print_sampled_stats empty_stats = mock.MagicMock() empty_stats.stats = {} with io.StringIO() as output: with mock.patch("sys.stdout", output): print_sampled_stats(empty_stats, sample_interval_usec=100) result = output.getvalue() # Should still print header self.assertIn("Profile Stats:", result) def test_print_sampled_stats_sample_percentage_sorting(self): """Test sample percentage sorting options.""" from profile.sample import print_sampled_stats # Add a function with high sample percentage (more direct calls than func3's 200) self.mock_stats.stats[("expensive.py", 60, "expensive_func")] = ( 300, # direct calls (higher than func3's 200) 300, # cumulative calls 1.0, # total time 1.0, # cumulative time {}, ) # Test sort by sample percentage with io.StringIO() as output: with mock.patch("sys.stdout", output): print_sampled_stats( self.mock_stats, sort=3, sample_interval_usec=100 ) # sample percentage result = output.getvalue() lines = result.strip().split("\n") data_lines = [l for l in lines if ".py" in l and "func" in l] # expensive_func should be first (highest sample percentage) self.assertIn("expensive_func", data_lines[0]) def test_print_sampled_stats_with_recursive_calls(self): """Test print_sampled_stats with recursive calls where nc != cc.""" from profile.sample import print_sampled_stats # Create stats with recursive calls (nc != cc) recursive_stats = mock.MagicMock() recursive_stats.stats = { # (direct_calls, cumulative_calls, tt, ct, callers) - recursive function ("recursive.py", 10, "factorial"): ( 5, # direct_calls 10, # cumulative_calls (appears more times in stack due to recursion) 0.5, 0.6, {}, ), ("normal.py", 20, "normal_func"): ( 3, # direct_calls 3, # cumulative_calls (same as direct for non-recursive) 0.2, 0.2, {}, ), } with io.StringIO() as output: with mock.patch("sys.stdout", output): print_sampled_stats(recursive_stats, sample_interval_usec=100) result = output.getvalue() # Should display recursive calls as "5/10" format self.assertIn("5/10", result) # nc/cc format for recursive calls self.assertIn("3", result) # just nc for non-recursive calls self.assertIn("factorial", result) self.assertIn("normal_func", result) def test_print_sampled_stats_with_zero_call_counts(self): """Test print_sampled_stats with zero call counts to trigger division protection.""" from profile.sample import print_sampled_stats # Create stats with zero call counts zero_stats = mock.MagicMock() zero_stats.stats = { ("file.py", 10, "zero_calls"): (0, 0, 0.0, 0.0, {}), # Zero calls ("file.py", 20, "normal_func"): ( 5, 5, 0.1, 0.1, {}, ), # Normal function } with io.StringIO() as output: with mock.patch("sys.stdout", output): print_sampled_stats(zero_stats, sample_interval_usec=100) result = output.getvalue() # Should handle zero call counts gracefully self.assertIn("zero_calls", result) self.assertIn("zero_calls", result) self.assertIn("normal_func", result) def test_print_sampled_stats_sort_by_name(self): """Test sort by function name option.""" from profile.sample import print_sampled_stats with io.StringIO() as output: with mock.patch("sys.stdout", output): print_sampled_stats( self.mock_stats, sort=-1, sample_interval_usec=100 ) # sort by name result = output.getvalue() lines = result.strip().split("\n") # Find the data lines (skip header and summary) # Data lines start with whitespace and numbers, and contain filename:lineno(function) data_lines = [] for line in lines: # Skip header lines and summary sections if ( line.startswith(" ") and "(" in line and ")" in line and not line.startswith( " 1." ) # Skip summary lines that start with times and not line.startswith( " 0." ) # Skip summary lines that start with times and not "per call" in line # Skip summary lines and not "calls" in line # Skip summary lines and not "total time" in line # Skip summary lines and not "cumulative time" in line ): # Skip summary lines data_lines.append(line) # Extract just the function names for comparison func_names = [] import re for line in data_lines: # Function name is between the last ( and ), accounting for ANSI color codes match = re.search(r"\(([^)]+)\)$", line) if match: func_name = match.group(1) # Remove ANSI color codes func_name = re.sub(r"\x1b\[[0-9;]*m", "", func_name) func_names.append(func_name) # Verify we extracted function names and they are sorted self.assertGreater( len(func_names), 0, "Should have extracted some function names" ) self.assertEqual( func_names, sorted(func_names), f"Function names {func_names} should be sorted alphabetically", ) def test_print_sampled_stats_with_zero_time_functions(self): """Test summary sections with functions that have zero time.""" from profile.sample import print_sampled_stats # Create stats with zero-time functions zero_time_stats = mock.MagicMock() zero_time_stats.stats = { ("file1.py", 10, "zero_time_func"): ( 5, 5, 0.0, 0.0, {}, ), # Zero time ("file2.py", 20, "normal_func"): ( 3, 3, 0.1, 0.1, {}, ), # Normal time } with io.StringIO() as output: with mock.patch("sys.stdout", output): print_sampled_stats( zero_time_stats, show_summary=True, sample_interval_usec=100, ) result = output.getvalue() # Should handle zero-time functions gracefully in summary self.assertIn("Summary of Interesting Functions:", result) self.assertIn("zero_time_func", result) self.assertIn("normal_func", result) def test_print_sampled_stats_with_malformed_qualified_names(self): """Test summary generation with function names that don't contain colons.""" from profile.sample import print_sampled_stats # Create stats with function names that would create malformed qualified names malformed_stats = mock.MagicMock() malformed_stats.stats = { # Function name without clear module separation ("no_colon_func", 10, "func"): (3, 3, 0.1, 0.1, {}), ("", 20, "empty_filename_func"): (2, 2, 0.05, 0.05, {}), ("normal.py", 30, "normal_func"): (5, 5, 0.2, 0.2, {}), } with io.StringIO() as output: with mock.patch("sys.stdout", output): print_sampled_stats( malformed_stats, show_summary=True, sample_interval_usec=100, ) result = output.getvalue() # Should handle malformed names gracefully in summary aggregation self.assertIn("Summary of Interesting Functions:", result) # All function names should appear somewhere in the output self.assertIn("func", result) self.assertIn("empty_filename_func", result) self.assertIn("normal_func", result) def test_print_sampled_stats_with_recursive_call_stats_creation(self): """Test create_stats with recursive call data to trigger total_rec_calls branch.""" collector = PstatsCollector(sample_interval_usec=1000000) # 1 second # Simulate recursive function data where total_rec_calls would be set # We need to manually manipulate the collector result to test this branch collector.result = { ("recursive.py", 10, "factorial"): { "total_rec_calls": 3, # Non-zero recursive calls "direct_calls": 5, "cumulative_calls": 10, }, ("normal.py", 20, "normal_func"): { "total_rec_calls": 0, # Zero recursive calls "direct_calls": 2, "cumulative_calls": 5, }, } collector.create_stats() # Check that recursive calls are handled differently from non-recursive factorial_stats = collector.stats[("recursive.py", 10, "factorial")] normal_stats = collector.stats[("normal.py", 20, "normal_func")] # factorial should use cumulative_calls (10) as nc self.assertEqual( factorial_stats[1], 10 ) # nc should be cumulative_calls self.assertEqual(factorial_stats[0], 5) # cc should be direct_calls # normal_func should use cumulative_calls as nc self.assertEqual(normal_stats[1], 5) # nc should be cumulative_calls self.assertEqual(normal_stats[0], 2) # cc should be direct_calls @skip_if_not_supported @unittest.skipIf( sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED, "Test only runs on Linux with process_vm_readv support", ) class TestRecursiveFunctionProfiling(unittest.TestCase): """Test profiling of recursive functions and complex call patterns.""" def test_recursive_function_call_counting(self): """Test that recursive function calls are counted correctly.""" collector = PstatsCollector(sample_interval_usec=1000) # Simulate a recursive call pattern: fibonacci(5) calling itself recursive_frames = [ ( 1, [ # First sample: deep in recursion MockFrameInfo("fib.py", 10, "fibonacci"), MockFrameInfo("fib.py", 10, "fibonacci"), # recursive call MockFrameInfo( "fib.py", 10, "fibonacci" ), # deeper recursion MockFrameInfo("fib.py", 10, "fibonacci"), # even deeper MockFrameInfo("main.py", 5, "main"), # main caller ], ), ( 1, [ # Second sample: different recursion depth MockFrameInfo("fib.py", 10, "fibonacci"), MockFrameInfo("fib.py", 10, "fibonacci"), # recursive call MockFrameInfo("main.py", 5, "main"), # main caller ], ), ( 1, [ # Third sample: back to deeper recursion MockFrameInfo("fib.py", 10, "fibonacci"), MockFrameInfo("fib.py", 10, "fibonacci"), MockFrameInfo("fib.py", 10, "fibonacci"), MockFrameInfo("main.py", 5, "main"), ], ), ] for frames in recursive_frames: collector.collect([frames]) collector.create_stats() # Check that recursive calls are counted properly fib_key = ("fib.py", 10, "fibonacci") main_key = ("main.py", 5, "main") self.assertIn(fib_key, collector.stats) self.assertIn(main_key, collector.stats) # Fibonacci should have many calls due to recursion fib_stats = collector.stats[fib_key] direct_calls, cumulative_calls, tt, ct, callers = fib_stats # Should have recorded multiple calls (9 total appearances in samples) self.assertEqual(cumulative_calls, 9) self.assertGreater(tt, 0) # Should have some total time self.assertGreater(ct, 0) # Should have some cumulative time # Main should have fewer calls main_stats = collector.stats[main_key] main_direct_calls, main_cumulative_calls = main_stats[0], main_stats[1] self.assertEqual(main_direct_calls, 0) # Never directly executing self.assertEqual(main_cumulative_calls, 3) # Appears in all 3 samples def test_nested_function_hierarchy(self): """Test profiling of deeply nested function calls.""" collector = PstatsCollector(sample_interval_usec=1000) # Simulate a deep call hierarchy deep_call_frames = [ ( 1, [ MockFrameInfo("level1.py", 10, "level1_func"), MockFrameInfo("level2.py", 20, "level2_func"), MockFrameInfo("level3.py", 30, "level3_func"), MockFrameInfo("level4.py", 40, "level4_func"), MockFrameInfo("level5.py", 50, "level5_func"), MockFrameInfo("main.py", 5, "main"), ], ), ( 1, [ # Same hierarchy sampled again MockFrameInfo("level1.py", 10, "level1_func"), MockFrameInfo("level2.py", 20, "level2_func"), MockFrameInfo("level3.py", 30, "level3_func"), MockFrameInfo("level4.py", 40, "level4_func"), MockFrameInfo("level5.py", 50, "level5_func"), MockFrameInfo("main.py", 5, "main"), ], ), ] for frames in deep_call_frames: collector.collect([frames]) collector.create_stats() # All levels should be recorded for level in range(1, 6): key = (f"level{level}.py", level * 10, f"level{level}_func") self.assertIn(key, collector.stats) stats = collector.stats[key] direct_calls, cumulative_calls, tt, ct, callers = stats # Each level should appear in stack twice (2 samples) self.assertEqual(cumulative_calls, 2) # Only level1 (deepest) should have direct calls if level == 1: self.assertEqual(direct_calls, 2) else: self.assertEqual(direct_calls, 0) # Deeper levels should have lower cumulative time than higher levels # (since they don't include time from functions they call) if level == 1: # Deepest level with most time self.assertGreater(ct, 0) def test_alternating_call_patterns(self): """Test profiling with alternating call patterns.""" collector = PstatsCollector(sample_interval_usec=1000) # Simulate alternating execution paths pattern_frames = [ # Pattern A: path through func_a ( 1, [ MockFrameInfo("module.py", 10, "func_a"), MockFrameInfo("module.py", 30, "shared_func"), MockFrameInfo("main.py", 5, "main"), ], ), # Pattern B: path through func_b ( 1, [ MockFrameInfo("module.py", 20, "func_b"), MockFrameInfo("module.py", 30, "shared_func"), MockFrameInfo("main.py", 5, "main"), ], ), # Pattern A again ( 1, [ MockFrameInfo("module.py", 10, "func_a"), MockFrameInfo("module.py", 30, "shared_func"), MockFrameInfo("main.py", 5, "main"), ], ), # Pattern B again ( 1, [ MockFrameInfo("module.py", 20, "func_b"), MockFrameInfo("module.py", 30, "shared_func"), MockFrameInfo("main.py", 5, "main"), ], ), ] for frames in pattern_frames: collector.collect([frames]) collector.create_stats() # Check that both paths are recorded equally func_a_key = ("module.py", 10, "func_a") func_b_key = ("module.py", 20, "func_b") shared_key = ("module.py", 30, "shared_func") main_key = ("main.py", 5, "main") # func_a and func_b should each be directly executing twice self.assertEqual(collector.stats[func_a_key][0], 2) # direct_calls self.assertEqual(collector.stats[func_a_key][1], 2) # cumulative_calls self.assertEqual(collector.stats[func_b_key][0], 2) # direct_calls self.assertEqual(collector.stats[func_b_key][1], 2) # cumulative_calls # shared_func should appear in all samples (4 times) but never directly executing self.assertEqual(collector.stats[shared_key][0], 0) # direct_calls self.assertEqual(collector.stats[shared_key][1], 4) # cumulative_calls # main should appear in all samples but never directly executing self.assertEqual(collector.stats[main_key][0], 0) # direct_calls self.assertEqual(collector.stats[main_key][1], 4) # cumulative_calls def test_collapsed_stack_with_recursion(self): """Test collapsed stack collector with recursive patterns.""" collector = CollapsedStackCollector() # Recursive call pattern recursive_frames = [ ( 1, [ ("factorial.py", 10, "factorial"), ("factorial.py", 10, "factorial"), # recursive ("factorial.py", 10, "factorial"), # deeper ("main.py", 5, "main"), ], ), ( 1, [ ("factorial.py", 10, "factorial"), ("factorial.py", 10, "factorial"), # different depth ("main.py", 5, "main"), ], ), ] for frames in recursive_frames: collector.collect([frames]) # Should capture both call trees self.assertEqual(len(collector.call_trees), 2) # First tree should be longer (deeper recursion) tree1 = collector.call_trees[0] tree2 = collector.call_trees[1] # Trees should be different lengths due to different recursion depths self.assertNotEqual(len(tree1), len(tree2)) # Both should contain factorial calls self.assertTrue(any("factorial" in str(frame) for frame in tree1)) self.assertTrue(any("factorial" in str(frame) for frame in tree2)) # Function samples should count all occurrences factorial_key = ("factorial.py", 10, "factorial") main_key = ("main.py", 5, "main") # factorial appears 5 times total (3 + 2) self.assertEqual(collector.function_samples[factorial_key], 5) # main appears 2 times total self.assertEqual(collector.function_samples[main_key], 2) @requires_subprocess() @skip_if_not_supported class TestSampleProfilerIntegration(unittest.TestCase): @classmethod def setUpClass(cls): cls.test_script = ''' import time import os def slow_fibonacci(n): """Recursive fibonacci - should show up prominently in profiler.""" if n <= 1: return n return slow_fibonacci(n-1) + slow_fibonacci(n-2) def cpu_intensive_work(): """CPU intensive work that should show in profiler.""" result = 0 for i in range(10000): result += i * i if i % 100 == 0: result = result % 1000000 return result def medium_computation(): """Medium complexity function.""" result = 0 for i in range(100): result += i * i return result def fast_loop(): """Fast simple loop.""" total = 0 for i in range(50): total += i return total def nested_calls(): """Test nested function calls.""" def level1(): def level2(): return medium_computation() return level2() return level1() def main_loop(): """Main test loop with different execution paths.""" iteration = 0 while True: iteration += 1 # Different execution paths - focus on CPU intensive work if iteration % 3 == 0: # Very CPU intensive result = cpu_intensive_work() elif iteration % 5 == 0: # Expensive recursive operation result = slow_fibonacci(12) else: # Medium operation result = nested_calls() # No sleep - keep CPU busy if __name__ == "__main__": main_loop() ''' def test_sampling_basic_functionality(self): with ( test_subprocess(self.test_script) as proc, io.StringIO() as captured_output, mock.patch("sys.stdout", captured_output), ): try: profile.sample.sample( proc.pid, duration_sec=2, sample_interval_usec=1000, # 1ms show_summary=False, ) except PermissionError: self.skipTest("Insufficient permissions for remote profiling") output = captured_output.getvalue() # Basic checks on output self.assertIn("Captured", output) self.assertIn("samples", output) self.assertIn("Profile Stats", output) # Should see some of our test functions self.assertIn("slow_fibonacci", output) def test_sampling_with_pstats_export(self): pstats_out = tempfile.NamedTemporaryFile( suffix=".pstats", delete=False ) self.addCleanup(close_and_unlink, pstats_out) with test_subprocess(self.test_script) as proc: # Suppress profiler output when testing file export with ( io.StringIO() as captured_output, mock.patch("sys.stdout", captured_output), ): try: profile.sample.sample( proc.pid, duration_sec=1, filename=pstats_out.name, sample_interval_usec=10000, ) except PermissionError: self.skipTest( "Insufficient permissions for remote profiling" ) # Verify file was created and contains valid data self.assertTrue(os.path.exists(pstats_out.name)) self.assertGreater(os.path.getsize(pstats_out.name), 0) # Try to load the stats file with open(pstats_out.name, "rb") as f: stats_data = marshal.load(f) # Should be a dictionary with the sampled marker self.assertIsInstance(stats_data, dict) self.assertIn(("__sampled__",), stats_data) self.assertTrue(stats_data[("__sampled__",)]) # Should have some function data function_entries = [ k for k in stats_data.keys() if k != ("__sampled__",) ] self.assertGreater(len(function_entries), 0) def test_sampling_with_collapsed_export(self): collapsed_file = tempfile.NamedTemporaryFile( suffix=".txt", delete=False ) self.addCleanup(close_and_unlink, collapsed_file) with ( test_subprocess(self.test_script) as proc, ): # Suppress profiler output when testing file export with ( io.StringIO() as captured_output, mock.patch("sys.stdout", captured_output), ): try: profile.sample.sample( proc.pid, duration_sec=1, filename=collapsed_file.name, output_format="collapsed", sample_interval_usec=10000, ) except PermissionError: self.skipTest( "Insufficient permissions for remote profiling" ) # Verify file was created and contains valid data self.assertTrue(os.path.exists(collapsed_file.name)) self.assertGreater(os.path.getsize(collapsed_file.name), 0) # Check file format with open(collapsed_file.name, "r") as f: content = f.read() lines = content.strip().split("\n") self.assertGreater(len(lines), 0) # Each line should have format: stack_trace count for line in lines: parts = line.rsplit(" ", 1) self.assertEqual(len(parts), 2) stack_trace, count_str = parts self.assertGreater(len(stack_trace), 0) self.assertTrue(count_str.isdigit()) self.assertGreater(int(count_str), 0) # Stack trace should contain semicolon-separated entries if ";" in stack_trace: stack_parts = stack_trace.split(";") for part in stack_parts: # Each part should be file:function:line self.assertIn(":", part) def test_sampling_all_threads(self): with ( test_subprocess(self.test_script) as proc, # Suppress profiler output io.StringIO() as captured_output, mock.patch("sys.stdout", captured_output), ): try: profile.sample.sample( proc.pid, duration_sec=1, all_threads=True, sample_interval_usec=10000, show_summary=False, ) except PermissionError: self.skipTest("Insufficient permissions for remote profiling") # Just verify that sampling completed without error # We're not testing output format here @skip_if_not_supported @unittest.skipIf( sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED, "Test only runs on Linux with process_vm_readv support", ) class TestSampleProfilerErrorHandling(unittest.TestCase): def test_invalid_pid(self): with self.assertRaises((OSError, RuntimeError)): profile.sample.sample(-1, duration_sec=1) def test_process_dies_during_sampling(self): with test_subprocess("import time; time.sleep(0.5); exit()") as proc: with ( io.StringIO() as captured_output, mock.patch("sys.stdout", captured_output), ): try: profile.sample.sample( proc.pid, duration_sec=2, # Longer than process lifetime sample_interval_usec=50000, ) except PermissionError: self.skipTest( "Insufficient permissions for remote profiling" ) output = captured_output.getvalue() self.assertIn("Error rate", output) def test_invalid_output_format(self): with self.assertRaises(ValueError): profile.sample.sample( os.getpid(), duration_sec=1, output_format="invalid_format", ) def test_invalid_output_format_with_mocked_profiler(self): """Test invalid output format with proper mocking to avoid permission issues.""" with mock.patch( "profile.sample.SampleProfiler" ) as mock_profiler_class: mock_profiler = mock.MagicMock() mock_profiler_class.return_value = mock_profiler with self.assertRaises(ValueError) as cm: profile.sample.sample( 12345, duration_sec=1, output_format="unknown_format", ) # Should raise ValueError with the invalid format name self.assertIn( "Invalid output format: unknown_format", str(cm.exception) ) def test_is_process_running(self): with test_subprocess("import time; time.sleep(1000)") as proc: try: profiler = SampleProfiler(pid=proc.pid, sample_interval_usec=1000, all_threads=False) except PermissionError: self.skipTest( "Insufficient permissions to read the stack trace" ) self.assertTrue(profiler._is_process_running()) self.assertIsNotNone(profiler.unwinder.get_stack_trace()) proc.kill() proc.wait() # ValueError on MacOS (yeah I know), ProcessLookupError on Linux and Windows self.assertRaises((ValueError, ProcessLookupError), profiler.unwinder.get_stack_trace) # Exit the context manager to ensure the process is terminated self.assertFalse(profiler._is_process_running()) self.assertRaises((ValueError, ProcessLookupError), profiler.unwinder.get_stack_trace) @unittest.skipUnless(sys.platform == "linux", "Only valid on Linux") def test_esrch_signal_handling(self): with test_subprocess("import time; time.sleep(1000)") as proc: try: unwinder = _remote_debugging.RemoteUnwinder(proc.pid) except PermissionError: self.skipTest( "Insufficient permissions to read the stack trace" ) initial_trace = unwinder.get_stack_trace() self.assertIsNotNone(initial_trace) proc.kill() # Wait for the process to die and try to get another trace proc.wait() with self.assertRaises(ProcessLookupError): unwinder.get_stack_trace() class TestSampleProfilerCLI(unittest.TestCase): def test_cli_collapsed_format_validation(self): """Test that CLI properly validates incompatible options with collapsed format.""" test_cases = [ # Test sort options are invalid with collapsed ( ["profile.sample", "--collapsed", "--sort-nsamples", "12345"], "sort", ), ( ["profile.sample", "--collapsed", "--sort-tottime", "12345"], "sort", ), ( [ "profile.sample", "--collapsed", "--sort-cumtime", "12345", ], "sort", ), ( [ "profile.sample", "--collapsed", "--sort-sample-pct", "12345", ], "sort", ), ( [ "profile.sample", "--collapsed", "--sort-cumul-pct", "12345", ], "sort", ), ( ["profile.sample", "--collapsed", "--sort-name", "12345"], "sort", ), # Test limit option is invalid with collapsed (["profile.sample", "--collapsed", "-l", "20", "12345"], "limit"), ( ["profile.sample", "--collapsed", "--limit", "20", "12345"], "limit", ), # Test no-summary option is invalid with collapsed ( ["profile.sample", "--collapsed", "--no-summary", "12345"], "summary", ), ] for test_args, expected_error_keyword in test_cases: with ( mock.patch("sys.argv", test_args), mock.patch("sys.stderr", io.StringIO()) as mock_stderr, self.assertRaises(SystemExit) as cm, ): profile.sample.main() self.assertEqual(cm.exception.code, 2) # argparse error code error_msg = mock_stderr.getvalue() self.assertIn("error:", error_msg) self.assertIn("--pstats format", error_msg) def test_cli_default_collapsed_filename(self): """Test that collapsed format gets a default filename when not specified.""" test_args = ["profile.sample", "--collapsed", "12345"] with ( mock.patch("sys.argv", test_args), mock.patch("profile.sample.sample") as mock_sample, ): profile.sample.main() # Check that filename was set to default collapsed format mock_sample.assert_called_once() call_args = mock_sample.call_args[1] self.assertEqual(call_args["output_format"], "collapsed") self.assertEqual(call_args["filename"], "collapsed.12345.txt") def test_cli_custom_output_filenames(self): """Test custom output filenames for both formats.""" test_cases = [ ( ["profile.sample", "--pstats", "-o", "custom.pstats", "12345"], "custom.pstats", "pstats", ), ( ["profile.sample", "--collapsed", "-o", "custom.txt", "12345"], "custom.txt", "collapsed", ), ] for test_args, expected_filename, expected_format in test_cases: with ( mock.patch("sys.argv", test_args), mock.patch("profile.sample.sample") as mock_sample, ): profile.sample.main() mock_sample.assert_called_once() call_args = mock_sample.call_args[1] self.assertEqual(call_args["filename"], expected_filename) self.assertEqual(call_args["output_format"], expected_format) def test_cli_missing_required_arguments(self): """Test that CLI requires PID argument.""" with ( mock.patch("sys.argv", ["profile.sample"]), mock.patch("sys.stderr", io.StringIO()), ): with self.assertRaises(SystemExit): profile.sample.main() def test_cli_mutually_exclusive_format_options(self): """Test that pstats and collapsed options are mutually exclusive.""" with ( mock.patch( "sys.argv", ["profile.sample", "--pstats", "--collapsed", "12345"], ), mock.patch("sys.stderr", io.StringIO()), ): with self.assertRaises(SystemExit): profile.sample.main() def test_argument_parsing_basic(self): test_args = ["profile.sample", "12345"] with ( mock.patch("sys.argv", test_args), mock.patch("profile.sample.sample") as mock_sample, ): profile.sample.main() mock_sample.assert_called_once_with( 12345, sample_interval_usec=100, duration_sec=10, filename=None, all_threads=False, limit=15, sort=2, show_summary=True, output_format="pstats", realtime_stats=False, ) def test_sort_options(self): sort_options = [ ("--sort-nsamples", 0), ("--sort-tottime", 1), ("--sort-cumtime", 2), ("--sort-sample-pct", 3), ("--sort-cumul-pct", 4), ("--sort-name", -1), ] for option, expected_sort_value in sort_options: test_args = ["profile.sample", option, "12345"] with ( mock.patch("sys.argv", test_args), mock.patch("profile.sample.sample") as mock_sample, ): profile.sample.main() mock_sample.assert_called_once() call_args = mock_sample.call_args[1] self.assertEqual( call_args["sort"], expected_sort_value, ) mock_sample.reset_mock() if __name__ == "__main__": unittest.main()