diff --git a/rich/console.py b/rich/console.py index 8c462d54..7800c2a5 100644 --- a/rich/console.py +++ b/rich/console.py @@ -548,31 +548,31 @@ class ConsoleThreadLocals(threading.local): def __getstate__(self): """Support for pickle serialization. - + Returns the serializable state of the thread-local object. Note: This loses the thread-local nature, but allows serialization for caching and other use cases. - + Returns: Dict[str, Any]: The serializable state containing theme_stack, buffer, and buffer_index. """ return { - 'theme_stack': self.theme_stack, - 'buffer': self.buffer.copy(), # Create a copy to be safe - 'buffer_index': self.buffer_index + "theme_stack": self.theme_stack, + "buffer": self.buffer.copy(), # Create a copy to be safe + "buffer_index": self.buffer_index, } - + def __setstate__(self, state): """Support for pickle deserialization. - + Args: state (Dict[str, Any]): The state dictionary from __getstate__ """ # Restore the state - self.theme_stack = state['theme_stack'] - self.buffer = state['buffer'] - self.buffer_index = state['buffer_index'] + self.theme_stack = state["theme_stack"] + self.buffer = state["buffer"] + self.buffer_index = state["buffer_index"] class RenderHook(ABC): @@ -2638,36 +2638,37 @@ class Console: ) with open(path, "w", encoding="utf-8") as write_file: write_file.write(svg) - + def __getstate__(self): """Support for pickle serialization. - + Returns the serializable state of the Console object. Note: Thread locks are recreated during deserialization. - + Returns: Dict[str, Any]: The serializable state of the Console. """ # Get all instance attributes except locks state = self.__dict__.copy() - + # Remove the unpickleable locks - state.pop('_lock', None) - state.pop('_record_buffer_lock', None) - + state.pop("_lock", None) + state.pop("_record_buffer_lock", None) + return state - + def __setstate__(self, state): """Support for pickle deserialization. - + Args: state (Dict[str, Any]): The state dictionary from __getstate__ """ # Restore the state self.__dict__.update(state) - + # Recreate the locks import threading + self._lock = threading.RLock() self._record_buffer_lock = threading.RLock() diff --git a/tests/test_pickle_fix.py b/tests/test_pickle_fix.py index be3365e9..e13fbea9 100644 --- a/tests/test_pickle_fix.py +++ b/tests/test_pickle_fix.py @@ -8,7 +8,7 @@ import sys import os # Add the current directory to the path so we can import the modified rich -sys.path.insert(0, '/tmp/rich') +sys.path.insert(0, "/tmp/rich") from rich.console import Console from rich.segment import Segment @@ -17,31 +17,31 @@ from rich.segment import Segment def test_basic_pickle(): """Test basic pickle functionality of ConsoleThreadLocals.""" print("๐Ÿงช Testing basic ConsoleThreadLocals pickle functionality...") - + console = Console() ctl = console._thread_locals - + # Add some data to make it more realistic ctl.buffer.append(Segment("test")) ctl.buffer_index = 1 - + try: # Test serialization pickled_data = pickle.dumps(ctl) print(" โœ… Serialization successful") - + # Test deserialization restored_ctl = pickle.loads(pickled_data) print(" โœ… Deserialization successful") - + # Verify state preservation assert type(restored_ctl.theme_stack) == type(ctl.theme_stack) assert restored_ctl.buffer == ctl.buffer assert restored_ctl.buffer_index == ctl.buffer_index print(" โœ… State preservation verified") - + return True - + except Exception as e: print(f" โŒ Test failed: {e}") return False @@ -50,29 +50,29 @@ def test_basic_pickle(): def test_langflow_compatibility(): """Test compatibility with Langflow's caching mechanism.""" print("๐Ÿ”ง Testing Langflow cache compatibility...") - + console = Console() - + # Simulate Langflow's cache data structure result_dict = { "result": console, "type": type(console), } - + try: # This is what Langflow's cache service tries to do pickled = pickle.dumps(result_dict) print(" โœ… Complex object serialization successful") - + restored = pickle.loads(pickled) print(" โœ… Complex object deserialization successful") - + # Verify the console is properly restored assert type(restored["result"]) == type(console) print(" โœ… Object type preservation verified") - + return True - + except Exception as e: print(f" โŒ Test failed: {e}") return False @@ -81,28 +81,28 @@ def test_langflow_compatibility(): def test_thread_local_behavior(): """Test that thread-local behavior works after unpickling.""" print("๐Ÿ”„ Testing thread-local behavior preservation...") - + import threading import time - + console = Console() ctl = console._thread_locals - + # Serialize and deserialize try: pickled = pickle.dumps(ctl) restored_ctl = pickle.loads(pickled) - + # Test that we can still use the restored object restored_ctl.buffer.append(Segment("thread test")) restored_ctl.buffer_index = 5 - + print(f" โœ… Restored object is functional") print(f" ๐Ÿ“Š Buffer length: {len(restored_ctl.buffer)}") print(f" ๐Ÿ“Š Buffer index: {restored_ctl.buffer_index}") - + return True - + except Exception as e: print(f" โŒ Test failed: {e}") return False @@ -111,24 +111,24 @@ def test_thread_local_behavior(): def main(): """Run all tests.""" print("๐Ÿš€ Starting Rich ConsoleThreadLocals pickle fix tests...\n") - + tests = [ test_basic_pickle, test_langflow_compatibility, test_thread_local_behavior, ] - + passed = 0 total = len(tests) - + for test in tests: if test(): passed += 1 print() # Add spacing between tests - + print("=" * 50) print(f"๐Ÿ“Š Test Results: {passed}/{total} tests passed") - + if passed == total: print("๐ŸŽ‰ All tests passed! The pickle fix is working correctly.") return 0 @@ -138,4 +138,4 @@ def main(): if __name__ == "__main__": - sys.exit(main()) \ No newline at end of file + sys.exit(main()) diff --git a/tests/test_pickle_support.py b/tests/test_pickle_support.py index 660b08f5..c05f183d 100644 --- a/tests/test_pickle_support.py +++ b/tests/test_pickle_support.py @@ -10,17 +10,17 @@ def test_console_thread_locals_pickle(): """Test that ConsoleThreadLocals can be pickled and unpickled.""" console = Console() ctl = console._thread_locals - + # Add some data to make it more realistic ctl.buffer.append(Segment("test")) ctl.buffer_index = 1 - + # Test serialization pickled_data = pickle.dumps(ctl) - + # Test deserialization restored_ctl = pickle.loads(pickled_data) - + # Verify state preservation assert type(restored_ctl.theme_stack) == type(ctl.theme_stack) assert restored_ctl.buffer == ctl.buffer @@ -30,54 +30,50 @@ def test_console_thread_locals_pickle(): def test_console_pickle(): """Test that Console objects can be pickled and unpickled.""" console = Console(width=120, height=40) - + # Test serialization pickled_data = pickle.dumps(console) - + # Test deserialization restored_console = pickle.loads(pickled_data) - + # Verify basic properties are preserved assert restored_console.width == console.width assert restored_console.height == console.height assert restored_console._color_system == console._color_system - + # Verify locks are recreated - assert hasattr(restored_console, '_lock') - assert hasattr(restored_console, '_record_buffer_lock') - + assert hasattr(restored_console, "_lock") + assert hasattr(restored_console, "_record_buffer_lock") + # Verify the console is functional with restored_console.capture() as capture: restored_console.print("Test message") - + assert "Test message" in capture.get() def test_console_with_complex_state_pickle(): """Test console pickle with more complex state.""" - theme = Theme({ - "info": "cyan", - "warning": "yellow", - "error": "red bold" - }) - + theme = Theme({"info": "cyan", "warning": "yellow", "error": "red bold"}) + console = Console(theme=theme, record=True) - + # Add some content console.print("Info message", style="info") console.print("Warning message", style="warning") console.record = False # Stop recording - + # Test serialization pickled_data = pickle.dumps(console) - + # Test deserialization restored_console = pickle.loads(pickled_data) - + # Verify theme is preserved assert restored_console.get_style("info").color.name == "cyan" assert restored_console.get_style("warning").color.name == "yellow" - + # Verify console functionality assert restored_console.record is False @@ -85,28 +81,28 @@ def test_console_with_complex_state_pickle(): def test_cache_simulation(): """Test cache-like usage scenario (similar to Langflow).""" console = Console() - + # Simulate caching scenario like Langflow cache_data = { "result": console, "type": type(console), - "metadata": {"created": "2025-09-25", "version": "1.0"} + "metadata": {"created": "2025-09-25", "version": "1.0"}, } - + # This should not raise any pickle errors pickled = pickle.dumps(cache_data) restored = pickle.loads(pickled) - + # Verify restoration assert type(restored["result"]) == Console assert restored["type"] == Console assert restored["metadata"]["created"] == "2025-09-25" - + # Verify the restored console works restored_console = restored["result"] with restored_console.capture() as capture: restored_console.print("Cache test successful") - + assert "Cache test successful" in capture.get() @@ -116,28 +112,28 @@ def test_nested_console_pickle(): container = { "console": Console(width=100), "name": "test_container", - "data": [1, 2, 3] + "data": [1, 2, 3], } - + # Should be able to pickle dict containing Console pickled = pickle.dumps(container) restored = pickle.loads(pickled) - + assert restored["name"] == "test_container" assert restored["data"] == [1, 2, 3] assert restored["console"].width == 100 - + # Verify console functionality with restored["console"].capture() as capture: restored["console"].print("Nested test") - + assert "Nested test" in capture.get() if __name__ == "__main__": # Run tests manually if called directly import sys - + tests = [ test_console_thread_locals_pickle, test_console_pickle, @@ -145,7 +141,7 @@ if __name__ == "__main__": test_cache_simulation, test_nested_console_pickle, ] - + passed = 0 for test in tests: try: @@ -154,6 +150,6 @@ if __name__ == "__main__": passed += 1 except Exception as e: print(f"โŒ {test.__name__} failed: {e}") - + print(f"\n๐Ÿ“Š Results: {passed}/{len(tests)} tests passed") - sys.exit(0 if passed == len(tests) else 1) \ No newline at end of file + sys.exit(0 if passed == len(tests) else 1)