From ed4d5f6c1bcf38b33ab338df2269f3bb85fbf7b2 Mon Sep 17 00:00:00 2001 From: lnd212 Date: Fri, 23 Sep 2022 19:24:12 +0400 Subject: [PATCH] Run dalle on 6gb+ memory --- README.md | 4 ++-- scripts/minDalle.py | 16 ++++++++++++---- scripts/stable_diff.py | 21 +++++++++++++-------- 3 files changed, 27 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index c530fe3..8dfb7b2 100644 --- a/README.md +++ b/README.md @@ -26,6 +26,6 @@ c. Create a write token/Copy an existing Token key and enter in the cmd window. ## Requirements * 11GB of free space. * Nvidia card with 8GB+ video memory. - + * Currently min dalle can run in 6gb. Testing purpose ## Additional info -Tested on RTX3070. One picture was making 14 seconds. \ No newline at end of file +Tested on RTX3070. One picture was making 12 - 14 seconds. \ No newline at end of file diff --git a/scripts/minDalle.py b/scripts/minDalle.py index 9f377b3..4e9513b 100644 --- a/scripts/minDalle.py +++ b/scripts/minDalle.py @@ -3,13 +3,14 @@ from distutils.log import error from msilib.schema import Directory import os import time -from IPython.display import display, update_display import torch from tqdm.auto import tqdm from min_dalle import MinDalle t = torch.cuda.get_device_properties(0).total_memory -if t <=8500000000: + +is_mega = True +if t <=6400000000: print("Not enough GPU memory to generate pictures") else: amount = int(input("Amount: ")) @@ -26,11 +27,18 @@ else: dtype = "float32" if t <= 10500000000: dtype = "float16" - + print("Running float16 type") + else: + print("Running float32 type") + if t<= 8500000000: + print("Not enough memory to run mega dalle. Running smaller variant") + is_mega=False + else: + print("Running in mega mode") model = MinDalle( dtype=getattr(torch, dtype), device='cuda', - is_mega=True, + is_mega=False, is_reusable=True ) diff --git a/scripts/stable_diff.py b/scripts/stable_diff.py index 87d13da..2b6c9cf 100644 --- a/scripts/stable_diff.py +++ b/scripts/stable_diff.py @@ -45,6 +45,10 @@ else: ).to(device) if t <= 10500000000: pipe.enable_attention_slicing() + + def dummy_checker(images, **kwargs): + return images, False + pipe.safety_checker = dummy_checker def preprocess(image): w, h = image.size @@ -76,7 +80,7 @@ else: if filename.endswith(".png"): path_img.append(os.path.join(root, filename)) - print("Found %i pictures", len(path_img)) + print("Found " +str(len(path_img)) +" pictures") start_time = time.time() counterr=0 @@ -99,10 +103,10 @@ else: for seed in range(endSeed-startSeed+1): generator = torch.Generator(device=device).manual_seed(startSeed+seed) - strenght = startStrength - while strenght <= endStrength: - guidance_scale = startScale - while guidance_scale <= endScale: + guidance_scale = startScale + while guidance_scale <= endScale: + strenght = startStrength + while strenght <= endStrength: with autocast("cuda"): image = pipe(prompt=prompt, init_image=init_image, strength=strenght, guidance_scale=guidance_scale, generator=generator)["sample"][0] image.save(directory+str(counterr)+"/" + str(allwork) +".jpg") @@ -118,9 +122,10 @@ else: piexif.insert(exif_bytes, directory+str(counterr)+"/" + str(allwork) +".jpg") allwork+=1 - - guidance_scale+=deltaScale - strenght+=deltaStrength + + strenght+=deltaStrength + guidance_scale+=deltaScale + counterr+=1