diffusers库的目标是:

  • 将扩散模型(diffusion models)集中到一个单一且长期维护的项目中
  • 以公众可访问的方式复现高影响力的机器学习系统,如DALLE、Imagen等
  • 让开发人员可以很容易地使用API进行模型训练或者使用现有模型进行推理

diffusers的核心分成三个组件:

  • Pipelines: 高层类,以一种用户友好的方式,基于流行的扩散模型快速生成样本
  • Models:训练新扩散模型的流行架构,如UNet
  • Schedulers:推理场景下基于噪声生成图像或训练场景下基于噪声生成带噪图像的各种技术
diffusers的安装
pip install diffusers
先看推理

导入Pipeline,from_pretrained()加载模型,可以是本地模型,或从the Hugging Face Hub自动下载。

from diffusers import StableDiffusionPipeline

image_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
# 加载本地模型:
# image_pipe = StableDiffusionPipeline.from_pretrained("./models/Stablediffusion/stable-diffusion-v1-4")
image_pipe.to("cuda") prompt = "a photograph of an astronaut riding a horse"
pipe_out = image_pipe(prompt) image = pipe_out.images[0]
# you can save the image with
# image.save(f"astronaut_rides_horse.png")

我们查看下image_pipe的内容:

StableDiffusionPipeline {
"_class_name": "StableDiffusionPipeline",
"_diffusers_version": "0.10.2",
"feature_extractor": [
"transformers",
"CLIPFeatureExtractor"
],
"requires_safety_checker": true,
"safety_checker": [
"stable_diffusion",
"StableDiffusionSafetyChecker"
],
"scheduler": [
"diffusers",
"PNDMScheduler"
],
"text_encoder": [
"transformers",
"CLIPTextModel"
],
"tokenizer": [
"transformers",
"CLIPTokenizer"
],
"unet": [
"diffusers",
"UNet2DConditionModel"
],
"vae": [
"diffusers",
"AutoencoderKL"
]
}

查看Images的结构:

StableDiffusionPipelineOutput(
images=[<PIL.Image.Image image mode=RGB size=512x512 at 0x1A14BDD7730>],
nsfw_content_detected=[False])

由此,可以看到pipe_out的包含两部分,第一部分就是生成的图片列表,如果只有一张图片,则pipe_out.images[0]即可取出目标图像。

如果我们要一次生成多张图像呢?只需要修改prompt的list长度即可,代码如下。

from diffusers import StableDiffusionPipeline

image_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")

image_pipe.to("cuda")
prompt = ["a photograph of an astronaut riding a horse"] * 3
out_images = image_pipe(prompt).images
for i, out_image in enumerate(out_images):
out_image.save("astronaut_rides_horse" + str(i) + ".png")

在使用image_pipe生成图像时,默认是float32精度的,若本地现在不足,可能会报Out of memory的错误,此时,可以通过加载float16精度的模型来解决。

Note: If you are limited by GPU memory and have less than 10GB of GPU RAM available, please make sure to load the StableDiffusionPipeline in float16 precision instead of the default float32 precision as done above.

You can do so by loading the weights from the fp16 branch and by telling diffusers to expect the weights to be in float16 precision:

image_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16)

对于每个PipeLine都有一些特定的配置,如StableDiffusionPipeline除了必要的prompt参数,还可以配置如下参数:

  • num_inference_steps: int = 50
  • guidance_scale: float = 7.5
  • generator: Optional[torch.Generator] = None
  • 等等

示例:如果你想要每次得到的结果均一致,可以设置每次的种子都一样

generator = torch.Generator("cuda").manual_seed(1024)
prompt = ["a photograph of an astronaut riding a horse"] * 3
out_images = image_pipe(prompt, generator=generator).images
再看训练

最新文章

  1. svn报错:“Previous operation has not finished; run &#39;cleanup&#39; if it was interrupted“ 的解决方法
  2. 【C语言入门教程】5.5 实现问题(效率)
  3. Information
  4. GridControl的用法(1)
  5. 一次MVVM+ReactiveCocoa实践
  6. iOS 添加阴影后 屏幕卡顿 抖动
  7. Android 每隔3s更新一次title
  8. hibernate-search-5.1.1简易使用
  9. MySQL高级特性——绑定变量
  10. spring boot 扫描不到自定义Controller
  11. A1014. Waiting in Line
  12. [蓝点ZigBee] Zstack 之按键驱动以及控制LED灯 ZigBee/CC2530 视频资料
  13. jQuery 查找元素2
  14. Azure PowerShell (13) 批量设置Azure ARM Network Security Group (NSG)
  15. How to Pronounce AR, ORN, etc.
  16. 【文文殿下】【洛谷】分治NTT模板
  17. 【three.js练习程序】动画效果,100个方块随机运动
  18. Elasticsearch java API (23)查询 DSL Geo查询
  19. Python pycurl使用
  20. bzoj 1270: [BeijingWc2008]雷涛的小猫 简单dp+滚动数组

热门文章

  1. python-CSV文件的读写
  2. 【终极解决办法】pyinstaller打包exe没有错误,运行exe提示Failed to execute script &#39;mainlmageWindows&#39; due tounhandled exception: No module named &#39;docx&#39;
  3. java中使用apache poi 读取 doc,docx,ppt,pptx,xls,xlsx,txt,csv格式的文件示例代码
  4. ChatGPT 可以联网了!浏览器插件下载
  5. 如何用 JavaScript 编写你的第一个单元测试
  6. m3u8文件后缀jpg,png等处理方法及视频合并
  7. C++可执行文件绝对路径获取与屏蔽VS安全检查
  8. ArcGIS工具 - 按要素裁切数据库
  9. [cocos2d-x]关于菜单项
  10. Java 进阶P-1.1+P-1.2