Tutorial: Train and evaluate a model on images of spots

Contents

Tutorial: Train and evaluate a model on images of spots#

Spot generation#

First, we need to generate our dataset. For this we use the data generator provided in the package.

Hide code cell source
import numpy as np
import matplotlib.pyplot as plt
import sys
sys.path.append("..")

import matplotlib.patches as patches
from matplotlib import colormaps as cmaps
import os
from tqdm import tqdm
import yaml
from importlib import reload  # Python 3.4+
import pandas as pd

import image_generator as ig

from yundon_functions import *
pi = np.pi
from multiprocessing import Pool
from PIL import Image
from cv2 import resize
from itertools import repeat
from matplotlib.patches import Rectangle, Circle
# %matplotlib widget
import cv2
from ultralytics.utils import ops


def plot_result(ax, img, objects, vmin=1.95e4, vmax=2.05e4):
    """Show an image on ``ax`` and overlay the annotated objects.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to draw on. Its axis decorations are switched off.
    img : PIL.Image.Image, numpy.ndarray or str
        Image to display. A ``str`` is treated as a file path and opened
        with PIL; anything else is handed to ``imshow`` unchanged.
    objects : pandas.DataFrame
        One row per object. The first five columns are unpacked by position
        as ``label, x, y, w, h``, where ``(x, y)`` is the box centre and
        ``(w, h)`` its width and height in image coordinates. Rows whose
        label is ``"Ripple"`` must also provide a ``"z"`` column.
    vmin, vmax : float or None
        Colour limits forwarded to ``imshow``.

    Returns
    -------
    None
    """
    if isinstance(img, str):
        # Only ``Image`` (``from PIL import Image``) is imported by name at
        # the top of this file, so use it rather than the bare ``PIL`` name.
        img = Image.open(img)
    ax.imshow(img, cmap="grey", vmin=vmin, vmax=vmax)
    ax.axis("off")

    for _, obj in objects.iterrows():
        label, x, y, w, h = obj.iloc[:5]

        # Rectangle takes (xy, width, height): the anchor is the lower-left
        # corner, so shift the centre by half the box size. Width must come
        # before height.
        rect = Rectangle((x - 0.5 * w, y - 0.5 * h), w, h,
                         linewidth=1, edgecolor="yellow", facecolor="none")
        ax.add_patch(rect)

        # Mark the object centre.
        ax.add_patch(Circle((x, y), 1, facecolor="red", edgecolor="red"))

        if label == "Ripple":
            # Anchor the annotation slightly offset from the box corner.
            tx, ty = rect.get_xy()
            tx += 6
            ty -= 12
            # NOTE(review): 0.134 presumably converts z into physical units -
            # confirm against the image generator.
            # A dict literal is used instead of ``dict(...)`` because a later
            # cell of this tutorial rebinds the name ``dict``.
            ax.text(tx, ty, f"{obj['z'] * 0.134:.3f}", fontsize="small",
                    bbox={"facecolor": "white", "alpha": 0.5})
def plot_gt(path, ax, imgsz=(512, 512), vmin=1.95e4, vmax=2.05e4):
    """Plot an image together with its ground-truth labels.

    The label file is located by replacing ``images`` with ``labels`` in
    ``path`` and swapping the file extension for ``.txt``. Each row of that
    file is expected in the layout written by the dataset-generation cells:
    ``class x y w h intensity size kpt_x kpt_y`` with coordinates normalised
    to ``[0, 1]``.

    This function calls ``plot_result`` with the
    ``(ax, img, cls, bboxes, z, kpts, ...)`` signature, i.e. the version
    defined further down in this tutorial.

    Parameters
    ----------
    path : str
        Path of the image file.
    ax : matplotlib.axes.Axes
        Axes to draw on.
    imgsz : sequence of two numbers
        Image size as ``(width, height)`` in pixels, used to scale the
        normalised coordinates back to pixels.
    vmin, vmax : float or None
        Colour limits forwarded to ``plot_result``.

    Returns
    -------
    Whatever ``plot_result`` returns.
    """
    # Replace only the real extension; a blanket ``replace("tif", "txt")``
    # would also rewrite directory or file names that contain "tif"/"jpg".
    stem, _ = os.path.splitext(path.replace("images", "labels"))
    label_path = stem + ".txt"

    # One column per object after the transpose: data[k] is field k of all objects.
    data = np.atleast_2d(np.loadtxt(label_path)).T
    if data.shape[0] == 0:
        # Empty label file: draw the image without any annotation.
        return plot_result(ax, path, np.empty(0), np.empty((0, 4)), np.empty(0),
                           np.empty((0, 0)), vmin=vmin, vmax=vmax)

    # Explicit array so lists, tuples and ndarrays all scale the same way.
    size = np.asarray(imgsz, dtype=float)

    cls = data[0]
    bboxes = data[1:5].T * np.tile(size, 2)  # x, y, w, h -> pixels
    z = data[5]  # column 5 is "intensity" in the generated labels; forwarded as z
    kpts = data[7:].T * size  # kpt_x, kpt_y -> pixels

    return plot_result(ax, path, cls, bboxes, z, kpts, vmin=vmin, vmax=vmax)
import image_generator as ig

# Generate a single test image and display it raw and with its labels overlaid.

# Parameters for the image generator. The keys "i" and "s" reappear as the
# "i" and "s" columns of the generated object table printed below.
# NOTE(review): "n" looks like the (min, max) number of objects per image - confirm in image_generator.
spot_params = {"label":"Spot", "i":("gaussian", [2**15]), "s":8, "n":(30,90)} # Parameters for the image generator
parameter_list = [spot_params]
image_size = 540,640       # Image dimensions as (height, width); x is later normalised by image_size[1], y by image_size[0]
distance = 20               # Min distance between points to avoid overlap
offset = 0                # Boundary padding
snr = 5                   # signal-to-noise ratio of the test image


objects = ig.getRandom(parameter_list, image_size, distance=distance, offset=offset) # Generate object positions for test image
image,_ = ig.generateImage(objects, image_size,snr=float(snr), background=0)                 # Generate image with objects

# Plot test image
plt.close("all")
plt.figure(figsize=(16,5))
# Left panel: the raw image with a colour bar.
ax = plt.subplot(1,2,1)
ax.axis("off")
plt.imshow(image, cmap="grey")
plt.colorbar()

# Right panel: the same image with boxes and centres drawn by plot_result,
# using the full intensity range of the image as colour limits.
ax = plt.subplot(1,2,2)
plt.title("labeled")
plot_result(ax,image,objects, vmin=image.min(), vmax=image.max())
ax.axis("off")

# plt.tight_layout()
plt.suptitle(f"Test image with generated objects at SNR={snr}")
plt.show()

# Print the object table (label, position, box size, intensity, size).
print(objects)
_images/18a109a6b5dddd16230200d6af5e5255cd00bff7d6a4544f950437e471d315f4.png
   label           x           y   w   h             i  s
0   Spot  582.803059  450.646852  32  32  32767.310194  8
1   Spot  338.485862  378.792683  32  32  32767.390046  8
2   Spot  499.688936  344.170586  32  32  32768.818425  8
3   Spot   89.162469  103.972775  32  32  32765.541459  8
4   Spot   43.183581  156.771007  32  32  32768.041847  8
..   ...         ...         ...  ..  ..           ... ..
61  Spot  179.588831   82.306793  32  32  32770.168981  8
62  Spot  470.131391  132.078602  32  32  32766.958940  8
63  Spot  327.914405   45.142683  32  32  32769.269676  8
64  Spot  287.459739  369.075568  32  32  32767.199348  8
65  Spot  356.549145  538.574278  32  32  32767.588308  8

[66 rows x 7 columns]

Now, let’s generate some training data!

# Generate the training and validation sets: one image folder and one label
# folder per entry of `folders`, filled in parallel by a process pool.
subdir = '../datasets/Dataset_spots/'
nimages = [10000,1000, 10000,1000]  # number of images per folder, same order as `folders`
folders = ["train_snr8-12","val_snr8-12", "train_snr15-20","val_snr15-20"] #the names for our folders
snr_range = [[8,12], [8,12], [15,20], [15,20]] # SNR range for the training and validation sets, same order as `folders`

import time
reload(ig)

l_dir = subdir + 'labels'
i_dir = subdir + 'images'

# makedirs(..., exist_ok=True) also creates missing parent directories and
# does not fail when the folder already exists.
for prefix in folders:
    image_dir = i_dir + '/' + prefix + '/'
    os.makedirs(image_dir, exist_ok=True)
    label_dir = l_dir + '/' + prefix + '/'
    os.makedirs(label_dir, exist_ok=True)


def init():
    """Create a fresh random number generator in each worker process."""
    global rng  # the rng variable is global so that it can be used in the call function
    rng = np.random.default_rng()
    time.sleep(rng.uniform(0, 2))


def call(j):
    """Generate image number ``j`` of the current folder plus its label file.

    Reads the module-level ``prefix`` and ``i`` of the folder currently being
    processed.
    NOTE(review): this relies on the pool workers seeing the notebook globals
    as they are when the pool is created - confirm on platforms that do not
    fork worker processes.
    """
    labelname = l_dir + '/' + prefix + '/' + f'image_{j:05d}.txt'
    fname = i_dir + '/' + prefix + '/' + f'image_{j:05d}.tif'

    objects = ig.getRandom(parameter_list, image_size, rng=rng, offset=offset, distance=distance)

    if len(objects) == 0:
        # No objects: write a label file containing only a newline.
        # Note that no image file is written in this case.
        with open(labelname, "w") as f:
            f.write("\n")
        return

    # Use the SNR range of the folder being generated, not the whole table
    # of ranges.
    # NOTE(review): assumes ig.generateImage accepts a [min, max] pair for snr - confirm.
    image, info_dict = ig.generateImage(objects, image_size, snr=snr_range[i], rng=rng)
    image = image.astype(np.uint16)  # cast image to 16 bit uint for tiff
    Image.fromarray(image).save(fname, quality=100)  # save the current image file to disk

    objects[["x", "w"]] /= image_size[1]  # scale x and w to [0,1]
    objects[["y", "h"]] /= image_size[0]  # scale y and h to [0,1]

    # turn label into int:
    # NOTE(review): factorize numbers labels by order of appearance within
    # this one image; with several classes the ids could differ between images.
    objects["label"] = pd.factorize(objects["label"])[0]

    # add keypoint coordinates
    objects[["kpt_x", "kpt_y"]] = objects[["x", "y"]]

    # save the labels to disk (in txt files) in the form:
    # class x y w h intensity(unused) size(unused) kpt_x kpt_y
    np.savetxt(labelname, objects.values, fmt='%g')


for i, prefix in enumerate(folders):
    with Pool(initializer=init) as pool:
        # imap_unordered is consumed only to drive the progress bar.
        for _ in tqdm(pool.imap_unordered(call, range(nimages[i])), total=nimages[i], desc=prefix):
            pass

# Leave a generator available in the notebook process as well.
rng = np.random.default_rng()
train_snr15-20: 100%|██████████| 10000/10000 [11:21<00:00, 14.68it/s]
val_snr15-20: 100%|██████████| 1000/1000 [01:08<00:00, 14.57it/s]

For testing, let’s also generate data at 20 evenly spaced SNR levels, \(\mathrm{SNR} \in [1, 20]\), with one folder per level.

Hide code cell content
# Generate the test sets: one folder of images per fixed SNR level.
subdir = '../datasets/Dataset_spots/'
snr_range = [1,20] # SNR range for the test set
levels = 20        # number of evenly spaced SNR levels within snr_range
nimages = [1000]*levels
test_noiselevels = list(np.linspace(*snr_range, levels))
# Folder names carry the float formatting of the level, e.g. "test_snr1.0".
folders = [f'test_snr{i}' for i in test_noiselevels]
import time
reload(ig)

l_dir = subdir + 'labels'
i_dir = subdir + 'images'

# makedirs(..., exist_ok=True) also creates missing parent directories and
# does not fail when the folder already exists.
for prefix in folders:
    image_dir = i_dir + '/' + prefix + '/'
    os.makedirs(image_dir, exist_ok=True)
    label_dir = l_dir + '/' + prefix + '/'
    os.makedirs(label_dir, exist_ok=True)


def init():
    """Create a fresh random number generator in each worker process."""
    global rng  # the rng variable is global so that it can be used in the call function
    rng = np.random.default_rng()
    time.sleep(rng.uniform(0, 2))


def call(j):
    """Generate image number ``j`` of the current folder plus its label file.

    Reads the module-level ``prefix`` and ``i`` of the folder currently being
    processed.
    NOTE(review): this relies on the pool workers seeing the notebook globals
    as they are when the pool is created - confirm on platforms that do not
    fork worker processes.
    """
    labelname = l_dir + '/' + prefix + '/' + f'image_{j:05d}.txt'
    fname = i_dir + '/' + prefix + '/' + f'image_{j:05d}.tif'

    objects = ig.getRandom(parameter_list, image_size, rng=rng, distance=distance, offset=offset)

    if len(objects) == 0:
        # No objects: write a label file containing only a newline.
        # Note that no image file is written in this case.
        with open(labelname, "w") as f:
            f.write("\n")
        return

    # Every image of this folder uses the folder's single SNR level.
    image, info_dict = ig.generateImage(objects, image_size, snr=test_noiselevels[i], rng=rng)
    image = image.astype(np.uint16)  # cast image to 16 bit uint for tiff
    Image.fromarray(image).save(fname, quality=100)

    objects[["x", "w"]] /= image_size[1]  # scale x and w to [0,1]
    objects[["y", "h"]] /= image_size[0]  # scale y and h to [0,1]

    # turn label into int:
    # NOTE(review): factorize numbers labels by order of appearance within
    # this one image; with several classes the ids could differ between images.
    objects["label"] = pd.factorize(objects["label"])[0]

    # add keypoint coordinates
    objects[["kpt_x", "kpt_y"]] = objects[["x", "y"]]

    # save the labels to disk (in txt files) in the form:
    # class x y w h intensity(unused) size(unused) kpt_x kpt_y
    np.savetxt(labelname, objects.values, fmt='%g')


for i, prefix in enumerate(folders):
    with Pool(initializer=init) as pool:
        # imap_unordered is consumed only to drive the progress bar.
        for _ in tqdm(pool.imap_unordered(call, range(nimages[i])), total=nimages[i], desc=prefix):
            pass

# Leave a generator available in the notebook process as well.
rng = np.random.default_rng()
test_snr1.0: 100%|██████████| 1000/1000 [01:08<00:00, 14.52it/s]
test_snr2.0: 100%|██████████| 1000/1000 [01:07<00:00, 14.71it/s]
test_snr3.0: 100%|██████████| 1000/1000 [01:06<00:00, 15.10it/s]
test_snr4.0: 100%|██████████| 1000/1000 [01:06<00:00, 15.06it/s]
test_snr5.0: 100%|██████████| 1000/1000 [01:05<00:00, 15.31it/s]
test_snr6.0: 100%|██████████| 1000/1000 [01:06<00:00, 15.00it/s]
test_snr7.0: 100%|██████████| 1000/1000 [01:07<00:00, 14.92it/s]
test_snr8.0: 100%|██████████| 1000/1000 [01:06<00:00, 15.09it/s]
test_snr9.0: 100%|██████████| 1000/1000 [01:07<00:00, 14.76it/s]
test_snr10.0: 100%|██████████| 1000/1000 [01:07<00:00, 14.74it/s]
test_snr11.0: 100%|██████████| 1000/1000 [01:07<00:00, 14.74it/s]
test_snr12.0: 100%|██████████| 1000/1000 [01:08<00:00, 14.52it/s]
test_snr13.0: 100%|██████████| 1000/1000 [01:06<00:00, 15.09it/s]
test_snr14.0: 100%|██████████| 1000/1000 [01:05<00:00, 15.19it/s]
test_snr15.0: 100%|██████████| 1000/1000 [01:06<00:00, 15.01it/s]
test_snr16.0: 100%|██████████| 1000/1000 [01:07<00:00, 14.91it/s]
test_snr17.0: 100%|██████████| 1000/1000 [01:06<00:00, 14.97it/s]
test_snr18.0: 100%|██████████| 1000/1000 [01:07<00:00, 14.76it/s]
test_snr19.0: 100%|██████████| 1000/1000 [01:07<00:00, 14.81it/s]
test_snr20.0: 100%|██████████| 1000/1000 [01:07<00:00, 14.85it/s]

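Next, generate the corresponding YAML files. They describe each dataset to the trainer: where the training and validation images are, the keypoint shape, the extra parameters, and the class names.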

subdir = '../datasets/Dataset_spots/'


def _write_dataset_yaml(file_path, train_dir, val_dir, extra_param_names=None):
    """Write one dataset description file.

    Parameters
    ----------
    file_path : str
        Destination of the YAML file.
    train_dir, val_dir : str
        Directories of the training and validation images. The label
        directory is inferred automatically by replacing images/ with labels/.
    extra_param_names : list of str, optional
        Names of the extra per-object parameters. The key is omitted from
        the file when this is ``None``.
    """
    # Deliberately not named ``dict``: rebinding the builtin would break the
    # ``dict(...)`` calls used elsewhere in this tutorial.
    config = {
        "train": train_dir,
        "val": val_dir,
        # shape of the keypoints: [n_keypoints, dim]. here, we use 1 keypoint
        # (particle center), having 2 dimensions
        "kpt_shape": [1, 2],
        # extra parameters provided in the data. here: intensity, size
        "num_extra_parameters": 2,
        "nc": 1,  # number of classes
        "names": ["Spot"],  # class names
    }
    if extra_param_names is not None:
        config["extra_param_names"] = list(extra_param_names)

    with open(file_path, 'w') as file:
        yaml.dump(config, file, default_flow_style=True)  # store to disk


images_root = os.getcwd() + '/' + subdir + "images/"

# One file per training SNR range.
for snr_tag in ("snr8-12", "snr15-20"):
    file_path = '../datasets/spots_' + snr_tag + '.yaml'
    _write_dataset_yaml(
        file_path,
        images_root + "train_" + snr_tag,  # directory of training images
        images_root + "val_" + snr_tag,    # directory of validation images
        extra_param_names=["intensity", "size"],
    )

# for the noise levels, create separate files
levels_dir = os.path.join(subdir, "levels")
os.makedirs(levels_dir, exist_ok=True)

# `folders` holds the test folder names created by the test-data cell above.
for folder in folders:
    file_path = levels_dir + f'/dataset_{folder}.yaml'
    print(file_path)

    # Only the "val" entry points at generated data; "train" is a placeholder path.
    # NOTE(review): no "train" folder is created by this tutorial - confirm
    # the validator never reads it.
    _write_dataset_yaml(file_path, images_root + "train", images_root + folder)

Let’s now train our model on the dataset!

from YOLOtrack11 import YOLOtrack11

# Alternative: build the model from its architecture file instead of loading weights.
# model = YOLOtrack11("yolo11n_zaxis.yaml")
model = YOLOtrack11("yolo11n_spots.pt")

# Training needs an integer image size: a (640, 540) tuple is replaced by 640
# with a warning (see the recorded output), so 640 is passed directly.
results = model.train(
    data="datasets/spots_snr8-12.yaml",
    epochs=1,
    imgsz=640,
    background=0,
    window=2**16-1,
    level=0,
    z=0,  # set z loss to 0 because we don't have z labels
    warmup_epochs=0,
    pose=1e5,
)
# Store the trained weights under the name the following cells load.
model.save("yolo11n_spots_8-12.pt")
Hide code cell output
New https://pypi.org/project/ultralytics/8.3.157 available 😃 Update with 'pip install -U ultralytics'
Ultralytics 8.3.79 🚀 Python-3.12.6 torch-2.5.1+cu124 CUDA:0 (NVIDIA GeForce RTX 4090, 24202MiB)
trainer: task=zaxis, mode=train, model=yolo11n_spots.pt, data=datasets/spots_snr8-12.yaml, epochs=1, time=None, patience=100, batch=16, imgsz=(640, 540), save=True, save_period=-1, cache=False, device=None, workers=8, project=None, name=train257, exist_ok=False, pretrained=True, optimizer=SGD, verbose=True, seed=0, deterministic=True, single_cls=False, rect=False, cos_lr=False, close_mosaic=10, resume=False, amp=False, fraction=1.0, profile=False, freeze=None, multi_scale=False, overlap_mask=True, mask_ratio=4, dropout=0.0, val=True, split=val, save_json=False, save_hybrid=False, conf=None, iou=0.7, max_det=300, half=False, dnn=False, plots=False, source=None, vid_stride=1, stream_buffer=False, visualize=False, augment=False, agnostic_nms=False, classes=None, retina_masks=False, embed=None, show=False, save_frames=False, save_txt=False, save_conf=False, save_crop=False, show_labels=True, show_conf=True, show_boxes=True, line_width=None, format=torchscript, keras=False, optimize=False, int8=False, dynamic=False, simplify=True, opset=None, workspace=4, nms=False, lr0=0.005, lrf=0.01, momentum=0.937, weight_decay=0.0005, warmup_epochs=0, warmup_momentum=0.8, warmup_bias_lr=0.1, box=7.5, cls=0.5, dfl=1.5, pose=100000.0, kobj=1.0, z=0, label_smoothing=0.0, nbs=64, hsv_h=0.0, hsv_s=0.0, hsv_v=0.0, degrees=0.0, translate=0.0, scale=0.0, shear=0.0, perspective=0.0, flipud=0.5, fliplr=0.5, bgr=0.0, mosaic=0.0, mixup=0.0, copy_paste=0.0, copy_paste_mode=flip, auto_augment=None, erasing=0.0, crop_fraction=1.0, cfg=None, tracker=botsort.yaml, background=0, level=0, window=65535, save_dir=/home/jupyter-till/git/ultralytics/runs/zaxis/train257

                   from  n    params  module                                       arguments                     
  0                  -1  1       176  ultralytics.nn.modules.conv.Conv             [1, 16, 3, 2]                 
  1                  -1  1      4672  ultralytics.nn.modules.conv.Conv             [16, 32, 3, 2]                
  2                  -1  1      6640  ultralytics.nn.modules.block.C3k2            [32, 64, 1, False, 0.25]      
  3                  -1  1     36992  ultralytics.nn.modules.conv.Conv             [64, 64, 3, 2]                
  4                  -1  1     26080  ultralytics.nn.modules.block.C3k2            [64, 128, 1, False, 0.25]     
  5                  -1  1    147712  ultralytics.nn.modules.conv.Conv             [128, 128, 3, 2]              
  6                  -1  1     87040  ultralytics.nn.modules.block.C3k2            [128, 128, 1, True]           
  7                  -1  1    295424  ultralytics.nn.modules.conv.Conv             [128, 256, 3, 2]              
  8                  -1  1    346112  ultralytics.nn.modules.block.C3k2            [256, 256, 1, True]           
  9                  -1  1    164608  ultralytics.nn.modules.block.SPPF            [256, 256, 5]                 
 10                  -1  1    249728  ultralytics.nn.modules.block.C2PSA           [256, 256, 1]                 
 11                  -1  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']          
 12             [-1, 6]  1         0  ultralytics.nn.modules.conv.Concat           [1]                           
 13                  -1  1    111296  ultralytics.nn.modules.block.C3k2            [384, 128, 1, False]          
 14                  -1  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']          
 15             [-1, 4]  1         0  ultralytics.nn.modules.conv.Concat           [1]                           
 16                  -1  1     32096  ultralytics.nn.modules.block.C3k2            [256, 64, 1, False]           
 17                  -1  1     36992  ultralytics.nn.modules.conv.Conv             [64, 64, 3, 2]                
 18            [-1, 13]  1         0  ultralytics.nn.modules.conv.Concat           [1]                           
 19                  -1  1     86720  ultralytics.nn.modules.block.C3k2            [192, 128, 1, False]          
 20                  -1  1    147712  ultralytics.nn.modules.conv.Conv             [128, 128, 3, 2]              
 21            [-1, 10]  1         0  ultralytics.nn.modules.conv.Concat           [1]                           
 22                  -1  1    378880  ultralytics.nn.modules.block.C3k2            [384, 256, 1, True]           
 23        [16, 19, 22]  1    574303  YOLOtrack11.model.ZAxis                      [1, 2, [1, 2], [64, 128, 256]]
YOLO11n-zaxis summary: 211 layers, 2,733,183 parameters, 2,733,167 gradients, 6.9 GFLOPs

Transferred 625/625 items from pretrained weights
Freezing layer 'model.23.dfl.conv.weight'
WARNING ⚠️ updating to 'imgsz=640'. 'train' and 'val' imgsz must be an integer, while 'predict' and 'export' imgsz may be a [h, w] list or an integer, i.e. 'yolo export imgsz=640,480' or 'yolo export imgsz=640'
train: Scanning /home/jupyter-till/git/YOLOTrack11/datasets/Dataset_spots/labels/train_snr8-12.cache... 10000 images, 0 backgrounds, 0 corrupt: 100%|██████████| 10000/10000 [00:00<?, ?it/s]
val: Scanning /home/jupyter-till/git/YOLOTrack11/datasets/Dataset_spots/labels/val_snr8-12.cache... 1000 images, 0 backgrounds, 0 corrupt: 100%|██████████| 1000/1000 [00:00<?, ?it/s]
WARNING ⚠️ 'label_smoothing' is deprecated and will be removed in in the future.
optimizer: SGD(lr=0.005, momentum=0.937) with parameter groups 93 weight(decay=0.0), 106 weight(decay=0.0005), 105 bias(decay=0.0)
Image sizes 640 train, 640 val
Using 8 dataloader workers
Logging results to /home/jupyter-till/git/ultralytics/runs/zaxis/train257
Starting training for 1 epochs...

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss zaxis_loss  pose_loss  kobj_loss  Instances       Size
        1/1      4.44G    0.09117     0.1047      0.755          0        7.3          0        908        640: 100%|██████████| 625/625 [00:34<00:00, 18.18it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95)     z rms.    xy rms.: 100%|██████████| 32/32 [00:02<00:00, 12.08it/s]
                   all       1000      59836          1          1      0.995      0.995   3.28e+04      0.711

1 epochs completed in 0.010 hours.
Optimizer stripped from /home/jupyter-till/git/ultralytics/runs/zaxis/train257/weights/last.pt, 5.8MB
Optimizer stripped from /home/jupyter-till/git/ultralytics/runs/zaxis/train257/weights/best.pt, 5.8MB

Validating /home/jupyter-till/git/ultralytics/runs/zaxis/train257/weights/best.pt...
Ultralytics 8.3.79 🚀 Python-3.12.6 torch-2.5.1+cu124 CUDA:0 (NVIDIA GeForce RTX 4090, 24202MiB)
YOLO11n-zaxis summary (fused): 118 layers, 2,725,303 parameters, 0 gradients, 6.7 GFLOPs
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95)     z rms.    xy rms.: 100%|██████████| 32/32 [00:02<00:00, 15.66it/s]
                   all       1000      59836          1          1      0.995      0.995   3.28e+04      0.711
Speed: 0.0ms preprocess, 0.5ms inference, 0.0ms loss, 0.2ms postprocess per image

Let’s now test our model:

from YOLOtrack11 import YOLOtrack11

# Load the weights saved by the training cell.
model = YOLOtrack11("yolo11n_spots_8-12.pt")
print("loaded")

# Folder of test images at SNR 5 and the image size as (width, height).
dataset_path = "../datasets/Dataset_spots/images/test_snr5.0"
imgsz = (640, 540)
test_images = os.listdir(dataset_path)


def random_image():
    """Return the path of a randomly picked file from the test folder."""
    pick = np.random.randint(0, len(test_images))
    return dataset_path + "/" + test_images[pick]
loaded
Hide code cell source
def plot_result(ax,img,cls,bboxes=repeat(None),z=None,kpts=None, conf=None,vmin=1.95e4,vmax=2.05e4):
    """Show an image on ``ax`` and overlay boxes, keypoints and confidences.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to draw on. Its axis decorations are switched off.
    img : PIL.Image.Image, numpy.ndarray or str
        Image to display. A ``str`` is treated as a file path and opened
        with PIL; anything else is handed to ``imshow`` unchanged.
    cls : array-like
        Class ids, one per object. Only used to size the placeholder
        confidences when ``conf`` is not given.
    bboxes : iterable, optional
        One ``(x, y, w, h)`` box per object (centre, width, height), given
        as NumPy arrays or as objects exposing ``.numpy()``. ``None``
        entries (the default) draw no box.
    z : unused
        Accepted for signature compatibility; not read.
    kpts : iterable, optional
        One keypoint ``(x, y)`` per object. With ``None`` no keypoints are
        drawn.
    conf : iterable, optional
        One confidence per object. When given, it is printed as a
        percentage next to each object.
    vmin, vmax : float or None
        Colour limits forwarded to ``imshow``.

    Returns
    -------
    The ``bboxes`` argument, unchanged.
    """
    show_conf = conf is not None
    if not show_conf:
        # Placeholder so that zip() below still yields one item per object.
        conf = np.zeros_like(cls)
    if kpts is None:
        # Allow box-only plots instead of failing inside zip().
        kpts = repeat(None)

    if isinstance(img, str):
        # Only ``Image`` (``from PIL import Image``) is imported by name at
        # the top of this file, so use it rather than the bare ``PIL`` name.
        img = Image.open(img)
    ax.imshow(img, cmap="grey", vmin=vmin, vmax=vmax)
    ax.axis("off")

    for bbox, kpt, c in zip(bboxes, kpts, conf):
        centre = None if kpt is None else kpt.squeeze()
        if bbox is None and centre is None:
            continue  # nothing to draw for this entry

        if bbox is not None:
            x, y, w, h = bbox if isinstance(bbox, np.ndarray) else bbox.numpy()
            # Rectangle takes (xy, width, height) with xy the lower-left
            # corner; width must come before height.
            rect = Rectangle((x - 0.5 * w, y - 0.5 * h), w, h,
                             linewidth=1, edgecolor="yellow", facecolor="none")
            ax.add_patch(rect)
            tx, ty = rect.get_xy()
            tx += 6
            ty -= 12
        else:
            # No box: anchor the label relative to the keypoint. New floats
            # are computed so the keypoint itself is never modified in place.
            tx = float(centre[0]) - 50
            ty = float(centre[1]) - 40

        if centre is not None:
            ax.add_patch(Circle(centre, 1, facecolor="red", edgecolor="red"))

        if show_conf:
            # A dict literal is used instead of ``dict(...)`` because an
            # earlier cell of this tutorial rebinds the name ``dict``.
            ax.text(tx, ty, f"{c*100:.0f}%", fontsize="small",
                    bbox={"facecolor": "white", "alpha": 0.5})

    return bboxes
# Reuse the current figure if one exists, otherwise create one.
if plt.get_fignums():
    plt.clf()
else:
    plt.figure(figsize=(12,5))

# Pick a random test image and load it into an array. ``Image`` is the PIL
# module imported at the top of the file; the context manager closes the
# file once the pixel data has been copied.
image = random_image()
with Image.open(image) as handle:
    img = np.array(handle)

# predict returns one result per input image; unpack the single one.
results, = model.predict(img, conf = 0.7, level=0, background=0)
# boxes2 = results.boxes2.cpu()
# results = results.cpu()

# Left panel: the model's own rendering of its predictions.
plt.subplot(1,2,1)
plt.title("Prediction")
kpt = results.keypoints
bbox = results.boxes

# Placeholder z values for the alternative plot_result call below.
z = np.zeros((len(kpt)))
# plot_result(plt.gca(),img,z,bbox.xywh.cpu(),z,kpt.cpu(),conf=bbox.conf, vmin=None, vmax=None)
plt.imshow(results.plot())
plt.axis("off")

# Right panel: the ground-truth labels of the same image.
plt.subplot(1,2,2)
plt.title("Ground Truth")
plot_gt(image,plt.gca(),imgsz, vmin=None, vmax=None)
plt.show()
results.to_df() #print results dataframe

0: 544x640 52 Spots, 29.1ms
Speed: 0.9ms preprocess, 29.1ms inference, 0.5ms postprocess per image at shape (1, 3, 544, 640)
None None
_images/ebca02d640ec79537cb62adbab560453c0e6eb48e775f97dda70106df68a4eda.png
name class confidence box keypoints z
0 Spot 0 0.99845 {'x1': 454.38867, 'y1': 401.5827, 'x2': 486.40... {'x': 469.69635009765625, 'y': 416.8154296875} {'z': -1.0488699674606323, '(unused)': -0.5316...
1 Spot 0 0.99575 {'x1': 429.12033, 'y1': 363.26398, 'x2': 461.1... {'x': 444.1501159667969, 'y': 379.28582763671875} {'z': -0.3155600130558014, '(unused)': -0.0702...
2 Spot 0 0.99341 {'x1': 456.2196, 'y1': 10.27181, 'x2': 488.207... {'x': 471.6224060058594, 'y': 26.113449096679688} {'z': -0.39500001072883606, '(unused)': -1.057...
3 Spot 0 0.99340 {'x1': 30.99963, 'y1': 80.36482, 'x2': 63.0684... {'x': 46.5235481262207, 'y': 96.21710968017578} {'z': 0.05926999822258949, '(unused)': 0.26298...
4 Spot 0 0.99296 {'x1': 11.29617, 'y1': 33.48018, 'x2': 43.2828... {'x': 26.595169067382812, 'y': 49.51639175415039} {'z': 0.6833999752998352, '(unused)': -0.47876...
5 Spot 0 0.99206 {'x1': 428.07178, 'y1': 384.62256, 'x2': 460.0... {'x': 443.7632751464844, 'y': 400.3846435546875} {'z': -0.360260009765625, '(unused)': 0.790880...
6 Spot 0 0.99162 {'x1': 15.75287, 'y1': 362.80835, 'x2': 47.746... {'x': 31.19702911376953, 'y': 378.7030029296875} {'z': 0.0009599999757483602, '(unused)': -0.43...
7 Spot 0 0.99160 {'x1': 509.73969, 'y1': 174.45181, 'x2': 541.7... {'x': 525.0051879882812, 'y': 190.24867248535156} {'z': -0.12634000182151794, '(unused)': -0.269...
8 Spot 0 0.99157 {'x1': 433.39673, 'y1': 242.21719, 'x2': 465.3... {'x': 448.74462890625, 'y': 258.0653991699219} {'z': 0.10350000113248825, '(unused)': -0.5722...
9 Spot 0 0.99149 {'x1': 533.24207, 'y1': 182.69238, 'x2': 565.2... {'x': 548.7119140625, 'y': 198.37135314941406} {'z': -0.9698600172996521, '(unused)': -0.8475...
10 Spot 0 0.99061 {'x1': 543.1864, 'y1': 110.34048, 'x2': 575.16... {'x': 558.679931640625, 'y': 126.2138671875} {'z': -0.02565000019967556, '(unused)': -0.169...
11 Spot 0 0.99020 {'x1': 520.57391, 'y1': 97.08963, 'x2': 552.57... {'x': 536.0918579101562, 'y': 113.06974029541016} {'z': -0.1432500034570694, '(unused)': 0.12768...
12 Spot 0 0.98983 {'x1': 145.42633, 'y1': 35.57172, 'x2': 177.43... {'x': 160.739990234375, 'y': 51.447818756103516} {'z': 0.050280001014471054, '(unused)': -0.483...
13 Spot 0 0.98970 {'x1': 336.9281, 'y1': 387.3891, 'x2': 368.932... {'x': 352.30133056640625, 'y': 403.27301025390... {'z': 0.04839000105857849, '(unused)': -0.4381...
14 Spot 0 0.98932 {'x1': 46.41614, 'y1': 66.05319, 'x2': 78.4793... {'x': 62.11166000366211, 'y': 81.77384948730469} {'z': 0.8098300099372864, '(unused)': -0.48807...
15 Spot 0 0.98884 {'x1': 113.28403, 'y1': 104.95433, 'x2': 145.2... {'x': 128.7715606689453, 'y': 120.79208374023438} {'z': -0.2915700078010559, '(unused)': 0.15180...
16 Spot 0 0.98853 {'x1': 33.51456, 'y1': 189.43661, 'x2': 65.510... {'x': 48.82320022583008, 'y': 205.25967407226562} {'z': -0.4638499915599823, '(unused)': -0.4169...
17 Spot 0 0.98843 {'x1': 515.76465, 'y1': 404.54803, 'x2': 547.7... {'x': 530.8435668945312, 'y': 420.3667297363281} {'z': -0.0623599998652935, '(unused)': -0.2614...
18 Spot 0 0.98834 {'x1': 205.86484, 'y1': -5.1156, 'x2': 237.871... {'x': 221.26718139648438, 'y': 10.648360252380... {'z': -0.08206000179052353, '(unused)': -0.478...
19 Spot 0 0.98788 {'x1': 189.1944, 'y1': 387.82385, 'x2': 221.19... {'x': 204.59280395507812, 'y': 403.63134765625} {'z': -0.07006999850273132, '(unused)': -0.396...
20 Spot 0 0.98779 {'x1': 206.36411, 'y1': 286.6189, 'x2': 238.35... {'x': 221.6929931640625, 'y': 302.7490234375} {'z': 0.009440000168979168, '(unused)': -0.160...
21 Spot 0 0.98762 {'x1': 608.91644, 'y1': 350.29907, 'x2': 640.9... {'x': 624.3041381835938, 'y': 366.15606689453125} {'z': -0.5039200186729431, '(unused)': -0.4855...
22 Spot 0 0.98760 {'x1': 203.92493, 'y1': 203.92267, 'x2': 235.9... {'x': 219.06910705566406, 'y': 219.67239379882... {'z': -0.05420999974012375, '(unused)': -0.238...
23 Spot 0 0.98745 {'x1': 126.61062, 'y1': 221.50766, 'x2': 158.6... {'x': 141.90545654296875, 'y': 237.41143798828... {'z': -0.5934500098228455, '(unused)': -0.6944...
24 Spot 0 0.98725 {'x1': 89.15135, 'y1': 290.26636, 'x2': 121.12... {'x': 104.6095962524414, 'y': 306.14031982421875} {'z': -0.6975399851799011, '(unused)': -1.0777...
25 Spot 0 0.98724 {'x1': 337.82831, 'y1': 264.13217, 'x2': 369.8... {'x': 353.34063720703125, 'y': 279.9650573730469} {'z': -0.039739999920129776, '(unused)': 0.105...
26 Spot 0 0.98722 {'x1': 605.66241, 'y1': 130.90872, 'x2': 637.6... {'x': 621.1229858398438, 'y': 146.7654571533203} {'z': -0.23690000176429749, '(unused)': -0.546...
27 Spot 0 0.98686 {'x1': 399.47757, 'y1': 119.53226, 'x2': 431.4... {'x': 414.8857727050781, 'y': 135.439697265625} {'z': 0.23531000316143036, '(unused)': -0.3317...
28 Spot 0 0.98666 {'x1': 297.89886, 'y1': 443.21854, 'x2': 329.9... {'x': 313.3101806640625, 'y': 459.16302490234375} {'z': -0.015259999781847, '(unused)': -0.34516...
29 Spot 0 0.98648 {'x1': 254.87854, 'y1': 395.76569, 'x2': 286.8... {'x': 270.2241516113281, 'y': 411.6671142578125} {'z': -0.08113999664783478, '(unused)': -0.508...
30 Spot 0 0.98607 {'x1': 61.62058, 'y1': 367.84253, 'x2': 93.626... {'x': 76.9057388305664, 'y': 383.5351867675781} {'z': -0.3797700107097626, '(unused)': 0.16829...
31 Spot 0 0.98605 {'x1': 202.87679, 'y1': 44.76674, 'x2': 234.90... {'x': 218.2605438232422, 'y': 60.645301818847656} {'z': -0.7087200284004211, '(unused)': -0.7812...
32 Spot 0 0.98550 {'x1': 395.59991, 'y1': 209.77979, 'x2': 427.5... {'x': 410.943603515625, 'y': 225.6527862548828} {'z': 0.5764399766921997, '(unused)': -0.44786...
33 Spot 0 0.98520 {'x1': 352.12585, 'y1': 128.67172, 'x2': 384.1... {'x': 367.6582946777344, 'y': 144.52285766601562} {'z': -0.40143001079559326, '(unused)': 0.5516...
34 Spot 0 0.98503 {'x1': 279.09735, 'y1': 128.2417, 'x2': 311.09... {'x': 294.4912109375, 'y': 144.0572509765625} {'z': 0.293720006942749, '(unused)': -0.213239...
35 Spot 0 0.98499 {'x1': 223.25951, 'y1': 366.11719, 'x2': 255.2... {'x': 238.6258544921875, 'y': 382.0509948730469} {'z': 0.017650000751018524, '(unused)': -0.330...
36 Spot 0 0.98470 {'x1': 351.91721, 'y1': 308.34253, 'x2': 383.9... {'x': 367.3822326660156, 'y': 324.24664306640625} {'z': -0.014560000039637089, '(unused)': -0.40...
37 Spot 0 0.98459 {'x1': 408.10101, 'y1': 79.82852, 'x2': 440.10... {'x': 423.4658508300781, 'y': 95.66857147216797} {'z': 0.12411999702453613, '(unused)': -0.3572...
38 Spot 0 0.98358 {'x1': 540.46283, 'y1': 11.82071, 'x2': 572.46... {'x': 555.801513671875, 'y': 27.6430606842041} {'z': -0.8114500045776367, '(unused)': -0.7148...
39 Spot 0 0.98357 {'x1': 87.99037, 'y1': 336.29224, 'x2': 119.99... {'x': 103.45982360839844, 'y': 352.1165466308594} {'z': -0.22312000393867493, '(unused)': 0.4206...
40 Spot 0 0.98343 {'x1': 584.92041, 'y1': 273.99628, 'x2': 616.9... {'x': 600.3150024414062, 'y': 289.9287414550781} {'z': 0.3510800004005432, '(unused)': -0.52626...
41 Spot 0 0.98319 {'x1': -1.27251, 'y1': 418.07477, 'x2': 30.738... {'x': 14.083209991455078, 'y': 433.9892272949219} {'z': 0.3756600022315979, '(unused)': -0.28567...
42 Spot 0 0.98151 {'x1': 337.53781, 'y1': 177.82043, 'x2': 369.5... {'x': 352.94207763671875, 'y': 193.58792114257... {'z': -0.2790899872779846, '(unused)': 0.16873...
43 Spot 0 0.98086 {'x1': 393.14984, 'y1': 252.96042, 'x2': 425.1... {'x': 408.6138916015625, 'y': 268.8370361328125} {'z': -0.2449599951505661, '(unused)': 0.22044...
44 Spot 0 0.98065 {'x1': 486.87518, 'y1': -5.62891, 'x2': 518.87... {'x': 502.29595947265625, 'y': 10.088509559631... {'z': -0.23973000049591064, '(unused)': -0.505...
45 Spot 0 0.98032 {'x1': 347.77759, 'y1': 502.85684, 'x2': 379.7... {'x': 363.0234680175781, 'y': 518.6696166992188} {'z': -0.2752000093460083, '(unused)': -0.1840...
46 Spot 0 0.97986 {'x1': 17.77096, 'y1': 315.86493, 'x2': 49.775... {'x': 33.1440315246582, 'y': 331.84716796875} {'z': 0.13955999910831451, '(unused)': -0.3315...
47 Spot 0 0.97956 {'x1': 60.46625, 'y1': 178.37817, 'x2': 92.492... {'x': 75.89572143554688, 'y': 194.2646026611328} {'z': -1.174180030822754, '(unused)': -1.02040...
48 Spot 0 0.97920 {'x1': 545.45563, 'y1': 230.59537, 'x2': 577.4... {'x': 560.6781616210938, 'y': 246.30990600585938} {'z': -0.28571000695228577, '(unused)': -0.396...
49 Spot 0 0.97641 {'x1': 73.4754, 'y1': 158.67316, 'x2': 105.514... {'x': 88.98802185058594, 'y': 174.60047912597656} {'z': -0.8868399858474731, '(unused)': -0.7769...
50 Spot 0 0.97594 {'x1': 31.17584, 'y1': 516.07214, 'x2': 63.122... {'x': 46.24225997924805, 'y': 531.9324340820312} {'z': 0.09358000010251999, '(unused)': -0.3732...
51 Spot 0 0.95577 {'x1': -14.68774, 'y1': 41.23734, 'x2': 17.269... {'x': 0.9134699702262878, 'y': 56.781700134277... {'z': -0.36566999554634094, '(unused)': 0.4220...

Next, let’s look at the validation scores for each SNR level.

# Validate the trained model on every test SNR level and collect the metrics.
snr_range = [1,20]
# Same 20 evenly spaced levels as used when the test folders were generated.
snrlevels = np.linspace(*snr_range, 20)
# Per-level accumulators: box results, xy error, z error, keypoint pairs.
r,xy_rms, z_rms, kpt_pairs = [],[],[],[]
# Note: the loop variable `snr` rebinds the module-level `snr` set in the first cell.
for i, snr in enumerate(snrlevels):
    # The level is formatted as a float (e.g. "1.0"), matching the
    # dataset_test_snr<level>.yaml files written by the YAML cell.
    fname = f'../datasets/Dataset_spots/levels/dataset_test_snr{snr}.yaml'
    res = model.val(conf=0.9,data=fname, workers=0, background=0)
    r.append(res.all_box_results)
    xy_rms.append(res.xy_rms)
    z_rms.append(res.z_rms)
    kpt_pairs.append(res.kpt_pairs)

    # Drop the reference to the result object before the next level is validated.
    del res

# Extract the "r" and "p" entries of each level's box results.
recall = np.array([res["r"] for res in r]).squeeze()
precision = np.array([res["p"] for res in r]).squeeze()

# Only recall, precision and xy_rms are stored; z_rms and kpt_pairs stay in memory.
np.savez("spots_validation.npz", recall=recall, precision=precision, xy_rms=np.array(xy_rms))
Hide code cell output
Ultralytics 8.3.79 🚀 Python-3.12.6 torch-2.5.1+cu124 CUDA:0 (NVIDIA GeForce RTX 4090, 24202MiB)
val: Scanning /home/jupyter-till/git/YOLOTrack11/datasets/Dataset_spots/labels/test_snr1.cache... 1000 images, 0 backgrounds, 0 corrupt: 100%|██████████| 1000/1000 [00:00<?, ?it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95)     z rms.    xy rms.: 100%|██████████| 63/63 [00:04<00:00, 14.43it/s]
                   all       1000      60179      0.996      0.988      0.992      0.813   3.28e+04       1.77
Speed: 0.1ms preprocess, 0.6ms inference, 0.0ms loss, 0.7ms postprocess per image
Ultralytics 8.3.79 🚀 Python-3.12.6 torch-2.5.1+cu124 CUDA:0 (NVIDIA GeForce RTX 4090, 24202MiB)
val: Scanning /home/jupyter-till/git/YOLOTrack11/datasets/Dataset_spots/labels/test_snr2.cache... 1000 images, 0 backgrounds, 0 corrupt: 100%|██████████| 1000/1000 [00:00<?, ?it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95)     z rms.    xy rms.: 100%|██████████| 63/63 [00:04<00:00, 14.83it/s]
                   all       1000      59890          1          1      0.995      0.945   3.28e+04      0.703
Speed: 0.1ms preprocess, 0.5ms inference, 0.0ms loss, 0.7ms postprocess per image
Ultralytics 8.3.79 🚀 Python-3.12.6 torch-2.5.1+cu124 CUDA:0 (NVIDIA GeForce RTX 4090, 24202MiB)
val: Scanning /home/jupyter-till/git/YOLOTrack11/datasets/Dataset_spots/labels/test_snr3.cache... 1000 images, 0 backgrounds, 0 corrupt: 100%|██████████| 1000/1000 [00:00<?, ?it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95)     z rms.    xy rms.: 100%|██████████| 63/63 [00:04<00:00, 14.83it/s]
                   all       1000      59146          1          1      0.995      0.981   3.28e+04      0.449
Speed: 0.1ms preprocess, 0.5ms inference, 0.0ms loss, 0.7ms postprocess per image
Ultralytics 8.3.79 🚀 Python-3.12.6 torch-2.5.1+cu124 CUDA:0 (NVIDIA GeForce RTX 4090, 24202MiB)
val: Scanning /home/jupyter-till/git/YOLOTrack11/datasets/Dataset_spots/labels/test_snr4.cache... 1000 images, 0 backgrounds, 0 corrupt: 100%|██████████| 1000/1000 [00:00<?, ?it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95)     z rms.    xy rms.: 100%|██████████| 63/63 [00:04<00:00, 14.85it/s]
                   all       1000      59792          1          1      0.995      0.992   3.28e+04      0.336
Speed: 0.1ms preprocess, 0.5ms inference, 0.0ms loss, 0.7ms postprocess per image
Ultralytics 8.3.79 🚀 Python-3.12.6 torch-2.5.1+cu124 CUDA:0 (NVIDIA GeForce RTX 4090, 24202MiB)
val: Scanning /home/jupyter-till/git/YOLOTrack11/datasets/Dataset_spots/labels/test_snr5.cache... 1000 images, 0 backgrounds, 0 corrupt: 100%|██████████| 1000/1000 [00:00<?, ?it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95)     z rms.    xy rms.: 100%|██████████| 63/63 [00:04<00:00, 14.79it/s]
                   all       1000      59366          1          1      0.995      0.995   3.28e+04      0.272
Speed: 0.1ms preprocess, 0.5ms inference, 0.0ms loss, 0.7ms postprocess per image
Ultralytics 8.3.79 🚀 Python-3.12.6 torch-2.5.1+cu124 CUDA:0 (NVIDIA GeForce RTX 4090, 24202MiB)
val: Scanning /home/jupyter-till/git/YOLOTrack11/datasets/Dataset_spots/labels/test_snr6.cache... 1000 images, 0 backgrounds, 0 corrupt: 100%|██████████| 1000/1000 [00:00<?, ?it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95)     z rms.    xy rms.: 100%|██████████| 63/63 [00:04<00:00, 14.78it/s]
                   all       1000      60315          1          1      0.995      0.995   3.28e+04      0.233
Speed: 0.1ms preprocess, 0.5ms inference, 0.0ms loss, 0.7ms postprocess per image
Ultralytics 8.3.79 🚀 Python-3.12.6 torch-2.5.1+cu124 CUDA:0 (NVIDIA GeForce RTX 4090, 24202MiB)
val: Scanning /home/jupyter-till/git/YOLOTrack11/datasets/Dataset_spots/labels/test_snr7.cache... 1000 images, 0 backgrounds, 0 corrupt: 100%|██████████| 1000/1000 [00:00<?, ?it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95)     z rms.    xy rms.: 100%|██████████| 63/63 [00:04<00:00, 14.67it/s]
                   all       1000      60693          1          1      0.995      0.995   3.28e+04      0.205
Speed: 0.1ms preprocess, 0.5ms inference, 0.0ms loss, 0.7ms postprocess per image
Ultralytics 8.3.79 🚀 Python-3.12.6 torch-2.5.1+cu124 CUDA:0 (NVIDIA GeForce RTX 4090, 24202MiB)
val: Scanning /home/jupyter-till/git/YOLOTrack11/datasets/Dataset_spots/labels/test_snr8.cache... 1000 images, 0 backgrounds, 0 corrupt: 100%|██████████| 1000/1000 [00:00<?, ?it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95)     z rms.    xy rms.: 100%|██████████| 63/63 [00:04<00:00, 14.84it/s]
                   all       1000      59994          1          1      0.995      0.995   3.28e+04      0.184
Speed: 0.1ms preprocess, 0.5ms inference, 0.0ms loss, 0.7ms postprocess per image
Ultralytics 8.3.79 🚀 Python-3.12.6 torch-2.5.1+cu124 CUDA:0 (NVIDIA GeForce RTX 4090, 24202MiB)
val: Scanning /home/jupyter-till/git/YOLOTrack11/datasets/Dataset_spots/labels/test_snr9.cache... 1000 images, 0 backgrounds, 0 corrupt: 100%|██████████| 1000/1000 [00:00<?, ?it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95)     z rms.    xy rms.: 100%|██████████| 63/63 [00:04<00:00, 14.76it/s]
                   all       1000      60798          1          1      0.995      0.995   3.28e+04      0.168
Speed: 0.1ms preprocess, 0.5ms inference, 0.0ms loss, 0.7ms postprocess per image
Ultralytics 8.3.79 🚀 Python-3.12.6 torch-2.5.1+cu124 CUDA:0 (NVIDIA GeForce RTX 4090, 24202MiB)
val: Scanning /home/jupyter-till/git/YOLOTrack11/datasets/Dataset_spots/labels/test_snr10.cache... 1000 images, 0 backgrounds, 0 corrupt: 100%|██████████| 1000/1000 [00:00<?, ?it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95)     z rms.    xy rms.: 100%|██████████| 63/63 [00:04<00:00, 14.12it/s]
                   all       1000      59803          1          1      0.995      0.995   3.28e+04      0.157
Speed: 0.1ms preprocess, 0.5ms inference, 0.0ms loss, 0.7ms postprocess per image
Ultralytics 8.3.79 🚀 Python-3.12.6 torch-2.5.1+cu124 CUDA:0 (NVIDIA GeForce RTX 4090, 24202MiB)
val: Scanning /home/jupyter-till/git/YOLOTrack11/datasets/Dataset_spots/labels/test_snr11.cache... 1000 images, 0 backgrounds, 0 corrupt: 100%|██████████| 1000/1000 [00:00<?, ?it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95)     z rms.    xy rms.: 100%|██████████| 63/63 [00:04<00:00, 14.67it/s]
                   all       1000      60149          1          1      0.995      0.995   3.28e+04      0.146
Speed: 0.1ms preprocess, 0.5ms inference, 0.0ms loss, 0.7ms postprocess per image
Ultralytics 8.3.79 🚀 Python-3.12.6 torch-2.5.1+cu124 CUDA:0 (NVIDIA GeForce RTX 4090, 24202MiB)
val: Scanning /home/jupyter-till/git/YOLOTrack11/datasets/Dataset_spots/labels/test_snr12.cache... 1000 images, 0 backgrounds, 0 corrupt: 100%|██████████| 1000/1000 [00:00<?, ?it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95)     z rms.    xy rms.: 100%|██████████| 63/63 [00:04<00:00, 14.76it/s]
                   all       1000      60333          1          1      0.995      0.995   3.28e+04      0.139
Speed: 0.1ms preprocess, 0.5ms inference, 0.0ms loss, 0.7ms postprocess per image
Ultralytics 8.3.79 🚀 Python-3.12.6 torch-2.5.1+cu124 CUDA:0 (NVIDIA GeForce RTX 4090, 24202MiB)
val: Scanning /home/jupyter-till/git/YOLOTrack11/datasets/Dataset_spots/labels/test_snr13.cache... 1000 images, 0 backgrounds, 0 corrupt: 100%|██████████| 1000/1000 [00:00<?, ?it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95)     z rms.    xy rms.: 100%|██████████| 63/63 [00:04<00:00, 14.68it/s]
                   all       1000      59609          1          1      0.995      0.995   3.28e+04      0.132
Speed: 0.1ms preprocess, 0.5ms inference, 0.0ms loss, 0.7ms postprocess per image
Ultralytics 8.3.79 🚀 Python-3.12.6 torch-2.5.1+cu124 CUDA:0 (NVIDIA GeForce RTX 4090, 24202MiB)
val: Scanning /home/jupyter-till/git/YOLOTrack11/datasets/Dataset_spots/labels/test_snr14.cache... 1000 images, 0 backgrounds, 0 corrupt: 100%|██████████| 1000/1000 [00:00<?, ?it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95)     z rms.    xy rms.: 100%|██████████| 63/63 [00:04<00:00, 14.78it/s]
                   all       1000      59375          1          1      0.995      0.995   3.28e+04      0.128
Speed: 0.1ms preprocess, 0.5ms inference, 0.0ms loss, 0.7ms postprocess per image
Ultralytics 8.3.79 🚀 Python-3.12.6 torch-2.5.1+cu124 CUDA:0 (NVIDIA GeForce RTX 4090, 24202MiB)
val: Scanning /home/jupyter-till/git/YOLOTrack11/datasets/Dataset_spots/labels/test_snr15.cache... 1000 images, 0 backgrounds, 0 corrupt: 100%|██████████| 1000/1000 [00:00<?, ?it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95)     z rms.    xy rms.: 100%|██████████| 63/63 [00:04<00:00, 14.66it/s]
                   all       1000      59628          1          1      0.995      0.995   3.28e+04      0.123
Speed: 0.1ms preprocess, 0.5ms inference, 0.0ms loss, 0.7ms postprocess per image
Ultralytics 8.3.79 🚀 Python-3.12.6 torch-2.5.1+cu124 CUDA:0 (NVIDIA GeForce RTX 4090, 24202MiB)
val: Scanning /home/jupyter-till/git/YOLOTrack11/datasets/Dataset_spots/labels/test_snr16.cache... 1000 images, 0 backgrounds, 0 corrupt: 100%|██████████| 1000/1000 [00:00<?, ?it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95)     z rms.    xy rms.: 100%|██████████| 63/63 [00:04<00:00, 14.67it/s]
                   all       1000      60449          1          1      0.995      0.995   3.28e+04       0.12
Speed: 0.1ms preprocess, 0.5ms inference, 0.0ms loss, 0.7ms postprocess per image
Ultralytics 8.3.79 🚀 Python-3.12.6 torch-2.5.1+cu124 CUDA:0 (NVIDIA GeForce RTX 4090, 24202MiB)
val: Scanning /home/jupyter-till/git/YOLOTrack11/datasets/Dataset_spots/labels/test_snr17.cache... 1000 images, 0 backgrounds, 0 corrupt: 100%|██████████| 1000/1000 [00:00<?, ?it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95)     z rms.    xy rms.: 100%|██████████| 63/63 [00:04<00:00, 14.58it/s]
                   all       1000      60427          1          1      0.995      0.995   3.28e+04      0.116
Speed: 0.1ms preprocess, 0.5ms inference, 0.0ms loss, 0.7ms postprocess per image
Ultralytics 8.3.79 🚀 Python-3.12.6 torch-2.5.1+cu124 CUDA:0 (NVIDIA GeForce RTX 4090, 24202MiB)
val: Scanning /home/jupyter-till/git/YOLOTrack11/datasets/Dataset_spots/labels/test_snr18.cache... 1000 images, 0 backgrounds, 0 corrupt: 100%|██████████| 1000/1000 [00:00<?, ?it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95)     z rms.    xy rms.: 100%|██████████| 63/63 [00:04<00:00, 14.27it/s]
                   all       1000      59911          1          1      0.995      0.995   3.28e+04      0.114
Speed: 0.1ms preprocess, 0.5ms inference, 0.0ms loss, 0.7ms postprocess per image
Ultralytics 8.3.79 🚀 Python-3.12.6 torch-2.5.1+cu124 CUDA:0 (NVIDIA GeForce RTX 4090, 24202MiB)
val: Scanning /home/jupyter-till/git/YOLOTrack11/datasets/Dataset_spots/labels/test_snr19.cache... 1000 images, 0 backgrounds, 0 corrupt: 100%|██████████| 1000/1000 [00:00<?, ?it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95)     z rms.    xy rms.: 100%|██████████| 63/63 [00:04<00:00, 14.39it/s]
                   all       1000      59731          1          1      0.995      0.995   3.28e+04      0.111
Speed: 0.1ms preprocess, 0.5ms inference, 0.0ms loss, 0.7ms postprocess per image
Ultralytics 8.3.79 🚀 Python-3.12.6 torch-2.5.1+cu124 CUDA:0 (NVIDIA GeForce RTX 4090, 24202MiB)
val: Scanning /home/jupyter-till/git/YOLOTrack11/datasets/Dataset_spots/labels/test_snr20.cache... 1000 images, 0 backgrounds, 0 corrupt: 100%|██████████| 1000/1000 [00:00<?, ?it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95)     z rms.    xy rms.: 100%|██████████| 63/63 [00:04<00:00, 14.24it/s]
                   all       1000      59596          1          1      0.995      0.995   3.28e+04      0.108
Speed: 0.1ms preprocess, 0.5ms inference, 0.0ms loss, 0.7ms postprocess per image
Hide code cell content
# calculate the CRLB
# Render one noise-free, background-free "Spot" at (40, 40) in an 80x80 image.
# NOTE(review): "i" and "s" are presumably the spot intensity and size -
# confirm against image_generator.
import image_generator as ig
objects = ig.pd.DataFrame(
    [{"label":"Spot", "i":2**15, "s":8, "x":40,"y":40}],
    index=[0],
)
img, _ = ig.generateImage(
    objects,
    (80,80),
    noise=0,
    refstack=np.zeros((40,40,40)),
    background=0,
)
# No offset or normalisation is applied; the image is only rounded and stored
# as unsigned 16-bit integers.
img = img.round().astype(np.uint16)
def derivative(psf, axis, delta=0.1):
    """Return the finite-difference derivative of ``psf`` for a shift along ``axis``.

    The image is shifted by ``+delta`` and ``-delta`` pixels using linear
    interpolation between ``psf`` and its one-pixel rolled copies, and the
    symmetric difference quotient of the two shifted images is returned.

    The computation is done in floating point. Rounding or casting the
    intermediate images to an integer type would quantise the difference,
    which zeroes the gradient wherever ``delta * psf`` is below 0.5.

    Parameters
    ----------
    psf : array_like
        Image to differentiate. It is converted to a float array, so integer
        images (e.g. ``uint16``) are accepted.
    axis : int
        Axis along which the image is shifted.
    delta : float, optional
        Size of the sub-pixel shift in pixels. Must be non-zero.

    Returns
    -------
    numpy.ndarray
        Float array with the same shape as ``psf``.

    Notes
    -----
    ``np.roll`` wraps around, so the image is treated as periodic along
    ``axis``; values at opposite edges influence each other.
    """
    psf = np.asarray(psf, dtype=float)
    # Linear interpolation towards the neighbouring pixel on either side.
    psf_forward = (1 - delta) * psf + delta * np.roll(psf, shift=1, axis=axis)
    psf_backward = (1 - delta) * psf + delta * np.roll(psf, shift=-1, axis=axis)
    return (psf_forward - psf_backward) / (2 * delta)
# Gradients of the reference spot image along both image axes.
dh_dx = derivative(img, axis=1)  # x-axis derivative
dh_dy = derivative(img, axis=0)  # y-axis derivative
plt.imshow(dh_dy)
plt.colorbar()
# Compute Fisher Information Matrix (FIM)
I_xx = np.square(dh_dx).sum()
I_yy = np.square(dh_dy).sum()
I_xy = np.multiply(dh_dx, dh_dy).sum()

FIM = np.array([
    [I_xx, I_xy],
    [I_xy, I_yy],
])
print(FIM)
# Square roots of the diagonal of the inverse FIM (one value per axis).
CRLB_0 = np.sqrt(np.linalg.inv(FIM).diagonal())
snr_range = [1,20]
snrlevels = np.linspace(snr_range[0], snr_range[1], 20)
# Scale by sqrt((2**15 / SNR)**2 + 1/12), using the mean of the two axes.
# NOTE(review): presumably 2**15/SNR is the noise standard deviation and 1/12
# the quantisation variance - confirm.
CRLB = np.sqrt((2**15 / snrlevels)**2 + 1/12) * CRLB_0.mean()
# The printed array holds one value per entry of snrlevels.
print(f"CRLB (x,y) at SNR=1: {CRLB}")
[[ 1.6734e+09           0]
 [          0  1.6734e+09]]
CRLB (x,y) at SNR=1: [    0.80103     0.40052     0.26701     0.20026     0.16021     0.13351     0.11443     0.10013    0.089004    0.080103    0.072821    0.066753    0.061618    0.057217    0.053402    0.050065     0.04712    0.044502     0.04216    0.040052]
_images/4566c256b75a2d68971d79557b10db1a7c77a594841903933faa3cf5696b8c79.png
from matplotlib.offsetbox import AnnotationBbox, OffsetImage

def place_images(ax, images, xmin=0.05,xmax=0.95, vmin=None, vmax=None, n_slots=11):
    """Draw thumbnails along the bottom edge of ``ax``.

    ``n_slots`` evenly spaced horizontal positions between ``xmin`` and
    ``xmax`` (in axes-fraction coordinates) are generated, and the images are
    assigned to them in order. Surplus images or surplus positions are
    ignored. Each thumbnail is anchored at its bottom centre on ``y = 0``.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes that receives the thumbnails.
    images : iterable
        Images accepted by ``OffsetImage``.
    xmin, xmax : float
        First and last horizontal position, as axes fractions.
    vmin, vmax : float or None
        Colour limits for the thumbnails. Either one may be given on its own;
        a limit left at ``None`` is not changed.
    n_slots : int
        Number of positions generated between ``xmin`` and ``xmax``.
    """
    x_range = np.linspace(xmin, xmax, n_slots)
    for x, image in zip(x_range, images):
        im = OffsetImage(image, zoom=0.32, cmap="gray")
        if vmin is not None or vmax is not None:
            # set_clim leaves a limit untouched when it is passed as None.
            im.get_children()[0].set_clim(vmin=vmin, vmax=vmax)
        ab = AnnotationBbox(im, (x, 0), xycoords="axes fraction", frameon=False, pad=0, box_alignment=(0.5, 0))
        ax.add_artist(ab)
# Detection metrics per SNR level, taken from the "r" and "p" fields of the
# collected box results.
recall = np.squeeze(np.array([box["r"] for box in r]))
precision = np.squeeze(np.array([box["p"] for box in r]))
# Harmonic mean of precision and recall.
f1 = (recall * precision) * 2 / (recall + precision)

# Example thumbnails for every second of the 20 SNR levels between 1 and 20.
images = [
    PIL.Image.open(f"spot_at_snr/{snr}.png")
    for snr in np.linspace(1,20,20)[::2]
]

# Two panels: detection error rates (left) and localisation error (right),
# both against SNR, with example thumbnails along the bottom edge.
# NOTE(review): column 7 of the per-run arrays is plotted throughout; what
# that column stands for is not visible here - confirm.
plt.figure(figsize=(12,4))

ax_det = plt.subplot(1,2,1)
place_images(ax_det, images, xmin=0.05,xmax=1.05)
ax_det.plot(snrlevels, 1-recall[:,7], ".-",label="missed detection rate")
ax_det.plot(snrlevels, 1-precision[:,7], ".-",label="false positive rate")
ax_det.plot(snrlevels, 1-f1[:,7], ".-",label="1 - F1 score")
ax_det.set_xlabel("SNR")
ax_det.set_yscale("log")
# Pin the current y-limits before the shaded band is drawn.
ax_det.set_ylim(*ax_det.get_ylim()) #quick fix to fix fill
ax_det.axvline(8,0,1,ls="dashed", lw=0.9,c="k")
ax_det.axvline(12,0,1,ls="dashed", lw=0.9,c="k")
ax_det.fill_between((8,12),*ax_det.get_ylim(), color="grey", alpha=0.2,label="training set SNR")
ax_det.legend(loc="upper right")

ax_loc = plt.subplot(1,2,2)
place_images(ax_loc, images, xmin=0.05,xmax=1.05)
ax_loc.plot(snrlevels, np.array(xy_rms)[:,7], ".-",label="xy rms. / px")
ax_loc.plot(snrlevels, CRLB, "--", label="CRLB")
# The legend is created here, before the labelled dashed line below is added.
ax_loc.legend(loc="upper right")
ax_loc.set_xscale("log")
ax_loc.set_yscale("log")
ax_loc.set_xlabel("SNR")
ax_loc.set_ylabel("xy rms. / px")
ax_loc.axvline(8,0,1,ls="dashed", lw=0.9,c="k")
ax_loc.axvline(12,0,1,ls="dashed", lw=0.9,c="k", label="training set SNR")
ax_loc.fill_between((8,12),*ax_loc.get_ylim(), color="grey", alpha=0.2)
plt.show()

# Localisation error relative to a 1/SNR reference curve, on log-log axes.
plt.figure()

# place_images(plt.gca(), images, xmin=0.05,xmax=1.05)
# NOTE(review): the denominator is CRLB.mean() / snrlevels rather than the
# per-level CRLB array, so the curve equals xy rms / CRLB only up to a
# constant factor - confirm that this is intended.
plt.plot(snrlevels, np.array(xy_rms)[:,7]/(1/snrlevels * CRLB.mean()))
plt.xscale("log")
plt.yscale("log")
plt.xlabel("SNR")
plt.ylabel("xy rms. / CRLB")
# plt.ylim(0.3e-1,2) #quick fix to fix fill
plt.axvline(8,0,1,ls="dashed", lw=0.9,c="k")
plt.axvline(12,0,1,ls="dashed", lw=0.9,c="k", label="training set SNR")
plt.fill_between((8,12),*plt.ylim(), color="grey", alpha=0.2)
# Create the legend only after the labelled line exists; calling it earlier
# finds no labelled artists and emits a UserWarning.
plt.legend(loc="upper right")
plt.show()
_images/ea3999db64266bc297aadb5983db7c76fd525eae845750d4ce03bd0fbb940fd0.png
/tmp/ipykernel_3742084/1627014097.py:41: UserWarning: No artists with labels found to put in legend.  Note that artists whose label start with an underscore are ignored when legend() is called with no argument.
  plt.legend(loc="upper right")
_images/56ffb87c3654e1e9703ad5820d17bb454ba3dee0c3b1312d22cbe6c6b1d745d5.png

It seems that our localization error (xy rms) is proportional to the theoretical limit at first, but plateaus at an rms of ~0.1 px for low noise. This might be because we have not trained in the high noise region; however, compared to https://doi.org/10.1038/s41467-022-35004-y, where they train on similar noise, our performance is worse.

# Mean signed difference of the paired keypoint coordinates per SNR level,
# plotted separately for x and y with a dashed zero line for reference.
# NOTE(review): kpt[6][0] / kpt[6][1] are presumably the paired x and y
# coordinates at index 6 of each run's kpt_pairs - confirm.
plt.figure(figsize=(12,3))
x_bias = [np.mean(np.subtract(*pairs[6][0])) for pairs in kpt_pairs]
y_bias = [np.mean(np.subtract(*pairs[6][1])) for pairs in kpt_pairs]

ax_x = plt.subplot(121)
ax_x.plot(snrlevels, x_bias)
ax_x.set_ylabel("x bias / px")
ax_x.axhline(0,0,1, lw=0.9,ls="--",c="k")
ax_x.set_xlabel("SNR")

ax_y = plt.subplot(122)
ax_y.plot(snrlevels, y_bias)
ax_y.axhline(0,0,1, lw=0.9,ls="--",c="k")
ax_y.set_ylabel("y bias / px")
ax_y.set_xlabel("SNR")
Text(0.5, 0, 'SNR')
_images/de9738d3e28b1a140baf1aae92c788e0907d11afd2307ac746619ba86fb71815.png

There is no significant bias.