How to save a NumPy array as a GeoTIFF file using GDAL

Dr. Huidae Cho
Institute for Environmental and Spatial Analysis...University of North Georgia

1   Introduction

Try these exercises first:

Use nlcd2001_clipped.tif and nlcd2016_clipped.tif.

2   Python code

from osgeo import gdal
import numpy as np
import matplotlib.pyplot as plt

def read_geotiff(filename):
    ds = gdal.Open(filename)
    band = ds.GetRasterBand(1)
    arr = band.ReadAsArray()
    return arr, ds

def write_geotiff(filename, arr, in_ds):
    if arr.dtype == np.float32:
        arr_type = gdal.GDT_Float32
    else:
        arr_type = gdal.GDT_Int32

    driver = gdal.GetDriverByName("GTiff")
    out_ds = driver.Create(filename, arr.shape[1], arr.shape[0], 1, arr_type)
    out_ds.SetProjection(in_ds.GetProjection())
    out_ds.SetGeoTransform(in_ds.GetGeoTransform())
    band = out_ds.GetRasterBand(1)
    band.WriteArray(arr)
    band.FlushCache()
    band.ComputeStatistics(False)

nlcd01_arr, nlcd01_ds = read_geotiff("nlcd2001_clipped.tif")
nlcd16_arr, nlcd16_ds = read_geotiff("nlcd2016_clipped.tif")

nlcd_changed = np.where(nlcd01_arr != nlcd16_arr, 1, 0)

write_geotiff("nlcd_changed.tif", nlcd_changed, nlcd01_ds)

plt.subplot(311)
plt.imshow(nlcd01_arr)

plt.subplot(312)
plt.imshow(nlcd16_arr)

plt.subplot(313)
plt.imshow(nlcd_changed)

plt.show()

3   Explanation

Import necessary modules.

from osgeo import gdal
import numpy as np
import matplotlib.pyplot as plt

Define the read_geotiff() function that returns both a NumPy array and a GDAL dataset. We need the GDAL dataset later when we want to create a new GeoTIFF file using the same projection and geotransform information.

def read_geotiff(filename):
    ds = gdal.Open(filename)
    band = ds.GetRasterBand(1)
    arr = band.ReadAsArray()
    return arr, ds

Define the write_geotiff() function that takes a filename, a NumPy array, and an input dataset. We use the input dataset argument to set projection and geotransform information for the new GeoTIFF file. This function can handle floating-point and integer NumPy arrays. A NumPy array has an attribute called dtype (data type) that holds its cell data type. If the data type of the input NumPy array is np.float32, we use gdal.GDT_Float32 for writing the new GeoTIFF file. Otherwise, we use gdal.GDT_Int32. Note that any array that is not np.float32, including np.float64, takes the integer branch, so its values are stored as 32-bit integers.

def write_geotiff(filename, arr, in_ds):
    if arr.dtype == np.float32:
        arr_type = gdal.GDT_Float32
    else:
        arr_type = gdal.GDT_Int32
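
The two-way choice above can be widened if other cell types need to be preserved. The following is a minimal sketch, not part of the original exercise: the table GDAL_TYPE_NAMES and the helper gdal_type_name() are hypothetical names introduced here for illustration, while the strings on the right-hand side are names of real constants in the osgeo.gdal module. It only needs NumPy to run.

```python
import numpy as np

# Hypothetical lookup table: NumPy dtype name -> name of a GDAL type constant.
# The names on the right are attributes of the osgeo.gdal module.
GDAL_TYPE_NAMES = {
    "uint8": "GDT_Byte",
    "int16": "GDT_Int16",
    "uint16": "GDT_UInt16",
    "int32": "GDT_Int32",
    "uint32": "GDT_UInt32",
    "float32": "GDT_Float32",
    "float64": "GDT_Float64",
}

def gdal_type_name(arr):
    """Return the GDAL constant name for arr's dtype, or raise TypeError."""
    name = arr.dtype.name
    if name not in GDAL_TYPE_NAMES:
        raise TypeError(f"no GDAL type chosen for dtype {name}")
    return GDAL_TYPE_NAMES[name]

print(gdal_type_name(np.zeros((2, 2), dtype=np.float32)))  # GDT_Float32
print(gdal_type_name(np.zeros((2, 2), dtype=np.uint8)))    # GDT_Byte
```

Inside write_geotiff(), the constant itself could then be looked up with getattr(gdal, gdal_type_name(arr)). Raising an error for unlisted dtypes is a design choice of this sketch; it avoids silently converting, say, a float64 array to integers.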

The GDAL module supports different file formats for raster data. One of them is GeoTIFF. In GDAL, any code that handles a specific file format is called a driver. For our exercise, we use GeoTIFF, so we get the GeoTIFF driver called GTiff.

    driver = gdal.GetDriverByName("GTiff")

Now, we pass the required arguments to the driver's Create() function. These are the filename (filename), the width (arr.shape[1], the number of columns), the height (arr.shape[0], the number of rows), the number of bands (1 band), and the data type from above (arr_type). Note the order: a NumPy array's shape is (rows, columns), but Create() expects the width first. Then, we need to set projection and geotransform information using the SetProjection() and SetGeoTransform() functions. Use the input dataset in_ds to retrieve this information from the input GeoTIFF file.

    out_ds = driver.Create(filename, arr.shape[1], arr.shape[0], 1, arr_type)
    out_ds.SetProjection(in_ds.GetProjection())
    out_ds.SetGeoTransform(in_ds.GetGeoTransform())
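
To see why copying the geotransform matters, it helps to know what it contains. A GDAL geotransform is a tuple of six coefficients that maps a column and row position to map coordinates. The sketch below applies that mapping in plain Python. The helper name pixel_to_map() and the example coefficients are made up for illustration; they are not taken from the NLCD files.

```python
def pixel_to_map(gt, col, row):
    """Map the upper-left corner of cell (col, row) to map coordinates."""
    x = gt[0] + col * gt[1] + row * gt[2]
    y = gt[3] + col * gt[4] + row * gt[5]
    return x, y

# Made-up north-up geotransform: upper-left corner at (1000000, 1500000),
# 30-unit cells, no rotation. The cell height is negative because row
# numbers increase downward while y decreases.
gt = (1000000.0, 30.0, 0.0, 1500000.0, 0.0, -30.0)

print(pixel_to_map(gt, 0, 0))   # (1000000.0, 1500000.0)
print(pixel_to_map(gt, 10, 5))  # (1000300.0, 1499850.0)
```

Because the output array has the same shape and cell layout as the input, reusing the input's six coefficients places every output cell at the same location as the corresponding input cell.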

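We can now get the first band of the new dataset and write the NumPy array (arr) into it. band.FlushCache() makes sure that the band's cached data is flushed out onto the disk before the next line is executed. Finally, band.ComputeStatistics(False) computes the raster statistics of the band (minimum, maximum, mean, and standard deviation). The False argument means that approximate statistics are not acceptable, so GDAL computes exact values from all cells.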

    band = out_ds.GetRasterBand(1)
    band.WriteArray(arr)
    band.FlushCache()
    band.ComputeStatistics(False)

Now, we have defined the read and write functions. Read the two NLCD GeoTIFF files. Since read_geotiff() returns two items, we need two variables on the left-hand side to receive both. The first return value is the array and the second is the dataset of the file. We’ll use the dataset later to retrieve projection and geotransform information.

nlcd01_arr, nlcd01_ds = read_geotiff("nlcd2001_clipped.tif")
nlcd16_arr, nlcd16_ds = read_geotiff("nlcd2016_clipped.tif")
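
The cell-by-cell comparison in the next step only makes sense if both rasters cover the same grid. The clipped files in this exercise are assumed to match, but a quick check is cheap. The sketch below uses a hypothetical helper, same_grid(), and made-up arrays and geotransforms in place of the real data.

```python
import numpy as np

def same_grid(arr_a, gt_a, arr_b, gt_b):
    """Return True if both arrays have the same shape and geotransform."""
    return arr_a.shape == arr_b.shape and tuple(gt_a) == tuple(gt_b)

# Made-up stand-ins for the arrays and geotransforms of two rasters.
gt = (1000000.0, 30.0, 0.0, 1500000.0, 0.0, -30.0)
a = np.zeros((4, 6), dtype=np.uint8)
b = np.ones((4, 6), dtype=np.uint8)
c = np.ones((4, 5), dtype=np.uint8)

print(same_grid(a, gt, b, gt))  # True
print(same_grid(a, gt, c, gt))  # False
```

With the real data, the call would be same_grid(nlcd01_arr, nlcd01_ds.GetGeoTransform(), nlcd16_arr, nlcd16_ds.GetGeoTransform()).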

Create a binary (0 or 1) difference array between the two datasets.

nlcd_changed = np.where(nlcd01_arr != nlcd16_arr, 1, 0)
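
The same expression can be tried on a tiny example to see what it produces. The arrays below are made-up land cover codes, not values from the NLCD files.

```python
import numpy as np

# Made-up 2 x 3 land cover grids for two years.
a = np.array([[11, 21, 41],
              [42, 42, 81]])
b = np.array([[11, 22, 41],
              [42, 21, 81]])

# 1 where the class changed between the two years, 0 where it did not.
changed = np.where(a != b, 1, 0)
print(changed.tolist())    # [[0, 1, 0], [0, 1, 0]]

# Number of changed cells.
print(int(changed.sum()))  # 2
```

Since the result holds the integers 0 and 1, it is an integer array, so write_geotiff() takes the gdal.GDT_Int32 branch when saving it.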

Create a new GeoTIFF file called nlcd_changed.tif.

write_geotiff("nlcd_changed.tif", nlcd_changed, nlcd01_ds)

Plot all three arrays using plt.imshow().

plt.subplot(311)
plt.imshow(nlcd01_arr)

plt.subplot(312)
plt.imshow(nlcd16_arr)

plt.subplot(313)
plt.imshow(nlcd_changed)

plt.show()