mirror of
				https://github.com/python/cpython.git
				synced 2025-10-31 02:15:10 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			325 lines
		
	
	
	
		
			9.6 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			325 lines
		
	
	
	
		
			9.6 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Implementation of marshal.loads() in pure Python
 | |
| 
 | |
| import ast
 | |
| 
 | |
| from typing import Any, Tuple
 | |
| 
 | |
| 
 | |
class Type:
    """One-byte type tags of the marshal format (adapted from marshal.c).

    Every attribute is the ordinal of the single ASCII character that
    identifies the kind of value that follows in the byte stream.
    """

    # Markers and singletons.
    NULL = ord("0")
    NONE = ord("N")
    FALSE = ord("F")
    TRUE = ord("T")
    STOPITER = ord("S")
    ELLIPSIS = ord(".")

    # Numbers.
    INT = ord("i")
    INT64 = ord("I")
    FLOAT = ord("f")
    BINARY_FLOAT = ord("g")
    COMPLEX = ord("x")
    BINARY_COMPLEX = ord("y")
    LONG = ord("l")

    # Byte strings, interned text and back-references.
    STRING = ord("s")
    INTERNED = ord("t")
    REF = ord("r")

    # Containers and code objects.
    TUPLE = ord("(")
    LIST = ord("[")
    DICT = ord("{")
    CODE = ord("c")
    UNICODE = ord("u")
    UNKNOWN = ord("?")
    SET = ord("<")
    FROZENSET = ord(">")

    # ASCII-only text variants and the short tuple form.
    ASCII = ord("a")
    ASCII_INTERNED = ord("A")
    SMALL_TUPLE = ord(")")
    SHORT_ASCII = ord("z")
    SHORT_ASCII_INTERNED = ord("Z")
 | |
| 
 | |
| 
 | |
# High bit of a type byte: when set, the decoded object is also recorded in
# the reader's reference list so that later REF entries can point back at it.
FLAG_REF = 0x80  # with a type, add obj to index

# Unique sentinel returned for the NULL type tag (distinct from Python's None).
NULL = object()  # marker

# Cell kinds: bit flags tested against each entry of co_localspluskinds
# to classify the corresponding name in co_localsplusnames.
CO_FAST_LOCAL = 0x20
CO_FAST_CELL = 0x40
CO_FAST_FREE = 0x80
 | |
| 
 | |
| 
 | |
class Code:
    """Lightweight stand-in for a code object built by the unmarshaller.

    Keyword arguments passed to the constructor (and attributes assigned
    afterwards) are stored as ordinary instance attributes.  The
    ``co_varnames``, ``co_cellvars``, ``co_freevars`` and ``co_nlocals``
    properties are derived from ``co_localsplusnames`` and
    ``co_localspluskinds``, which must therefore be set before they are read.
    """

    def __init__(self, **kwds: Any):
        self.__dict__.update(kwds)

    def __repr__(self) -> str:
        return f"Code(**{self.__dict__})"

    # Names of all "localsplus" variables, parallel to co_localspluskinds.
    co_localsplusnames: Tuple[str, ...]
    # One kind bitmask (CO_FAST_* flags) per name; iterating must yield ints.
    # NOTE(review): marshalled code objects appear to deliver this as bytes,
    # which also iterates as ints -- confirm before tightening the annotation.
    co_localspluskinds: Tuple[int, ...]

    def get_localsplus_names(self, select_kind: int) -> Tuple[str, ...]:
        """Return the names whose kind has any bit of *select_kind* set.

        Names are returned in their original order; a name whose kind
        combines several flags is reported for each matching flag.
        """
        return tuple(
            name
            for name, kind in zip(self.co_localsplusnames,
                                  self.co_localspluskinds)
            if kind & select_kind
        )

    @property
    def co_varnames(self) -> Tuple[str, ...]:
        """Names flagged CO_FAST_LOCAL."""
        return self.get_localsplus_names(CO_FAST_LOCAL)

    @property
    def co_cellvars(self) -> Tuple[str, ...]:
        """Names flagged CO_FAST_CELL."""
        return self.get_localsplus_names(CO_FAST_CELL)

    @property
    def co_freevars(self) -> Tuple[str, ...]:
        """Names flagged CO_FAST_FREE."""
        return self.get_localsplus_names(CO_FAST_FREE)

    @property
    def co_nlocals(self) -> int:
        """Number of names in ``co_varnames``."""
        return len(self.co_varnames)
 | |
| 
 | |
| 
 | |
class Reader:
    """Pure-Python decoder for marshal data.

    A fairly literal translation of the marshal reader.  The reader walks
    ``data`` from ``pos`` towards ``end`` and keeps in ``refs`` every object
    whose type byte carried ``FLAG_REF`` so that later REF entries can be
    resolved by index.
    """

    def __init__(self, data: bytes):
        self.data: bytes = data
        self.end: int = len(self.data)
        self.pos: int = 0
        self.refs: list[Any] = []
        self.level: int = 0  # current nesting depth of r_object() calls

    def r_string(self, n: int) -> bytes:
        """Consume and return the next *n* bytes.

        Asserts that *n* is non-negative and that enough data remains.
        """
        assert 0 <= n <= self.end - self.pos
        buf = self.data[self.pos : self.pos + n]
        self.pos += n
        return buf

    def r_byte(self) -> int:
        """Consume one byte and return it as an unsigned int (0-255)."""
        return self.r_string(1)[0]

    def r_short(self) -> int:
        """Consume a signed 16-bit little-endian integer."""
        return int.from_bytes(self.r_string(2), "little", signed=True)

    def r_long(self) -> int:
        """Consume a signed 32-bit little-endian integer."""
        return int.from_bytes(self.r_string(4), "little", signed=True)

    def r_long64(self) -> int:
        """Consume a signed 64-bit little-endian integer."""
        return int.from_bytes(self.r_string(8), "little", signed=True)

    def r_PyLong(self) -> int:
        """Consume an arbitrary-precision integer.

        The encoding is a signed 32-bit digit count (its sign is the sign
        of the value) followed by that many 16-bit fields, each contributing
        15 bits, least significant digit first.
        """
        n = self.r_long()
        x = 0
        for i in range(abs(n)):
            x |= self.r_short() << (i * 15)
        return -x if n < 0 else x

    def r_float_bin(self) -> float:
        """Consume an 8-byte binary double.

        The byte order is pinned to little-endian, matching the integer
        decoders above, so the result does not depend on the host's
        native byte order.
        """
        buf = self.r_string(8)
        import struct  # Lazy import to avoid breaking UNIX build
        return struct.unpack("<d", buf)[0]

    def r_float_str(self) -> float:
        """Consume a length-prefixed ASCII float literal.

        ``float()`` is used rather than ``ast.literal_eval`` so that
        ``inf`` and ``nan`` are accepted and the result is always a float.
        """
        n = self.r_byte()
        buf = self.r_string(n)
        return float(buf.decode("ascii"))

    def r_ref_reserve(self, flag: int) -> int:
        """Reserve a slot in ``refs`` before an object's children are read.

        Returns the reserved index, or 0 when *flag* is unset (in which
        case the index is never used).
        """
        if flag:
            idx = len(self.refs)
            self.refs.append(None)
            return idx
        else:
            return 0

    def r_ref_insert(self, obj: Any, idx: int, flag: int) -> Any:
        """Fill the slot reserved by ``r_ref_reserve`` and return *obj*."""
        if flag:
            self.refs[idx] = obj
        return obj

    def r_ref(self, obj: Any, flag: int) -> Any:
        """Append *obj* to ``refs`` and return it; *flag* must be set."""
        assert flag & FLAG_REF
        self.refs.append(obj)
        return obj

    def r_object(self) -> Any:
        """Decode the next object, restoring ``level`` afterwards."""
        old_level = self.level
        try:
            return self._r_object()
        finally:
            self.level = old_level

    def _r_object(self) -> Any:
        """Decode one object according to its leading type byte.

        Raises AssertionError for an unrecognised type tag.
        """
        code = self.r_byte()
        flag = code & FLAG_REF
        kind = code & ~FLAG_REF
        self.level += 1

        def R_REF(obj: Any) -> Any:
            # Record the object in the reference list when requested.
            if flag:
                obj = self.r_ref(obj, flag)
            return obj

        if kind == Type.NULL:
            return NULL
        elif kind == Type.NONE:
            return None
        elif kind == Type.STOPITER:
            return StopIteration
        elif kind == Type.ELLIPSIS:
            return Ellipsis
        elif kind == Type.FALSE:
            return False
        elif kind == Type.TRUE:
            return True
        elif kind == Type.INT:
            return R_REF(self.r_long())
        elif kind == Type.INT64:
            return R_REF(self.r_long64())
        elif kind == Type.LONG:
            return R_REF(self.r_PyLong())
        elif kind == Type.FLOAT:
            return R_REF(self.r_float_str())
        elif kind == Type.BINARY_FLOAT:
            return R_REF(self.r_float_bin())
        elif kind == Type.COMPLEX:
            return R_REF(complex(self.r_float_str(),
                                 self.r_float_str()))
        elif kind == Type.BINARY_COMPLEX:
            return R_REF(complex(self.r_float_bin(),
                                 self.r_float_bin()))
        elif kind == Type.STRING:
            n = self.r_long()
            return R_REF(self.r_string(n))
        elif kind == Type.ASCII_INTERNED or kind == Type.ASCII:
            n = self.r_long()
            return R_REF(self.r_string(n).decode("ascii"))
        elif kind == Type.SHORT_ASCII_INTERNED or kind == Type.SHORT_ASCII:
            n = self.r_byte()
            return R_REF(self.r_string(n).decode("ascii"))
        elif kind == Type.INTERNED or kind == Type.UNICODE:
            n = self.r_long()
            return R_REF(self.r_string(n).decode("utf8", "surrogatepass"))
        elif kind == Type.SMALL_TUPLE or kind == Type.TUPLE:
            # The two tuple forms differ only in the width of the length.
            n = self.r_byte() if kind == Type.SMALL_TUPLE else self.r_long()
            # Immutable container: reserve the ref slot first so children
            # that register references get later indices.
            idx = self.r_ref_reserve(flag)
            retval: Any = tuple(self.r_object() for _ in range(n))
            self.r_ref_insert(retval, idx, flag)
            return retval
        elif kind == Type.LIST:
            n = self.r_long()
            retval = R_REF([])
            for _ in range(n):
                retval.append(self.r_object())
            return retval
        elif kind == Type.DICT:
            retval = R_REF({})
            while True:
                key = self.r_object()
                # NULL is a unique sentinel, so compare by identity.
                if key is NULL:
                    break
                val = self.r_object()
                retval[key] = val
            return retval
        elif kind == Type.SET:
            n = self.r_long()
            retval = R_REF(set())
            for _ in range(n):
                v = self.r_object()
                retval.add(v)
            return retval
        elif kind == Type.FROZENSET:
            n = self.r_long()
            s: set[Any] = set()
            idx = self.r_ref_reserve(flag)
            for _ in range(n):
                v = self.r_object()
                s.add(v)
            retval = frozenset(s)
            self.r_ref_insert(retval, idx, flag)
            return retval
        elif kind == Type.CODE:
            # Fields are read in the exact order they appear in the stream.
            retval = R_REF(Code())
            retval.co_argcount = self.r_long()
            retval.co_posonlyargcount = self.r_long()
            retval.co_kwonlyargcount = self.r_long()
            retval.co_stacksize = self.r_long()
            retval.co_flags = self.r_long()
            retval.co_code = self.r_object()
            retval.co_consts = self.r_object()
            retval.co_names = self.r_object()
            retval.co_localsplusnames = self.r_object()
            retval.co_localspluskinds = self.r_object()
            retval.co_filename = self.r_object()
            retval.co_name = self.r_object()
            retval.co_qualname = self.r_object()
            retval.co_firstlineno = self.r_long()
            retval.co_linetable = self.r_object()
            retval.co_exceptiontable = self.r_object()
            return retval
        elif kind == Type.REF:
            n = self.r_long()
            retval = self.refs[n]
            # A still-reserved (None) slot means the reference is invalid.
            assert retval is not None
            return retval
        else:
            raise AssertionError(f"Unknown type {kind} {chr(kind)!r}")
 | |
| 
 | |
| 
 | |
def loads(data: bytes) -> Any:
    """Decode and return the first object stored in marshal-format *data*.

    *data* must be a ``bytes`` instance; this is checked with an assertion.
    """
    assert isinstance(data, bytes)
    return Reader(data).r_object()
 | |
| 
 | |
| 
 | |
def main():
    """Smoke test: round-trip sample values through marshal.dumps() and loads()."""
    import marshal
    import pprint

    # A nested container: dict -> set -> tuple of int, str and float.
    expected = {'foo': {(42, "bar", 3.14)}}
    decoded = loads(marshal.dumps(expected))
    assert decoded == expected, decoded

    # A real code object must come back as a Code instance.
    code_blob = marshal.dumps(main.__code__)
    decoded = loads(code_blob)
    assert isinstance(decoded, Code), decoded
    pprint.pprint(decoded.__dict__)
 | |
| 
 | |
| 
 | |
# Run the self-test only when this file is executed as a script.
if __name__ == "__main__":
    main()
 | 
