jetracer: Autonomous Car Project (interactive_regression.ipynb)

Camera

  • Initialize and start the camera
from jetcam.csi_camera import CSICamera
# from jetcam.usb_camera import USBCamera

camera = CSICamera(width=224, height=224)
# camera = USBCamera(width=224, height=224)

camera.running = True
  • Plug in a CSI camera.
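    If the notebook shuts down without stopping the stream, the next run may find the camera busy (see note 1 at the end). A minimal cleanup sketch, using only the running traitlet toggled above:

# stop streaming so the CSI camera is released cleanly before exiting
camera.running = False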

Task

  • Define the task, categories, datasets, and image transforms
import torchvision.transforms as transforms
from xy_dataset import XYDataset

TASK = 'road_following'

CATEGORIES = ['apex']

DATASETS = ['A', 'B']

TRANSFORMS = transforms.Compose([
    transforms.ColorJitter(0.2, 0.2, 0.2, 0.2),
    transforms.Resize((224, 224)),  # 224 matches the camera/model input (the source had 244, likely a typo)
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

datasets = {}
for name in DATASETS:
    datasets[name] = XYDataset(TASK + '_' + name, CATEGORIES, TRANSFORMS, random_hflip=True)
  • Task setup:
    DATASETS is the list of dataset names;
    datasets is a dictionary mapping each name to its dataset;
    XYDataset is a predefined dataset class whose constructor takes the dataset name, the CATEGORIES category list, the TRANSFORMS image-preprocessing pipeline, and a random_hflip flag for random horizontal flipping. A hedged sketch of the interface the notebook relies on follows below.
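    The notebook only touches XYDataset through its constructor, the categories attribute, get_count(), save_entry(), and DataLoader indexing. The real class ships in xy_dataset.py next to the notebook; the sketch below is an assumption about that interface (file naming, normalization, and the class name XYDatasetSketch are illustrative only):

# Hedged sketch of the XYDataset interface this notebook relies on.
# The real class ships in xy_dataset.py; file layout and normalization
# here are assumptions for illustration, not the actual implementation.
import os
import glob
import uuid
import cv2
import PIL.Image
import torch
import torch.utils.data

class XYDatasetSketch(torch.utils.data.Dataset):
    def __init__(self, directory, categories, transform=None, random_hflip=False):
        self.directory = directory
        self.categories = categories
        self.transform = transform
        self.random_hflip = random_hflip

    def save_entry(self, category, image, x, y):
        # persist the BGR8 frame; the clicked pixel label travels in the file name
        folder = os.path.join(self.directory, category)
        os.makedirs(folder, exist_ok=True)
        cv2.imwrite(os.path.join(folder, '%d_%d_%s.jpg' % (x, y, uuid.uuid1())), image)

    def get_count(self, category):
        return len(glob.glob(os.path.join(self.directory, category, '*.jpg')))

    def _entries(self):
        return sorted(glob.glob(os.path.join(self.directory, '*', '*.jpg')))

    def __len__(self):
        return len(self._entries())

    def __getitem__(self, idx):
        path = self._entries()[idx]
        category = os.path.basename(os.path.dirname(path))
        image = PIL.Image.open(path)
        x_pix, y_pix = (int(v) for v in os.path.basename(path).split('_')[:2])
        # normalize the pixel label to [-1, 1]; the live-execution cell later
        # applies the inverse mapping to draw predictions
        x = (x_pix / image.width - 0.5) * 2.0
        y = (y_pix / image.height - 0.5) * 2.0
        if self.transform is not None:
            image = self.transform(image)
        # (random_hflip, when enabled, would flip the image and negate x)
        return image, self.categories.index(category), torch.tensor([x, y]).float()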

Data Collection

  • Collect labeled data by clicking on the live image
import cv2
import ipywidgets
import traitlets
from IPython.display import display
from jetcam.utils import bgr8_to_jpeg
from jupyter_clickable_image_widget import ClickableImageWidget


# initialize active dataset
dataset = datasets[DATASETS[0]]

# unobserve all callbacks from camera in case we are running this cell for second time
camera.unobserve_all()

# create image preview
camera_widget = ClickableImageWidget(width=camera.width, height=camera.height)
snapshot_widget = ipywidgets.Image(width=camera.width, height=camera.height)
traitlets.dlink((camera, 'value'), (camera_widget, 'value'), transform=bgr8_to_jpeg)

# create widgets
dataset_widget = ipywidgets.Dropdown(options=DATASETS, description='dataset')
category_widget = ipywidgets.Dropdown(options=dataset.categories, description='category')
count_widget = ipywidgets.IntText(description='count')

# manually update counts at initialization
count_widget.value = dataset.get_count(category_widget.value)

# sets the active dataset
def set_dataset(change):
    global dataset
    dataset = datasets[change['new']]
    count_widget.value = dataset.get_count(category_widget.value)
dataset_widget.observe(set_dataset, names='value')

# update counts when we select a new category
def update_counts(change):
    count_widget.value = dataset.get_count(change['new'])
category_widget.observe(update_counts, names='value')


def save_snapshot(_, content, msg):
    if content['event'] == 'click':
        data = content['eventData']
        x = data['offsetX']
        y = data['offsetY']

        # save to disk
        dataset.save_entry(category_widget.value, camera.value, x, y)

        # display saved snapshot
        snapshot = camera.value.copy()
        snapshot = cv2.circle(snapshot, (x, y), 8, (0, 255, 0), 3)
        snapshot_widget.value = bgr8_to_jpeg(snapshot)
        count_widget.value = dataset.get_count(category_widget.value)

camera_widget.on_msg(save_snapshot)

data_collection_widget = ipywidgets.VBox([
    ipywidgets.HBox([camera_widget, snapshot_widget]),
    dataset_widget,
    category_widget,
    count_widget
])

display(data_collection_widget)
  • Result:
    (screenshot)
    Once running, you can choose dataset A or B; category offers only apex; count shows how many samples the active dataset already contains.
    The left image is the live camera feed. The right image looks blank at first; as soon as you click the car's ideal driving point in the left image, the right image shows the clicked point overlaid on the scene at that moment. Each click therefore yields one labeled sample.
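    Clicks are saved in pixel coordinates, while the model is trained on and predicts coordinates normalized to [-1, 1]; the live() cell later decodes predictions with x = int(camera.width * (x / 2.0 + 0.5)). A small illustrative sketch of that mapping (the helper names are mine, not the notebook's):

# Illustrative pixel <-> normalized mapping implied by the live() cell below.
def pixel_to_normalized(x_pix, y_pix, width=224, height=224):
    return (x_pix / width - 0.5) * 2.0, (y_pix / height - 0.5) * 2.0

def normalized_to_pixel(x, y, width=224, height=224):
    return int(width * (x / 2.0 + 0.5)), int(height * (y / 2.0 + 0.5))

print(pixel_to_normalized(112, 112))  # frame center -> (0.0, 0.0)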

Model

  • Build the model
import torch
import torchvision

device = torch.device('cuda')
output_dim = 2 * len(dataset.categories)  # x, y coordinate for each category

# ALEXNET (classic CNN)
# model = torchvision.models.alexnet(pretrained=True)
# model.classifier[-1] = torch.nn.Linear(4096, output_dim)

# SQUEEZENET (lightweight classification network)
# model = torchvision.models.squeezenet1_1(pretrained=True)
# model.classifier[1] = torch.nn.Conv2d(512, output_dim, kernel_size=1)
# model.num_classes = len(dataset.categories)

# RESNET 18 (residual network, 18 layers)
model = torchvision.models.resnet18(pretrained=True)
model.fc = torch.nn.Linear(512, output_dim)

# RESNET 34 (residual network, 34 layers)
# model = torchvision.models.resnet34(pretrained=True)
# model.fc = torch.nn.Linear(512, output_dim)

# DENSENET 121 (densely connected convolutional network)
# model = torchvision.models.densenet121(pretrained=True)
# model.classifier = torch.nn.Linear(model.num_features, output_dim)

model = model.to(device)

model_save_button = ipywidgets.Button(description='save model')
model_load_button = ipywidgets.Button(description='load model')
model_path_widget = ipywidgets.Text(description='model path', value='road_following_model.pth')

def load_model(c):
    model.load_state_dict(torch.load(model_path_widget.value))
model_load_button.on_click(load_model)

def save_model(c):
    torch.save(model.state_dict(), model_path_widget.value)
model_save_button.on_click(save_model)

model_widget = ipywidgets.VBox([
    model_path_widget,
    ipywidgets.HBox([model_load_button, model_save_button])
])

display(model_widget)
  • Result:
    (screenshot)
    A widget for setting the model path and loading/saving the model.

Live Execution

  • Run live inference in a background thread
import threading
import time
from utils import preprocess
import torch.nn.functional as F

state_widget = ipywidgets.ToggleButtons(options=['stop', 'live'], description='state', value='stop')
prediction_widget = ipywidgets.Image(format='jpeg', width=camera.width, height=camera.height)

def live(state_widget, model, camera, prediction_widget):
    global dataset
    while state_widget.value == 'live':
        image = camera.value
        preprocessed = preprocess(image)
        output = model(preprocessed).detach().cpu().numpy().flatten()
        category_index = dataset.categories.index(category_widget.value)
        x = output[2 * category_index]
        y = output[2 * category_index + 1]

        # map normalized [-1, 1] outputs back to pixel coordinates
        x = int(camera.width * (x / 2.0 + 0.5))
        y = int(camera.height * (y / 2.0 + 0.5))

        prediction = image.copy()
        prediction = cv2.circle(prediction, (x, y), 8, (255, 0, 0), 3)
        prediction_widget.value = bgr8_to_jpeg(prediction)

def start_live(change):
    if change['new'] == 'live':
        execute_thread = threading.Thread(target=live, args=(state_widget, model, camera, prediction_widget))
        execute_thread.start()

state_widget.observe(start_live, names='value')

live_execution_widget = ipywidgets.VBox([
    prediction_widget,
    state_widget
])

display(live_execution_widget)
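The preprocess helper imported from utils above is not shown in the notebook. It has to turn a raw BGR8 camera frame into a batched, normalized CUDA tensor consistent with the Normalize statistics in TRANSFORMS. A hedged sketch, assuming RGB channel order and the same mean/std; the real helper ships with jetracer:

# Hedged sketch of what utils.preprocess must do; treat the details
# (including channel order) as assumptions, not the shipped implementation.
import torch

_mean = torch.tensor([0.485, 0.456, 0.406]).cuda()[:, None, None]
_std = torch.tensor([0.229, 0.224, 0.225]).cuda()[:, None, None]

def preprocess_sketch(image_bgr8):
    # HWC uint8 BGR -> CHW float RGB in [0, 1]
    rgb = image_bgr8[:, :, ::-1].copy()
    t = torch.from_numpy(rgb).float().permute(2, 0, 1).cuda() / 255.0
    # normalize with the training statistics and add a batch dimension
    return ((t - _mean) / _std)[None, ...]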

Last come the widgets that control and display training: epochs, train, evaluate, loss, and progress.

BATCH_SIZE = 8

optimizer = torch.optim.Adam(model.parameters())
# optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)

epochs_widget = ipywidgets.IntText(description='epochs', value=1)
eval_button = ipywidgets.Button(description='evaluate')
train_button = ipywidgets.Button(description='train')
loss_widget = ipywidgets.FloatText(description='loss')
progress_widget = ipywidgets.FloatProgress(min=0.0, max=1.0, description='progress')

def train_eval(is_training):
    global BATCH_SIZE, model, dataset, optimizer, eval_button, train_button, loss_widget, progress_widget, state_widget

    try:
        train_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=BATCH_SIZE,
            shuffle=True
        )

        state_widget.value = 'stop'
        train_button.disabled = True
        eval_button.disabled = True
        time.sleep(1)

        if is_training:
            model = model.train()
        else:
            model = model.eval()

        while epochs_widget.value > 0:
            i = 0
            sum_loss = 0.0
            error_count = 0.0
            for images, category_idx, xy in iter(train_loader):
                # send data to device
                images = images.to(device)
                xy = xy.to(device)

                if is_training:
                    # zero gradients of parameters
                    optimizer.zero_grad()

                # execute model to get outputs
                outputs = model(images)

                # compute MSE loss over x, y coordinates for associated categories
                loss = 0.0
                for batch_idx, cat_idx in enumerate(list(category_idx.flatten())):
                    loss += torch.mean((outputs[batch_idx][2 * cat_idx:2 * cat_idx+2] - xy[batch_idx])**2)
                loss /= len(category_idx)

                if is_training:
                    # run backpropagation to accumulate gradients
                    loss.backward()

                    # step optimizer to adjust parameters
                    optimizer.step()

                # increment progress
                count = len(category_idx.flatten())
                i += count
                sum_loss += float(loss)
                progress_widget.value = i / len(dataset)
                loss_widget.value = sum_loss / i

            if is_training:
                epochs_widget.value = epochs_widget.value - 1
            else:
                break
    except Exception:
        pass
    model = model.eval()

    train_button.disabled = False
    eval_button.disabled = False
    state_widget.value = 'live'

train_button.on_click(lambda c: train_eval(is_training=True))
eval_button.on_click(lambda c: train_eval(is_training=False))

train_eval_widget = ipywidgets.VBox([
    epochs_widget,
    progress_widget,
    loss_widget,
    ipywidgets.HBox([train_button, eval_button])
])

display(train_eval_widget)
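To make the loss concrete: for each sample, the (x, y) slice of the output belonging to that sample's category is compared against the clicked label, and the squared error is averaged. A self-contained illustration on a dummy batch (all values here are made up):

# Stand-alone illustration of the per-category MSE used in train_eval().
# With a single category ('apex'), the slice is simply outputs[:, 0:2].
import torch

outputs = torch.tensor([[0.1, -0.2], [0.4, 0.0]])  # predictions, normalized
xy = torch.tensor([[0.0, -0.25], [0.5, 0.1]])      # clicked labels, normalized
category_idx = torch.tensor([0, 0])                # both samples are 'apex'

loss = 0.0
for batch_idx, cat_idx in enumerate(category_idx.flatten().tolist()):
    loss += torch.mean((outputs[batch_idx][2 * cat_idx:2 * cat_idx + 2] - xy[batch_idx]) ** 2)
loss /= len(category_idx)
print(float(loss))  # mean squared error over x and y, averaged over the batch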

All together!

  • Summary

The following widget can be used to label a multi-class x, y dataset. It supports labeling only one instance of each class per image (i.e. only one dog), but multiple classes per image (i.e. dog, cat, horse) are possible.


Click the image on the top left to save an image of the active category to the active dataset, labeled at the clicked location.

Widget        Description
dataset       Selects the active dataset
category      Selects the active category
epochs        Sets the number of epochs to train for
train         Trains on the active dataset for the number of epochs specified
evaluate      Evaluates the accuracy on the active dataset over one epoch
model path    Sets the active model path
load          Loads a model from the active model path
save          Saves a model to the active model path
stop          Disables the live demo
live          Enables the live demo

Run all of the widgets together and display them below:

all_widget = ipywidgets.VBox([
    ipywidgets.HBox([data_collection_widget, live_execution_widget]),
    train_eval_widget,
    model_widget
])

display(all_widget)
  • Result:
    (screenshot)

  • The left image is the live feed from the car's camera; the middle image shows the labeled point once you click the car's ideal driving position; the rightmost image, when the state is set to live, shows the driving point the car currently predicts.

1. After this data-collection program exits, reopening the camera may fail with a "camera busy" error. To restart the camera daemon, run the following on the command line:

sudo systemctl restart nvargus-daemon

2. Once the notebook is running, you only need to operate the combined widget at the bottom; the individual widgets above it can be left alone.
3. With the notebook running, place the car on the track and also open the previous notebook, teleoperation.ipynb, so the gamepad can drive the car.
4. Drive the car along the track; each time it moves a short distance, click the car's ideal driving point in the image to save a labeled frame. Collect about 10 laps (the lap count is not fixed; what matters is having enough photos — see the count check after this list).
5. After collecting the data, set epochs to 10 and click train to train for 10 epochs.
6. After training, click evaluate to assess the model. If the collected data is good, the leftmost image will show the ideal heading for the car's current position.
7. When collecting data, move the car to varied positions, offsets, and orientations along the track, and whenever possible click the farthest point along the ideal path, so that the car will not leave the track or hit obstacles.
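
As a quick sanity check while collecting (an illustrative snippet that uses only the get_count API shown earlier):

# how many labeled frames each dataset currently holds
for name, ds in datasets.items():
    print(name, ds.get_count('apex'))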

Reposted from blog.csdn.net/weixin_44350337/article/details/114554186