2022-09-17 15:14:39 +03:00
from ast import parse
from distutils . log import error
from msilib . schema import Directory
import os
import time
import torch
from tqdm . auto import tqdm
from min_dalle import MinDalle
# Generate `amount` images from the prompt stored in prompt.txt with min-dalle
# and save them as ./data/input/<index>.png.
#
# The GPU's total memory decides how (and whether) the model is loaded:
#   t <= 6.4 GB           -> refuse to run
#   6.4 GB < t <= 8.5 GB  -> float16, smaller (non-mega) model
#   8.5 GB < t <= 10.5 GB -> float16, mega model
#   t > 10.5 GB           -> float32, mega model

MIN_GPU_MEMORY = 6_400_000_000  # at or below this, refuse to generate
FLOAT16_THRESHOLD = 10_500_000_000  # at or below this, load weights as float16
MEGA_THRESHOLD = 8_500_000_000  # at or below this, fall back to the small model

# Treat "no CUDA device" like "no GPU memory" instead of crashing inside
# torch.cuda.get_device_properties.
if torch.cuda.is_available():
    t = torch.cuda.get_device_properties(0).total_memory
else:
    t = 0

is_mega = True
if t <= MIN_GPU_MEMORY:
    print("Not enough GPU memory to generate pictures")
else:
    amount = int(input("Amount: "))

    # Remove PNGs left over from a previous run (subdirectories included)
    # so the output directory only holds images from this run.
    directory = "./data/input/"
    os.makedirs(directory, exist_ok=True)
    for root, _subdirectories, files in os.walk(directory):
        for filename in files:
            if filename.endswith(".png"):
                os.remove(os.path.join(root, filename))

    if t <= FLOAT16_THRESHOLD:
        dtype = "float16"
        print("Running float16 type")
    else:
        dtype = "float32"
        print("Running float32 type")

    if t <= MEGA_THRESHOLD:
        print("Not enough memory to run mega dalle. Running smaller variant")
        is_mega = False
    else:
        print("Running in mega mode")

    model = MinDalle(
        dtype=getattr(torch, dtype),
        device="cuda",
        # Use the model size selected from available memory above; a
        # hard-coded False here would make "mega mode" a no-op.
        is_mega=is_mega,
        is_reusable=True,
    )

    # Read-only access is all that is needed ("r+" also demands write
    # permission); the context manager guarantees the handle is closed.
    with open("prompt.txt", "r", encoding="utf-8") as prompt_file:
        text = prompt_file.read()
    print(text)

    # Example prompt.txt contents:
    # "Brave cat knight in closeface helmet with big cat eyes and cat fur,
    #  open neck, full face, rainy background, raytraced, digital art, 4k,
    #  highly detailed, trending on artstation, close to life"
    progressive_outputs = True
    seamless = True
    grid_size = 1
    temperature = 2
    supercondition_factor = 16
    top_k = 128

    start_time = time.perf_counter()
    for counterr in range(amount):
        print("Making pic " + str(counterr))
        image_stream = model.generate_image_stream(
            text=text,
            seed=-1,  # NOTE(review): presumably means "random seed" in min_dalle — confirm
            grid_size=grid_size,
            progressive_outputs=progressive_outputs,
            is_seamless=seamless,
            temperature=temperature,
            top_k=int(top_k),
            supercondition_factor=float(supercondition_factor),
        )
        # Every image the stream yields is written to the same path, so the
        # file ends up holding the last image the stream produced.
        output_path = os.path.join(directory, str(counterr) + ".png")
        for image in image_stream:
            image.save(output_path)
    elapsed = time.perf_counter() - start_time
    print("Made " + str(amount) + " pictures in " + str(elapsed) + " seconds")