Run dalle on 6gb+ memory

This commit is contained in:
Морозов Андрей 2022-09-23 19:24:12 +04:00
parent fb8f9f253e
commit ed4d5f6c1b
3 changed files with 27 additions and 14 deletions

View File

@ -26,6 +26,6 @@ c. Create a write token/Copy an existing Token key and enter in the cmd window.
## Requirements ## Requirements
* 11GB of free space. * 11GB of free space.
* Nvidia card with 8GB+ video memory. * Nvidia card with 8GB+ video memory.
* Currently min dalle can run in 6gb. Testing purpose
## Additional info ## Additional info
Tested on RTX3070. One picture was making 14 seconds. Tested on RTX3070. One picture was making 12 - 14 seconds.

View File

@ -3,13 +3,14 @@ from distutils.log import error
from msilib.schema import Directory from msilib.schema import Directory
import os import os
import time import time
from IPython.display import display, update_display
import torch import torch
from tqdm.auto import tqdm from tqdm.auto import tqdm
from min_dalle import MinDalle from min_dalle import MinDalle
t = torch.cuda.get_device_properties(0).total_memory t = torch.cuda.get_device_properties(0).total_memory
if t <=8500000000:
is_mega = True
if t <=6400000000:
print("Not enough GPU memory to generate pictures") print("Not enough GPU memory to generate pictures")
else: else:
amount = int(input("Amount: ")) amount = int(input("Amount: "))
@ -26,11 +27,18 @@ else:
dtype = "float32" dtype = "float32"
if t <= 10500000000: if t <= 10500000000:
dtype = "float16" dtype = "float16"
print("Running float16 type")
else:
print("Running float32 type")
if t<= 8500000000:
print("Not enough memory to run mega dalle. Running smaller variant")
is_mega=False
else:
print("Running in mega mode")
model = MinDalle( model = MinDalle(
dtype=getattr(torch, dtype), dtype=getattr(torch, dtype),
device='cuda', device='cuda',
is_mega=True, is_mega=False,
is_reusable=True is_reusable=True
) )

View File

@ -46,6 +46,10 @@ else:
if t <= 10500000000: if t <= 10500000000:
pipe.enable_attention_slicing() pipe.enable_attention_slicing()
def dummy_checker(images, **kwargs):
return images, False
pipe.safety_checker = dummy_checker
def preprocess(image): def preprocess(image):
w, h = image.size w, h = image.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
@ -76,7 +80,7 @@ else:
if filename.endswith(".png"): if filename.endswith(".png"):
path_img.append(os.path.join(root, filename)) path_img.append(os.path.join(root, filename))
print("Found %i pictures", len(path_img)) print("Found " +str(len(path_img)) +" pictures")
start_time = time.time() start_time = time.time()
counterr=0 counterr=0
@ -99,10 +103,10 @@ else:
for seed in range(endSeed-startSeed+1): for seed in range(endSeed-startSeed+1):
generator = torch.Generator(device=device).manual_seed(startSeed+seed) generator = torch.Generator(device=device).manual_seed(startSeed+seed)
strenght = startStrength guidance_scale = startScale
while strenght <= endStrength: while guidance_scale <= endScale:
guidance_scale = startScale strenght = startStrength
while guidance_scale <= endScale: while strenght <= endStrength:
with autocast("cuda"): with autocast("cuda"):
image = pipe(prompt=prompt, init_image=init_image, strength=strenght, guidance_scale=guidance_scale, generator=generator)["sample"][0] image = pipe(prompt=prompt, init_image=init_image, strength=strenght, guidance_scale=guidance_scale, generator=generator)["sample"][0]
image.save(directory+str(counterr)+"/" + str(allwork) +".jpg") image.save(directory+str(counterr)+"/" + str(allwork) +".jpg")
@ -119,8 +123,9 @@ else:
allwork+=1 allwork+=1
guidance_scale+=deltaScale strenght+=deltaStrength
strenght+=deltaStrength guidance_scale+=deltaScale
counterr+=1 counterr+=1