Added write support for GDALRaster

- Instantiation of GDALRaster instances from dict or json data.
- Retrieve and write pixel values in GDALBand objects.
- Support for the GDALFlushCache in gdal C prototypes
- Added private flush method to GDALRaster to make sure all
  data is written to files when file-based rasters are changed.
- Replaced ``ptr`` with ``_ptr`` for internal ptr variable

Refs #23804. Thanks Claude Paroz and Tim Graham for the reviews.
This commit is contained in:
Daniel Wiesmann 2015-03-13 18:49:02 +00:00 committed by Claude Paroz
parent 8758a63ddb
commit f269c1d6f6
11 changed files with 600 additions and 57 deletions

View file

@ -1,3 +1,4 @@
import json
import os
from ctypes import addressof, byref, c_double
@ -7,6 +8,7 @@ from django.contrib.gis.gdal.error import GDALException
from django.contrib.gis.gdal.prototypes import raster as capi
from django.contrib.gis.gdal.raster.band import GDALBand
from django.contrib.gis.gdal.srs import SpatialReference, SRSException
from django.contrib.gis.geometry.regex import json_regex
from django.utils import six
from django.utils.encoding import (
force_bytes, force_text, python_2_unicode_compatible,
@ -33,10 +35,22 @@ class TransformPoint(list):
def x(self):
return self[0]
@x.setter
def x(self, value):
gtf = self._raster.geotransform
gtf[self.indices[self._prop][0]] = value
self._raster.geotransform = gtf
@property
def y(self):
return self[1]
@y.setter
def y(self, value):
gtf = self._raster.geotransform
gtf[self.indices[self._prop][1]] = value
self._raster.geotransform = gtf
@python_2_unicode_compatible
class GDALRaster(GDALBase):
@ -47,17 +61,64 @@ class GDALRaster(GDALBase):
self._write = 1 if write else 0
Driver.ensure_registered()
# Preprocess json inputs. This converts json strings to dictionaries,
# which are parsed below the same way as direct dictionary inputs.
if isinstance(ds_input, six.string_types) and json_regex.match(ds_input):
ds_input = json.loads(ds_input)
# If input is a valid file path, try setting file as source.
if isinstance(ds_input, six.string_types):
if os.path.exists(ds_input):
try:
# GDALOpen will auto-detect the data source type.
self.ptr = capi.open_ds(force_bytes(ds_input), self._write)
except GDALException as err:
raise GDALException('Could not open the datasource at "{}" ({}).'.format(
ds_input, err))
else:
if not os.path.exists(ds_input):
raise GDALException('Unable to read raster source input "{}"'.format(ds_input))
try:
# GDALOpen will auto-detect the data source type.
self._ptr = capi.open_ds(force_bytes(ds_input), self._write)
except GDALException as err:
raise GDALException('Could not open the datasource at "{}" ({}).'.format(ds_input, err))
elif isinstance(ds_input, dict):
# A new raster needs to be created in write mode
self._write = 1
# Create driver (in memory by default)
driver = Driver(ds_input.get('driver', 'MEM'))
# For out of memory drivers, check filename argument
if driver.name != 'MEM' and 'name' not in ds_input:
raise GDALException('Specify name for creation of raster with driver "{}".'.format(driver.name))
# Check if width and height where specified
if 'width' not in ds_input or 'height' not in ds_input:
raise GDALException('Specify width and height attributes for JSON or dict input.')
# Create GDAL Raster
self._ptr = capi.create_ds(
driver._ptr,
force_bytes(ds_input.get('name', '')),
ds_input['width'],
ds_input['height'],
ds_input.get('nr_of_bands', len(ds_input.get('bands', []))),
ds_input.get('datatype', 6),
None
)
# Set band data if provided
for i, band_input in enumerate(ds_input.get('bands', [])):
self.bands[i].data(band_input['data'])
if 'nodata_value' in band_input:
self.bands[i].nodata_value = band_input['nodata_value']
# Set SRID, default to 0 (this assures SRS is always instanciated)
self.srs = ds_input.get('srid', 0)
# Set additional properties if provided
if 'origin' in ds_input:
self.origin.x, self.origin.y = ds_input['origin']
if 'scale' in ds_input:
self.scale.x, self.scale.y = ds_input['scale']
if 'skew' in ds_input:
self.skew.x, self.skew.y = ds_input['skew']
else:
raise GDALException('Invalid data source input type: "{}".'.format(type(ds_input)))
@ -72,15 +133,34 @@ class GDALRaster(GDALBase):
"""
Short-hand representation because WKB may be very large.
"""
return '<Raster object at %s>' % hex(addressof(self.ptr))
return '<Raster object at %s>' % hex(addressof(self._ptr))
def _flush(self):
"""
Flush all data from memory into the source file if it exists.
The data that needs flushing are geotransforms, coordinate systems,
nodata_values and pixel values. This function will be called
automatically wherever it is needed.
"""
# Raise an Exception if the value is being changed in read mode.
if not self._write:
raise GDALException('Raster needs to be opened in write mode to change values.')
capi.flush_ds(self._ptr)
@property
def name(self):
return force_text(capi.get_ds_description(self.ptr))
"""
Returns the name of this raster. Corresponds to filename
for file-based rasters.
"""
return force_text(capi.get_ds_description(self._ptr))
@cached_property
def driver(self):
ds_driver = capi.get_ds_driver(self.ptr)
"""
Returns the GDAL Driver used for this raster.
"""
ds_driver = capi.get_ds_driver(self._ptr)
return Driver(ds_driver)
@property
@ -88,48 +168,85 @@ class GDALRaster(GDALBase):
"""
Width (X axis) in pixels.
"""
return capi.get_ds_xsize(self.ptr)
return capi.get_ds_xsize(self._ptr)
@property
def height(self):
"""
Height (Y axis) in pixels.
"""
return capi.get_ds_ysize(self.ptr)
return capi.get_ds_ysize(self._ptr)
@property
def srs(self):
"""
Returns the Spatial Reference used in this GDALRaster.
Returns the SpatialReference used in this GDALRaster.
"""
try:
wkt = capi.get_ds_projection_ref(self.ptr)
wkt = capi.get_ds_projection_ref(self._ptr)
if not wkt:
return None
return SpatialReference(wkt, srs_type='wkt')
except SRSException:
return None
@cached_property
@srs.setter
def srs(self, value):
"""
Sets the spatial reference used in this GDALRaster. The input can be
a SpatialReference or any parameter accepted by the SpatialReference
constructor.
"""
if isinstance(value, SpatialReference):
srs = value
elif isinstance(value, six.integer_types + six.string_types):
srs = SpatialReference(value)
else:
raise ValueError('Could not create a SpatialReference from input.')
capi.set_ds_projection_ref(self._ptr, srs.wkt.encode())
self._flush()
@property
def geotransform(self):
"""
Returns the geotransform of the data source.
Returns the default geotransform if it does not exist or has not been
set previously. The default is (0.0, 1.0, 0.0, 0.0, 0.0, -1.0).
set previously. The default is [0.0, 1.0, 0.0, 0.0, 0.0, -1.0].
"""
# Create empty ctypes double array for data
gtf = (c_double * 6)()
capi.get_ds_geotransform(self.ptr, byref(gtf))
return tuple(gtf)
capi.get_ds_geotransform(self._ptr, byref(gtf))
return list(gtf)
@geotransform.setter
def geotransform(self, values):
"Sets the geotransform for the data source."
if sum([isinstance(x, (int, float)) for x in values]) != 6:
raise ValueError('Geotransform must consist of 6 numeric values.')
# Create ctypes double array with input and write data
values = (c_double * 6)(*values)
capi.set_ds_geotransform(self._ptr, byref(values))
self._flush()
@property
def origin(self):
"""
Coordinates of the raster origin.
"""
return TransformPoint(self, 'origin')
@property
def scale(self):
"""
Pixel scale in units of the raster projection.
"""
return TransformPoint(self, 'scale')
@property
def skew(self):
"""
Skew of pixels (rotation parameters).
"""
return TransformPoint(self, 'skew')
@property
@ -150,7 +267,10 @@ class GDALRaster(GDALBase):
@cached_property
def bands(self):
"""
Returns the bands of this raster as a list of GDALBand instances.
"""
bands = []
for idx in range(1, capi.get_ds_raster_count(self.ptr) + 1):
for idx in range(1, capi.get_ds_raster_count(self._ptr) + 1):
bands.append(GDALBand(self, idx))
return bands