# Source code for brightfield2fish.data.dataset

import warnings
import numpy as np
import pandas as pd
from scipy.ndimage.interpolation import zoom
import torch
from torch.utils.data import Dataset
from aicsimageio import AICSImage
from brightfield2fish.data.utils import RandomCrop, normalize


class FishDataframeDatasetTIFF(Dataset):
    r"""
    Dataset class for Brightfield -> FISH prediction that reads single channel tiffs.

    Args:
        df (pd.DataFrame): input dataframe that specifies the dataset; the columns
            ``file``, ``channel_content`` and ``normalized_single_channel_image``
            are read from it
        csv (bool): if True, accept a csv file path rather than a DataFrame
        channel_content (str): what content to pair with brightfield, e.g. DNA
        resize_original (float, tuple, or None): if not None, how to resize the
            original 3D images (passed to ``zoom`` as the zoom factor)
        random_crop (tuple, or None): if not None, tuple of z,y,x sizes (in pixels)
            to which the images will be randomly cropped
        math_dtype (numpy.dtype): data type in which internal computations will be done
        out_dtype (numpy.dtype): data type that will be output
        output_torch (bool): if True, output a torch.tensor rather than a np.array
        channel_dim (bool): if True, include a singleton channel dimension for
            output 3D images
        return_tuple (bool): if True, return images as (brightfield, target),
            else return as a dict
    """

    def __init__(
        self,
        df,
        csv=False,
        channel_content="DNA",
        resize_original=None,
        random_crop=None,
        math_dtype=np.float64,
        out_dtype=np.float32,
        output_torch=True,
        channel_dim=True,
        return_tuple=True,
    ):
        if csv:
            df = pd.read_csv(df)

        path_columns = ["file", "normalized_single_channel_image"]

        # Rows of the requested target channel. ``rename`` is used in its
        # non-inplace form: renaming a frame obtained by slicing in place is
        # prone to pandas' SettingWithCopyWarning.
        df_channel = (
            df[df["channel_content"] == channel_content]
            .reset_index(drop=True)[path_columns]
            .rename(columns={"normalized_single_channel_image": "Target"})
        )

        # Brightfield rows, restricted to files that also have a target channel.
        df_brightf = (
            df[
                (df["channel_content"] == "Brightfield")
                & (df["file"].isin(df_channel["file"]))
            ]
            .reset_index(drop=True)[path_columns]
            .rename(columns={"normalized_single_channel_image": "Brightfield"})
        )

        # One row per (brightfield image path, target image path) pair.
        self.df = df_brightf.merge(df_channel, how="inner", on="file")[
            ["Brightfield", "Target"]
        ].reset_index(drop=True)

        self._resize_original = resize_original
        self._random_crop = random_crop
        self._math_dtype = math_dtype
        self._out_dtype = out_dtype
        self._output_torch = output_torch
        self._channel_dim = channel_dim
        self._return_tuple = return_tuple

    def __len__(self):
        """Return the number of (brightfield, target) image pairs."""
        return len(self.df)

    @staticmethod
    def _add_channel_dim(image):
        """Prepend a singleton axis to a torch tensor or a numpy array."""
        if torch.is_tensor(image):
            return torch.unsqueeze(image, 0)
        return np.expand_dims(image, 0)

    def __getitem__(self, idx):
        """
        Load, preprocess and return the image pair at position ``idx``.

        Returns:
            ``(brightfield, target)`` if ``return_tuple`` is True, otherwise a
            dict with the keys ``"Brightfield"`` and ``"Target"``.
        """
        row = self.df.iloc[idx]

        # Silence FutureWarnings emitted while the images are read.
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=FutureWarning)
            out = {
                k: AICSImage(row[k]).get_image_data("ZYX")
                for k in ("Brightfield", "Target")
            }

        if self._resize_original is not None:
            out = {
                k: zoom(v.astype(self._math_dtype), self._resize_original).astype(
                    self._out_dtype
                )
                for k, v in out.items()
            }

        # Normalization is content specific: the dict key is handed to the
        # helper as ``content``.
        out = {
            k: normalize(v.astype(self._math_dtype), content=k).astype(self._out_dtype)
            for k, v in out.items()
        }

        if self._random_crop is not None:
            # A single cropper, built from the brightfield image, crops both images.
            randomcropper = RandomCrop(out["Brightfield"], self._random_crop)
            out = {k: randomcropper.crop(v) for k, v in out.items()}

        if self._output_torch:
            out = {k: torch.from_numpy(v) for k, v in out.items()}

        if self._channel_dim:
            # Works for tensors and for arrays, so ``output_torch=False`` no
            # longer fails inside ``torch.unsqueeze``.
            out = {k: self._add_channel_dim(v) for k, v in out.items()}

        if self._return_tuple:
            out = (out["Brightfield"], out["Target"])

        return out
class FishSegDataframeDatasetTIFF(Dataset):
    r"""
    Dataset class for Brightfield -> FISH prediction that reads in 3D tiffs for inputs
    and 2d fish segs for targets. Extrudes the 2d data along z for image to image
    prediction task.

    Args:
        df (pd.DataFrame): input dataframe that specifies the dataset; the columns
            ``file``, ``probe name`` and ``fish segmetation path`` are read from it
        csv (bool): if True, accept a csv file path rather than a DataFrame
        channel_content (str): probe name whose rows are selected, e.g. MYH7
        resize_original (float, tuple, or None): if not None, how to resize the
            original 3D images (passed to ``zoom`` as the zoom factor)
        random_crop (tuple, or None): if not None, tuple of z,y,x sizes (in pixels)
            to which the images will be randomly cropped
        math_dtype (numpy.dtype): data type in which internal computations will be done
        out_dtype (numpy.dtype): data type that will be output
        output_torch (bool): if True, output a torch.tensor rather than a np.array
        channel_dim (bool): if True, include a singleton channel dimension for
            output 3D images
        return_tuple (bool): if True, return images as (brightfield, target),
            else return as a dict
        fish_3d (bool): if True, return fish image as 3D, extruded along z axis
        bf_clip_percentiles (sequence of two floats, or None): lower and upper
            percentiles of pixel intensity at which to clip the brightfield image;
            no clipping if None
        normalize (bool): if True, normalize the brightfield image to zero mean and
            unit variance, and normalize the fish image to min zero and max one
    """

    def __init__(
        self,
        df,
        csv=False,
        channel_content="MYH7",
        resize_original=None,
        random_crop=None,
        math_dtype=np.float64,
        out_dtype=np.float32,
        output_torch=True,
        channel_dim=True,
        return_tuple=True,
        fish_3d=True,
        bf_clip_percentiles=(0.01, 99.99),
        normalize=True,
    ):
        # ``bf_clip_percentiles`` defaults to a tuple rather than a list so the
        # default object cannot be mutated and shared between instances.
        # The ``normalize`` argument shadows the module-level ``normalize``
        # helper inside this method only; the helper is not called here.
        if csv:
            df = pd.read_csv(df)

        # Keep the rows of the requested probe and expose the two path columns
        # under the names used by ``__getitem__``.
        self.df = (
            df[df["probe name"] == channel_content][["file", "fish segmetation path"]]
            .copy()
            .reset_index(drop=True)
            .rename(
                columns={"fish segmetation path": "Target", "file": "Brightfield"}
            )
        )

        self._resize_original = resize_original
        self._random_crop = random_crop
        self._math_dtype = math_dtype
        self._out_dtype = out_dtype
        self._output_torch = output_torch
        self._channel_dim = channel_dim
        self._return_tuple = return_tuple
        self._fish_3d = fish_3d
        self._bf_clip_percentiles = bf_clip_percentiles
        self._normalize = normalize

    def __len__(self):
        """Return the number of (brightfield, target) image pairs."""
        return len(self.df)

    @staticmethod
    def _add_channel_dim(image):
        """Prepend a singleton axis to a torch tensor or a numpy array."""
        if torch.is_tensor(image):
            return torch.unsqueeze(image, 0)
        return np.expand_dims(image, 0)

    @staticmethod
    def _extrude_target(target, shape):
        """Broadcast ``target`` to ``shape`` (tensor view, or a copied array)."""
        if torch.is_tensor(target):
            return target.expand(*shape)
        # ``broadcast_to`` yields a read-only view; copy so callers get a
        # regular writable array.
        return np.broadcast_to(target, shape).copy()

    def __getitem__(self, idx):
        """
        Load, preprocess and return the image pair at position ``idx``.

        Returns:
            ``(brightfield, target)`` if ``return_tuple`` is True, otherwise a
            dict with the keys ``"Brightfield"`` and ``"Target"``.
        """
        row = self.df.iloc[idx]

        # Silence FutureWarnings emitted while the images are read.
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=FutureWarning)
            out = {
                k: AICSImage(row[k]).get_image_data("ZYX")
                for k in ("Brightfield", "Target")
            }

        if self._resize_original is not None:
            out = {
                k: zoom(
                    v.astype(self._math_dtype),
                    self._resize_original,
                    order=1,
                    mode="reflect",
                ).astype(self._out_dtype)
                for k, v in out.items()
            }

        # NOTE(review): this normalization runs regardless of the ``normalize``
        # flag, so ``normalize=False`` does not disable it -- confirm intent.
        out = {
            k: normalize(v.astype(self._math_dtype), content=k).astype(self._out_dtype)
            for k, v in out.items()
        }

        if self._bf_clip_percentiles is not None:
            # Both clipping bounds are computed in a single percentile call.
            lower, upper = np.percentile(
                out["Brightfield"],
                [self._bf_clip_percentiles[0], self._bf_clip_percentiles[1]],
            )
            out["Brightfield"] = np.clip(out["Brightfield"], a_min=lower, a_max=upper)

        if self._normalize:
            # Re-normalize after clipping.
            out = {k: normalize(v, content=k) for k, v in out.items()}

        if self._random_crop is not None:
            # A single cropper, built from the brightfield image, crops both images.
            randomcropper = RandomCrop(out["Brightfield"], self._random_crop)
            out = {k: randomcropper.crop(v) for k, v in out.items()}

        if self._output_torch:
            out = {k: torch.from_numpy(v) for k, v in out.items()}

        if self._fish_3d:
            # Extrudes the 2d fish seg in 3d; works for tensors and arrays.
            out["Target"] = self._extrude_target(
                out["Target"], out["Brightfield"].shape
            )

        if self._channel_dim:
            out = {k: self._add_channel_dim(v) for k, v in out.items()}

        if self._return_tuple:
            out = (out["Brightfield"], out["Target"])

        return out