bpo-35813: Tests and docs for shared_memory (#11816)

* Added tests for shared_memory submodule.

* Added tests for ShareableList.

* Fix bug in allocationn size during creation of empty ShareableList illuminated by existing test run on Linux.

* Initial set of docs for shared_memory module.

* Added docs for ShareableList, added doctree entry for shared_memory submodule, name refactoring for greater clarity.

* Added examples to SharedMemoryManager docs, for ease of documentation switched away from exclusively registered functions to some explicit methods on SharedMemoryManager.

* Wording tweaks to docs.

* Fix test failures on Windows.

* Added tests around SharedMemoryManager.

* Documentation tweaks.

* Fix inappropriate test on Windows.

* Further documentation tweaks.

* Fix bare exception.

* Removed __copyright__.

* Fixed typo in doc, removed comment.

* Updated SharedMemoryManager preliminary tests to reflect change of not supporting all registered functions on SyncManager.

* Added Sphinx doctest run controls.

* CloseHandle should be in a finally block in case MapViewOfFile fails.

* Missed opportunity to use with statement.

* Switch to self.addCleanup to spare long try/finally blocks and save one indentation, change to use decorator to skip test instead.

* Simplify the posixshmem extension module.

Provide shm_open() and shm_unlink() functions.  Move other
functionality into the shared_memory.py module.

* Added to doc around size parameter of SharedMemory.

* Changed PosixSharedMemory.size to use os.fstat.

* Change SharedMemory.buf to a read-only property as well as NamedSharedMemory.size.

* Marked as provisional per PEP411 in docstring.

* Changed SharedMemoryTracker to be private.

* Removed registered Proxy Objects from SharedMemoryManager.

* Removed shareable_wrap().

* Removed shareable_wrap() and dangling references to it.

* For consistency added __reduce__ to key classes.

* Fix for potential race condition on Windows for O_CREX.

* Remove unused imports.

* Update access to kernel32 on Windows per feedback from eryksun.

* Moved kernel32 calls to _winapi.

* Removed ShareableList.copy as redundant.

* Changes to _winapi use from eryksun feedback.

* Adopt simpler SharedMemory API, collapsing PosixSharedMemory and WindowsNamedSharedMemory into one.

* Fix missing docstring on class, add test for ignoring size when attaching.

* Moved SharedMemoryManager to managers module, tweak to fragile test.

* Tweak to exception in OpenFileMapping suggested by eryksun.

* Mark a few dangling bits as private as suggested by Giampaolo.
This commit is contained in:
Davin Potts 2019-02-23 22:08:16 -06:00 committed by GitHub
parent d610116a2e
commit e895de3e7f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 1510 additions and 1028 deletions

View file

@ -1,5 +1,5 @@
#
# Module providing the `SyncManager` class for dealing
# Module providing manager classes for dealing
# with shared objects
#
# multiprocessing/managers.py
@ -8,7 +8,8 @@
# Licensed to PSF under a Contributor Agreement.
#
__all__ = [ 'BaseManager', 'SyncManager', 'BaseProxy', 'Token' ]
__all__ = [ 'BaseManager', 'SyncManager', 'BaseProxy', 'Token',
'SharedMemoryManager' ]
#
# Imports
@ -19,6 +20,7 @@ import threading
import array
import queue
import time
from os import getpid
from traceback import format_exc
@ -28,6 +30,11 @@ from . import pool
from . import process
from . import util
from . import get_context
try:
from . import shared_memory
HAS_SHMEM = True
except ImportError:
HAS_SHMEM = False
#
# Register some things for pickling
@ -1200,3 +1207,143 @@ SyncManager.register('Namespace', Namespace, NamespaceProxy)
# types returned by methods of PoolProxy
SyncManager.register('Iterator', proxytype=IteratorProxy, create_method=False)
SyncManager.register('AsyncResult', create_method=False)
#
# Definition of SharedMemoryManager and SharedMemoryServer
#
if HAS_SHMEM:
class _SharedMemoryTracker:
"Manages one or more shared memory segments."
def __init__(self, name, segment_names=[]):
self.shared_memory_context_name = name
self.segment_names = segment_names
def register_segment(self, segment_name):
"Adds the supplied shared memory block name to tracker."
util.debug(f"Register segment {segment_name!r} in pid {getpid()}")
self.segment_names.append(segment_name)
def destroy_segment(self, segment_name):
"""Calls unlink() on the shared memory block with the supplied name
and removes it from the list of blocks being tracked."""
util.debug(f"Destroy segment {segment_name!r} in pid {getpid()}")
self.segment_names.remove(segment_name)
segment = shared_memory.SharedMemory(segment_name)
segment.close()
segment.unlink()
def unlink(self):
"Calls destroy_segment() on all tracked shared memory blocks."
for segment_name in self.segment_names[:]:
self.destroy_segment(segment_name)
def __del__(self):
util.debug(f"Call {self.__class__.__name__}.__del__ in {getpid()}")
self.unlink()
def __getstate__(self):
return (self.shared_memory_context_name, self.segment_names)
def __setstate__(self, state):
self.__init__(*state)
class SharedMemoryServer(Server):
public = Server.public + \
['track_segment', 'release_segment', 'list_segments']
def __init__(self, *args, **kwargs):
Server.__init__(self, *args, **kwargs)
self.shared_memory_context = \
_SharedMemoryTracker(f"shmm_{self.address}_{getpid()}")
util.debug(f"SharedMemoryServer started by pid {getpid()}")
def create(self, c, typeid, *args, **kwargs):
"""Create a new distributed-shared object (not backed by a shared
memory block) and return its id to be used in a Proxy Object."""
# Unless set up as a shared proxy, don't make shared_memory_context
# a standard part of kwargs. This makes things easier for supplying
# simple functions.
if hasattr(self.registry[typeid][-1], "_shared_memory_proxy"):
kwargs['shared_memory_context'] = self.shared_memory_context
return Server.create(self, c, typeid, *args, **kwargs)
def shutdown(self, c):
"Call unlink() on all tracked shared memory, terminate the Server."
self.shared_memory_context.unlink()
return Server.shutdown(self, c)
def track_segment(self, c, segment_name):
"Adds the supplied shared memory block name to Server's tracker."
self.shared_memory_context.register_segment(segment_name)
def release_segment(self, c, segment_name):
"""Calls unlink() on the shared memory block with the supplied name
and removes it from the tracker instance inside the Server."""
self.shared_memory_context.destroy_segment(segment_name)
def list_segments(self, c):
"""Returns a list of names of shared memory blocks that the Server
is currently tracking."""
return self.shared_memory_context.segment_names
class SharedMemoryManager(BaseManager):
"""Like SyncManager but uses SharedMemoryServer instead of Server.
It provides methods for creating and returning SharedMemory instances
and for creating a list-like object (ShareableList) backed by shared
memory. It also provides methods that create and return Proxy Objects
that support synchronization across processes (i.e. multi-process-safe
locks and semaphores).
"""
_Server = SharedMemoryServer
def __init__(self, *args, **kwargs):
BaseManager.__init__(self, *args, **kwargs)
util.debug(f"{self.__class__.__name__} created by pid {getpid()}")
def __del__(self):
util.debug(f"{self.__class__.__name__}.__del__ by pid {getpid()}")
pass
def get_server(self):
'Better than monkeypatching for now; merge into Server ultimately'
if self._state.value != State.INITIAL:
if self._state.value == State.STARTED:
raise ProcessError("Already started SharedMemoryServer")
elif self._state.value == State.SHUTDOWN:
raise ProcessError("SharedMemoryManager has shut down")
else:
raise ProcessError(
"Unknown state {!r}".format(self._state.value))
return self._Server(self._registry, self._address,
self._authkey, self._serializer)
def SharedMemory(self, size):
"""Returns a new SharedMemory instance with the specified size in
bytes, to be tracked by the manager."""
with self._Client(self._address, authkey=self._authkey) as conn:
sms = shared_memory.SharedMemory(None, create=True, size=size)
try:
dispatch(conn, None, 'track_segment', (sms.name,))
except BaseException as e:
sms.unlink()
raise e
return sms
def ShareableList(self, sequence):
"""Returns a new ShareableList instance populated with the values
from the input sequence, to be tracked by the manager."""
with self._Client(self._address, authkey=self._authkey) as conn:
sl = shared_memory.ShareableList(sequence)
try:
dispatch(conn, None, 'track_segment', (sl.shm.name,))
except BaseException as e:
sl.shm.unlink()
raise e
return sl

View file

@ -1,228 +1,234 @@
"Provides shared memory for direct access across processes."
"""Provides shared memory for direct access across processes.
The API of this package is currently provisional. Refer to the
documentation for details.
"""
__all__ = [ 'SharedMemory', 'PosixSharedMemory', 'WindowsNamedSharedMemory',
'ShareableList', 'shareable_wrap',
'SharedMemoryServer', 'SharedMemoryManager', 'SharedMemoryTracker' ]
__all__ = [ 'SharedMemory', 'ShareableList' ]
from functools import reduce
from functools import partial
import mmap
from .managers import DictProxy, SyncManager, Server
from . import util
import os
import random
import errno
import struct
import sys
try:
from _posixshmem import _PosixSharedMemory, Error, ExistentialError, O_CREX
except ImportError as ie:
if os.name != "nt":
# On Windows, posixshmem is not required to be available.
raise ie
else:
_PosixSharedMemory = object
class ExistentialError(BaseException): pass
class Error(BaseException): pass
O_CREX = -1
import secrets
if os.name == "nt":
import _winapi
_USE_POSIX = False
else:
import _posixshmem
_USE_POSIX = True
class WindowsNamedSharedMemory:
_O_CREX = os.O_CREAT | os.O_EXCL
def __init__(self, name, flags=None, mode=None, size=None, read_only=False):
if name is None:
name = f'wnsm_{os.getpid()}_{random.randrange(100000)}'
# FreeBSD (and perhaps other BSDs) limit names to 14 characters.
_SHM_SAFE_NAME_LENGTH = 14
self._mmap = mmap.mmap(-1, size, tagname=name)
self.buf = memoryview(self._mmap)
self.name = name
self.size = size
def __repr__(self):
return f'{self.__class__.__name__}({self.name!r}, size={self.size})'
def close(self):
self.buf.release()
self._mmap.close()
def unlink(self):
"""Windows ensures that destruction of the last reference to this
named shared memory block will result in the release of this memory."""
pass
# Shared memory block name prefix
if _USE_POSIX:
_SHM_NAME_PREFIX = 'psm_'
else:
_SHM_NAME_PREFIX = 'wnsm_'
class PosixSharedMemory(_PosixSharedMemory):
def __init__(self, name, flags=None, mode=None, size=None, read_only=False):
if name and (flags is None):
_PosixSharedMemory.__init__(self, name)
else:
if name is None:
name = f'psm_{os.getpid()}_{random.randrange(100000)}'
_PosixSharedMemory.__init__(self, name, flags=O_CREX, size=size)
self._mmap = mmap.mmap(self.fd, self.size)
self.buf = memoryview(self._mmap)
def __repr__(self):
return f'{self.__class__.__name__}({self.name!r}, size={self.size})'
def close(self):
self.buf.release()
self._mmap.close()
self.close_fd()
def _make_filename():
"Create a random filename for the shared memory object."
# number of random bytes to use for name
nbytes = (_SHM_SAFE_NAME_LENGTH - len(_SHM_NAME_PREFIX)) // 2
assert nbytes >= 2, '_SHM_NAME_PREFIX too long'
name = _SHM_NAME_PREFIX + secrets.token_hex(nbytes)
assert len(name) <= _SHM_SAFE_NAME_LENGTH
return name
class SharedMemory:
"""Creates a new shared memory block or attaches to an existing
shared memory block.
def __new__(cls, *args, **kwargs):
if os.name == 'nt':
cls = WindowsNamedSharedMemory
else:
cls = PosixSharedMemory
return cls(*args, **kwargs)
Every shared memory block is assigned a unique name. This enables
one process to create a shared memory block with a particular name
so that a different process can attach to that same shared memory
block using that same name.
As a resource for sharing data across processes, shared memory blocks
may outlive the original process that created them. When one process
no longer needs access to a shared memory block that might still be
needed by other processes, the close() method should be called.
When a shared memory block is no longer needed by any process, the
unlink() method should be called to ensure proper cleanup."""
def shareable_wrap(
existing_obj=None,
shmem_name=None,
cls=None,
shape=(0,),
strides=None,
dtype=None,
format=None,
**kwargs
):
augmented_kwargs = dict(kwargs)
extras = dict(shape=shape, strides=strides, dtype=dtype, format=format)
for key, value in extras.items():
if value is not None:
augmented_kwargs[key] = value
# Defaults; enables close() and unlink() to run without errors.
_name = None
_fd = -1
_mmap = None
_buf = None
_flags = os.O_RDWR
_mode = 0o600
if existing_obj is not None:
existing_type = getattr(
existing_obj,
"_proxied_type",
type(existing_obj)
)
def __init__(self, name=None, create=False, size=0):
if not size >= 0:
raise ValueError("'size' must be a positive integer")
if create:
self._flags = _O_CREX | os.O_RDWR
if name is None and not self._flags & os.O_EXCL:
raise ValueError("'name' can only be None if create=True")
#agg = existing_obj.itemsize
#size = [ agg := i * agg for i in existing_obj.shape ][-1]
# TODO: replace use of reduce below with above 2 lines once available
size = reduce(
lambda x, y: x * y,
existing_obj.shape,
existing_obj.itemsize
)
if _USE_POSIX:
else:
assert shmem_name is not None
existing_type = cls
size = 1
# POSIX Shared Memory
shm = SharedMemory(shmem_name, size=size)
class CustomShareableProxy(existing_type):
def __init__(self, *args, buffer=None, **kwargs):
# If copy method called, prevent recursion from replacing _shm.
if not hasattr(self, "_shm"):
self._shm = shm
self._proxied_type = existing_type
if name is None:
while True:
name = _make_filename()
try:
self._fd = _posixshmem.shm_open(
name,
self._flags,
mode=self._mode
)
except FileExistsError:
continue
self._name = name
break
else:
# _proxied_type only used in pickling.
assert hasattr(self, "_proxied_type")
self._fd = _posixshmem.shm_open(
name,
self._flags,
mode=self._mode
)
self._name = name
try:
existing_type.__init__(self, *args, **kwargs)
except:
pass
if create and size:
os.ftruncate(self._fd, size)
stats = os.fstat(self._fd)
size = stats.st_size
self._mmap = mmap.mmap(self._fd, size)
except OSError:
self.unlink()
raise
def __repr__(self):
if not hasattr(self, "_shm"):
return existing_type.__repr__(self)
formatted_pairs = (
"%s=%r" % kv for kv in self._build_state(self).items()
)
return f"{self.__class__.__name__}({', '.join(formatted_pairs)})"
else:
#def __getstate__(self):
# if not hasattr(self, "_shm"):
# return existing_type.__getstate__(self)
# state = self._build_state(self)
# return state
# Windows Named Shared Memory
#def __setstate__(self, state):
# self.__init__(**state)
if create:
while True:
temp_name = _make_filename() if name is None else name
# Create and reserve shared memory block with this name
# until it can be attached to by mmap.
h_map = _winapi.CreateFileMapping(
_winapi.INVALID_HANDLE_VALUE,
_winapi.NULL,
_winapi.PAGE_READWRITE,
(size >> 32) & 0xFFFFFFFF,
size & 0xFFFFFFFF,
temp_name
)
try:
last_error_code = _winapi.GetLastError()
if last_error_code == _winapi.ERROR_ALREADY_EXISTS:
if name is not None:
raise FileExistsError(
errno.EEXIST,
os.strerror(errno.EEXIST),
name,
_winapi.ERROR_ALREADY_EXISTS
)
else:
continue
self._mmap = mmap.mmap(-1, size, tagname=temp_name)
finally:
_winapi.CloseHandle(h_map)
self._name = temp_name
break
def __reduce__(self):
return (
shareable_wrap,
(
None,
self._shm.name,
self._proxied_type,
self.shape,
self.strides,
self.dtype.str if hasattr(self, "dtype") else None,
getattr(self, "format", None),
),
)
def copy(self):
dupe = existing_type.copy(self)
if not hasattr(dupe, "_shm"):
dupe = shareable_wrap(dupe)
return dupe
@staticmethod
def _build_state(existing_obj, generics_only=False):
state = {
"shape": existing_obj.shape,
"strides": existing_obj.strides,
}
try:
state["dtype"] = existing_obj.dtype
except AttributeError:
else:
self._name = name
# Dynamically determine the existing named shared memory
# block's size which is likely a multiple of mmap.PAGESIZE.
h_map = _winapi.OpenFileMapping(
_winapi.FILE_MAP_READ,
False,
name
)
try:
state["format"] = existing_obj.format
except AttributeError:
pass
if not generics_only:
try:
state["shmem_name"] = existing_obj._shm.name
state["cls"] = existing_type
except AttributeError:
pass
return state
p_buf = _winapi.MapViewOfFile(
h_map,
_winapi.FILE_MAP_READ,
0,
0,
0
)
finally:
_winapi.CloseHandle(h_map)
size = _winapi.VirtualQuerySize(p_buf)
self._mmap = mmap.mmap(-1, size, tagname=name)
proxy_type = type(
f"{existing_type.__name__}Shareable",
CustomShareableProxy.__bases__,
dict(CustomShareableProxy.__dict__),
)
self._size = size
self._buf = memoryview(self._mmap)
if existing_obj is not None:
def __del__(self):
try:
proxy_obj = proxy_type(
buffer=shm.buf,
**proxy_type._build_state(existing_obj)
)
except Exception:
proxy_obj = proxy_type(
buffer=shm.buf,
**proxy_type._build_state(existing_obj, True)
)
self.close()
except OSError:
pass
mveo = memoryview(existing_obj)
proxy_obj._shm.buf[:mveo.nbytes] = mveo.tobytes()
def __reduce__(self):
return (
self.__class__,
(
self.name,
False,
self.size,
),
)
else:
proxy_obj = proxy_type(buffer=shm.buf, **augmented_kwargs)
def __repr__(self):
return f'{self.__class__.__name__}({self.name!r}, size={self.size})'
return proxy_obj
@property
def buf(self):
"A memoryview of contents of the shared memory block."
return self._buf
@property
def name(self):
"Unique name that identifies the shared memory block."
return self._name
@property
def size(self):
"Size in bytes."
return self._size
def close(self):
"""Closes access to the shared memory from this instance but does
not destroy the shared memory block."""
if self._buf is not None:
self._buf.release()
self._buf = None
if self._mmap is not None:
self._mmap.close()
self._mmap = None
if _USE_POSIX and self._fd >= 0:
os.close(self._fd)
self._fd = -1
def unlink(self):
"""Requests that the underlying shared memory block be destroyed.
In order to ensure proper cleanup of resources, unlink should be
called once (and only once) across all processes which have access
to the shared memory block."""
if _USE_POSIX and self.name:
_posixshmem.shm_unlink(self.name)
encoding = "utf8"
_encoding = "utf8"
class ShareableList:
"""Pattern for a mutable list-like object shareable via a shared
@ -234,8 +240,7 @@ class ShareableList:
packing format for any storable value must require no more than 8
characters to describe its format."""
# TODO: Adjust for discovered word size of machine.
types_mapping = {
_types_mapping = {
int: "q",
float: "d",
bool: "xxxxxxx?",
@ -243,17 +248,17 @@ class ShareableList:
bytes: "%ds",
None.__class__: "xxxxxx?x",
}
alignment = 8
back_transform_codes = {
_alignment = 8
_back_transforms_mapping = {
0: lambda value: value, # int, float, bool
1: lambda value: value.rstrip(b'\x00').decode(encoding), # str
1: lambda value: value.rstrip(b'\x00').decode(_encoding), # str
2: lambda value: value.rstrip(b'\x00'), # bytes
3: lambda _value: None, # None
}
@staticmethod
def _extract_recreation_code(value):
"""Used in concert with back_transform_codes to convert values
"""Used in concert with _back_transforms_mapping to convert values
into the appropriate Python objects when retrieving them from
the list as well as when storing them."""
if not isinstance(value, (str, bytes, None.__class__)):
@ -265,36 +270,42 @@ class ShareableList:
else:
return 3 # NoneType
def __init__(self, iterable=None, name=None):
if iterable is not None:
def __init__(self, sequence=None, *, name=None):
if sequence is not None:
_formats = [
self.types_mapping[type(item)]
self._types_mapping[type(item)]
if not isinstance(item, (str, bytes))
else self.types_mapping[type(item)] % (
self.alignment * (len(item) // self.alignment + 1),
else self._types_mapping[type(item)] % (
self._alignment * (len(item) // self._alignment + 1),
)
for item in iterable
for item in sequence
]
self._list_len = len(_formats)
assert sum(len(fmt) <= 8 for fmt in _formats) == self._list_len
self._allocated_bytes = tuple(
self.alignment if fmt[-1] != "s" else int(fmt[:-1])
self._alignment if fmt[-1] != "s" else int(fmt[:-1])
for fmt in _formats
)
_back_transform_codes = [
self._extract_recreation_code(item) for item in iterable
_recreation_codes = [
self._extract_recreation_code(item) for item in sequence
]
requested_size = struct.calcsize(
"q" + self._format_size_metainfo + "".join(_formats)
"q" + self._format_size_metainfo +
"".join(_formats) +
self._format_packing_metainfo +
self._format_back_transform_codes
)
else:
requested_size = 1 # Some platforms require > 0.
requested_size = 8 # Some platforms require > 0.
self.shm = SharedMemory(name, size=requested_size)
if name is not None and sequence is None:
self.shm = SharedMemory(name)
else:
self.shm = SharedMemory(name, create=True, size=requested_size)
if iterable is not None:
_enc = encoding
if sequence is not None:
_enc = _encoding
struct.pack_into(
"q" + self._format_size_metainfo,
self.shm.buf,
@ -306,7 +317,7 @@ class ShareableList:
"".join(_formats),
self.shm.buf,
self._offset_data_start,
*(v.encode(_enc) if isinstance(v, str) else v for v in iterable)
*(v.encode(_enc) if isinstance(v, str) else v for v in sequence)
)
struct.pack_into(
self._format_packing_metainfo,
@ -318,7 +329,7 @@ class ShareableList:
self._format_back_transform_codes,
self.shm.buf,
self._offset_back_transform_codes,
*(_back_transform_codes)
*(_recreation_codes)
)
else:
@ -341,7 +352,7 @@ class ShareableList:
self._offset_packing_formats + position * 8
)[0]
fmt = v.rstrip(b'\x00')
fmt_as_str = fmt.decode(encoding)
fmt_as_str = fmt.decode(_encoding)
return fmt_as_str
@ -357,7 +368,7 @@ class ShareableList:
self.shm.buf,
self._offset_back_transform_codes + position
)[0]
transform_function = self.back_transform_codes[transform_code]
transform_function = self._back_transforms_mapping[transform_code]
return transform_function
@ -373,7 +384,7 @@ class ShareableList:
"8s",
self.shm.buf,
self._offset_packing_formats + position * 8,
fmt_as_str.encode(encoding)
fmt_as_str.encode(_encoding)
)
transform_code = self._extract_recreation_code(value)
@ -410,14 +421,14 @@ class ShareableList:
raise IndexError("assignment index out of range")
if not isinstance(value, (str, bytes)):
new_format = self.types_mapping[type(value)]
new_format = self._types_mapping[type(value)]
else:
if len(value) > self._allocated_bytes[position]:
raise ValueError("exceeds available storage for existing str")
if current_format[-1] == "s":
new_format = current_format
else:
new_format = self.types_mapping[str] % (
new_format = self._types_mapping[str] % (
self._allocated_bytes[position],
)
@ -426,16 +437,24 @@ class ShareableList:
new_format,
value
)
value = value.encode(encoding) if isinstance(value, str) else value
value = value.encode(_encoding) if isinstance(value, str) else value
struct.pack_into(new_format, self.shm.buf, offset, value)
def __reduce__(self):
return partial(self.__class__, name=self.shm.name), ()
def __len__(self):
return struct.unpack_from("q", self.shm.buf, 0)[0]
def __repr__(self):
return f'{self.__class__.__name__}({list(self)}, name={self.shm.name!r})'
@property
def format(self):
"The struct packing format used by all currently stored values."
return "".join(self._get_packing_format(i) for i in range(self._list_len))
return "".join(
self._get_packing_format(i) for i in range(self._list_len)
)
@property
def _format_size_metainfo(self):
@ -464,12 +483,6 @@ class ShareableList:
def _offset_back_transform_codes(self):
return self._offset_packing_formats + self._list_len * 8
@classmethod
def copy(cls, self):
"L.copy() -> ShareableList -- a shallow copy of L."
return cls(self)
def count(self, value):
"L.count(value) -> integer -- return number of occurrences of value."
@ -484,90 +497,3 @@ class ShareableList:
return position
else:
raise ValueError(f"{value!r} not in this container")
class SharedMemoryTracker:
"Manages one or more shared memory segments."
def __init__(self, name, segment_names=[]):
self.shared_memory_context_name = name
self.segment_names = segment_names
def register_segment(self, segment):
util.debug(f"Registering segment {segment.name!r} in pid {os.getpid()}")
self.segment_names.append(segment.name)
def destroy_segment(self, segment_name):
util.debug(f"Destroying segment {segment_name!r} in pid {os.getpid()}")
self.segment_names.remove(segment_name)
segment = SharedMemory(segment_name, size=1)
segment.close()
segment.unlink()
def unlink(self):
for segment_name in self.segment_names[:]:
self.destroy_segment(segment_name)
def __del__(self):
util.debug(f"Called {self.__class__.__name__}.__del__ in {os.getpid()}")
self.unlink()
def __getstate__(self):
return (self.shared_memory_context_name, self.segment_names)
def __setstate__(self, state):
self.__init__(*state)
def wrap(self, obj_exposing_buffer_protocol):
wrapped_obj = shareable_wrap(obj_exposing_buffer_protocol)
self.register_segment(wrapped_obj._shm)
return wrapped_obj
class SharedMemoryServer(Server):
def __init__(self, *args, **kwargs):
Server.__init__(self, *args, **kwargs)
self.shared_memory_context = \
SharedMemoryTracker(f"shmm_{self.address}_{os.getpid()}")
util.debug(f"SharedMemoryServer started by pid {os.getpid()}")
def create(self, c, typeid, *args, **kwargs):
# Unless set up as a shared proxy, don't make shared_memory_context
# a standard part of kwargs. This makes things easier for supplying
# simple functions.
if hasattr(self.registry[typeid][-1], "_shared_memory_proxy"):
kwargs['shared_memory_context'] = self.shared_memory_context
return Server.create(self, c, typeid, *args, **kwargs)
def shutdown(self, c):
self.shared_memory_context.unlink()
return Server.shutdown(self, c)
class SharedMemoryManager(SyncManager):
"""Like SyncManager but uses SharedMemoryServer instead of Server.
TODO: Consider relocate/merge into managers submodule."""
_Server = SharedMemoryServer
def __init__(self, *args, **kwargs):
SyncManager.__init__(self, *args, **kwargs)
util.debug(f"{self.__class__.__name__} created by pid {os.getpid()}")
def __del__(self):
util.debug(f"{self.__class__.__name__} told die by pid {os.getpid()}")
pass
def get_server(self):
'Better than monkeypatching for now; merge into Server ultimately'
if self._state.value != State.INITIAL:
if self._state.value == State.STARTED:
raise ProcessError("Already started server")
elif self._state.value == State.SHUTDOWN:
raise ProcessError("Manager has shut down")
else:
raise ProcessError(
"Unknown state {!r}".format(self._state.value))
return _Server(self._registry, self._address,
self._authkey, self._serializer)