Diffusers库的初识及使用
diffusers库的目标是:
- 将扩散模型(diffusion models)集中到一个单一且长期维护的项目中
- 以公众可访问的方式复现高影响力的机器学习系统,如DALLE、Imagen等
- 让开发人员可以很容易地使用API进行模型训练或者使用现有模型进行推理
diffusers的核心分成三个组件:
- Pipelines: 高层类,以一种用户友好的方式,基于流行的扩散模型快速生成样本
- Models:训练新扩散模型的流行架构,如UNet
- Schedulers:推理场景下基于噪声生成图像或训练场景下基于噪声生成带噪图像的各种技术
diffusers的安装
pip install diffusers
先看推理
导入Pipeline,from_pretrained()
加载模型,可以是本地模型,或从the Hugging Face Hub自动下载。
from diffusers import StableDiffusionPipeline
image_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
# 加载本地模型:
# image_pipe = StableDiffusionPipeline.from_pretrained("./models/Stablediffusion/stable-diffusion-v1-4")
image_pipe.to("cuda")
prompt = "a photograph of an astronaut riding a horse"
pipe_out = image_pipe(prompt)
image = pipe_out.images[0]
# you can save the image with
# image.save(f"astronaut_rides_horse.png")
我们查看下image_pipe
的内容:
StableDiffusionPipeline {
"_class_name": "StableDiffusionPipeline",
"_diffusers_version": "0.10.2",
"feature_extractor": [
"transformers",
"CLIPFeatureExtractor"
],
"requires_safety_checker": true,
"safety_checker": [
"stable_diffusion",
"StableDiffusionSafetyChecker"
],
"scheduler": [
"diffusers",
"PNDMScheduler"
],
"text_encoder": [
"transformers",
"CLIPTextModel"
],
"tokenizer": [
"transformers",
"CLIPTokenizer"
],
"unet": [
"diffusers",
"UNet2DConditionModel"
],
"vae": [
"diffusers",
"AutoencoderKL"
]
}
查看Images的结构:
StableDiffusionPipelineOutput(
images=[<PIL.Image.Image image mode=RGB size=512x512 at 0x1A14BDD7730>],
nsfw_content_detected=[False])
由此,可以看到pipe_out
的包含两部分,第一部分就是生成的图片列表,如果只有一张图片,则pipe_out.images[0]
即可取出目标图像。
如果我们要一次生成多张图像呢?只需要修改prompt的list长度即可,代码如下。
from diffusers import StableDiffusionPipeline
image_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
image_pipe.to("cuda")
prompt = ["a photograph of an astronaut riding a horse"] * 3
out_images = image_pipe(prompt).images
for i, out_image in enumerate(out_images):
out_image.save("astronaut_rides_horse" + str(i) + ".png")
在使用image_pipe
生成图像时,默认是float32
精度的,若本地现在不足,可能会报Out of memory
的错误,此时,可以通过加载float16
精度的模型来解决。
Note: If you are limited by GPU memory and have less than 10GB of GPU RAM available, please make sure to load the
StableDiffusionPipeline
in float16 precision instead of the default float32 precision as done above.You can do so by loading the weights from the
fp16
branch and by tellingdiffusers
to expect the weights to be in float16 precision:image_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16)
对于每个PipeLine
都有一些特定的配置,如StableDiffusionPipeline
除了必要的prompt
参数,还可以配置如下参数:
num_inference_steps: int = 50
guidance_scale: float = 7.5
generator: Optional[torch.Generator] = None
- 等等
示例:如果你想要每次得到的结果均一致,可以设置每次的种子都一样
generator = torch.Generator("cuda").manual_seed(1024)
prompt = ["a photograph of an astronaut riding a horse"] * 3
out_images = image_pipe(prompt, generator=generator).images
再看训练
最新文章
- svn报错:“Previous operation has not finished; run &#39;cleanup&#39; if it was interrupted“ 的解决方法
- 【C语言入门教程】5.5 实现问题(效率)
- Information
- GridControl的用法(1)
- 一次MVVM+ReactiveCocoa实践
- iOS 添加阴影后 屏幕卡顿 抖动
- Android 每隔3s更新一次title
- hibernate-search-5.1.1简易使用
- MySQL高级特性——绑定变量
- spring boot 扫描不到自定义Controller
- A1014. Waiting in Line
- [蓝点ZigBee] Zstack 之按键驱动以及控制LED灯 ZigBee/CC2530 视频资料
- jQuery 查找元素2
- Azure PowerShell (13) 批量设置Azure ARM Network Security Group (NSG)
- How to Pronounce AR, ORN, etc.
- 【文文殿下】【洛谷】分治NTT模板
- 【three.js练习程序】动画效果,100个方块随机运动
- Elasticsearch java API (23)查询 DSL Geo查询
- Python pycurl使用
- bzoj 1270: [BeijingWc2008]雷涛的小猫 简单dp+滚动数组
热门文章
- python-CSV文件的读写
- 【终极解决办法】pyinstaller打包exe没有错误,运行exe提示Failed to execute script &#39;mainlmageWindows&#39; due tounhandled exception: No module named &#39;docx&#39;
- java中使用apache poi 读取 doc,docx,ppt,pptx,xls,xlsx,txt,csv格式的文件示例代码
- ChatGPT 可以联网了!浏览器插件下载
- 如何用 JavaScript 编写你的第一个单元测试
- m3u8文件后缀jpg,png等处理方法及视频合并
- C++可执行文件绝对路径获取与屏蔽VS安全检查
- ArcGIS工具 - 按要素裁切数据库
- [cocos2d-x]关于菜单项
- Java 进阶P-1.1+P-1.2