# -*- coding: utf-8 -*-
##--------------------------------------------------
## Stable Diffusion with diffusers(041) Ver 0.01
##
## 2025.06.17 Masahiro Izutsu
##--------------------------------------------------
## sd_041.py
## Ver 0.01 2025.06.17 SD1.5/SDXL 対応版
import warnings
# Silence every Python warning for the lifetime of the process
warnings.simplefilter('ignore')
# ANSI color escape sequences used to decorate console output
GREEN = '\033[1;32m'
RED = '\033[1;31m'
NOCOLOR = '\033[0m'
YELLOW = '\033[1;33m'
CYAN = '\033[1;36m'
BLUE = '\033[1;34m'
from torch.cuda import is_available
gpu_d = is_available() # GPU check: result of torch.cuda.is_available(), read by the entry point
# インポート&初期設定
import os
import argparse
import glob
import re
import random
import torch
from diffusers import StableDiffusionPipeline, logging # SD1.5
from diffusers import StableDiffusionXLPipeline, logging # SDXL
from translate import Translator
import my_logging
logging.set_verbosity_error() # 不要なエラー出力の抑制
# Default values for the command-line options
DEF_RESULT_IMAGE = './sd_results/sd.png'
DEF_MODEL_PATH = '/StabilityMatrix/Data/Models/StableDiffusion/SD1.5/v1-5-pruned-emaonly.safetensors'
# Alternative default (SDXL base checkpoint):
#DEF_MODEL_PATH = '/StabilityMatrix/Data/Models/StableDiffusion/sd_xl_base_1.0.safetensors'
DEF_SEED = -1       # -1 asks main() to pick a random seed
DEF_PROMPT = '満開の蘭'
DEF_STEP = 30
DEF_SCALE = 7.5
DEF_WIDTH = 512
DEF_HEIGHT = 512
# Application title shown in the start-up banner.
# Kept in step with the file header (sd_041.py, "(041) Ver 0.01"); the
# previous value "(050) Ver 0.00" did not match this script.
title = 'Stable Diffusion with diffusers(041) Ver 0.01'
# Parses arguments for the application
def parse_args():
    """Build the command-line argument parser for the application.

    Despite its name, this function does not parse anything: it returns the
    configured parser and the caller invokes ``parse_args()`` on it.

    The numeric options declare a ``type`` so that an invalid value is
    rejected by argparse with a usage message instead of surfacing later as
    a ``ValueError`` traceback when the value is converted.

    Returns:
        argparse.ArgumentParser: parser with all application options added.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--result_image", default = DEF_RESULT_IMAGE, help = "path to output image file")
    parser.add_argument("--cpu", dest = "cpu", action = "store_true", help = "cpu mode.")
    parser.add_argument('--log', metavar = 'LOG', default = '3', help = 'Log level(-1/0/1/2/3/4/5) Default value is \'3\'')
    parser.add_argument("--model_path", default = DEF_MODEL_PATH, help = "Model Path")
    parser.add_argument("--prompt", default = DEF_PROMPT, help = "Prompt text")
    parser.add_argument("--seed", type = int, default = DEF_SEED, help = "Seed parameter (-1 = random)")
    parser.add_argument("--width", type = int, default = DEF_WIDTH, help = "image size width")
    parser.add_argument("--height", type = int, default = DEF_HEIGHT, help = "image size height")
    parser.add_argument("--step", type = int, default = DEF_STEP, help = "infer step")
    parser.add_argument("--scale", type = float, default = DEF_SCALE, help = "guidance scale")
    return parser
# Show the basic run information
def display_info(opt, title):
    """Print the start-up banner followed by the value of every option.

    Args:
        opt: parsed options; the attributes listed in ``rows`` below are read
            from it, in that order.
        title: application title printed in the banner line.
    """
    print(''.join(('\n', GREEN, title, ': Starting application...', NOCOLOR)))
    # (leading text, option name) - a leading '\n' starts a new group
    rows = (
        ('\n - ', 'result_image'),
        (' - ', 'cpu'),
        (' - ', 'log'),
        ('\n - ', 'model_path'),
        (' - ', 'prompt'),
        (' - ', 'seed'),
        (' - ', 'width'),
        (' - ', 'height'),
        (' - ', 'step'),
        (' - ', 'scale'),
    )
    for lead, field in rows:
        print(f'{lead}{YELLOW}{field} : {NOCOLOR}', getattr(opt, field))
    print(' ')
# Get the next sequentially numbered file name in a folder
# (existing files are searched as "header + sequence number")
def make_filename_by_seq(dirname, filename, seq_digit = 3, ex = ''):
    """Return the next free "<header>_<seq>_<ex><ext>" file name.

    Existing files in ``dirname`` named ``<header>_<digits>...<ext>`` are
    scanned. Only the first ``seq_digit`` characters after ``<header>_`` are
    taken as the sequence number, and a file counts only when those
    characters are all digits. The result uses the highest number found
    plus one, or 0 when nothing matches.

    The header and extension are escaped before being placed in the glob and
    regular-expression patterns, so names containing characters such as
    ``+``, ``(`` or ``[`` are matched literally. At least one digit is
    required, so a non-numbered look-alike file can no longer trigger
    ``int('')``.

    Args:
        dirname: folder to scan for existing files.
        filename: "header.ext" template for the generated name.
        seq_digit: zero-padded width of the sequence number.
        ex: extra text appended after the sequence number; it is formatted
            into the name, so non-string values such as an int seed work.

    Returns:
        str: the new file name (no directory part). The underscore before
        ``ex`` is always emitted, even when ``ex`` is empty.
    """
    stem, ext = os.path.splitext(filename)
    seq_re = re.compile(rf"{re.escape(stem)}_([0-9]+){re.escape(ext)}")
    candidates = glob.glob(
        os.path.join(
            glob.escape(dirname),
            f"{glob.escape(stem)}_[0-9]*{glob.escape(ext)}",
        )
    )
    # header + '_' + sequence digits: anything after that is the "ex" part
    head_len = len(stem) + 1 + seq_digit
    max_seq = -1
    for path in candidates:
        base, _ = os.path.splitext(os.path.basename(path))
        m = seq_re.fullmatch(base[:head_len] + ext)
        if m:
            max_seq = max(max_seq, int(m.group(1)))
    return f"{stem}_{max_seq + 1:0{seq_digit}}_{ex}{ext}"
# Identify the model family (assumes SD1.5 models live under an "SD1.5/" folder)
def is_sd15(model):
    """Tell whether ``model`` refers to an SD1.5 checkpoint.

    Args:
        model: model name or path.

    Returns:
        bool: True when the text 'SD1.5' occurs in ``model`` (treated as
        SD1.5), False otherwise (treated as SDXL by the callers).
    """
    marker = 'SD1.5'
    found = marker in model
    return found
# Image generation
def image_generation(model, prompt, seed, num_inference_steps = 50, guidance_scale = 7.0, width = 512, height = 512, device = 'cpu'):
    """Load a single-file checkpoint and generate one image from a prompt.

    The pipeline class is chosen with ``is_sd15(model)``:
    ``StableDiffusionPipeline`` for SD1.5, otherwise
    ``StableDiffusionXLPipeline``.

    Args:
        model: path to the checkpoint file.
        prompt: text prompt passed to the pipeline.
        seed: value used to seed the ``torch.Generator``.
        num_inference_steps: forwarded to the pipeline call.
        guidance_scale: forwarded to the pipeline call.
        width: forwarded to the pipeline call.
        height: forwarded to the pipeline call.
        device: torch device name the pipeline and generator are created on.

    Returns:
        The first entry of the pipeline result's ``images`` list
        (presumably a PIL image - the caller calls ``.save()`` on it).
    """
    # Build the pipeline
    if is_sd15(model):
        pipeline = StableDiffusionPipeline.from_single_file(model)
    else:
        # Half precision is only used off-CPU: float16 inference on the CPU
        # is unsupported or very slow, so fall back to float32 there.
        dtype = torch.float32 if device == 'cpu' else torch.float16
        pipeline = StableDiffusionXLPipeline.from_single_file(model, torch_dtype = dtype)
    pipeline = pipeline.to(device)

    # Seeded generator so that a given seed reproduces the same image
    generator = torch.Generator(device).manual_seed(seed)

    # Generate the image
    result = pipeline(
        prompt = prompt,
        num_inference_steps = num_inference_steps,
        guidance_scale = guidance_scale,
        width = width,
        height = height,
        generator = generator,
    )
    return result.images[0]
# Release cached device memory
def device_empty_cache(device):
    """Empty torch's memory cache for the given device.

    Args:
        device: device name; 'cuda' and 'mps' are handled, any other value
            is a no-op.
    """
    if device == 'cuda':
        torch.cuda.empty_cache()
        return
    if device == 'mps':
        torch.mps.empty_cache()
# ** main function **
def main(opt):
    """Generate one image from the parsed options and save it to disk.

    Relies on the module-level ``logger`` that the script entry point sets
    up before calling this function.

    Args:
        opt: parsed command-line options (``cpu``, ``result_image``,
            ``prompt``, ``model_path``, ``width``, ``height``, ``seed``,
            ``step`` and ``scale`` are read).
    """
    # Parameter setup
    device = 'cpu' if opt.cpu else 'cuda'
    logger.debug(f'device: {device}')                       # device selection

    # dirname() is '' when the output path has no directory part; fall back
    # to the current directory because os.makedirs('') raises.
    result_path = os.path.dirname(opt.result_image) or '.'
    result_file = os.path.basename(opt.result_image)
    logger.debug(f'result_path: {result_path}')
    os.makedirs(result_path, exist_ok = True)               # output folder

    # A prompt whose UTF-8 byte length differs from its character count
    # contains non-ASCII text; it is translated from Japanese to English.
    if len(opt.prompt) != len(opt.prompt.encode('utf-8')):
        trans = Translator('en','ja').translate
        prompt = trans(opt.prompt)
    else:
        prompt = opt.prompt
    logger.info(f'prompt: {prompt}')

    model_path = opt.model_path
    width = int(opt.width)
    height = int(opt.height)
    # For non-SD1.5 (SDXL) models, double both sides when either is below 1024
    if not is_sd15(model_path) and (width < 1024 or height < 1024):
        width = width * 2
        height = height * 2
    logger.info(f'size: {width}, {height}')

    seed = int(opt.seed)
    if seed == -1:                                          # pick a random seed
        seed = random.randint(0, 2**32-1)
    logger.info(f'seed: {seed}')

    num_inference_steps = int(opt.step)
    guidance_scale = float(opt.scale)

    # Image generation
    image = image_generation(model_path, prompt, seed, num_inference_steps, guidance_scale, width, height, device)

    # Output name: <header>_<5-digit sequence>_<seed><ext> inside result_path
    filename = os.path.join(
        result_path,
        make_filename_by_seq(result_path, result_file, seq_digit = 5, ex = seed),
    )
    image.save(filename)
    logger.info(f'result_file: {filename}')
# Script entry point (execution starts here)
if __name__ == "__main__":
    import datetime

    parser = parse_args()
    opt = parser.parse_args()

    # Force CPU mode when it was not requested but no GPU was detected
    if not (opt.cpu or gpu_d):
        opt.cpu = True

    # Application logger, named after this file without its extension
    module = os.path.basename(__file__)
    module_name, _ = os.path.splitext(module)
    logger = my_logging.get_module_logger_sel(module_name, int(opt.log))

    display_info(opt, title)

    # Time the whole run
    start_time = datetime.datetime.now()
    main(opt)
    end_time = datetime.datetime.now()

    # Elapsed-time report
    print(start_time.strftime('\nprocessing start >>\t %Y/%m/%d %H:%M:%S'))
    print(end_time.strftime('processing end >>\t %Y/%m/%d %H:%M:%S'))
    print('processing time >>\t', end_time - start_time)

    logger.info('\nFinished.\n')
# NOTE: In the source listing above, the half-width '}' character may appear as the full-width '}' because of how the page displays it.