大象也能P转身?超火的DragGAN正式开源,带你一秒解锁新技能

生成对抗网络(GAN)是一种深度学习模型,由生成器网络和判别器网络组成。GAN用于生成类似真实数据的合成数据。

GAN的工作原理如下:

  1. 生成器网络:生成器以随机噪声为输入,生成合成数据,如图像或文本。它通过一系列神经层学习将输入噪声映射到期望的输出。

  2. 判别器网络:判别器接收真实数据和生成数据作为输入,并试图正确分类它们。它通过训练优化权重来区分真实和假数据。

  3. 对抗训练:生成器和判别器以对抗性方式一起进行训练。生成器旨在生成可以欺骗判别器将其视为真实数据的数据,而判别器旨在正确分类真实和假数据。

在训练过程中,生成器和判别器相互竞争。这种竞争迫使生成器提高生成更逼真数据的能力,同时判别器变得更擅长区分真实和假数据。

训练过程通常涉及交替更新生成器和判别器网络。最终目标是使生成器生成的合成数据在判别器的眼中无法与真实数据区分。

GAN已成功应用于多个领域,包括图像合成、文本生成、音乐创作和视频生成。它们也为计算机视觉、自然语言处理和创造性人工智能等领域的进展做出了贡献。

 什么时候即使是不合理的需求也能通过简简单单的拖拽就能实现了呢,DragGAN就做到了,最近DragGAN项目已经正式开源了,官方项目地址在这里,如下所示:

 短短几天就已经收获了将近3w的star量,相信后续还是会持续增长的。

官方项目介绍地址在这里,如下所示:

 合成满足用户需求的视觉内容通常需要对生成对象的姿势、形状、表情和布局进行灵活而精确的控制。现有的方法通过手动注释的训练数据或先前的3D模型来获得生成对抗性网络(GAN)的可控性,这通常缺乏灵活性、准确性和通用性。在这项工作中,我们研究了一种功能强大但探索较少的控制GANs的方法,即以用户交互的方式“拖动”图像的任何点,以精确到达目标点,如图1所示。为了实现这一点,我们提出了DragGAN,它由两个主要组件组成,包括:1)一种基于特征的运动监督,它驱动手柄点向目标位置移动,以及2)一种新的点跟踪方法,它利用判别性GAN特征来保持手柄点的位置定位。通过DragGAN,任何人都可以通过精确控制像素的位置来变形图像,从而操纵不同类别(如动物、汽车、人类、风景等)的姿势、形状、表情和布局。由于这些操作是在GAN的学习生成图像流形上执行的,即使在具有挑战性的场景中,如产生幻觉的被遮挡内容和始终遵循物体刚性的变形形状,它们也倾向于产生逼真的输出。定性和定量比较都证明了DragGAN在图像处理和点跟踪任务中优于现有方法的优势。我们还展示了通过GAN反演对真实图像的操作。

作者提供的论文地址在这里,感兴趣的话可以仔细研读一下,如下所示:

 这里我还没有仔细读过论文,所以就不再赘述了。

接下来就先以官方开源的项目为基准实际实践试用一下,官方项目下载好后需要下载对应的模型文件才可以继续使用,下载脚本如下所示:

import os
import sys
import json
import requests
from tqdm import tqdm

def download_file(url: str, filename: str, download_dir: str):
    """Download a file if it does not already exist."""

    try:
        filepath = os.path.join(download_dir, filename)
        content_length = int(requests.head(url).headers.get("content-length", 0))

        # If file already exists and size matches, skip download
        if os.path.isfile(filepath) and os.path.getsize(filepath) == content_length:
            print(f"{filepath} already exists. Skipping download.")
            return
        if os.path.isfile(filepath) and os.path.getsize(filepath) != content_length:
            print(f"{filepath} already exists but size does not match. Redownloading.")
        else:
            print(f"Downloading {filename} from {url}")

        # Start download, stream=True allows for progress tracking
        response = requests.get(url, stream=True)

        # Check if request was successful
        response.raise_for_status()

        # Create progress bar
        total_size = int(response.headers.get('content-length', 0))
        progress_bar = tqdm(
            total=total_size, 
            unit='iB', 
            unit_scale=True, 
            ncols=70, 
            file=sys.stdout
        )

        # Write response content to file
        with open(filepath, 'wb') as f:
            for data in response.iter_content(chunk_size=1024):
                f.write(data)
                progress_bar.update(len(data))  # Update progress bar

        # Close progress bar
        progress_bar.close()

        # Error handling for incomplete downloads
        if total_size != 0 and progress_bar.n != total_size:
            print("ERROR, something went wrong while downloading")
            raise Exception()


    except Exception as e:
        print(f"An error occurred: {e}")

def main():
    """Main function to download files from URLs in a config file."""
    
    # Get JSON config file path
    script_dir = os.path.dirname(os.path.realpath(__file__))
    config_file_path = os.path.join(script_dir, "download_models.json")

    # Set download directory
    download_dir = "checkpoints"
    os.makedirs(download_dir, exist_ok=True)

    # Load URL and filenames from JSON
    with open(config_file_path, "r") as f:
        config = json.load(f)

    # Download each file specified in config
    for url, filename in config.items():
        download_file(url, filename, download_dir)


if __name__ == "__main__":
    main()

直接终端执行即可。

当然了,也可以自己手动下载放到指定目录中即可,我这里就是直接选择使用手动下载的方式,创建checkpoints目录,将下载好的模型存入该目录中即可,如下所示:

 之后就可以使用服务了。

作者这里提供了可视化界面和基于Gradio的web服务系统两种形式,我比较喜欢可视化界面的形式,但是无奈界面启动后会报错导致使用不了,界面启动命令如下所示:
 

.\scripts\gui.bat

短暂显示界面后就会因为报错而终止了,如下所示:

 详细报错如下所示:

    return super(Event, cls).__new__(
TypeError: object.__new__() takes exactly one argument (the type to instantiate)

因为不太清楚具体的报错原因,这里就暂时搁置了,如果有知道的朋友欢迎留言交流。

这里就直接使用web系统的形式了,终端执行:

python visualizer_drag_gradio.py

即可一键启动web服务,启动成功输出如下所示:

 在浏览器端输入下述链接即可访问页面:

http://127.0.0.1:7860/

如下所示:

快速入门

选择所需的预训练模型并调整“种子”以生成初始图像。

单击图像以添加控制点。

单击开始并享受它!

提前使用

更改步长以调整阻力优化中的学习率。

选择w或w+以更改要优化的潜在空间:

对w空间进行优化可能会对图像产生更大的影响。

在w+空间上优化可能比w工作得慢,但通常会获得更好的结果。

请注意,更改潜在空间将重置图像、点和遮罩(这与“重置图像”按钮的效果相同)。

单击“编辑柔性区域”以创建遮罩,并约束未遮罩的区域以保持不变。

设置完成后点击star即可启动计算,终端输出如下所示:

 因为我的机器性能一般,所以计算的还是比较慢的,但是你还是能比较清楚地看到:狮子扭头了,简单看下我的操作实例,如需所示:

 计算结果如下所示:

 真的扭头了。。。。

还是很有意思的了,感兴趣的话都可以体验尝试一番。

猜你喜欢

转载自blog.csdn.net/Together_CZ/article/details/131480192
今日推荐