MinDalle_StableDiff/scripts/minDalle.py

67 lines
2.0 KiB
Python
Raw Normal View History

2022-09-17 15:14:39 +03:00
from ast import parse
from distutils.log import error
from msilib.schema import Directory
import os
import time
from IPython.display import display, update_display
import torch
from tqdm.auto import tqdm
from min_dalle import MinDalle
# Generate `amount` images with min-dalle from the prompt stored in
# prompt.txt, writing them as ./data/input/<index>.png.  Any .png files
# already under the output directory are deleted first.  The dtype of the
# model weights is chosen from the total memory of CUDA device 0.

# Thresholds on the total memory (in bytes) torch reports for CUDA device 0.
MIN_GPU_MEMORY_BYTES = 8_500_000_000  # at or below this: refuse to run
FLOAT16_MAX_GPU_MEMORY_BYTES = 10_500_000_000  # at or below this: float16
OUTPUT_DIRECTORY = "./data/input/"
PROMPT_PATH = "prompt.txt"


def gpu_total_memory():
    """Return the total memory in bytes of CUDA device 0.

    Returns 0 when CUDA is unavailable, so callers treat a GPU-less machine
    as "not enough GPU memory" instead of crashing inside torch.
    """
    if not torch.cuda.is_available():
        return 0
    return torch.cuda.get_device_properties(0).total_memory


def select_dtype(total_memory):
    """Choose the torch dtype name for the model weights.

    Args:
        total_memory: Total GPU memory in bytes.

    Returns:
        None if the GPU is too small to run the model at all, "float16" for
        GPUs up to FLOAT16_MAX_GPU_MEMORY_BYTES, otherwise "float32".
    """
    if total_memory <= MIN_GPU_MEMORY_BYTES:
        return None
    if total_memory <= FLOAT16_MAX_GPU_MEMORY_BYTES:
        return "float16"
    return "float32"


def clear_png_files(directory):
    """Create *directory* if needed, then delete every .png file beneath it.

    The walk is recursive, so .png files in subdirectories are removed too.
    Files with other extensions and the subdirectories themselves are kept.
    """
    os.makedirs(directory, exist_ok=True)
    for root, _subdirectories, files in os.walk(directory):
        for filename in files:
            if filename.endswith(".png"):
                os.remove(os.path.join(root, filename))


t = gpu_total_memory()
dtype = select_dtype(t)
if dtype is None:
    print("Not enough GPU memory to generate pictures")
else:
    amount = int(input("Amount: "))
    directory = OUTPUT_DIRECTORY
    clear_png_files(directory)

    model = MinDalle(
        dtype=getattr(torch, dtype),
        device='cuda',
        is_mega=True,
        is_reusable=True,
    )

    # Read-only access is enough; the context manager closes the handle
    # (the original opened it "r+" and never closed it).
    with open(PROMPT_PATH, "r", encoding="utf-8") as prompt_file:
        text = prompt_file.read()
    print(text)

    # Generation settings forwarded unchanged to MinDalle.generate_image_stream.
    progressive_outputs = True
    seamless = True
    grid_size = 1
    temperature = 2
    supercondition_factor = 16.0
    top_k = 128

    start_time = time.perf_counter()
    for index in range(amount):
        print("Making pic " + str(index))
        image_stream = model.generate_image_stream(
            text=text,
            seed=-1,  # passed through as-is; presumably "random seed" — confirm in min_dalle
            grid_size=grid_size,
            progressive_outputs=progressive_outputs,
            is_seamless=seamless,
            temperature=temperature,
            top_k=top_k,
            supercondition_factor=supercondition_factor,
        )
        output_path = os.path.join(directory, str(index) + ".png")
        # Every image the stream yields is saved to the same path, so the
        # file ends up holding the last image produced for this index.
        for image in image_stream:
            image.save(output_path)
    elapsed = time.perf_counter() - start_time
    print("Made " + str(amount) + " pictures in " + str(elapsed) + " seconds")