High level wrappers

High level API to fastgs functionality
bands.get_captions([["B03","B02","B01"], ["B01"], ["B03"]]),
(['B03,B02,B01', 'B01', 'B03'],)

MSDescriptor

We need a class that provides basic information about all the channels in the source data. The initial fields are based on the requirements of Sentinel-2 images.


source

MSDescriptor

 MSDescriptor ()

Initialize self. See help(type(self)) for accurate signature.

We use factories to create sensible defaults


source

MSDescriptor.from_bands


source

MSDescriptor.from_band_brgt


source

MSDescriptor.from_all

@patch
def num_bands(self: MSDescriptor) -> int:
    return len(self.band_ids)
test_eq(MSDescriptor.from_bands(["B0"]).num_bands(),1)
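The factory approach can be illustrated with a minimal, self-contained sketch. The class below, BandTable, is a hypothetical stand-in and not the fastgs MSDescriptor; its field names and the default multiplier of 1.0 are assumptions made only for illustration.

```python
from dataclasses import dataclass, field


@dataclass
class BandTable:
    """Hypothetical stand-in for a band descriptor (not the fastgs class)."""

    band_ids: list[str] = field(default_factory=list)
    brgt: list[float] = field(default_factory=list)

    @classmethod
    def from_bands(cls, band_ids: list[str]) -> "BandTable":
        # Assumed default: a neutral multiplier of 1.0 for every band.
        return cls(list(band_ids), [1.0] * len(band_ids))

    @classmethod
    def from_band_brgt(cls, band_ids: list[str], brgt: list[float]) -> "BandTable":
        # Reject mismatched inputs early instead of failing at lookup time.
        if len(band_ids) != len(brgt):
            raise ValueError("need exactly one multiplier per band")
        return cls(list(band_ids), list(brgt))

    def num_bands(self) -> int:
        return len(self.band_ids)


table = BandTable.from_bands(["B0"])
print(table.num_bands(), table.brgt)  # 1 [1.0]
```

Each factory fills in whatever the caller did not supply, so the plain constructor can stay argument-free.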

The list of ‘raw’ bands for Sentinel-2 images is ["B01","B02","B03","B04","B05","B06","B07","B08","B8A","B09","B10","B11","B12","AOT"].

As described here, these images are naturally “dark”, so we allow them to be brightened for display by providing a list of brightening multipliers. For this factory, I have selected values that seem to work well with our data, but they are by no means authoritative.

The third parameter lists the resolution of each raw band, in metres.

Finally, we provide some named groups of three bands each that have been found useful for producing false-color images for different applications. The goal is to create (multiple) RGB images, corresponding to such sets of bands, for each multi-spectral tensor.
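To make the role of the multipliers concrete, here is a rough sketch of how one RGB image could be assembled from three named bands. It uses plain nested lists rather than tensors, the function name brighten_composite is hypothetical, the toy multipliers are not the values chosen for the factory, and clipping at 1.0 is an assumption about how overflow would be handled.

```python
def brighten_composite(bands, band_ids, multipliers):
    """Build an H x W x 3 image from three named bands.

    bands: dict mapping band id -> 2D list of reflectances in [0, 1]
    band_ids: the three band ids to use as R, G and B, in that order
    multipliers: one brightening multiplier per selected band
    """
    if len(band_ids) != 3 or len(multipliers) != 3:
        raise ValueError("an RGB composite needs exactly three bands")
    planes = [bands[b] for b in band_ids]
    height, width = len(planes[0]), len(planes[0][0])
    # Scale every pixel by its band's multiplier and clip at 1.0.
    return [
        [
            [min(1.0, planes[c][y][x] * multipliers[c]) for c in range(3)]
            for x in range(width)
        ]
        for y in range(height)
    ]


# Toy 1 x 2 tile with made-up reflectances.
tile = {
    "B04": [[0.125, 0.75]],
    "B03": [[0.25, 0.5]],
    "B02": [[0.0625, 0.25]],
}
rgb = brighten_composite(tile, ["B04", "B03", "B02"], [2.0, 2.0, 4.0])
print(rgb)  # [[[0.25, 0.5, 0.25], [1.0, 1.0, 1.0]]]
```

Calling the function once per named group yields one displayable image per group for the same tile.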


source

createSentinel2Descriptor

 createSentinel2Descriptor ()
sentinel2 = createSentinel2Descriptor()

This method lists all bands of a given resolution.


source

MSDescriptor.get_res_ids

 MSDescriptor.get_res_ids (res:int)
test_eq(sentinel2.get_res_ids(10),["B02","B03","B04","B08"])
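A lookup of this kind amounts to filtering the band list by its resolution entry. The sketch below is an illustration only: BAND_RES and ids_at_resolution are hypothetical names, and the table covers just a subset of bands with their native Sentinel-2 resolutions.

```python
# Native Sentinel-2 resolutions in metres for a subset of bands.
# Insertion order is kept so results come back in band order.
BAND_RES = {
    "B01": 60,
    "B02": 10,
    "B03": 10,
    "B04": 10,
    "B05": 20,
    "B08": 10,
    "B8A": 20,
    "B11": 20,
    "B12": 20,
}


def ids_at_resolution(band_res, res):
    """Return the ids of all bands whose resolution equals res."""
    return [band for band, band_r in band_res.items() if band_r == res]


print(ids_at_resolution(BAND_RES, 10))  # ['B02', 'B03', 'B04', 'B08']
```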

We can find the brightness multipliers corresponding to a list of channel names with this method.


source

MSDescriptor.get_brgtX

 MSDescriptor.get_brgtX (ids:list[str])
test_eq(sentinel2.get_brgtX(["B8A","B01"]), [2.5,2.5])
test_eq(sentinel2.get_brgtX(sentinel2.rgb_combo["natural_color"]), [3.75,4.25,4.75])

… and, likewise, the lists of brightness multipliers corresponding to a list of name lists.


source

MSDescriptor.get_brgtX_list

 MSDescriptor.get_brgtX_list (ids_list:list[list[str]])
test_eq(
    sentinel2.get_brgtX_list([sentinel2.rgb_combo["color_infrared"],["B12","B11"]]), 
    [[1.7,3.75,4.25],[2.2,1.6]]
)
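Both lookups reduce to indexing a per-band table, once for a flat list and once per group. The sketch below uses hypothetical names (BRGT, multipliers_for, multipliers_for_groups). Its values are read off the test expectations above, assuming natural_color is B04,B03,B02 and color_infrared is B08,B04,B03.

```python
# Multipliers inferred from the test expectations on this page (assumption).
BRGT = {
    "B01": 2.5,
    "B02": 4.75,
    "B03": 4.25,
    "B04": 3.75,
    "B08": 1.7,
    "B8A": 2.5,
    "B11": 1.6,
    "B12": 2.2,
}


def multipliers_for(ids):
    """Multipliers for a flat list of band ids; unknown ids raise KeyError."""
    return [BRGT[band_id] for band_id in ids]


def multipliers_for_groups(groups):
    """Multipliers for a list of band-id lists, preserving the grouping."""
    return [multipliers_for(group) for group in groups]


print(multipliers_for(["B8A", "B01"]))
print(multipliers_for_groups([["B08", "B04", "B03"], ["B12", "B11"]]))
```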

MSData

We then create an MSData wrapper class, which takes parameters that specify how to load the multispectral image into a TensorImageMS object.


source

MSData

 MSData ()

Initialize self. See help(type(self)) for accurate signature.


source

MSData.from_loader

 MSData.from_loader (ms_descriptor:__main__.MSDescriptor,
                     band_ids:list[str], chn_grp_ids:list[list[str]],
                     tg_fn:Callable[[list[str],Any],torch.Tensor])

source

MSData.from_files

 MSData.from_files (ms_descriptor:__main__.MSDescriptor,
                    band_ids:list[str], chn_grp_ids:list[list[str]],
                    files_getter:Callable[[list[str],Any],list[str]],
                    chan_io_fn:Callable[[list[str]],torch.Tensor])
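Judging only by the type signatures, the two callables passed to from_files can be chained into a single loader of the shape from_loader expects. The sketch below shows that composition with stand-in functions; make_loader, fake_files and fake_reader are hypothetical, and whether fastgs composes them this way internally is an assumption.

```python
from typing import Any, Callable


def make_loader(
    files_getter: Callable[[list[str], Any], list[str]],
    chan_io_fn: Callable[[list[str]], Any],
) -> Callable[[list[str], Any], Any]:
    """Chain 'band ids + image id -> paths' with 'paths -> data'."""

    def load(band_ids: list[str], img_id: Any) -> Any:
        return chan_io_fn(files_getter(band_ids, img_id))

    return load


# Stand-ins for a real file getter and a real image reader.
def fake_files(band_ids, tile):
    return [f"tile{tile:03d}-{band}.png" for band in band_ids]


def fake_reader(paths):
    # A real reader would return a tensor; this one just echoes the paths.
    return "|".join(paths)


load = make_loader(fake_files, fake_reader)
print(load(["B02", "B03"], 66))  # tile066-B02.png|tile066-B03.png
```

Keeping file discovery separate from file reading lets either half be swapped, for example to read a different image format, without touching the other.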
from fastgs.test.io import *
def get_input(stem: str) -> str:
    "Get full input path for stem"
    return "./images/" + stem

def tile_img_name(chn_id: str, tile_num: int) -> str:
    "File name from channel id and tile number"
    return f"Sentinel20m-{chn_id}-20200215-{tile_num:03d}.png"

def get_channel_filenames(chn_ids, tile_idx):
    "Get list of all channel filenames for one tile idx"
    return [get_input(tile_img_name(x, tile_idx)) for x in chn_ids]

seg_codes = ["not-cloudy","cloudy"]

We can create a sentinel data loader for only the RGB channels

rgb_bands = MSData.from_files(
    sentinel2,
    ["B02","B03","B04"],
    [sentinel2.rgb_combo["natural_color"]],
    get_channel_filenames,
    read_multichan_files
)

where read_multichan_files is the reader brought in by the fastgs.test.io import above

or we might choose to only look at the 10m resolution bands

tenm_bands = MSData.from_files(
    sentinel2,
    sentinel2.get_res_ids(10),
    [sentinel2.rgb_combo["natural_color"], ["B08"]],
    get_channel_filenames,
    read_multichan_files
)

or even 11 channels of sentinel 2 data

elvn_bands = MSData.from_files(
    sentinel2,
    ["B02","B03","B04","B05","B06","B07","B08","B8A","B11","B12","AOT"],
    [sentinel2.rgb_combo["natural_color"], ["B07","B06","B05"],["B12","B11","B8A"],["B08"]],
    get_channel_filenames,
    read_multichan_files
)

source

MSData.load_image

 MSData.load_image (img_id)

source

MSData.num_channels

 MSData.num_channels ()
rgb_tensor = rgb_bands.load_image(66)
test_eq(rgb_bands.num_channels(),3)
rgb_tensor.show()
[<AxesSubplot:title={'center':'B04,B03,B02'}>]

tenm_tensor = tenm_bands.load_image(66)
test_eq(tenm_bands.num_channels(),4)
tenm_tensor.show()
[<AxesSubplot:title={'center':'B04,B03,B02'}>,
 <AxesSubplot:title={'center':'B08'}>]

elvn_tensor = elvn_bands.load_image(66)
test_eq(elvn_bands.num_channels(),11)
elvn_tensor.show()
[<AxesSubplot:title={'center':'B04,B03,B02'}>,
 <AxesSubplot:title={'center':'B07,B06,B05'}>,
 <AxesSubplot:title={'center':'B12,B11,B8A'}>,
 <AxesSubplot:title={'center':'B08'}>]

MaskData

Finally, for convenience, we provide a wrapper class to load mask data


source

MaskData

 MaskData ()

Initialize self. See help(type(self)) for accurate signature.


source

MaskData.load_mask

 MaskData.load_mask (img_id)

source

MaskData.num_channels

 MaskData.num_channels ()

source

MaskData.from_loader

 MaskData.from_loader (mask_id:str,
                       tg_fn:Callable[[list[str],Any],torch.Tensor],
                       mask_codes:list[str])

source

MaskData.from_files

 MaskData.from_files (mask_id:str,
                      files_getter:Callable[[list[str],Any],list[str]],
                      mask_io_fn:Callable[[list[str]],torch.Tensor],
                      mask_codes:list[str])
masks = MaskData.from_files("LC",get_channel_filenames,read_mask_file,["non-building","building"])
test_eq(masks.num_channels(),2)
mask = masks.load_mask(66)
mask.show()
<AxesSubplot:>

MSAugment

A wrapper class for augmentations


source

MSAugment

 MSAugment ()

Initialize self. See help(type(self)) for accurate signature.


source

MSAugment.from_augs

 MSAugment.from_augs (train_aug=None, valid_aug=None)

Transforms

Next we create the various transforms required for the fastai pipeline


source

MSData.create_xform_block

 MSData.create_xform_block ()

source

MaskData.create_xform_block

 MaskData.create_xform_block ()

source

MSAugment.create_item_xforms

GSModel

We create a wrapper class that encapsulates the model we use.


source

GSUnetModel.load_learner

 GSUnetModel.load_learner (model_path:str, dl)

source

GSUnetModel.create_learner

 GSUnetModel.create_learner (dl, pretrained=False, **kwargs)

source

GSUnetModel.from_all

 GSUnetModel.from_all (model, ms_data:__main__.MSData,
                       mask_codes:[<class'str'>], loss_func=FlattenedLoss
                       of CrossEntropyLoss(), metrics=<fastai.metrics.Dice
                       object at 0x7fd10d591610>)

source

GSUnetModel

 GSUnetModel (model, n_in, n_out, loss_func, metrics)

Initialize self. See help(type(self)) for accurate signature.

FastGS

Finally, we have a master wrapper class that provides the high-level API to create fastai datablocks and learners.


source

FastGS

 FastGS ()

Initialize self. See help(type(self)) for accurate signature.


source

FastGS.for_inference

 FastGS.for_inference (ms_data:__main__.MSData, mask_codes:[<class'str'>])

source

FastGS.for_training

 FastGS.for_training (ms_data:__main__.MSData,
                      mask_data:__main__.MaskData,
                      ms_aug:__main__.MSAugment=<__main__.MSAugment object
                      at 0x7fd10d6035b0>)
fgs = FastGS.for_training(elvn_bands,masks)

source

FastGS.create_data_block

 FastGS.create_data_block (splitter=<function _inner>)
db = fgs.create_data_block()
dl = db.dataloaders([66]*10,bs=8)

source

FastGS.load_learner

 FastGS.load_learner (model_path, dl)

source

FastGS.create_learner

 FastGS.create_learner (dl, reweight='avg')
learner = fgs.create_learner(dl,resnet18)