72 lines
2.1 KiB
Python
72 lines
2.1 KiB
Python
from msilib.schema import Directory
|
|
import os
|
|
import time
|
|
import torch
|
|
from min_dalle import MinDalle
|
|
|
|
# Total VRAM (in bytes) of CUDA device 0; every model-size decision below is
# keyed off this number. Without a CUDA device, get_device_properties() would
# raise here, so report it and use 0 instead: the "not enough memory" branch
# then refuses to generate rather than crashing with an opaque torch error.
if torch.cuda.is_available():
    t = torch.cuda.get_device_properties(0).total_memory
else:
    print("No CUDA device available")
    t = 0

# Prefer the larger "mega" model; downgraded later when VRAM is too small.
is_mega = True
# VRAM thresholds in bytes; each check below is "t <= threshold".
VRAM_TOO_SMALL = 6_400_000_000     # refuse to generate at all
VRAM_USE_FLOAT16 = 10_500_000_000  # run the model in float16 instead of float32
VRAM_NO_MEGA = 8_500_000_000       # use the smaller (non-mega) model

if t <= VRAM_TOO_SMALL:
    print("Not enough GPU memory to generate pictures")
else:
    amount = int(input("Amount: "))

    # Output folder. Any .png left from a previous run (in it or its
    # subfolders) is deleted so only this run's pictures remain.
    directory = "./data/input/"
    os.makedirs(directory, exist_ok=True)
    for root, _, files in os.walk(directory):
        for filename in files:
            if filename.endswith(".png"):
                os.remove(os.path.join(root, filename))

    # Pick numeric precision and model size from the available VRAM.
    dtype = "float32"
    if t <= VRAM_USE_FLOAT16:
        dtype = "float16"
        print("Running float16 type")
    else:
        print("Running float32 type")

    if t <= VRAM_NO_MEGA:
        print("Not enough memory to run mega dalle. Running smaller variant")
        is_mega = False
    else:
        print("Running in mega mode")

    model = MinDalle(
        dtype=getattr(torch, dtype),
        device='cuda',
        # Honour the VRAM-based choice above (this used to be hard-coded to
        # False, so the mega model was never used).
        is_mega=is_mega,
        is_reusable=True,
    )

    # The prompt is read once and reused for every picture. Opened read-only
    # (no write access needed) with an explicit encoding, and closed promptly.
    with open("prompt.txt", encoding="utf-8") as prompt_file:
        text = prompt_file.read()
    print(text)

    # Example prompt:
    #   "Brave cat knight in closeface helmet with big cat eyes and cat fur,
    #   open neck, full face, rainy background, raytraced, digital art , 4k ,
    #   highly detailed , trending on artstation, close to life"

    # Generation settings forwarded to MinDalle.generate_image_stream.
    progressive_outputs = True
    seamless = True
    grid_size = 1
    temperature = 2
    supercondition_factor = 16
    top_k = 128

    # perf_counter is monotonic, so the elapsed time is immune to clock changes.
    start_time = time.perf_counter()
    for counterr in range(amount):
        print(f"Making pic {counterr}")
        image_stream = model.generate_image_stream(
            text=text,
            seed=-1,
            grid_size=grid_size,
            progressive_outputs=progressive_outputs,
            is_seamless=seamless,
            temperature=temperature,
            top_k=int(top_k),
            supercondition_factor=float(supercondition_factor),
        )
        # Every image the stream yields is saved to the same path, so the file
        # ends up holding the last one (presumably the final picture when
        # progressive_outputs is on; confirm against min_dalle).
        output_path = os.path.join(directory, f"{counterr}.png")
        for image in image_stream:
            image.save(output_path)

    print(f"Made {amount} pictures in {time.perf_counter() - start_time} seconds")