PyTorch bridges the gap between research and production, and Torch Hub is one of the ways it does so. Torch Hub was conceived to further extend PyTorch's credibility as a production-ready framework. In today's tutorial, we'll learn how to use Torch Hub to load a YOLOv7 pre-trained model for wide-scale use.
YOLO object-detection technology is widely used in many applications and serves as a backbone for many computer vision tasks, including object segmentation, object tracking, object classification, and object counting.
YOLOv7 was released with contributions from AlexeyAB (YOLOv4 author) and WongKinYiu (YOLOR author). It aims to achieve better accuracy than YOLOR, YOLOv5, and YOLOX. Clone the YOLOv7 repository from the link.
```
# Download YOLOv7 code
!git clone https://github.com/WongKinYiu/yolov7
%cd yolov7
!ls
```
Pre-trained Object Detection
The YOLOv7 weights need to be in your local project directory; download the pre-trained weights file from this link.
```
# Download trained weights
!wget https://github.com/WongKinYiu/yolov7/releases/download/v0.1/yolov7.pt
```
Torch Hub helps you use and publish pre-trained models to support research sharing and reproducibility. The process of harnessing Torch Hub is simple.
To load a model with torch.hub, you need a script called hubconf.py in your repository/directory. In that script, you define ordinary callable functions known as entry points; calling an entry point initializes and returns the model the user requires. This script is what connects your own model to Torch Hub.
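To make the entry-point idea concrete, here is a minimal, hypothetical hubconf.py sketch. The model and function name are made up for illustration; the real yolov7 hubconf.py does considerably more setup than this:

```python
# hubconf.py -- a minimal, hypothetical example of a Torch Hub entry point.
dependencies = ['torch']  # pip packages the entry points require

import torch.nn as nn

def tiny_classifier(pretrained=False, num_classes=10):
    """Entry point: build and return a model. Torch Hub calls this when a
    user runs torch.hub.load(repo_or_dir, 'tiny_classifier', ...)."""
    model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, num_classes))
    if pretrained:
        # A real hubconf.py would fetch a checkpoint here, e.g. with
        # torch.hub.load_state_dict_from_url(...)
        pass
    return model
```

Any keyword arguments the user passes to torch.hub.load after the entry-point name are forwarded to this function.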
We'll be using torch.hub.load to load our model from a local directory. After loading the model with pre-trained weights, we'll run inference on some sample data. Let's look at an example of how it works.
```python
import cv2
from google.colab.patches import cv2_imshow
import torch

# Load fine-tuned custom model from the local yolov7 clone
model = torch.hub.load('/content/yolov7', 'custom', 'yolov7.pt',
                       force_reload=True, source='local', trust_repo=True)
```
We now test the detection with pre-trained weights to confirm that all of our modules are working fine.
```python
img = cv2.imread('/content/elephentes.jpg')
results = model(img)  # batched inference

# results.pandas().xyxy is a list of DataFrames, one per image
df = results.pandas().xyxy[0]
for _, row in df.iterrows():
    cv2.rectangle(img,
                  (int(row['xmin']), int(row['ymin'])),
                  (int(row['xmax']), int(row['ymax'])),
                  (255, 155, 0), 2)
cv2_imshow(img)
```
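Each row of the per-image detection DataFrame also carries `confidence` and `name` columns, so you can filter out weak detections before drawing. Here is a sketch using a stand-in DataFrame with the same columns the model's pandas output provides (the values are made up for illustration):

```python
import pandas as pd

# Stand-in for the per-image detection DataFrame; columns match what
# the hub model emits, but these values are fabricated for the example.
df = pd.DataFrame({
    'xmin': [34.0, 120.0], 'ymin': [50.0, 80.0],
    'xmax': [200.0, 260.0], 'ymax': [310.0, 300.0],
    'confidence': [0.92, 0.18],
    'class': [20, 20], 'name': ['elephant', 'elephant'],
})

# Keep only detections above a confidence threshold before drawing
keep = df[df['confidence'] >= 0.25]
for _, row in keep.iterrows():
    print(f"{row['name']} {row['confidence']:.2f} "
          f"({int(row['xmin'])}, {int(row['ymin'])})")
```

The same filtering works on the real detection DataFrame before the cv2.rectangle loop, and the `name` and `confidence` values can be passed to cv2.putText to label each box.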
If everything is working fine, you will get the following results: