Source code for brightfield2fish.data.split_data

import os
import json
import hashlib

import numpy as np
import pandas as pd


def hashsplit(X, splits={"train": 0.8, "test": 0.2}, salt=1, N=5):
    r"""
    Split a list of items pseudorandomly (but deterministically) based on
    the hashes of the items.

    Each item is mapped to a number in [0, 1) derived from the SHA-512 hash
    of ``str(item) + str(salt)``, and that number is binned into consecutive
    intervals whose widths are the normalized split weights. Split names are
    processed in sorted order. Items with the same string representation
    always land in the same split.

    Args:
        X (iterable): items to be split into non-overlapping groups
        splits (dict): dict of {name: weight} pairs defining the desired
            split; weights are normalized to sum to one. The default dict is
            never mutated.
        salt: any object; ``str(salt)`` is appended to ``str(item)`` before
            hashing
        N (int): number of leading decimal digits of the hash used to place
            each item in [0, 1)

    Returns:
        (dict): {name: indices} for all names in the input split dict, where
        indices is a list of ints giving positions in ``X``. Every position
        appears in exactly one list.

    Example:
        >>> hashsplit(list("allen cell institute"), {'train':0.7,'test':0.3}, salt=3, N=8)
        {'test': [4, 12, 17], 'train': [0, 1, 2, 3, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 18, 19]}
    """

    # normalize the weights, just in case (sum computed once, not per item)
    total = sum(splits.values())
    weights = {k: v / total for k, v in splits.items()}

    # determine the bin edges in [0, 1] that correspond to each split
    names = sorted(weights)
    bounds = np.cumsum([0.0] + [weights[k] for k in names])
    # cumsum can land a hair off 1.0; pin the top edge so nothing escapes
    bounds[-1] = 1.0

    # hash the strings deterministically
    hashes = [
        hashlib.sha512((str(x) + str(salt)).encode("utf-8")).hexdigest() for x in X
    ]

    # create some numbers in [0, 1) (at N sig figs) from the decimal digits
    # found in each hex digest
    nums = np.array(
        [float("".join(filter(str.isdigit, h))[:N]) / 10 ** N for h in hashes]
    )

    # Bins are (left, right], except that the first bin also includes its
    # left edge: nums can be exactly 0.0 (all leading digits zero), and a
    # strictly-open left edge would silently drop those items from every split.
    assigned = {}
    for i, name in enumerate(names):
        lo, hi = bounds[i], bounds[i + 1]
        above_lo = (nums >= lo) if i == 0 else (nums > lo)
        # np.where returns a singleton tuple holding the index array
        assigned[name] = np.where(above_lo & (nums <= hi))[0].tolist()

    return assigned
def split_and_save(
    csv_name="data_by_images_normalized.csv",
    csv_dir="/allen/aics/modeling/data/brightfield2fish/preprocessed",
    split_col="file",
    save_dir="data/splits",
    splits={"train": 0.7, "valid": 0.15, "test": 0.15},
    seed=0,
):
    r"""
    Split a csv dataset and save the splits and indices to disk.

    Writes ``splits.json`` (the row positions belonging to each split) and
    one ``<name>.csv`` per split into ``save_dir``. Rows are assigned by
    hashing the value in ``split_col``, so rows sharing that value end up in
    the same split.

    Args:
        csv_name (str): csv to be split into non-overlapping groups
        csv_dir (str): path to directory in which csv resides
        split_col (str): column to use as id for splitting into groups
        save_dir (str): path to directory where split csvs and indices should
            be saved; created if it does not exist
        splits (dict): dict of {name: size} by which to split data
        seed (int): salt for the hash function that does the splitting
    """

    df = pd.read_csv(os.path.join(csv_dir, csv_name))

    # {split name: list of positional row indices into df}
    split_indices = hashsplit(df[split_col], splits=splits, salt=seed)

    # exist_ok avoids the check-then-create race of a separate exists() test
    os.makedirs(save_dir, exist_ok=True)

    with open(os.path.join(save_dir, "splits.json"), "w") as fp:
        # cast to plain int so numpy integer types never reach the encoder
        json.dump({k: [int(i) for i in v] for k, v in split_indices.items()}, fp)

    for name, idx in split_indices.items():
        subset = df.iloc[idx, :].reset_index(drop=True)
        subset.to_csv(os.path.join(save_dir, "{}.csv".format(name)), index=False)