Added write support for GDALRaster

- Instantiation of GDALRaster instances from dict or json data.
- Retrieve and write pixel values in GDALBand objects.
- Support for the GDALFlushCache in gdal C prototypes
- Added private flush method to GDALRaster to make sure all
  data is written to files when file-based rasters are changed.
- Replaced ``ptr`` with ``_ptr`` for internal ptr variable

Refs #23804. Thanks Claude Paroz and Tim Graham for the reviews.
This commit is contained in:
Daniel Wiesmann 2015-03-13 18:49:02 +00:00 committed by Claude Paroz
parent 8758a63ddb
commit f269c1d6f6
11 changed files with 600 additions and 57 deletions

View file

@ -31,6 +31,7 @@ get_driver_description = const_string_output(lgdal.GDALGetDescription, [c_void_p
create_ds = voidptr_output(lgdal.GDALCreate, [c_void_p, c_char_p, c_int, c_int, c_int, c_int])
open_ds = voidptr_output(lgdal.GDALOpen, [c_char_p, c_int])
close_ds = void_output(lgdal.GDALClose, [c_void_p])
flush_ds = int_output(lgdal.GDALFlushCache, [c_void_p])
copy_ds = voidptr_output(lgdal.GDALCreateCopy, [c_void_p, c_char_p, c_void_p, c_int,
POINTER(c_char_p), c_void_p, c_void_p])
add_band_ds = void_output(lgdal.GDALAddBand, [c_void_p, c_int])

View file

@ -2,9 +2,11 @@ from ctypes import byref, c_int
from django.contrib.gis.gdal.base import GDALBase
from django.contrib.gis.gdal.prototypes import raster as capi
from django.contrib.gis.shortcuts import numpy
from django.utils import six
from django.utils.encoding import force_text
from .const import GDAL_PIXEL_TYPES
from .const import GDAL_PIXEL_TYPES, GDAL_TO_CTYPES
class GDALBand(GDALBase):
@ -13,51 +15,49 @@ class GDALBand(GDALBase):
"""
def __init__(self, source, index):
self.source = source
self.ptr = capi.get_ds_raster_band(source.ptr, index)
self._ptr = capi.get_ds_raster_band(source._ptr, index)
@property
def description(self):
"""
Returns the description string of the band.
"""
return force_text(capi.get_band_description(self.ptr))
return force_text(capi.get_band_description(self._ptr))
@property
def width(self):
"""
Width (X axis) in pixels of the band.
"""
return capi.get_band_xsize(self.ptr)
return capi.get_band_xsize(self._ptr)
@property
def height(self):
"""
Height (Y axis) in pixels of the band.
"""
return capi.get_band_ysize(self.ptr)
return capi.get_band_ysize(self._ptr)
def datatype(self, as_string=False):
@property
def pixel_count(self):
"""
Returns the GDAL Pixel Datatype for this band.
Returns the total number of pixels in this band.
"""
dtype = capi.get_band_datatype(self.ptr)
if as_string:
dtype = GDAL_PIXEL_TYPES[dtype]
return dtype
return self.width * self.height
@property
def min(self):
"""
Returns the minimum pixel value for this band.
"""
return capi.get_band_minimum(self.ptr, byref(c_int()))
return capi.get_band_minimum(self._ptr, byref(c_int()))
@property
def max(self):
"""
Returns the maximum pixel value for this band.
"""
return capi.get_band_maximum(self.ptr, byref(c_int()))
return capi.get_band_maximum(self._ptr, byref(c_int()))
@property
def nodata_value(self):
@ -65,5 +65,80 @@ class GDALBand(GDALBase):
Returns the nodata value for this band, or None if it isn't set.
"""
nodata_exists = c_int()
value = capi.get_band_nodata_value(self.ptr, nodata_exists)
value = capi.get_band_nodata_value(self._ptr, nodata_exists)
return value if nodata_exists else None
@nodata_value.setter
def nodata_value(self, value):
"""
Sets the nodata value for this band.
"""
if not isinstance(value, (int, float)):
raise ValueError('Nodata value must be numeric.')
capi.set_band_nodata_value(self._ptr, value)
self.source._flush()
def datatype(self, as_string=False):
"""
Returns the GDAL Pixel Datatype for this band.
"""
dtype = capi.get_band_datatype(self._ptr)
if as_string:
dtype = GDAL_PIXEL_TYPES[dtype]
return dtype
def data(self, data=None, offset=None, size=None, as_memoryview=False):
"""
Reads or writes pixel values for this band. Blocks of data can
be accessed by specifying the width, height and offset of the
desired block. The same specification can be used to update
parts of a raster by providing an array of values.
Allowed input data types are bytes, memoryview, list, tuple, and array.
"""
if not offset:
offset = (0, 0)
if not size:
size = (self.width - offset[0], self.height - offset[1])
if any(x <= 0 for x in size):
raise ValueError('Offset too big for this raster.')
if size[0] > self.width or size[1] > self.height:
raise ValueError('Size is larger than raster.')
# Create ctypes type array generator
ctypes_array = GDAL_TO_CTYPES[self.datatype()] * (size[0] * size[1])
if data is None:
# Set read mode
access_flag = 0
# Prepare empty ctypes array
data_array = ctypes_array()
else:
# Set write mode
access_flag = 1
# Instantiate ctypes array holding the input data
if isinstance(data, (bytes, six.memoryview, numpy.ndarray)):
data_array = ctypes_array.from_buffer_copy(data)
else:
data_array = ctypes_array(*data)
# Access band
capi.band_io(self._ptr, access_flag, offset[0], offset[1],
size[0], size[1], byref(data_array), size[0],
size[1], self.datatype(), 0, 0)
# Return data as numpy array if possible, otherwise as list
if data is None:
if as_memoryview:
return memoryview(data_array)
elif numpy:
return numpy.frombuffer(
data_array, dtype=numpy.dtype(data_array)).reshape(size)
else:
return list(data_array)
else:
self.source._flush()

View file

@ -1,6 +1,9 @@
"""
GDAL - Constant definitions
"""
from ctypes import (
c_byte, c_double, c_float, c_int16, c_int32, c_uint16, c_uint32,
)
# See http://www.gdal.org/gdal_8h.html#a22e22ce0a55036a96f652765793fb7a4
GDAL_PIXEL_TYPES = {
@ -17,3 +20,12 @@ GDAL_PIXEL_TYPES = {
10: 'GDT_CFloat32', # Complex Float32
11: 'GDT_CFloat64', # Complex Float64
}
# Lookup values to convert GDAL pixel type indices into ctypes objects.
# The GDAL band-io works with ctypes arrays to hold data to be written
# or to hold the space for data to be read into. The lookup below helps
# selecting the right ctypes object for a given gdal pixel type.
GDAL_TO_CTYPES = [
None, c_byte, c_uint16, c_int16, c_uint32, c_int32,
c_float, c_double, None, None, None, None
]

View file

@ -1,3 +1,4 @@
import json
import os
from ctypes import addressof, byref, c_double
@ -7,6 +8,7 @@ from django.contrib.gis.gdal.error import GDALException
from django.contrib.gis.gdal.prototypes import raster as capi
from django.contrib.gis.gdal.raster.band import GDALBand
from django.contrib.gis.gdal.srs import SpatialReference, SRSException
from django.contrib.gis.geometry.regex import json_regex
from django.utils import six
from django.utils.encoding import (
force_bytes, force_text, python_2_unicode_compatible,
@ -33,10 +35,22 @@ class TransformPoint(list):
def x(self):
return self[0]
@x.setter
def x(self, value):
gtf = self._raster.geotransform
gtf[self.indices[self._prop][0]] = value
self._raster.geotransform = gtf
@property
def y(self):
return self[1]
@y.setter
def y(self, value):
gtf = self._raster.geotransform
gtf[self.indices[self._prop][1]] = value
self._raster.geotransform = gtf
@python_2_unicode_compatible
class GDALRaster(GDALBase):
@ -47,17 +61,64 @@ class GDALRaster(GDALBase):
self._write = 1 if write else 0
Driver.ensure_registered()
# Preprocess json inputs. This converts json strings to dictionaries,
# which are parsed below the same way as direct dictionary inputs.
if isinstance(ds_input, six.string_types) and json_regex.match(ds_input):
ds_input = json.loads(ds_input)
# If input is a valid file path, try setting file as source.
if isinstance(ds_input, six.string_types):
if os.path.exists(ds_input):
try:
# GDALOpen will auto-detect the data source type.
self.ptr = capi.open_ds(force_bytes(ds_input), self._write)
except GDALException as err:
raise GDALException('Could not open the datasource at "{}" ({}).'.format(
ds_input, err))
else:
if not os.path.exists(ds_input):
raise GDALException('Unable to read raster source input "{}"'.format(ds_input))
try:
# GDALOpen will auto-detect the data source type.
self._ptr = capi.open_ds(force_bytes(ds_input), self._write)
except GDALException as err:
raise GDALException('Could not open the datasource at "{}" ({}).'.format(ds_input, err))
elif isinstance(ds_input, dict):
# A new raster needs to be created in write mode
self._write = 1
# Create driver (in memory by default)
driver = Driver(ds_input.get('driver', 'MEM'))
# For out of memory drivers, check filename argument
if driver.name != 'MEM' and 'name' not in ds_input:
raise GDALException('Specify name for creation of raster with driver "{}".'.format(driver.name))
# Check if width and height where specified
if 'width' not in ds_input or 'height' not in ds_input:
raise GDALException('Specify width and height attributes for JSON or dict input.')
# Create GDAL Raster
self._ptr = capi.create_ds(
driver._ptr,
force_bytes(ds_input.get('name', '')),
ds_input['width'],
ds_input['height'],
ds_input.get('nr_of_bands', len(ds_input.get('bands', []))),
ds_input.get('datatype', 6),
None
)
# Set band data if provided
for i, band_input in enumerate(ds_input.get('bands', [])):
self.bands[i].data(band_input['data'])
if 'nodata_value' in band_input:
self.bands[i].nodata_value = band_input['nodata_value']
# Set SRID, default to 0 (this assures SRS is always instanciated)
self.srs = ds_input.get('srid', 0)
# Set additional properties if provided
if 'origin' in ds_input:
self.origin.x, self.origin.y = ds_input['origin']
if 'scale' in ds_input:
self.scale.x, self.scale.y = ds_input['scale']
if 'skew' in ds_input:
self.skew.x, self.skew.y = ds_input['skew']
else:
raise GDALException('Invalid data source input type: "{}".'.format(type(ds_input)))
@ -72,15 +133,34 @@ class GDALRaster(GDALBase):
"""
Short-hand representation because WKB may be very large.
"""
return '<Raster object at %s>' % hex(addressof(self.ptr))
return '<Raster object at %s>' % hex(addressof(self._ptr))
def _flush(self):
"""
Flush all data from memory into the source file if it exists.
The data that needs flushing are geotransforms, coordinate systems,
nodata_values and pixel values. This function will be called
automatically wherever it is needed.
"""
# Raise an Exception if the value is being changed in read mode.
if not self._write:
raise GDALException('Raster needs to be opened in write mode to change values.')
capi.flush_ds(self._ptr)
@property
def name(self):
return force_text(capi.get_ds_description(self.ptr))
"""
Returns the name of this raster. Corresponds to filename
for file-based rasters.
"""
return force_text(capi.get_ds_description(self._ptr))
@cached_property
def driver(self):
ds_driver = capi.get_ds_driver(self.ptr)
"""
Returns the GDAL Driver used for this raster.
"""
ds_driver = capi.get_ds_driver(self._ptr)
return Driver(ds_driver)
@property
@ -88,48 +168,85 @@ class GDALRaster(GDALBase):
"""
Width (X axis) in pixels.
"""
return capi.get_ds_xsize(self.ptr)
return capi.get_ds_xsize(self._ptr)
@property
def height(self):
"""
Height (Y axis) in pixels.
"""
return capi.get_ds_ysize(self.ptr)
return capi.get_ds_ysize(self._ptr)
@property
def srs(self):
"""
Returns the Spatial Reference used in this GDALRaster.
Returns the SpatialReference used in this GDALRaster.
"""
try:
wkt = capi.get_ds_projection_ref(self.ptr)
wkt = capi.get_ds_projection_ref(self._ptr)
if not wkt:
return None
return SpatialReference(wkt, srs_type='wkt')
except SRSException:
return None
@cached_property
@srs.setter
def srs(self, value):
"""
Sets the spatial reference used in this GDALRaster. The input can be
a SpatialReference or any parameter accepted by the SpatialReference
constructor.
"""
if isinstance(value, SpatialReference):
srs = value
elif isinstance(value, six.integer_types + six.string_types):
srs = SpatialReference(value)
else:
raise ValueError('Could not create a SpatialReference from input.')
capi.set_ds_projection_ref(self._ptr, srs.wkt.encode())
self._flush()
@property
def geotransform(self):
"""
Returns the geotransform of the data source.
Returns the default geotransform if it does not exist or has not been
set previously. The default is (0.0, 1.0, 0.0, 0.0, 0.0, -1.0).
set previously. The default is [0.0, 1.0, 0.0, 0.0, 0.0, -1.0].
"""
# Create empty ctypes double array for data
gtf = (c_double * 6)()
capi.get_ds_geotransform(self.ptr, byref(gtf))
return tuple(gtf)
capi.get_ds_geotransform(self._ptr, byref(gtf))
return list(gtf)
@geotransform.setter
def geotransform(self, values):
"Sets the geotransform for the data source."
if sum([isinstance(x, (int, float)) for x in values]) != 6:
raise ValueError('Geotransform must consist of 6 numeric values.')
# Create ctypes double array with input and write data
values = (c_double * 6)(*values)
capi.set_ds_geotransform(self._ptr, byref(values))
self._flush()
@property
def origin(self):
"""
Coordinates of the raster origin.
"""
return TransformPoint(self, 'origin')
@property
def scale(self):
"""
Pixel scale in units of the raster projection.
"""
return TransformPoint(self, 'scale')
@property
def skew(self):
"""
Skew of pixels (rotation parameters).
"""
return TransformPoint(self, 'skew')
@property
@ -150,7 +267,10 @@ class GDALRaster(GDALBase):
@cached_property
def bands(self):
"""
Returns the bands of this raster as a list of GDALBand instances.
"""
bands = []
for idx in range(1, capi.get_ds_raster_count(self.ptr) + 1):
for idx in range(1, capi.get_ds_raster_count(self._ptr) + 1):
bands.append(GDALBand(self, idx))
return bands