MinDalle_StableDiff/scripts/minDalle.py

75 lines
2.2 KiB
Python

from ast import parse
from distutils.log import error
from msilib.schema import Directory
import os
import time
import torch
from tqdm.auto import tqdm
from min_dalle import MinDalle
# Generate `amount` images for the prompt in prompt.txt with min-dalle, writing
# them as ./data/input/0.png, 1.png, ... Model precision and size are chosen
# from the total memory of CUDA device 0.

# Memory thresholds in bytes, compared with "<=" against device 0's total memory.
MIN_GPU_MEMORY = 6_400_000_000       # at or below: refuse to generate at all
FLOAT16_MAX_MEMORY = 10_500_000_000  # at or below: load the model as float16
MEGA_MIN_MEMORY = 8_500_000_000      # at or below: use the smaller (non-mega) model


def _clear_pngs(folder):
    """Delete every ``.png`` file under *folder*, including its subfolders."""
    for root, _dirs, files in os.walk(folder):
        for filename in files:
            if filename.endswith(".png"):
                os.remove(os.path.join(root, filename))


def _read_prompt(path):
    """Return the text of *path* with surrounding whitespace removed.

    Stripping drops the trailing newline most editors add, so it does not
    become part of the prompt.
    """
    with open(path, "r", encoding="utf-8") as prompt_file:
        return prompt_file.read().strip()


is_mega = True
# Querying device 0 raises when CUDA is unavailable; report 0 bytes instead so
# the script stops through the "not enough memory" branch below.
total_memory = (
    torch.cuda.get_device_properties(0).total_memory
    if torch.cuda.is_available()
    else 0
)

if total_memory <= MIN_GPU_MEMORY:
    print("Not enough GPU memory to generate pictures")
else:
    amount = int(input("Amount: "))

    # Output file names restart at 0.png every run, so create the folder if
    # needed and remove the PNGs left over from earlier runs.
    directory = "./data/input/"
    os.makedirs(directory, exist_ok=True)
    _clear_pngs(directory)

    dtype = "float32"
    if total_memory <= FLOAT16_MAX_MEMORY:
        dtype = "float16"
        print("Running float16 type")
    else:
        print("Running float32 type")

    if total_memory <= MEGA_MIN_MEMORY:
        print("Not enough memory to run mega dalle. Running smaller variant")
        is_mega = False
    else:
        print("Running in mega mode")

    model = MinDalle(
        dtype=getattr(torch, dtype),
        device="cuda",
        # Previously hard-coded to False, which ignored the memory check above
        # even though "Running in mega mode" was printed.
        is_mega=is_mega,
        is_reusable=True,
    )

    text = _read_prompt("prompt.txt")
    print(text)

    # Generation settings forwarded unchanged to generate_image_stream.
    progressive_outputs = True
    seamless = True
    grid_size = 1
    temperature = 2
    supercondition_factor = 16.0
    top_k = 128

    start_time = time.time()
    for index in range(amount):
        print("Making pic " + str(index))
        image_stream = model.generate_image_stream(
            text=text,
            # NOTE(review): presumably a negative seed asks min_dalle for a
            # random seed on every call — confirm against the library.
            seed=-1,
            grid_size=grid_size,
            progressive_outputs=progressive_outputs,
            is_seamless=seamless,
            temperature=temperature,
            top_k=top_k,
            supercondition_factor=supercondition_factor,
        )
        # The stream may yield several images (presumably intermediate
        # refinements when progressive_outputs is True). Each one overwrites
        # the same file, so only the last image yielded remains on disk.
        output_path = os.path.join(directory, str(index) + ".png")
        for image in image_stream:
            image.save(output_path)
    print("Made " + str(amount) + " pictures in " + str(time.time() - start_time) + " seconds")