Skip to content

Commit ce0c892

Browse files
committed
offer ability to generate more than one text
1 parent dbbbcfd commit ce0c892

File tree

2 files changed

+24
-13
lines changed

2 files changed

+24
-13
lines changed

README.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -386,6 +386,14 @@ $ python generate.py --dalle_path ./dalle.pt --text 'fireflies in a field under
386386

387387
You should see your images saved as `./outputs/{your prompt}/{image number}.jpg`
388388

389+
To generate multiple images, just pass in your text with '|' character as a separator.
390+
391+
ex.
392+
393+
```python
394+
$ python generate.py --dalle_path ./dalle.pt --text 'a dog chewing a bone|a cat chasing mice|a frog eating a fly'
395+
```
396+
389397
### Distributed Training
390398

391399
#### DeepSpeed

generate.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -88,24 +88,27 @@ def exists(val):
8888

8989
image_size = vae.image_size
9090

91-
text = tokenizer.tokenize([args.text], dalle.text_seq_len).cuda()
91+
texts = args.text.split('|')
9292

93-
text = repeat(text, '() n -> b n', b = args.num_images)
93+
for text in tqdm(texts):
94+
text = tokenizer.tokenize([args.text], dalle.text_seq_len).cuda()
9495

95-
outputs = []
96+
text = repeat(text, '() n -> b n', b = args.num_images)
9697

97-
for text_chunk in tqdm(text.split(args.batch_size), desc = 'generating images'):
98-
output = dalle.generate_images(text_chunk, filter_thres = args.top_k)
99-
outputs.append(output)
98+
outputs = []
10099

101-
outputs = torch.cat(outputs)
100+
for text_chunk in tqdm(text.split(args.batch_size), desc = f'generating images for - {text}'):
101+
output = dalle.generate_images(text_chunk, filter_thres = args.top_k)
102+
outputs.append(output)
102103

103-
# save all images
104+
outputs = torch.cat(outputs)
104105

105-
outputs_dir = Path(args.outputs_dir) / args.text.replace(' ', '_')
106-
outputs_dir.mkdir(parents = True, exist_ok = True)
106+
# save all images
107107

108-
for i, image in tqdm(enumerate(outputs), desc = 'saving images'):
109-
save_image(image, outputs_dir / f'{i}.jpg', normalize=True)
108+
outputs_dir = Path(args.outputs_dir) / args.text.replace(' ', '_')
109+
outputs_dir.mkdir(parents = True, exist_ok = True)
110110

111-
print(f'created {args.num_images} images at "{str(outputs_dir)}"')
111+
for i, image in tqdm(enumerate(outputs), desc = 'saving images'):
112+
save_image(image, outputs_dir / f'{i}.jpg', normalize=True)
113+
114+
print(f'created {args.num_images} images at "{str(outputs_dir)}"')

0 commit comments

Comments
 (0)