Hello,
To get highly accurate membrane-based cell segmentation, I have trained several cellpose models on different membrane markers (CD3, CD103, etc.). Each one gives me a mask the size of our image with ~30-100k cells, and these masks are more accurate than the ones from cellpose cyto3.

I want the final mask to hold the most accurate mask we have for each cell. So I first calculate the IoU between the CD3 and CD103 masks, and if the overlap is above a certain threshold (20%), I discard the CD103 mask and the CD3 mask is the one that goes into the final mask. The same applies to the nuclear mask: if a CD3 or CD103 mask overlaps a nuclear mask, the nuclear mask is the one we discard.
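For reference, the function further down calls a helper named `get_iou` whose definition is not shown. A minimal sketch of what such a helper might look like, assuming it takes two boolean masks of the same shape; it is written with NumPy so it runs anywhere, and the body is an assumption rather than the poster's actual code:

```python
import numpy as np


def get_iou(mask_a, mask_b):
    """Intersection over union of two boolean masks of the same shape.

    Hypothetical stand-in for the helper used in the post.
    """
    intersection = np.logical_and(mask_a, mask_b).sum()
    union = np.logical_or(mask_a, mask_b).sum()
    # Two empty masks have no union; report zero overlap instead of dividing by 0.
    return float(intersection) / float(union) if union > 0 else 0.0


# Two 2x2 squares that share one row of two pixels.
a = np.zeros((4, 4), dtype=bool)
a[0:2, 0:2] = True
b = np.zeros((4, 4), dtype=bool)
b[1:3, 0:2] = True

iou = get_iou(a, b)  # 2 shared pixels out of 6 covered pixels
```

With the 20% threshold described above, this pair would count as the same cell, since 2/6 is above 0.2.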

Here’s the part I need help with: utilizing my GPU as efficiently as possible. Right now, I have a for loop that first merges the CD3 and CD103 masks into a membrane mask, and then merges that membrane mask with the DAPI mask. This took tolerably long on ~10-20k cells, but scaling up to 0.6-1 million cells has made the execution time scale proportionally, so much so that a run is projected to take 111 hours. Fitting the whole image on the GPU also started to give me OOM errors, so I had to process the image in tiles. That made the script functional again, but it still took intolerably long. Then I realized that I only have to check for overlap with cells within a certain radius of the current cell of interest, which cut my projected execution time by around 40%.

I just know there is a way to speed this up. I’m on an HPC with a literal A100 GPU, and the script is running a for loop that I have to assume runs sequentially, not taking full advantage of the GPU’s parallel processing capabilities.
So people can see what I’m talking about, I’m adding my mask merging function below:

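import cupy as cp
from scipy.spatial import KDTree

# get_iou and print_progress_bar are helper functions that are not shown here.
def merge_masks_on_tiles(cd3_mask, cd103_mask, dapi_mask, iou_threshold=0.2, tile_size=1024, overlap=128, radius=1):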
    # Ensure masks are uint32
    print("Ensuring masks are cp.uint32...", flush=True)
    cd3_mask = cp.asarray(cd3_mask, dtype=cp.uint32)
    cd103_mask = cp.asarray(cd103_mask, dtype=cp.uint32)
    dapi_mask = cp.asarray(dapi_mask, dtype=cp.uint32)
    print("Done!", flush=True)
    # Get unique IDs for each mask
    print("Getting unique IDs for each mask...", flush=True)
    cd3_ids = cp.unique(cd3_mask)[1:]
    cd103_ids = cp.unique(cd103_mask)[1:]
    dapi_ids = cp.unique(dapi_mask)[1:]
    print("Done!", flush=True)
    # Make DAPI and CD103 IDs unique from CD3 IDs
    print("Making DAPI and CD103 IDs unique from CD3 IDs...", flush=True)
    print(f"adding {cp.max(cd3_mask)} to CD103 IDs...", flush=True)
    cd103_mask[cd103_mask > 0] += cp.max(cd3_mask)
    print(f"adding {cp.max(cd103_mask)} to Nuclear mask IDs...", flush=True)
    dapi_mask[dapi_mask > 0] += cp.max(cd103_mask)
    print("Done!", flush=True)
    # Split masks into tiles
    print("Splitting masks into tiles...", flush=True)
    tile_rows = int(cp.ceil(cd3_mask.shape[0] / tile_size))
    print(f"tile_rows = {tile_rows}", flush=True)
    tile_cols = int(cp.ceil(cd3_mask.shape[1] / tile_size))
    print(f"tile_cols = {tile_cols}", flush=True)
    print(f"Making empty merged mask with shape {cd3_mask.shape}...", flush=True)
    merged_mask = cp.zeros_like(cd3_mask, dtype=cp.uint32)
    print("Done!", flush=True)
    total_tiles = tile_rows * tile_cols
    print(f"total_tiles = {total_tiles}", flush=True)
    tile_counter = 0
    print(f"Starting Tile loooop...", flush=True)
    for row in range(tile_rows):
        for col in range(tile_cols):
            print(f"Starting Tile {row},{col}")
            print("Adding overlap...")
            start_y = row * (tile_size - overlap)
            end_y = min((row + 1) * tile_size, cd3_mask.shape[0])
            start_x = col * (tile_size - overlap)
            end_x = min((col + 1) * tile_size, cd3_mask.shape[1])
            print("Done! Subsetting tile...")
            cd3_tile = cd3_mask[start_y:end_y, start_x:end_x]
            cd103_tile = cd103_mask[start_y:end_y, start_x:end_x]
            dapi_tile = dapi_mask[start_y:end_y, start_x:end_x]
            print("Done! Merging CD3 and CD103 tiles...")
            # Merge CD3 and CD103 tiles
            print("Initialize empty tile...")
            merged_membrane_tile = cp.zeros_like(cd3_tile, dtype=cp.uint32)
            print("Find cell centroids...")
            cd3_centroids = cp.argwhere(cd3_tile > 0).get()  # Convert to NumPy array
            cd103_centroids = cp.argwhere(cd103_tile > 0).get()  # Convert to NumPy array
            print("Build KD-Tree...")
            cd3_tree = KDTree(cd3_centroids)
            print("Query nearby cells...")
            nearby_indices = cd3_tree.query_ball_point(cd103_centroids, r=radius)
            print("Calculate IOU for nearby cells...")
            for i, cd103_indices in enumerate(nearby_indices):
                if len(cd103_indices) > 0:
                    cd103_id = cd103_tile[cd103_centroids[i][0], cd103_centroids[i][1]]
                    cd103_cell_mask = cd103_tile == cd103_id
                    max_iou = 0
                    max_cd3_id = 0
                    for j in cd103_indices:
                        cd3_id = cd3_tile[cd3_centroids[j][0], cd3_centroids[j][1]]
                        cd3_cell_mask = cd3_tile == cd3_id
                        iou = get_iou(cd3_cell_mask, cd103_cell_mask)
                        if iou > max_iou:
                            max_iou = iou
                            max_cd3_id = cd3_id
                    if max_iou > iou_threshold:
                        merged_membrane_tile[cd103_cell_mask] = max_cd3_id
            for cd3_id in cp.unique(cd3_tile):
                merged_membrane_tile[cd3_tile == cd3_id] = cd3_id
            print("Done! Adding remaining CD103 cells to merged membrane tile")
            remaining_cd103_mask = cp.isin(cd103_tile, merged_membrane_tile, invert=True)
            remaining_cd103_ids = cp.unique(cd103_tile[remaining_cd103_mask])
            for cd103_id in remaining_cd103_ids:
                merged_membrane_tile[cd103_tile == cd103_id] = cd103_id
            # Merge membrane and DAPI tiles
            print("Done! Merging membrane and DAPI tiles")
            merged_tile = cp.zeros_like(merged_membrane_tile, dtype=cp.uint32)
            membrane_centroids = cp.argwhere(merged_membrane_tile > 0).get()  # Convert to NumPy array
            dapi_centroids = cp.argwhere(dapi_tile > 0).get()  # Convert to NumPy array
            membrane_tree = KDTree(membrane_centroids)
            nearby_indices = membrane_tree.query_ball_point(dapi_centroids, r=radius)
            for i, dapi_indices in enumerate(nearby_indices):
                if len(dapi_indices) > 0:
                    dapi_id = dapi_tile[dapi_centroids[i][0], dapi_centroids[i][1]]
                    dapi_cell_mask = dapi_tile == dapi_id
                    max_iou = 0
                    max_membrane_id = 0
                    for j in dapi_indices:
                        membrane_id = merged_membrane_tile[membrane_centroids[j][0], membrane_centroids[j][1]]
                        membrane_cell_mask = merged_membrane_tile == membrane_id
                        iou = get_iou(membrane_cell_mask, dapi_cell_mask)
                        if iou > max_iou:
                            max_iou = iou
                            max_membrane_id = membrane_id
                    if max_iou > iou_threshold:
                        merged_tile[dapi_cell_mask] = max_membrane_id
            remaining_membrane_mask = cp.isin(merged_membrane_tile, merged_tile, invert=True)
            remaining_membrane_ids = cp.unique(merged_membrane_tile[remaining_membrane_mask])
            for membrane_id in remaining_membrane_ids:
                merged_tile[merged_membrane_tile == membrane_id] = membrane_id
            remaining_dapi_mask = cp.isin(dapi_tile, merged_tile, invert=True)
            remaining_dapi_ids = cp.unique(dapi_tile[remaining_dapi_mask])
            for dapi_id in remaining_dapi_ids:
                merged_tile[dapi_tile == dapi_id] = dapi_id            
            # Update the merged mask with the processed tile
            print("Done! Update the merged mask with the processed tile")
            merged_mask[start_y:end_y, start_x:end_x] = merged_tile
            tile_counter += 1
            print_progress_bar(tile_counter, total_tiles, prefix='Processing tiles:', suffix='Complete', length=50)
    return merged_mask

Don’t mind all my print statements… I want to be able to identify which steps take longer. Once it’s working I’ll strip out everything unnecessary. I realize this might be a Stack Overflow question, but are there any imaging GPU people here who know how to accelerate processes like this with GPUs?
In theory, I think the overlaps for all the cells within a radius could be calculated at the same time, and then maybe I can increase the tile size until it takes up the maximum GPU memory.

Anyway, thank you for any help at all.

Hi there, welcome to the forum!

I can currently think of two possible directions of attack: one leverages more resources, and the other simplifies the problem.

  • First, given that this problem seems to be separable into tiles, have you considered using the dask imaging library (dask-image)? Truthfully, I haven’t used it yet, so I can’t speak to how many niggling details you’ll have to work out. But from hanging around the imaging space long enough, my impression is that this is the ideal use case for the work you’re aiming to do.
  • This would allow parallel processing at the tile level, which would let more of your GPU’s capacity be used at a time (without storing the entire image in GPU memory).
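Tile-level parallelism needs tile bounds that can be computed independently of one another. A minimal sketch in plain Python (not dask API), assuming a fixed stride of `tile_size - overlap` so that every tile has the same maximum size; the name `tile_windows` is hypothetical:

```python
def tile_windows(height, width, tile_size=1024, overlap=128):
    """Yield (y0, y1, x0, x1) windows that cover a height x width image.

    Neighbouring windows share `overlap` pixels, and no window is larger
    than tile_size in either direction.
    """
    if not 0 <= overlap < tile_size:
        raise ValueError("overlap must satisfy 0 <= overlap < tile_size")
    stride = tile_size - overlap
    # A window starting at or beyond (size - overlap) would already be
    # fully covered by the previous window, so stop before that point.
    for y0 in range(0, max(height - overlap, 1), stride):
        for x0 in range(0, max(width - overlap, 1), stride):
            yield y0, min(y0 + tile_size, height), x0, min(x0 + tile_size, width)


windows = list(tile_windows(2048, 1500))
```

Each window can then be handed to whichever scheduler is in use. Cells that straddle a tile edge still need their own handling, for example by only keeping results for cells whose centre falls outside the overlap margin.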

  • Without sample images, it’s hard to tell the exact scope of your problem. But before throwing more computational firepower at it, it might be good to see if it can be simplified first.
  • My understanding of your problem scope is: CD3 is the best marker for a cell, then CD103, then the nuclear stain, i.e. for a given cell, the preferred label would be CD3 > CD103 > DAPI/nuclear. For each label, you consider the nearest labels (via a KD-Tree) to narrow down the candidate matches and determine whether any of them exceed the IoU threshold. Then you keep specific labels according to your preferences.

    If this is correct, it seems like it would be easier to invert the problem a bit and start by checking for binary-mask intersection. As in, start by looking for nuclear/DAPI labels which do not share (enough) pixels with the masks of either the CD103 or CD3 channels. These mark unique cells, so you can avoid calculating KD-Trees on them. You can do the same for the CD103 labels against the CD3 labels. This can be a fairly fast operation, especially if performed using CuPy.

    Since this reduces the total number of cell calculations in a heavily nested chain, it might have a fairly large effect on the processing time. Hard to say, but it could be worth a shot.

    A contrived example of what I mean is below. You could conceivably call any CD103 label ‘real’ if less than 20% of its area (or some other fraction) intersects with the more ‘reliable’ channel.
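One way the intersection-first idea might be sketched is to count the shared pixels for every overlapping label pair in a single pass, with no per-cell loop and no KD-Tree. This is a hypothetical illustration, not code from the thread: `overlap_table` and `drop_overlapping` are made-up names, and it uses NumPy so it runs without a GPU. CuPy generally mirrors these NumPy calls (the post already uses `cp.unique` and `cp.isin`), but that path is untested here.

```python
import numpy as np


def overlap_table(labels_a, labels_b):
    """List every (a, b) label pair that shares at least one pixel.

    Returns ids_a, ids_b, intersection, area_a, area_b as aligned 1-D arrays.
    Label 0 is treated as background in both images.
    """
    a = labels_a.ravel().astype(np.int64)
    b = labels_b.ravel().astype(np.int64)
    area_a = np.bincount(a)   # pixel count per label, indexed by label ID
    area_b = np.bincount(b)
    both = (a > 0) & (b > 0)  # pixels labelled in both images
    n_b = area_b.size
    # Encode each co-occurring pair as one integer so a single unique()
    # call counts the shared pixels of all pairs at once.
    keys, inter = np.unique(a[both] * n_b + b[both], return_counts=True)
    ids_a, ids_b = keys // n_b, keys % n_b
    return ids_a, ids_b, inter, area_a[ids_a], area_b[ids_b]


def drop_overlapping(labels_a, labels_b, threshold=0.2, use_iou=True):
    """Zero out labels in labels_b that overlap any label in labels_a too much."""
    ids_a, ids_b, inter, area_a, area_b = overlap_table(labels_a, labels_b)
    if use_iou:
        score = inter / (area_a + area_b - inter)  # IoU, as in the question
    else:
        score = inter / area_b                     # fraction of b's own area
    to_drop = np.unique(ids_b[score > threshold])
    out = labels_b.copy()
    out[np.isin(labels_b, to_drop)] = 0
    return out


# Toy data: CD103 label 1 overlaps CD3 label 1, CD103 label 2 stands alone.
cd3 = np.zeros((6, 6), dtype=np.uint32)
cd3[0:3, 0:3] = 1
cd103 = np.zeros((6, 6), dtype=np.uint32)
cd103[1:4, 0:3] = 1
cd103[4:6, 4:6] = 2

kept_cd103 = drop_overlapping(cd3, cd103)  # only label 2 survives
```

The surviving CD103 labels can then be written into the background pixels of the CD3 mask, and the same two functions applied again to the DAPI mask against that combined membrane mask. The work scales with the number of pixels rather than with cells times neighbours, which is where most of the saving would be expected to come from.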

    Hope any of this helps!