#8295 : Added shutil.unpack_archive and related APIs

This commit is contained in:
Tarek Ziadé 2010-04-28 17:51:36 +00:00
parent 71fb6c88a8
commit 6ac91723bd
4 changed files with 295 additions and 6 deletions

View file

@ -11,6 +11,7 @@ from os.path import abspath
import fnmatch
import collections
import errno
import tarfile
try:
from pwd import getpwnam
@ -25,7 +26,9 @@ except ImportError:
__all__ = ["copyfileobj", "copyfile", "copymode", "copystat", "copy", "copy2",
"copytree", "move", "rmtree", "Error", "SpecialFileError",
"ExecError", "make_archive", "get_archive_formats",
"register_archive_format", "unregister_archive_format"]
"register_archive_format", "unregister_archive_format",
"get_unpack_formats", "register_unpack_format",
"unregister_unpack_format", "unpack_archive"]
class Error(EnvironmentError):
pass
@ -37,6 +40,14 @@ class SpecialFileError(EnvironmentError):
class ExecError(EnvironmentError):
"""Raised when a command could not be executed"""
class ReadError(EnvironmentError):
"""Raised when an archive cannot be read"""
class RegistryError(Exception):
"""Raised when a registery operation with the archiving
and unpacking registeries fails"""
try:
WindowsError
except NameError:
@ -381,10 +392,7 @@ def _make_tarball(base_name, base_dir, compress="gzip", verbose=0, dry_run=0,
if not dry_run:
os.makedirs(archive_dir)
# creating the tarball
import tarfile # late import so Python build itself doesn't break
if logger is not None:
logger.info('Creating tar archive')
@ -567,3 +575,165 @@ def make_archive(base_name, format, root_dir=None, base_dir=None, verbose=0,
os.chdir(save_cwd)
return filename
def get_unpack_formats():
"""Returns a list of supported formats for unpacking.
Each element of the returned sequence is a tuple
(name, extensions, description)
"""
formats = [(name, info[0], info[3]) for name, info in
_UNPACK_FORMATS.items()]
formats.sort()
return formats
def _check_unpack_options(extensions, function, extra_args):
"""Checks what gets registered as an unpacker."""
# first make sure no other unpacker is registered for this extension
existing_extensions = {}
for name, info in _UNPACK_FORMATS.items():
for ext in info[0]:
existing_extensions[ext] = name
for extension in extensions:
if extension in existing_extensions:
msg = '%s is already registered for "%s"'
raise RegistryError(msg % (extension,
existing_extensions[extension]))
if not isinstance(function, collections.Callable):
raise TypeError('The registered function must be a callable')
def register_unpack_format(name, extensions, function, extra_args=None,
description=''):
"""Registers an unpack format.
`name` is the name of the format. `extensions` is a list of extensions
corresponding to the format.
`function` is the callable that will be
used to unpack archives. The callable will receive archives to unpack.
If it's unable to handle an archive, it needs to raise a ReadError
exception.
If provided, `extra_args` is a sequence of
(name, value) tuples that will be passed as arguments to the callable.
description can be provided to describe the format, and will be returned
by the get_unpack_formats() function.
"""
if extra_args is None:
extra_args = []
_check_unpack_options(extensions, function, extra_args)
_UNPACK_FORMATS[name] = extensions, function, extra_args, description
def unregister_unpack_format(name):
"""Removes the pack format from the registery."""
del _UNPACK_FORMATS[name]
def _ensure_directory(path):
"""Ensure that the parent directory of `path` exists"""
dirname = os.path.dirname(path)
if not os.path.isdir(dirname):
os.makedirs(dirname)
def _unpack_zipfile(filename, extract_dir):
"""Unpack zip `filename` to `extract_dir`
"""
try:
import zipfile
except ImportError:
raise ReadError('zlib not supported, cannot unpack this archive.')
if not zipfile.is_zipfile(filename):
raise ReadError("%s is not a zip file" % filename)
zip = zipfile.ZipFile(filename)
try:
for info in zip.infolist():
name = info.filename
# don't extract absolute paths or ones with .. in them
if name.startswith('/') or '..' in name:
continue
target = os.path.join(extract_dir, *name.split('/'))
if not target:
continue
_ensure_directory(target)
if not name.endswith('/'):
# file
data = zip.read(info.filename)
f = open(target,'wb')
try:
f.write(data)
finally:
f.close()
del data
finally:
zip.close()
def _unpack_tarfile(filename, extract_dir):
"""Unpack tar/tar.gz/tar.bz2 `filename` to `extract_dir`
"""
try:
tarobj = tarfile.open(filename)
except tarfile.TarError:
raise ReadError(
"%s is not a compressed or uncompressed tar file" % filename)
try:
tarobj.extractall(extract_dir)
finally:
tarobj.close()
_UNPACK_FORMATS = {
'gztar': (['.tar.gz', '.tgz'], _unpack_tarfile, [], "gzip'ed tar-file"),
'bztar': (['.bz2'], _unpack_tarfile, [], "bzip2'ed tar-file"),
'tar': (['.tar'], _unpack_tarfile, [], "uncompressed tar file"),
'zip': (['.zip'], _unpack_zipfile, [], "ZIP file")
}
def _find_unpack_format(filename):
for name, info in _UNPACK_FORMATS.items():
for extension in info[0]:
if filename.endswith(extension):
return name
return None
def unpack_archive(filename, extract_dir=None, format=None):
"""Unpack an archive.
`filename` is the name of the archive.
`extract_dir` is the name of the target directory, where the archive
is unpacked. If not provided, the current working directory is used.
`format` is the archive format: one of "zip", "tar", or "gztar". Or any
other registered format. If not provided, unpack_archive will use the
filename extension and see if an unpacker was registered for that
extension.
In case none is found, a ValueError is raised.
"""
if extract_dir is None:
extract_dir = os.getcwd()
if format is not None:
try:
format_info = _UNPACK_FORMATS[format]
except KeyError:
raise ValueError("Unknown unpack format '{0}'".format(format))
func = format_info[0]
func(filename, extract_dir, **dict(format_info[1]))
else:
# we need to look at the registered unpackers supported extensions
format = _find_unpack_format(filename)
if format is None:
raise ReadError("Unknown archive format '{0}'".format(filename))
func = _UNPACK_FORMATS[format][1]
kwargs = dict(_UNPACK_FORMATS[format][2])
func(filename, extract_dir, **kwargs)