Run dalle on 6gb+ memory
This commit is contained in:
parent
fb8f9f253e
commit
ed4d5f6c1b
|
@ -26,6 +26,6 @@ c. Create a write token/Copy an existing Token key and enter in the cmd window.
|
|||
## Requirements
|
||||
* 11GB of free space.
|
||||
* Nvidia card with 8GB+ video memory.
|
||||
|
||||
* Currently min dalle can run in 6gb. Testing purpose
|
||||
## Additional info
|
||||
Tested on RTX3070. One picture was making 14 seconds.
|
||||
Tested on RTX3070. One picture was making 12 - 14 seconds.
|
|
@ -3,13 +3,14 @@ from distutils.log import error
|
|||
from msilib.schema import Directory
|
||||
import os
|
||||
import time
|
||||
from IPython.display import display, update_display
|
||||
import torch
|
||||
from tqdm.auto import tqdm
|
||||
from min_dalle import MinDalle
|
||||
|
||||
t = torch.cuda.get_device_properties(0).total_memory
|
||||
if t <=8500000000:
|
||||
|
||||
is_mega = True
|
||||
if t <=6400000000:
|
||||
print("Not enough GPU memory to generate pictures")
|
||||
else:
|
||||
amount = int(input("Amount: "))
|
||||
|
@ -26,11 +27,18 @@ else:
|
|||
dtype = "float32"
|
||||
if t <= 10500000000:
|
||||
dtype = "float16"
|
||||
|
||||
print("Running float16 type")
|
||||
else:
|
||||
print("Running float32 type")
|
||||
if t<= 8500000000:
|
||||
print("Not enough memory to run mega dalle. Running smaller variant")
|
||||
is_mega=False
|
||||
else:
|
||||
print("Running in mega mode")
|
||||
model = MinDalle(
|
||||
dtype=getattr(torch, dtype),
|
||||
device='cuda',
|
||||
is_mega=True,
|
||||
is_mega=False,
|
||||
is_reusable=True
|
||||
)
|
||||
|
||||
|
|
|
@ -46,6 +46,10 @@ else:
|
|||
if t <= 10500000000:
|
||||
pipe.enable_attention_slicing()
|
||||
|
||||
def dummy_checker(images, **kwargs):
|
||||
return images, False
|
||||
pipe.safety_checker = dummy_checker
|
||||
|
||||
def preprocess(image):
|
||||
w, h = image.size
|
||||
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
|
||||
|
@ -76,7 +80,7 @@ else:
|
|||
if filename.endswith(".png"):
|
||||
path_img.append(os.path.join(root, filename))
|
||||
|
||||
print("Found %i pictures", len(path_img))
|
||||
print("Found " +str(len(path_img)) +" pictures")
|
||||
start_time = time.time()
|
||||
|
||||
counterr=0
|
||||
|
@ -99,10 +103,10 @@ else:
|
|||
|
||||
for seed in range(endSeed-startSeed+1):
|
||||
generator = torch.Generator(device=device).manual_seed(startSeed+seed)
|
||||
strenght = startStrength
|
||||
while strenght <= endStrength:
|
||||
guidance_scale = startScale
|
||||
while guidance_scale <= endScale:
|
||||
strenght = startStrength
|
||||
while strenght <= endStrength:
|
||||
with autocast("cuda"):
|
||||
image = pipe(prompt=prompt, init_image=init_image, strength=strenght, guidance_scale=guidance_scale, generator=generator)["sample"][0]
|
||||
image.save(directory+str(counterr)+"/" + str(allwork) +".jpg")
|
||||
|
@ -119,8 +123,9 @@ else:
|
|||
|
||||
allwork+=1
|
||||
|
||||
guidance_scale+=deltaScale
|
||||
strenght+=deltaStrength
|
||||
guidance_scale+=deltaScale
|
||||
|
||||
|
||||
counterr+=1
|
||||
|
||||
|
|
Loading…
Reference in New Issue