Exporting a CLIP model to ONNX

OpenAI CLIP

GitHub - openai/CLIP: CLIP (Contrastive Language-Image Pretraining), Predict the most relevant text snippet given an image

import torch
import torch.nn as nn
import clip
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

model.float()
model.eval()

image = preprocess(Image.open("clip_dog.png")).unsqueeze(0).to(device)
text = clip.tokenize(["a dog", "a cat"]).to(device)

print("text:", text)

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)

    logits_per_image, logits_per_text = model(image, text)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()

print("Label probs:", probs)  # prints: [[0.9927937  0.00421068 0.00299572]]

# export to ONNX


class ImgModelWrapper(nn.Module):
    """Wraps the image encoder so it can be exported on its own."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, image):
        return self.model.encode_image(image)


class TxtModelWrapper(nn.Module):
    """Wraps the text encoder so it can be exported on its own."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, text):
        return self.model.encode_text(text)


img_model = ImgModelWrapper(model)
txt_model = TxtModelWrapper(model)

torch.onnx.export(img_model,               # model being run
                  image,                         # model input (or a tuple for multiple inputs)
                  "openai_vit_img.onnx",   # where to save the model (can be a file or file-like object)
                  export_params=True,        # store the trained parameter weights inside the model file
                  opset_version=12,          # the ONNX version to export the model to
                  do_constant_folding=False,  # whether to execute constant folding for optimization
                  input_names=['input'],   # the model's input names
                  output_names=['output'],  # the model's output names
                  dynamic_axes={'input': {0: 'batch'}})
torch.onnx.export(txt_model,               # model being run
                  text,                         # model input (or a tuple for multiple inputs)
                  "openai_vit_txt.onnx",   # where to save the model (can be a file or file-like object)
                  export_params=True,        # store the trained parameter weights inside the model file
                  opset_version=12,          # the ONNX version to export the model to
                  do_constant_folding=False,  # whether to execute constant folding for optimization
                  input_names=['input'],   # the model's input names
                  output_names=['output'],  # the model's output names
                  dynamic_axes={'input': {0: 'batch'}})
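
Before deploying the exported graphs, it is worth checking that they reproduce the PyTorch features. Below is a minimal sketch (not part of the original export flow) that assumes onnxruntime is installed and that the two .onnx files above were written to the working directory:

import numpy as np
import onnxruntime as ort

# Run both exported encoders on CPU and compare against the PyTorch features.
img_sess = ort.InferenceSession("openai_vit_img.onnx", providers=["CPUExecutionProvider"])
txt_sess = ort.InferenceSession("openai_vit_txt.onnx", providers=["CPUExecutionProvider"])

onnx_img = img_sess.run(None, {"input": image.cpu().numpy()})[0]
onnx_txt = txt_sess.run(None, {"input": text.cpu().numpy()})[0]

print("max image-feature diff:", np.abs(onnx_img - image_features.cpu().numpy()).max())
print("max text-feature diff:", np.abs(onnx_txt - text_features.cpu().numpy()).max())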

Chinese-CLIP can be exported to ONNX in exactly the same way:

GitHub - OFA-Sys/Chinese-CLIP: Chinese version of CLIP which achieves Chinese cross-modal retrieval and representation generation.
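
For reference, here is a sketch of the same text-encoder export for Chinese-CLIP, assuming the cn_clip package's load_from_name / tokenize interface from the repository's README (the model name and prompts are just examples):

import torch
import cn_clip.clip as cn_clip
from cn_clip.clip import load_from_name

device = "cuda" if torch.cuda.is_available() else "cpu"
cn_model, cn_preprocess = load_from_name("ViT-B-16", device=device, download_root="./")
cn_model.float()
cn_model.eval()

cn_text = cn_clip.tokenize(["一只狗", "一只猫"]).to(device)  # "a dog", "a cat"


class CnTxtModelWrapper(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, text):
        return self.model.encode_text(text)


torch.onnx.export(CnTxtModelWrapper(cn_model), cn_text, "cn_clip_txt.onnx",
                  export_params=True, opset_version=12,
                  input_names=['input'], output_names=['output'],
                  dynamic_axes={'input': {0: 'batch'}})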

One research direction I find interesting for CLIP: the same sentence expressed in different languages should ideally map to similar semantic feature vectors. Today, however, each language trains its own pair of text and image encoders on that language's image-text data, and the intermediate feature spaces do not line up, so models built for one language cannot be reused for another. If every language produced aligned semantic vectors, you would only need to swap in a different text encoder while keeping the same image encoder. Likewise, when CLIP is combined with Stable Diffusion, you would not have to retrain Stable Diffusion for each language, only the text encoder.

Similar work:

FlagAI/README.md at master · FlagAI-Open/FlagAI · GitHub

But that work only aligns the semantics of a small number of languages in a multilingual model, and Stable Diffusion still has to be retrained. Ideally, Stable Diffusion would never need retraining; each new language would only require training an aligned text encoder.
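
To make the idea concrete, here is a rough, hypothetical sketch of what inference would look like with one shared image encoder and swappable per-language text encoders (all names are illustrative; it assumes the text encoders have already been trained to map into the image encoder's embedding space):

import torch
import torch.nn.functional as F

def rank_texts(image_encoder, text_encoder, image, tokens):
    # Cosine similarity between one image and a batch of tokenized texts,
    # both projected into the (assumed) shared embedding space.
    with torch.no_grad():
        img = F.normalize(image_encoder(image), dim=-1)   # (1, d)
        txt = F.normalize(text_encoder(tokens), dim=-1)   # (n, d)
    return (img @ txt.t()).softmax(dim=-1)

# If the spaces were aligned, only the text encoder (and its tokenizer)
# would change per language, e.g.:
#   rank_texts(shared_img_encoder, english_txt_encoder, image, en_tokens)
#   rank_texts(shared_img_encoder, chinese_txt_encoder, image, zh_tokens)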

Origin: https://blog.csdn.net/u013701860/article/details/129764269