Run dalle on 6gb+ memory

parent fb8f9f253e
commit ed4d5f6c1b
@@ -26,6 +26,7 @@ c. Create a write token/Copy an existing Token key and enter in the cmd window.
 ## Requirements
 * 11GB of free space.
 * Nvidia card with 8GB+ video memory.
+* Currently min dalle can run in 6gb. Testing purpose
 ## Additional info
-Tested on RTX3070. One picture was making 14 seconds.
+Tested on RTX3070. One picture was making 12 - 14 seconds.
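The updated requirement says the smaller min-dalle model can run in roughly 6 GB of VRAM; the script below enforces this with a t <= 6400000000 byte check. A minimal sketch of that check on its own, assuming a CUDA build of PyTorch:

import torch

# Total VRAM on the first CUDA device, in bytes.
total = torch.cuda.get_device_properties(0).total_memory
print(f"Detected {total / 1e9:.1f} GB of VRAM")

# 6400000000 bytes (~6 GB) is the floor the script uses before it will generate at all.
if total <= 6_400_000_000:
    print("Not enough GPU memory to generate pictures")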
@@ -3,13 +3,14 @@ from distutils.log import error
 from msilib.schema import Directory
 import os
 import time
-from IPython.display import display, update_display
 import torch
 from tqdm.auto import tqdm
 from min_dalle import MinDalle

 t = torch.cuda.get_device_properties(0).total_memory
-if t <=8500000000:
+is_mega = True
+if t <=6400000000:
     print("Not enough GPU memory to generate pictures")
 else:
     amount = int(input("Amount: "))
@@ -26,11 +27,18 @@ else:
     dtype = "float32"
     if t <= 10500000000:
         dtype = "float16"
+        print("Running float16 type")
+    else:
+        print("Running float32 type")
+    if t<= 8500000000:
+        print("Not enough memory to run mega dalle. Running smaller variant")
+        is_mega=False
+    else:
+        print("Running in mega mode")
     model = MinDalle(
         dtype=getattr(torch, dtype),
         device='cuda',
-        is_mega=True,
+        is_mega=False,
         is_reusable=True
     )

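Taken together, these two hunks pick the weight precision and model size from total VRAM: below ~6.4 GB the script refuses to run, float16 is used at or below ~10.5 GB, and the "mega" variant is reserved for cards above ~8.5 GB. A minimal sketch of the same selection as a standalone helper (the function name pick_model_settings is mine; note that the commit itself still passes is_mega=False to the constructor, while the sketch passes the computed flag):

import torch
from min_dalle import MinDalle

def pick_model_settings(device_index=0):
    # Thresholds mirror the byte values used in the diff above.
    total = torch.cuda.get_device_properties(device_index).total_memory
    if total <= 6_400_000_000:
        raise RuntimeError("Not enough GPU memory to generate pictures")
    dtype = torch.float16 if total <= 10_500_000_000 else torch.float32
    is_mega = total > 8_500_000_000  # the mega checkpoint needs roughly 8.5 GB or more
    return dtype, is_mega

dtype, is_mega = pick_model_settings()
model = MinDalle(dtype=dtype, device='cuda', is_mega=is_mega, is_reusable=True)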
@@ -46,6 +46,10 @@ else:
     if t <= 10500000000:
         pipe.enable_attention_slicing()

+    def dummy_checker(images, **kwargs):
+        return images, False
+    pipe.safety_checker = dummy_checker
+
     def preprocess(image):
         w, h = image.size
         w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
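This hunk adds two tweaks to pipe, which from the later pipe(prompt=..., init_image=..., strength=..., guidance_scale=...) call appears to be a diffusers img2img pipeline. A sketch of the same tweaks in isolation, assuming a StableDiffusionImg2ImgPipeline and a placeholder checkpoint id (the commit does not show how pipe is constructed):

import torch
from diffusers import StableDiffusionImg2ImgPipeline

# Placeholder checkpoint; the commit does not show which weights are loaded.
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
).to("cuda")

# Compute attention in slices to lower peak VRAM at a small speed cost.
pipe.enable_attention_slicing()

# Replace the safety checker with a no-op that returns the images unchanged
# and reports that nothing was flagged.
def dummy_checker(images, **kwargs):
    return images, False

pipe.safety_checker = dummy_checker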
@@ -76,7 +80,7 @@ else:
             if filename.endswith(".png"):
                 path_img.append(os.path.join(root, filename))

-    print("Found %i pictures", len(path_img))
+    print("Found " +str(len(path_img)) +" pictures")
     start_time = time.time()

     counterr=0
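The old line passed two arguments to print, so it printed the literal "%i" placeholder followed by the count rather than substituting it; the new line concatenates strings instead. An f-string would be the more idiomatic fix, for example:

path_img = ["a.png", "b.png"]  # stand-in for the list collected by the os.walk loop above
print(f"Found {len(path_img)} pictures")  # -> Found 2 pictures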
@@ -99,10 +103,10 @@ else:

     for seed in range(endSeed-startSeed+1):
         generator = torch.Generator(device=device).manual_seed(startSeed+seed)
-        strenght = startStrength
-        while strenght <= endStrength:
         guidance_scale = startScale
         while guidance_scale <= endScale:
+            strenght = startStrength
+            while strenght <= endStrength:
                 with autocast("cuda"):
                     image = pipe(prompt=prompt, init_image=init_image, strength=strenght, guidance_scale=guidance_scale, generator=generator)["sample"][0]
                 image.save(directory+str(counterr)+"/" + str(allwork) +".jpg")
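This hunk and the next swap the nesting of the two parameter sweeps: guidance scale is now the outer loop and strength the inner one, with each counter incremented at the level of its own loop. A minimal sketch of the resulting sweep order, using hypothetical bounds in place of the values the script reads from input():

# Hypothetical sweep bounds; in the script these come from user input.
startSeed, endSeed = 0, 1
startScale, endScale, deltaScale = 7.0, 9.0, 1.0
startStrength, endStrength, deltaStrength = 0.5, 1.0, 0.25

for seed in range(endSeed - startSeed + 1):
    guidance_scale = startScale
    while guidance_scale <= endScale:          # outer sweep: guidance scale
        strenght = startStrength               # spelling kept from the script
        while strenght <= endStrength:         # inner sweep: img2img strength
            print(f"seed={startSeed + seed} scale={guidance_scale} strength={strenght}")
            strenght += deltaStrength
        guidance_scale += deltaScale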
@@ -119,8 +123,9 @@ else:

                     allwork+=1

-                guidance_scale+=deltaScale
                 strenght+=deltaStrength
+            guidance_scale+=deltaScale


     counterr+=1