Stable Diffusion v2でGridを保存しない方法

Stable Diffusionではグリッド形式で作った画像をまとめて出力してくれます。
どうやらv2からskip_gridのパラメータがなくなったようなので追加することにします。

Gridのイメージは例えばこんな感じです。
4つの画像が1にまとめられて出力されたりします。

さほど出力スピードに差がないので別にあってもよいのですが無くてもよい。
というわけで今回はStable Diffusion v2を少し弄ってskip_gridのパラメータを追加します。

Stable Diffusion v2と言っているもの → https://github.com/Stability-AI/StableDiffusion

修正箇所

修正箇所の1つめ

ファイル:scripts/txt2img.py
関数名:parse_args

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--prompt",
        type=str,
        nargs="?",
        default="a professional photograph of an astronaut riding a triceratops",
        help="the prompt to render"
    )
    # ↓ここを追加
    parser.add_argument(
        "--skip_grid",
        action='store_true',
        help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
    )
    # ↑ここを追加

修正箇所の2つめ

ファイル:scripts/txt2img.py(同じ)
関数名:main
行数:368行目あたり

修正前

                    all_samples.append(x_samples)

            # additionally, save as grid
            grid = torch.stack(all_samples, 0)
            grid = rearrange(grid, 'n b c h w -> (n b) c h w')
            grid = make_grid(grid, nrow=n_rows)

            # to image
            grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
            grid = Image.fromarray(grid.astype(np.uint8))
            grid = put_watermark(grid, wm_encoder)
            grid.save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
            grid_count += 1

修正後

                    if not opt.skip_grid:
                        all_samples.append(x_samples)

            if not opt.skip_grid:
                # additionally, save as grid
                grid = torch.stack(all_samples, 0)
                grid = rearrange(grid, 'n b c h w -> (n b) c h w')
                grid = make_grid(grid, nrow=n_rows)

                # to image
                grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
                grid = Image.fromarray(grid.astype(np.uint8))
                grid = put_watermark(grid, wm_encoder)
                grid.save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
                grid_count += 1

コマンド実行

適当にプロンプトを実行します。

python scripts/txt2img.py --prompt "a cute himalayan cat" --ckpt checkpoints/v2-1_512-ema-pruned.ckpt --config configs/stable-diffusion/v2-inference.yaml --H 512 --W 512 --n_samples 1 --device cuda --skip_grid

以下の部分が重要です。

--skip_grid

まとめ

なくさなくても良いと思うのですけどね。
確かにどっちでもよいといえばどっちでもよい。

とりあえず修正したものをGitHubに置いておきます。

https://github.com/zeikomi552/stablediffusion/

おわり

PR

コメント

コメントを残す

メールアドレスが公開されることはありません。 が付いている欄は必須項目です