Unverified Commit f42e0aae authored by AUTOMATIC1111, Committed by GitHub

Merge branch 'master' into master

parents 0e77ee24 d13ce89e
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
name: Feature request name: Feature request
about: Suggest an idea for this project about: Suggest an idea for this project
title: '' title: ''
labels: '' labels: 'suggestion'
assignees: '' assignees: ''
--- ---
......
# Please read the [contributing wiki page](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Contributing) before submitting a pull request!
If you have a large change, pay special attention to this paragraph:
> Before making changes, if you think that your feature will result in more than 100 lines changing, find me and talk to me about the feature you are proposing. It pains me to reject the hard work someone else did, but I won't add everything to the repo, and it's better if the rejection happens before you have to waste time working on the feature.
Otherwise, after making sure you're following the rules described in the wiki page, remove this section and continue on.
**Describe what this pull request is trying to achieve.**
A clear and concise description of what you're trying to accomplish with this, so your intent doesn't have to be extracted from your code.
**Additional notes and description of your changes**
More technical discussion about your changes goes here, plus anything that a maintainer might have to specifically take a look at, or be wary of.
**Environment this was tested in**
List the environment you have developed / tested this on. As per the contributing page, changes should be able to work on Windows out of the box.
- OS: [e.g. Windows, Linux]
- Browser [e.g. chrome, safari]
- Graphics card [e.g. NVIDIA RTX 2080 8GB, AMD RX 6600 8GB]
**Screenshots or videos of your changes**
If applicable, screenshots or a video showing off your changes. If it edits an existing UI, it should ideally contain a comparison of what used to be there, before your changes were made.
This is **required** for anything that touches the user interface.
\ No newline at end of file
...@@ -17,6 +17,7 @@ __pycache__ ...@@ -17,6 +17,7 @@ __pycache__
/webui.settings.bat /webui.settings.bat
/embeddings /embeddings
/styles.csv /styles.csv
/params.txt
/styles.csv.bak /styles.csv.bak
/webui-user.bat /webui-user.bat
/webui-user.sh /webui-user.sh
...@@ -25,3 +26,4 @@ __pycache__ ...@@ -25,3 +26,4 @@ __pycache__
/.idea /.idea
notification.mp3 notification.mp3
/SwinIR /SwinIR
/textual_inversion
* @AUTOMATIC1111
...@@ -11,12 +11,13 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web ...@@ -11,12 +11,13 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web
- One click install and run script (but you still must install python and git) - One click install and run script (but you still must install python and git)
- Outpainting - Outpainting
- Inpainting - Inpainting
- Prompt - Prompt Matrix
- Stable Diffusion upscale - Stable Diffusion Upscale
- Attention, specify parts of text that the model should pay more attention to - Attention, specify parts of text that the model should pay more attention to
- a man in a ((txuedo)) - will pay more attentinoto tuxedo - a man in a ((tuxedo)) - will pay more attention to tuxedo
- a man in a (txuedo:1.21) - alternative syntax - a man in a (tuxedo:1.21) - alternative syntax
- Loopback, run img2img procvessing multiple times - select text and press ctrl+up or ctrl+down to automatically adjust attention to selected text (code contributed by anonymous user)
- Loopback, run img2img processing multiple times
- X/Y plot, a way to draw a 2 dimensional plot of images with different parameters - X/Y plot, a way to draw a 2 dimensional plot of images with different parameters
- Textual Inversion - Textual Inversion
- have as many embeddings as you want and use any names you like for them - have as many embeddings as you want and use any names you like for them
...@@ -27,23 +28,25 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web ...@@ -27,23 +28,25 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web
- CodeFormer, face restoration tool as an alternative to GFPGAN - CodeFormer, face restoration tool as an alternative to GFPGAN
- RealESRGAN, neural network upscaler - RealESRGAN, neural network upscaler
- ESRGAN, neural network upscaler with a lot of third party models - ESRGAN, neural network upscaler with a lot of third party models
- SwinIR, neural network upscaler - SwinIR and Swin2SR([see here](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/2092)), neural network upscalers
- LDSR, Latent diffusion super resolution upscaling - LDSR, Latent diffusion super resolution upscaling
- Resizing aspect ratio options - Resizing aspect ratio options
- Sampling method selection - Sampling method selection
- Adjust sampler eta values (noise multiplier)
- More advanced noise setting options
- Interrupt processing at any time - Interrupt processing at any time
- 4GB video card support (also reports of 2GB working) - 4GB video card support (also reports of 2GB working)
- Correct seeds for batches - Correct seeds for batches
- Prompt length validation - Prompt length validation
- get length of prompt in tokensas you type - get length of prompt in tokens as you type
- get a warning after geenration if some text was truncated - get a warning after generation if some text was truncated
- Generation parameters - Generation parameters
- parameters you used to generate images are saved with that image - parameters you used to generate images are saved with that image
- in PNG chunks for PNG, in EXIF for JPEG - in PNG chunks for PNG, in EXIF for JPEG
- can drag the image to PNG info tab to restore generation parameters and automatically copy them into UI - can drag the image to PNG info tab to restore generation parameters and automatically copy them into UI
- can be disabled in settings - can be disabled in settings
- Settings page - Settings page
- Running arbitrary python code from UI (must run with commandline flag to enable) - Running arbitrary python code from UI (must run with --allow-code to enable)
- Mouseover hints for most UI elements - Mouseover hints for most UI elements
- Possible to change defaults/mix/max/step values for UI elements via text config - Possible to change defaults/mix/max/step values for UI elements via text config
- Random artist button - Random artist button
...@@ -61,6 +64,12 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web ...@@ -61,6 +64,12 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web
- Reloading checkpoints on the fly - Reloading checkpoints on the fly
- Checkpoint Merger, a tab that allows you to merge two checkpoints into one - Checkpoint Merger, a tab that allows you to merge two checkpoints into one
- [Custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Scripts) with many extensions from community - [Custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Scripts) with many extensions from community
- [Composable-Diffusion](https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/), a way to use multiple prompts at once
- separate prompts using uppercase `AND`
- also supports weights for prompts: `a cat :1.2 AND a dog AND a penguin :2.2`
- No token limit for prompts (original stable diffusion lets you use up to 75 tokens)
- DeepDanbooru integration, creates danbooru style tags for anime prompts (add --deepdanbooru to commandline args)
- [xformers](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers), major speed increase for select cards (add --xformers to commandline args)
## Installation and Running ## Installation and Running
Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs. Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs.
...@@ -110,11 +119,17 @@ The documentation was moved from this README over to the project's [wiki](https: ...@@ -110,11 +119,17 @@ The documentation was moved from this README over to the project's [wiki](https:
- CodeFormer - https://github.com/sczhou/CodeFormer - CodeFormer - https://github.com/sczhou/CodeFormer
- ESRGAN - https://github.com/xinntao/ESRGAN - ESRGAN - https://github.com/xinntao/ESRGAN
- SwinIR - https://github.com/JingyunLiang/SwinIR - SwinIR - https://github.com/JingyunLiang/SwinIR
- Swin2SR - https://github.com/mv-lab/swin2sr
- LDSR - https://github.com/Hafiidz/latent-diffusion - LDSR - https://github.com/Hafiidz/latent-diffusion
- Ideas for optimizations - https://github.com/basujindal/stable-diffusion - Ideas for optimizations - https://github.com/basujindal/stable-diffusion
- Doggettx - Cross Attention layer optimization - https://github.com/Doggettx/stable-diffusion, original idea for prompt editing. - Doggettx - Cross Attention layer optimization - https://github.com/Doggettx/stable-diffusion, original idea for prompt editing.
- InvokeAI, lstein - Cross Attention layer optimization - https://github.com/invoke-ai/InvokeAI (originally http://github.com/lstein/stable-diffusion)
- Rinon Gal - Textual Inversion - https://github.com/rinongal/textual_inversion (we're not using his code, but we are using his ideas).
- Idea for SD upscale - https://github.com/jquesnelle/txt2imghd - Idea for SD upscale - https://github.com/jquesnelle/txt2imghd
- Noise generation for outpainting mk2 - https://github.com/parlance-zz/g-diffuser-bot - Noise generation for outpainting mk2 - https://github.com/parlance-zz/g-diffuser-bot
- CLIP interrogator idea and borrowing some code - https://github.com/pharmapsychotic/clip-interrogator - CLIP interrogator idea and borrowing some code - https://github.com/pharmapsychotic/clip-interrogator
- Idea for Composable Diffusion - https://github.com/energy-based-model/Compositional-Visual-Generation-with-Composable-Diffusion-Models-PyTorch
- xformers - https://github.com/facebookresearch/xformers
- DeepDanbooru - interrogator for anime diffusers https://github.com/KichangKim/DeepDanbooru
- Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user. - Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.
- (You) - (You)
...@@ -1045,7 +1045,6 @@ Bakemono Zukushi,0.67051035,anime ...@@ -1045,7 +1045,6 @@ Bakemono Zukushi,0.67051035,anime
Lucy Madox Brown,0.67032814,fineart Lucy Madox Brown,0.67032814,fineart
Paul Wonner,0.6700563,scribbles Paul Wonner,0.6700563,scribbles
Guido Borelli Da Caluso,0.66966087,digipa-high-impact Guido Borelli Da Caluso,0.66966087,digipa-high-impact
Guido Borelli da Caluso,0.66966087,digipa-high-impact
Emil Alzamora,0.5844039,nudity Emil Alzamora,0.5844039,nudity
Heinrich Brocksieper,0.64469147,fineart Heinrich Brocksieper,0.64469147,fineart
Dan Smith,0.669563,digipa-high-impact Dan Smith,0.669563,digipa-high-impact
......
...@@ -3,9 +3,9 @@ channels: ...@@ -3,9 +3,9 @@ channels:
- pytorch - pytorch
- defaults - defaults
dependencies: dependencies:
- python=3.8.5 - python=3.10
- pip=20.3 - pip=22.2.2
- cudatoolkit=11.3 - cudatoolkit=11.3
- pytorch=1.11.0 - pytorch=1.12.1
- torchvision=0.12.0 - torchvision=0.13.1
- numpy=1.19.2 - numpy=1.23.1
\ No newline at end of file
contextMenuInit = function(){
let eventListenerApplied=false;
let menuSpecs = new Map();
const uid = function(){
return Date.now().toString(36) + Math.random().toString(36).substr(2);
}
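// Build a floating <nav id="context-menu"> at the pointer position, styled after the current tab button,
// with one clickable entry per item in menuEntries; if the menu would overflow the window it is shifted back inside.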
function showContextMenu(event,element,menuEntries){
let posx = event.clientX + document.body.scrollLeft + document.documentElement.scrollLeft;
let posy = event.clientY + document.body.scrollTop + document.documentElement.scrollTop;
let oldMenu = gradioApp().querySelector('#context-menu')
if(oldMenu){
oldMenu.remove()
}
let tabButton = uiCurrentTab
let baseStyle = window.getComputedStyle(tabButton)
const contextMenu = document.createElement('nav')
contextMenu.id = "context-menu"
contextMenu.style.background = baseStyle.background
contextMenu.style.color = baseStyle.color
contextMenu.style.fontFamily = baseStyle.fontFamily
contextMenu.style.top = posy+'px'
contextMenu.style.left = posx+'px'
const contextMenuList = document.createElement('ul')
contextMenuList.className = 'context-menu-items';
contextMenu.append(contextMenuList);
menuEntries.forEach(function(entry){
let contextMenuEntry = document.createElement('a')
contextMenuEntry.innerHTML = entry['name']
contextMenuEntry.addEventListener("click", function(e) {
entry['func']();
})
contextMenuList.append(contextMenuEntry);
})
gradioApp().getRootNode().appendChild(contextMenu)
let menuWidth = contextMenu.offsetWidth + 4;
let menuHeight = contextMenu.offsetHeight + 4;
let windowWidth = window.innerWidth;
let windowHeight = window.innerHeight;
if ( (windowWidth - posx) < menuWidth ) {
contextMenu.style.left = windowWidth - menuWidth + "px";
}
if ( (windowHeight - posy) < menuHeight ) {
contextMenu.style.top = windowHeight - menuHeight + "px";
}
}
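// Register a context-menu entry for every element matching targetEmementSelector; returns an id
// that removeContextMenuOption() can use to remove the entry later.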
function appendContextMenuOption(targetEmementSelector,entryName,entryFunction){
currentItems = menuSpecs.get(targetEmementSelector)
if(!currentItems){
currentItems = []
menuSpecs.set(targetEmementSelector,currentItems);
}
let newItem = {'id':targetEmementSelector+'_'+uid(),
'name':entryName,
'func':entryFunction,
'isNew':true}
currentItems.push(newItem)
return newItem['id']
}
function removeContextMenuOption(uid){
menuSpecs.forEach(function(v,k) {
let index = -1
v.forEach(function(e,ei){if(e['id']==uid){index=ei}})
if(index>=0){
v.splice(index, 1);
}
})
}
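// Install the global listeners once: any left click dismisses an open menu, and a right click on an element
// matching a registered selector shows that selector's menu instead of the browser's default one.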
function addContextMenuEventListener(){
if(eventListenerApplied){
return;
}
gradioApp().addEventListener("click", function(e) {
let source = e.composedPath()[0]
if(source.id && source.id.indexOf('check_progress')>-1){
return
}
let oldMenu = gradioApp().querySelector('#context-menu')
if(oldMenu){
oldMenu.remove()
}
});
gradioApp().addEventListener("contextmenu", function(e) {
let oldMenu = gradioApp().querySelector('#context-menu')
if(oldMenu){
oldMenu.remove()
}
menuSpecs.forEach(function(v,k) {
if(e.composedPath()[0].matches(k)){
showContextMenu(e,e.composedPath()[0],v)
e.preventDefault()
return
}
})
});
eventListenerApplied=true
}
return [appendContextMenuOption, removeContextMenuOption, addContextMenuEventListener]
}
initResponse = contextMenuInit();
appendContextMenuOption = initResponse[0];
removeContextMenuOption = initResponse[1];
addContextMenuEventListener = initResponse[2];
(function(){
//Start example Context Menu Items
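// generateOnRepeat: click the generate button right away and then every 500ms, but only while the interrupt
// button is hidden (i.e. no generation is currently running); used by the 'Generate forever' entries below.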
let generateOnRepeat = function(genbuttonid,interruptbuttonid){
let genbutton = gradioApp().querySelector(genbuttonid);
let interruptbutton = gradioApp().querySelector(interruptbuttonid);
if(!interruptbutton.offsetParent){
genbutton.click();
}
clearInterval(window.generateOnRepeatInterval)
window.generateOnRepeatInterval = setInterval(function(){
if(!interruptbutton.offsetParent){
genbutton.click();
}
},
500)
}
appendContextMenuOption('#txt2img_generate','Generate forever',function(){
generateOnRepeat('#txt2img_generate','#txt2img_interrupt');
})
appendContextMenuOption('#img2img_generate','Generate forever',function(){
generateOnRepeat('#img2img_generate','#img2img_interrupt');
})
let cancelGenerateForever = function(){
clearInterval(window.generateOnRepeatInterval)
}
appendContextMenuOption('#txt2img_interrupt','Cancel generate forever',cancelGenerateForever)
appendContextMenuOption('#txt2img_generate', 'Cancel generate forever',cancelGenerateForever)
appendContextMenuOption('#img2img_interrupt','Cancel generate forever',cancelGenerateForever)
appendContextMenuOption('#img2img_generate', 'Cancel generate forever',cancelGenerateForever)
appendContextMenuOption('#roll','Roll three',
function(){
let rollbutton = get_uiCurrentTabContent().querySelector('#roll');
setTimeout(function(){rollbutton.click()},100)
setTimeout(function(){rollbutton.click()},200)
setTimeout(function(){rollbutton.click()},300)
}
)
})();
//End example Context Menu Items
onUiUpdate(function(){
addContextMenuEventListener()
});
...@@ -43,7 +43,7 @@ function dropReplaceImage( imgWrap, files ) { ...@@ -43,7 +43,7 @@ function dropReplaceImage( imgWrap, files ) {
window.document.addEventListener('dragover', e => { window.document.addEventListener('dragover', e => {
const target = e.composedPath()[0]; const target = e.composedPath()[0];
const imgWrap = target.closest('[data-testid="image"]'); const imgWrap = target.closest('[data-testid="image"]');
if ( !imgWrap ) { if ( !imgWrap && target.placeholder != "Prompt") {
return; return;
} }
e.stopPropagation(); e.stopPropagation();
...@@ -53,6 +53,9 @@ window.document.addEventListener('dragover', e => { ...@@ -53,6 +53,9 @@ window.document.addEventListener('dragover', e => {
window.document.addEventListener('drop', e => { window.document.addEventListener('drop', e => {
const target = e.composedPath()[0]; const target = e.composedPath()[0];
if (target.placeholder === "Prompt") {
return;
}
const imgWrap = target.closest('[data-testid="image"]'); const imgWrap = target.closest('[data-testid="image"]');
if ( !imgWrap ) { if ( !imgWrap ) {
return; return;
......
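// edit-attention: with text selected inside a prompt textbox, pressing ArrowUp/ArrowDown wraps the selection
// in "(...:1.0)" if the character before it is not already "(", or otherwise adjusts the existing weight
// that follows the selection by ±0.1.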
addEventListener('keydown', (event) => {
let target = event.originalTarget || event.composedPath()[0];
if (!target.hasAttribute("placeholder")) return;
if (!target.placeholder.toLowerCase().includes("prompt")) return;
let plus = "ArrowUp"
let minus = "ArrowDown"
if (event.key != plus && event.key != minus) return;
selectionStart = target.selectionStart;
selectionEnd = target.selectionEnd;
if(selectionStart == selectionEnd) return;
event.preventDefault();
if (selectionStart == 0 || target.value[selectionStart - 1] != "(") {
target.value = target.value.slice(0, selectionStart) +
"(" + target.value.slice(selectionStart, selectionEnd) + ":1.0)" +
target.value.slice(selectionEnd);
target.focus();
target.selectionStart = selectionStart + 1;
target.selectionEnd = selectionEnd + 1;
} else {
end = target.value.slice(selectionEnd + 1).indexOf(")") + 1;
weight = parseFloat(target.value.slice(selectionEnd + 1, selectionEnd + 1 + end));
if (isNaN(weight)) return;
if (event.key == minus) weight -= 0.1;
if (event.key == plus) weight += 0.1;
weight = parseFloat(weight.toPrecision(12));
target.value = target.value.slice(0, selectionEnd + 1) +
weight +
target.value.slice(selectionEnd + 1 + end - 1);
target.focus();
target.selectionStart = selectionStart;
target.selectionEnd = selectionEnd;
}
// Since we've modified a Gradio Textbox component manually, we need to simulate an `input` DOM event to ensure its
// internal Svelte data binding remains in sync.
target.dispatchEvent(new Event("input", { bubbles: true }));
});
...@@ -14,8 +14,8 @@ titles = { ...@@ -14,8 +14,8 @@ titles = {
"\u{1f3b2}\ufe0f": "Set seed to -1, which will cause a new random number to be used every time", "\u{1f3b2}\ufe0f": "Set seed to -1, which will cause a new random number to be used every time",
"\u267b\ufe0f": "Reuse seed from last generation, mostly useful if it was randomed", "\u267b\ufe0f": "Reuse seed from last generation, mostly useful if it was randomed",
"\u{1f3a8}": "Add a random artist to the prompt.", "\u{1f3a8}": "Add a random artist to the prompt.",
"\u2199\ufe0f": "Read generation parameters from prompt into user interface.", "\u2199\ufe0f": "Read generation parameters from prompt or last generation if prompt is empty into user interface.",
"\uD83D\uDCC2": "Open images output directory", "\u{1f4c2}": "Open images output directory",
"Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt", "Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt",
"SD upscale": "Upscale image normally, split result into tiles, improve each tile using img2img, merge whole image back", "SD upscale": "Upscale image normally, split result into tiles, improve each tile using img2img, merge whole image back",
...@@ -35,6 +35,7 @@ titles = { ...@@ -35,6 +35,7 @@ titles = {
"Denoising strength": "Determines how little respect the algorithm should have for image's content. At 0, nothing will change, and at 1 you'll get an unrelated image. With values below 1.0, processing will take less steps than the Sampling Steps slider specifies.", "Denoising strength": "Determines how little respect the algorithm should have for image's content. At 0, nothing will change, and at 1 you'll get an unrelated image. With values below 1.0, processing will take less steps than the Sampling Steps slider specifies.",
"Denoising strength change factor": "In loopback mode, on each loop the denoising strength is multiplied by this value. <1 means decreasing variety so your sequence will converge on a fixed picture. >1 means increasing variety so your sequence will become more and more chaotic.", "Denoising strength change factor": "In loopback mode, on each loop the denoising strength is multiplied by this value. <1 means decreasing variety so your sequence will converge on a fixed picture. >1 means increasing variety so your sequence will become more and more chaotic.",
"Skip": "Stop processing current image and continue processing.",
"Interrupt": "Stop processing images and return any results accumulated so far.", "Interrupt": "Stop processing images and return any results accumulated so far.",
"Save": "Write image to a directory (default - log/images) and generation parameters into csv file.", "Save": "Write image to a directory (default - log/images) and generation parameters into csv file.",
...@@ -47,6 +48,7 @@ titles = { ...@@ -47,6 +48,7 @@ titles = {
"Custom code": "Run Python code. Advanced user only. Must run program with --allow-code for this to work", "Custom code": "Run Python code. Advanced user only. Must run program with --allow-code for this to work",
"Prompt S/R": "Separate a list of words with commas, and the first word will be used as a keyword: script will search for this word in the prompt, and replace it with others", "Prompt S/R": "Separate a list of words with commas, and the first word will be used as a keyword: script will search for this word in the prompt, and replace it with others",
"Prompt order": "Separate a list of words with commas, and the script will make a variation of prompt with those words for their every possible order",
"Tiling": "Produce an image that can be tiled.", "Tiling": "Produce an image that can be tiled.",
"Tile overlap": "For SD upscale, how much overlap in pixels should there be between tiles. Tiles overlap so that when they are merged back into one picture, there is no clearly visible seam.", "Tile overlap": "For SD upscale, how much overlap in pixels should there be between tiles. Tiles overlap so that when they are merged back into one picture, there is no clearly visible seam.",
...@@ -77,6 +79,16 @@ titles = { ...@@ -77,6 +79,16 @@ titles = {
"Highres. fix": "Use a two step process to partially create an image at smaller resolution, upscale, and then improve details in it without changing composition", "Highres. fix": "Use a two step process to partially create an image at smaller resolution, upscale, and then improve details in it without changing composition",
"Scale latent": "Uscale the image in latent space. Alternative is to produce the full image from latent representation, upscale that, and then move it back to latent space.", "Scale latent": "Uscale the image in latent space. Alternative is to produce the full image from latent representation, upscale that, and then move it back to latent space.",
"Eta noise seed delta": "If this values is non-zero, it will be added to seed and used to initialize RNG for noises when using samplers with Eta. You can use this to produce even more variation of images, or you can use this to match images of other software if you know what you are doing.",
"Do not add watermark to images": "If this option is enabled, watermark will not be added to created images. Warning: if you do not add watermark, you may be behaving in an unethical manner.",
"Filename word regex": "This regular expression will be used extract words from filename, and they will be joined using the option below into label text used for training. Leave empty to keep filename text as it is.",
"Filename join string": "This string will be used to join split words into a single line if the option above is enabled.",
"Quicksettings list": "List of setting names, separated by commas, for settings that should go to the quick access bar at the top, rather than the usual setting tab. See modules/shared.py for setting names. Requires restarting to apply.",
"Weighted sum": "Result = A * (1 - M) + B * M",
"Add difference": "Result = A + (B - C) * M",
} }
......
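The two Checkpoint Merger hints above ("Weighted sum" and "Add difference") are plain per-tensor arithmetic over the models' weights. Below is a minimal sketch of what those formulas compute, assuming the checkpoints are already loaded as dictionaries keyed by parameter name (plain numbers stand in for tensors here); it only illustrates the formulas and is not the webui's actual merger code.

def weighted_sum(a, b, m):
    # "Weighted sum": Result = A * (1 - M) + B * M, applied key by key
    return {k: a[k] * (1 - m) + b[k] * m for k in a}

def add_difference(a, b, c, m):
    # "Add difference": Result = A + (B - C) * M; the delta between B and C is added to A
    return {k: a[k] + (b[k] - c[k]) * m for k in a}

# toy usage with scalars standing in for tensors
print(weighted_sum({"w": 1.0}, {"w": 3.0}, 0.5))                  # {'w': 2.0}
print(add_difference({"w": 1.0}, {"w": 3.0}, {"w": 2.0}, 0.5))    # {'w': 1.5}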
window.onload = (function(){
window.addEventListener('drop', e => {
const target = e.composedPath()[0];
const idx = selected_gallery_index();
if (target.placeholder != "Prompt") return;
let prompt_target = get_tab_index('tabs') == 1 ? "img2img_prompt_image" : "txt2img_prompt_image";
e.stopPropagation();
e.preventDefault();
const imgParent = gradioApp().getElementById(prompt_target);
const files = e.dataTransfer.files;
const fileInput = imgParent.querySelector('input[type="file"]');
if ( fileInput ) {
fileInput.files = files;
fileInput.dispatchEvent(new Event('change'));
}
});
});
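// images_history tab: helpers that track which gallery thumbnail is selected, hide deleted images,
// and lazily initialize the history tabs once Gradio has built them.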
var images_history_click_image = function(){
if (!this.classList.contains("transform")){
var gallery = images_history_get_parent_by_class(this, "images_history_cantainor");
var buttons = gallery.querySelectorAll(".gallery-item");
var i = 0;
var hidden_list = [];
buttons.forEach(function(e){
if (e.style.display == "none"){
hidden_list.push(i);
}
i += 1;
})
if (hidden_list.length > 0){
setTimeout(images_history_hide_buttons, 10, hidden_list, gallery);
}
}
images_history_set_image_info(this);
}
var images_history_click_tab = function(){
var tabs_box = gradioApp().getElementById("images_history_tab");
if (!tabs_box.classList.contains(this.getAttribute("tabname"))) {
gradioApp().getElementById(this.getAttribute("tabname") + "_images_history_renew_page").click();
tabs_box.classList.add(this.getAttribute("tabname"))
}
}
function images_history_disabled_del(){
gradioApp().querySelectorAll(".images_history_del_button").forEach(function(btn){
btn.setAttribute('disabled','disabled');
});
}
function images_history_get_parent_by_class(item, class_name){
var parent = item.parentElement;
while(!parent.classList.contains(class_name)){
parent = parent.parentElement;
}
return parent;
}
function images_history_get_parent_by_tagname(item, tagname){
var parent = item.parentElement;
tagname = tagname.toUpperCase()
while(parent.tagName != tagname){
console.log(parent.tagName, tagname)
parent = parent.parentElement;
}
return parent;
}
function images_history_hide_buttons(hidden_list, gallery){
var buttons = gallery.querySelectorAll(".gallery-item");
var num = 0;
buttons.forEach(function(e){
if (e.style.display == "none"){
num += 1;
}
});
if (num == hidden_list.length){
setTimeout(images_history_hide_buttons, 10, hidden_list, gallery);
}
for( i in hidden_list){
buttons[hidden_list[i]].style.display = "none";
}
}
function images_history_set_image_info(button){
var buttons = images_history_get_parent_by_tagname(button, "DIV").querySelectorAll(".gallery-item");
var index = -1;
var i = 0;
buttons.forEach(function(e){
if(e == button){
index = i;
}
if(e.style.display != "none"){
i += 1;
}
});
var gallery = images_history_get_parent_by_class(button, "images_history_cantainor");
var set_btn = gallery.querySelector(".images_history_set_index");
var curr_idx = set_btn.getAttribute("img_index", index);
if (curr_idx != index) {
set_btn.setAttribute("img_index", index);
images_history_disabled_del();
}
set_btn.click();
}
function images_history_get_current_img(tabname, image_path, files){
return [
gradioApp().getElementById(tabname + '_images_history_set_index').getAttribute("img_index"),
image_path,
files
];
}
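// The gallery renders each image twice, so hide both copies of every deleted thumbnail and select the next
// visible one; if the deletion would empty the page, the page is renewed from the backend instead.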
function images_history_delete(del_num, tabname, img_path, img_file_name, page_index, filenames, image_index){
image_index = parseInt(image_index);
var tab = gradioApp().getElementById(tabname + '_images_history');
var set_btn = tab.querySelector(".images_history_set_index");
var buttons = [];
tab.querySelectorAll(".gallery-item").forEach(function(e){
if (e.style.display != 'none'){
buttons.push(e);
}
});
var img_num = buttons.length / 2;
if (img_num <= del_num){
setTimeout(function(tabname){
gradioApp().getElementById(tabname + '_images_history_renew_page').click();
}, 30, tabname);
} else {
var next_img
for (var i = 0; i < del_num; i++){
if (image_index + i < image_index + img_num){
buttons[image_index + i].style.display = 'none';
buttons[image_index + img_num + 1].style.display = 'none';
next_img = image_index + i + 1
}
}
var btn;
if (next_img >= img_num){
btn = buttons[image_index - del_num];
} else {
btn = buttons[next_img];
}
setTimeout(function(btn){btn.click()}, 30, btn);
}
images_history_disabled_del();
return [del_num, tabname, img_path, img_file_name, page_index, filenames, image_index];
}
function images_history_turnpage(img_path, page_index, image_index, tabname){
var buttons = gradioApp().getElementById(tabname + '_images_history').querySelectorAll(".gallery-item");
buttons.forEach(function(elem) {
elem.style.display = 'block';
})
return [img_path, page_index, image_index, tabname];
}
function images_history_enable_del_buttons(){
gradioApp().querySelectorAll(".images_history_del_button").forEach(function(btn){
btn.removeAttribute('disabled');
})
}
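// Wait for the history UI to exist (retrying every 500ms), then tag its galleries and buttons with the
// classes that the handlers above look up.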
function images_history_init(){
var load_txt2img_button = gradioApp().getElementById('txt2img_images_history_renew_page')
if (load_txt2img_button){
for (var i in images_history_tab_list ){
tab = images_history_tab_list[i];
gradioApp().getElementById(tab + '_images_history').classList.add("images_history_cantainor");
gradioApp().getElementById(tab + '_images_history_set_index').classList.add("images_history_set_index");
gradioApp().getElementById(tab + '_images_history_del_button').classList.add("images_history_del_button");
gradioApp().getElementById(tab + '_images_history_gallery').classList.add("images_history_gallery");
}
var tabs_box = gradioApp().getElementById("tab_images_history").querySelector("div").querySelector("div").querySelector("div");
tabs_box.setAttribute("id", "images_history_tab");
var tab_btns = tabs_box.querySelectorAll("button");
for (var i in images_history_tab_list){
var tabname = images_history_tab_list[i]
tab_btns[i].setAttribute("tabname", tabname);
// this refreshes history upon tab switch
// until the history is known to work well, which is not the case now, we do not do this at startup
//tab_btns[i].addEventListener('click', images_history_click_tab);
}
tabs_box.classList.add(images_history_tab_list[0]);
// same as above, at page load
//load_txt2img_button.click();
} else {
setTimeout(images_history_init, 500);
}
}
var images_history_tab_list = ["txt2img", "img2img", "extras"];
setTimeout(images_history_init, 500);
document.addEventListener("DOMContentLoaded", function() {
var mutationObserver = new MutationObserver(function(m){
for (var i in images_history_tab_list ){
let tabname = images_history_tab_list[i]
var buttons = gradioApp().querySelectorAll('#' + tabname + '_images_history .gallery-item');
buttons.forEach(function(bnt){
bnt.addEventListener('click', images_history_click_image, true);
});
// same as load_txt2img_button.click() above
/*
var cls_btn = gradioApp().getElementById(tabname + '_images_history_gallery').querySelector("svg");
if (cls_btn){
cls_btn.addEventListener('click', function(){
gradioApp().getElementById(tabname + '_images_history_renew_page').click();
}, false);
}*/
}
});
mutationObserver.observe( gradioApp(), { childList:true, subtree:true });
});
// A full size 'lightbox' preview modal shown when left clicking on gallery previews // A full size 'lightbox' preview modal shown when left clicking on gallery previews
function closeModal() { function closeModal() {
gradioApp().getElementById("lightboxModal").style.display = "none"; gradioApp().getElementById("lightboxModal").style.display = "none";
} }
...@@ -21,29 +20,53 @@ function negmod(n, m) { ...@@ -21,29 +20,53 @@ function negmod(n, m) {
return ((n % m) + m) % m; return ((n % m) + m) % m;
} }
function modalImageSwitch(offset){ function updateOnBackgroundChange() {
const modalImage = gradioApp().getElementById("modalImage")
if (modalImage && modalImage.offsetParent) {
let allcurrentButtons = gradioApp().querySelectorAll(".gallery-item.transition-all.\\!ring-2")
let currentButton = null
allcurrentButtons.forEach(function(elem) {
if (elem.parentElement.offsetParent) {
currentButton = elem;
}
})
if (modalImage.src != currentButton.children[0].src) {
modalImage.src = currentButton.children[0].src;
if (modalImage.style.display === 'none') {
modal.style.setProperty('background-image', `url(${modalImage.src})`)
}
}
}
}
function modalImageSwitch(offset) {
var allgalleryButtons = gradioApp().querySelectorAll(".gallery-item.transition-all") var allgalleryButtons = gradioApp().querySelectorAll(".gallery-item.transition-all")
var galleryButtons = [] var galleryButtons = []
allgalleryButtons.forEach(function(elem){ allgalleryButtons.forEach(function(elem) {
if(elem.parentElement.offsetParent){ if (elem.parentElement.offsetParent) {
galleryButtons.push(elem); galleryButtons.push(elem);
} }
}) })
if(galleryButtons.length>1){ if (galleryButtons.length > 1) {
var allcurrentButtons = gradioApp().querySelectorAll(".gallery-item.transition-all.\\!ring-2") var allcurrentButtons = gradioApp().querySelectorAll(".gallery-item.transition-all.\\!ring-2")
var currentButton = null var currentButton = null
allcurrentButtons.forEach(function(elem){ allcurrentButtons.forEach(function(elem) {
if(elem.parentElement.offsetParent){ if (elem.parentElement.offsetParent) {
currentButton = elem; currentButton = elem;
} }
}) })
var result = -1 var result = -1
galleryButtons.forEach(function(v, i){ if(v==currentButton) { result = i } }) galleryButtons.forEach(function(v, i) {
if (v == currentButton) {
result = i
}
})
if(result != -1){ if (result != -1) {
nextButton = galleryButtons[negmod((result+offset),galleryButtons.length)] nextButton = galleryButtons[negmod((result + offset), galleryButtons.length)]
nextButton.click() nextButton.click()
const modalImage = gradioApp().getElementById("modalImage"); const modalImage = gradioApp().getElementById("modalImage");
const modal = gradioApp().getElementById("lightboxModal"); const modal = gradioApp().getElementById("lightboxModal");
...@@ -51,22 +74,24 @@ function modalImageSwitch(offset){ ...@@ -51,22 +74,24 @@ function modalImageSwitch(offset){
if (modalImage.style.display === 'none') { if (modalImage.style.display === 'none') {
modal.style.setProperty('background-image', `url(${modalImage.src})`) modal.style.setProperty('background-image', `url(${modalImage.src})`)
} }
setTimeout( function(){modal.focus()},10) setTimeout(function() {
modal.focus()
}, 10)
} }
} }
} }
function modalNextImage(event){ function modalNextImage(event) {
modalImageSwitch(1) modalImageSwitch(1)
event.stopPropagation() event.stopPropagation()
} }
function modalPrevImage(event){ function modalPrevImage(event) {
modalImageSwitch(-1) modalImageSwitch(-1)
event.stopPropagation() event.stopPropagation()
} }
function modalKeyHandler(event){ function modalKeyHandler(event) {
switch (event.key) { switch (event.key) {
case "ArrowLeft": case "ArrowLeft":
modalPrevImage(event) modalPrevImage(event)
...@@ -80,21 +105,22 @@ function modalKeyHandler(event){ ...@@ -80,21 +105,22 @@ function modalKeyHandler(event){
} }
} }
function showGalleryImage(){ function showGalleryImage() {
setTimeout(function() { setTimeout(function() {
fullImg_preview = gradioApp().querySelectorAll('img.w-full.object-contain') fullImg_preview = gradioApp().querySelectorAll('img.w-full.object-contain')
if(fullImg_preview != null){ if (fullImg_preview != null) {
fullImg_preview.forEach(function function_name(e) { fullImg_preview.forEach(function function_name(e) {
if (e.dataset.modded)
return;
e.dataset.modded = true;
if(e && e.parentElement.tagName == 'DIV'){ if(e && e.parentElement.tagName == 'DIV'){
e.style.cursor='pointer' e.style.cursor='pointer'
e.addEventListener('click', function (evt) { e.addEventListener('click', function (evt) {
if(!opts.js_modal_lightbox) return; if(!opts.js_modal_lightbox) return;
modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initialy_zoomed) modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed)
showModal(evt) showModal(evt)
},true); }, true);
} }
}); });
} }
...@@ -102,21 +128,21 @@ function showGalleryImage(){ ...@@ -102,21 +128,21 @@ function showGalleryImage(){
}, 100); }, 100);
} }
function modalZoomSet(modalImage, enable){ function modalZoomSet(modalImage, enable) {
if( enable ){ if (enable) {
modalImage.classList.add('modalImageFullscreen'); modalImage.classList.add('modalImageFullscreen');
} else{ } else {
modalImage.classList.remove('modalImageFullscreen'); modalImage.classList.remove('modalImageFullscreen');
} }
} }
function modalZoomToggle(event){ function modalZoomToggle(event) {
modalImage = gradioApp().getElementById("modalImage"); modalImage = gradioApp().getElementById("modalImage");
modalZoomSet(modalImage, !modalImage.classList.contains('modalImageFullscreen')) modalZoomSet(modalImage, !modalImage.classList.contains('modalImageFullscreen'))
event.stopPropagation() event.stopPropagation()
} }
function modalTileImageToggle(event){ function modalTileImageToggle(event) {
const modalImage = gradioApp().getElementById("modalImage"); const modalImage = gradioApp().getElementById("modalImage");
const modal = gradioApp().getElementById("lightboxModal"); const modal = gradioApp().getElementById("lightboxModal");
const isTiling = modalImage.style.display === 'none'; const isTiling = modalImage.style.display === 'none';
...@@ -131,17 +157,18 @@ function modalTileImageToggle(event){ ...@@ -131,17 +157,18 @@ function modalTileImageToggle(event){
event.stopPropagation() event.stopPropagation()
} }
function galleryImageHandler(e){ function galleryImageHandler(e) {
if(e && e.parentElement.tagName == 'BUTTON'){ if (e && e.parentElement.tagName == 'BUTTON') {
e.onclick = showGalleryImage; e.onclick = showGalleryImage;
} }
} }
onUiUpdate(function(){ onUiUpdate(function() {
fullImg_preview = gradioApp().querySelectorAll('img.w-full') fullImg_preview = gradioApp().querySelectorAll('img.w-full')
if(fullImg_preview != null){ if (fullImg_preview != null) {
fullImg_preview.forEach(galleryImageHandler); fullImg_preview.forEach(galleryImageHandler);
} }
updateOnBackgroundChange();
}) })
document.addEventListener("DOMContentLoaded", function() { document.addEventListener("DOMContentLoaded", function() {
...@@ -149,7 +176,7 @@ document.addEventListener("DOMContentLoaded", function() { ...@@ -149,7 +176,7 @@ document.addEventListener("DOMContentLoaded", function() {
const modal = document.createElement('div') const modal = document.createElement('div')
modal.onclick = closeModal; modal.onclick = closeModal;
modal.id = "lightboxModal"; modal.id = "lightboxModal";
modal.tabIndex=0 modal.tabIndex = 0
modal.addEventListener('keydown', modalKeyHandler, true) modal.addEventListener('keydown', modalKeyHandler, true)
const modalControls = document.createElement('div') const modalControls = document.createElement('div')
...@@ -180,23 +207,23 @@ document.addEventListener("DOMContentLoaded", function() { ...@@ -180,23 +207,23 @@ document.addEventListener("DOMContentLoaded", function() {
const modalImage = document.createElement('img') const modalImage = document.createElement('img')
modalImage.id = 'modalImage'; modalImage.id = 'modalImage';
modalImage.onclick = closeModal; modalImage.onclick = closeModal;
modalImage.tabIndex=0 modalImage.tabIndex = 0
modalImage.addEventListener('keydown', modalKeyHandler, true) modalImage.addEventListener('keydown', modalKeyHandler, true)
modal.appendChild(modalImage) modal.appendChild(modalImage)
const modalPrev = document.createElement('a') const modalPrev = document.createElement('a')
modalPrev.className = 'modalPrev'; modalPrev.className = 'modalPrev';
modalPrev.innerHTML = '&#10094;' modalPrev.innerHTML = '&#10094;'
modalPrev.tabIndex=0 modalPrev.tabIndex = 0
modalPrev.addEventListener('click',modalPrevImage,true); modalPrev.addEventListener('click', modalPrevImage, true);
modalPrev.addEventListener('keydown', modalKeyHandler, true) modalPrev.addEventListener('keydown', modalKeyHandler, true)
modal.appendChild(modalPrev) modal.appendChild(modalPrev)
const modalNext = document.createElement('a') const modalNext = document.createElement('a')
modalNext.className = 'modalNext'; modalNext.className = 'modalNext';
modalNext.innerHTML = '&#10095;' modalNext.innerHTML = '&#10095;'
modalNext.tabIndex=0 modalNext.tabIndex = 0
modalNext.addEventListener('click',modalNextImage,true); modalNext.addEventListener('click', modalNextImage, true);
modalNext.addEventListener('keydown', modalKeyHandler, true) modalNext.addEventListener('keydown', modalKeyHandler, true)
modal.appendChild(modalNext) modal.appendChild(modalNext)
......
...@@ -36,7 +36,7 @@ onUiUpdate(function(){ ...@@ -36,7 +36,7 @@ onUiUpdate(function(){
const notification = new Notification( const notification = new Notification(
'Stable Diffusion', 'Stable Diffusion',
{ {
body: `Generated ${imgs.size > 1 ? imgs.size - 1 : 1} image${imgs.size > 1 ? 's' : ''}`, body: `Generated ${imgs.size > 1 ? imgs.size - opts.return_grid : 1} image${imgs.size > 1 ? 's' : ''}`,
icon: headImg, icon: headImg,
image: headImg, image: headImg,
} }
......
// code related to showing and updating progressbar shown as the image is being made // code related to showing and updating progressbar shown as the image is being made
global_progressbars = {} global_progressbars = {}
galleries = {}
galleryObservers = {}
function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_interrupt, id_preview, id_gallery){ function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip, id_interrupt, id_preview, id_gallery){
var progressbar = gradioApp().getElementById(id_progressbar) var progressbar = gradioApp().getElementById(id_progressbar)
var skip = id_skip ? gradioApp().getElementById(id_skip) : null
var interrupt = gradioApp().getElementById(id_interrupt) var interrupt = gradioApp().getElementById(id_interrupt)
if(opts.show_progress_in_title && progressbar && progressbar.offsetParent){
if(progressbar.innerText){
let newtitle = 'Stable Diffusion - ' + progressbar.innerText
if(document.title != newtitle){
document.title = newtitle;
}
}else{
let newtitle = 'Stable Diffusion'
if(document.title != newtitle){
document.title = newtitle;
}
}
}
if(progressbar!= null && progressbar != global_progressbars[id_progressbar]){ if(progressbar!= null && progressbar != global_progressbars[id_progressbar]){
global_progressbars[id_progressbar] = progressbar global_progressbars[id_progressbar] = progressbar
...@@ -15,31 +33,72 @@ function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_inte ...@@ -15,31 +33,72 @@ function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_inte
preview.style.width = gallery.clientWidth + "px" preview.style.width = gallery.clientWidth + "px"
preview.style.height = gallery.clientHeight + "px" preview.style.height = gallery.clientHeight + "px"
//only watch gallery if there is a generation process going on
check_gallery(id_gallery);
var progressDiv = gradioApp().querySelectorAll('#' + id_progressbar_span).length > 0; var progressDiv = gradioApp().querySelectorAll('#' + id_progressbar_span).length > 0;
if(!progressDiv){ if(!progressDiv){
if (skip) {
skip.style.display = "none"
}
interrupt.style.display = "none" interrupt.style.display = "none"
//disconnect the observer once generation has finished, so the user can close the selected image if they want
if (galleryObservers[id_gallery]) {
galleryObservers[id_gallery].disconnect();
galleries[id_gallery] = null;
} }
} }
window.setTimeout(function(){ requestMoreProgress(id_part, id_progressbar_span, id_interrupt) }, 500)
}
window.setTimeout(function() { requestMoreProgress(id_part, id_progressbar_span, id_skip, id_interrupt) }, 500)
}); });
mutationObserver.observe( progressbar, { childList:true, subtree:true }) mutationObserver.observe( progressbar, { childList:true, subtree:true })
} }
} }
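// check_gallery: while generation is running, watch the gallery with a MutationObserver and re-select the
// previously selected thumbnail if Gradio rebuilds the gallery without a selection.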
function check_gallery(id_gallery){
let gallery = gradioApp().getElementById(id_gallery)
// if the gallery hasn't changed, there's no need to set up the observer again.
if (gallery && galleries[id_gallery] !== gallery){
galleries[id_gallery] = gallery;
if(galleryObservers[id_gallery]){
galleryObservers[id_gallery].disconnect();
}
let prevSelectedIndex = selected_gallery_index();
galleryObservers[id_gallery] = new MutationObserver(function (){
let galleryButtons = gradioApp().querySelectorAll('#'+id_gallery+' .gallery-item')
let galleryBtnSelected = gradioApp().querySelector('#'+id_gallery+' .gallery-item.\\!ring-2')
if (prevSelectedIndex !== -1 && galleryButtons.length>prevSelectedIndex && !galleryBtnSelected) {
//automatically re-open previously selected index (if exists)
galleryButtons[prevSelectedIndex].click();
showGalleryImage();
}
})
galleryObservers[id_gallery].observe( gallery, { childList:true, subtree:false })
}
}
onUiUpdate(function(){ onUiUpdate(function(){
check_progressbar('txt2img', 'txt2img_progressbar', 'txt2img_progress_span', 'txt2img_interrupt', 'txt2img_preview', 'txt2img_gallery') check_progressbar('txt2img', 'txt2img_progressbar', 'txt2img_progress_span', 'txt2img_skip', 'txt2img_interrupt', 'txt2img_preview', 'txt2img_gallery')
check_progressbar('img2img', 'img2img_progressbar', 'img2img_progress_span', 'img2img_interrupt', 'img2img_preview', 'img2img_gallery') check_progressbar('img2img', 'img2img_progressbar', 'img2img_progress_span', 'img2img_skip', 'img2img_interrupt', 'img2img_preview', 'img2img_gallery')
check_progressbar('ti', 'ti_progressbar', 'ti_progress_span', '', 'ti_interrupt', 'ti_preview', 'ti_gallery')
}) })
function requestMoreProgress(id_part, id_progressbar_span, id_interrupt){ function requestMoreProgress(id_part, id_progressbar_span, id_skip, id_interrupt){
btn = gradioApp().getElementById(id_part+"_check_progress"); btn = gradioApp().getElementById(id_part+"_check_progress");
if(btn==null) return; if(btn==null) return;
btn.click(); btn.click();
var progressDiv = gradioApp().querySelectorAll('#' + id_progressbar_span).length > 0; var progressDiv = gradioApp().querySelectorAll('#' + id_progressbar_span).length > 0;
var skip = id_skip ? gradioApp().getElementById(id_skip) : null
var interrupt = gradioApp().getElementById(id_interrupt) var interrupt = gradioApp().getElementById(id_interrupt)
if(progressDiv && interrupt){ if(progressDiv && interrupt){
if (skip) {
skip.style.display = "block"
}
interrupt.style.display = "block" interrupt.style.display = "block"
} }
} }
......
function start_training_textual_inversion(){
requestProgress('ti')
gradioApp().querySelector('#ti_error').innerHTML=''
return args_to_array(arguments)
}
...@@ -33,27 +33,27 @@ function args_to_array(args){ ...@@ -33,27 +33,27 @@ function args_to_array(args){
} }
function switch_to_txt2img(){ function switch_to_txt2img(){
gradioApp().querySelectorAll('button')[0].click(); gradioApp().querySelector('#tabs').querySelectorAll('button')[0].click();
return args_to_array(arguments); return args_to_array(arguments);
} }
function switch_to_img2img_img2img(){ function switch_to_img2img_img2img(){
gradioApp().querySelectorAll('button')[1].click(); gradioApp().querySelector('#tabs').querySelectorAll('button')[1].click();
gradioApp().getElementById('mode_img2img').querySelectorAll('button')[0].click(); gradioApp().getElementById('mode_img2img').querySelectorAll('button')[0].click();
return args_to_array(arguments); return args_to_array(arguments);
} }
function switch_to_img2img_inpaint(){ function switch_to_img2img_inpaint(){
gradioApp().querySelectorAll('button')[1].click(); gradioApp().querySelector('#tabs').querySelectorAll('button')[1].click();
gradioApp().getElementById('mode_img2img').querySelectorAll('button')[1].click(); gradioApp().getElementById('mode_img2img').querySelectorAll('button')[1].click();
return args_to_array(arguments); return args_to_array(arguments);
} }
function switch_to_extras(){ function switch_to_extras(){
gradioApp().querySelectorAll('button')[2].click(); gradioApp().querySelector('#tabs').querySelectorAll('button')[2].click();
return args_to_array(arguments); return args_to_array(arguments);
} }
...@@ -101,7 +101,8 @@ function create_tab_index_args(tabId, args){ ...@@ -101,7 +101,8 @@ function create_tab_index_args(tabId, args){
} }
function get_extras_tab_index(){ function get_extras_tab_index(){
return create_tab_index_args('mode_extras', arguments) const [,,...args] = [...arguments]
return [get_tab_index('mode_extras'), get_tab_index('extras_resize_mode'), ...args]
} }
function create_submit_args(args){ function create_submit_args(args){
...@@ -186,12 +187,10 @@ onUiUpdate(function(){ ...@@ -186,12 +187,10 @@ onUiUpdate(function(){
if (!txt2img_textarea) { if (!txt2img_textarea) {
txt2img_textarea = gradioApp().querySelector("#txt2img_prompt > label > textarea"); txt2img_textarea = gradioApp().querySelector("#txt2img_prompt > label > textarea");
txt2img_textarea?.addEventListener("input", () => update_token_counter("txt2img_token_button")); txt2img_textarea?.addEventListener("input", () => update_token_counter("txt2img_token_button"));
txt2img_textarea?.addEventListener("keyup", (event) => submit_prompt(event, "txt2img_generate"));
} }
if (!img2img_textarea) { if (!img2img_textarea) {
img2img_textarea = gradioApp().querySelector("#img2img_prompt > label > textarea"); img2img_textarea = gradioApp().querySelector("#img2img_prompt > label > textarea");
img2img_textarea?.addEventListener("input", () => update_token_counter("img2img_token_button")); img2img_textarea?.addEventListener("input", () => update_token_counter("img2img_token_button"));
img2img_textarea?.addEventListener("keyup", (event) => submit_prompt(event, "img2img_generate"));
} }
}) })
...@@ -199,12 +198,18 @@ let txt2img_textarea, img2img_textarea = undefined; ...@@ -199,12 +198,18 @@ let txt2img_textarea, img2img_textarea = undefined;
let wait_time = 800 let wait_time = 800
let token_timeout; let token_timeout;
function submit_prompt(event, generate_button_id) { function update_txt2img_tokens(...args) {
if (event.altKey && event.keyCode === 13) { update_token_counter("txt2img_token_button")
event.preventDefault(); if (args.length == 2)
gradioApp().getElementById(generate_button_id).click(); return args[0]
return; return args;
} }
function update_img2img_tokens(...args) {
update_token_counter("img2img_token_button")
if (args.length == 2)
return args[0]
return args;
} }
function update_token_counter(button_id) { function update_token_counter(button_id) {
...@@ -212,3 +217,8 @@ function update_token_counter(button_id) { ...@@ -212,3 +217,8 @@ function update_token_counter(button_id) {
clearTimeout(token_timeout); clearTimeout(token_timeout);
token_timeout = setTimeout(() => gradioApp().getElementById(button_id)?.click(), wait_time); token_timeout = setTimeout(() => gradioApp().getElementById(button_id)?.click(), wait_time);
} }
function restart_reload(){
document.body.innerHTML='<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>';
setTimeout(function(){location.reload()},2000)
}
...@@ -4,38 +4,18 @@ import os ...@@ -4,38 +4,18 @@ import os
import sys import sys
import importlib.util import importlib.util
import shlex import shlex
import platform
dir_repos = "repositories" dir_repos = "repositories"
dir_tmp = "tmp"
python = sys.executable python = sys.executable
git = os.environ.get('GIT', "git") git = os.environ.get('GIT', "git")
torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113") index_url = os.environ.get('INDEX_URL', "")
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379")
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc")
taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6")
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "a7ec1974d4ccb394c2dca275f42cd97490618924")
codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
args = shlex.split(commandline_args)
def extract_arg(args, name): def extract_arg(args, name):
return [x for x in args if x != name], name in args return [x for x in args if x != name], name in args
args, skip_torch_cuda_test = extract_arg(args, '--skip-torch-cuda-test')
def repo_dir(name):
return os.path.join(dir_repos, name)
def run(command, desc=None, errdesc=None): def run(command, desc=None, errdesc=None):
if desc is not None: if desc is not None:
print(desc) print(desc)
...@@ -55,23 +35,11 @@ stderr: {result.stderr.decode(encoding="utf8", errors="ignore") if len(result.st ...@@ -55,23 +35,11 @@ stderr: {result.stderr.decode(encoding="utf8", errors="ignore") if len(result.st
return result.stdout.decode(encoding="utf8", errors="ignore") return result.stdout.decode(encoding="utf8", errors="ignore")
def run_python(code, desc=None, errdesc=None):
return run(f'"{python}" -c "{code}"', desc, errdesc)
def run_pip(args, desc=None):
return run(f'"{python}" -m pip {args} --prefer-binary', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}")
def check_run(command): def check_run(command):
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True) result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
return result.returncode == 0 return result.returncode == 0
def check_run_python(code):
return check_run(f'"{python}" -c "{code}"')
def is_installed(package): def is_installed(package):
try: try:
spec = importlib.util.find_spec(package) spec = importlib.util.find_spec(package)
...@@ -81,10 +49,36 @@ def is_installed(package): ...@@ -81,10 +49,36 @@ def is_installed(package):
return spec is not None return spec is not None
def repo_dir(name):
return os.path.join(dir_repos, name)
def run_python(code, desc=None, errdesc=None):
return run(f'"{python}" -c "{code}"', desc, errdesc)
def run_pip(args, desc=None):
index_url_line = f' --index-url {index_url}' if index_url != '' else ''
return run(f'"{python}" -m pip {args} --prefer-binary{index_url_line}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}")
def check_run_python(code):
return check_run(f'"{python}" -c "{code}"')
def git_clone(url, dir, name, commithash=None): def git_clone(url, dir, name, commithash=None):
# TODO clone into temporary dir and move if successful # TODO clone into temporary dir and move if successful
if os.path.exists(dir): if os.path.exists(dir):
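# the repository already exists: if a specific commit was requested and we're not on it, fetch and check it out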
if commithash is None:
return
current_hash = run(f'"{git}" -C {dir} rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}").strip()
if current_hash == commithash:
return
run(f'"{git}" -C {dir} fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}")
run(f'"{git}" -C {dir} checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}")
return return
run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}") run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}")
...@@ -93,47 +87,85 @@ def git_clone(url, dir, name, commithash=None): ...@@ -93,47 +87,85 @@ def git_clone(url, dir, name, commithash=None):
run(f'"{git}" -C {dir} checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}") run(f'"{git}" -C {dir} checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}")
try: def prepare_enviroment():
torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113")
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379")
clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1")
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc")
taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6")
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "f4e99857772fc3a126ba886aadf795a332774878")
codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
args = shlex.split(commandline_args)
args, skip_torch_cuda_test = extract_arg(args, '--skip-torch-cuda-test')
xformers = '--xformers' in args
deepdanbooru = '--deepdanbooru' in args
ngrok = '--ngrok' in args
try:
commit = run(f"{git} rev-parse HEAD").strip() commit = run(f"{git} rev-parse HEAD").strip()
except Exception: except Exception:
commit = "<none>" commit = "<none>"
print(f"Python {sys.version}") print(f"Python {sys.version}")
print(f"Commit hash: {commit}") print(f"Commit hash: {commit}")
if not is_installed("torch") or not is_installed("torchvision"): if not is_installed("torch") or not is_installed("torchvision"):
run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch") run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch")
if not skip_torch_cuda_test: if not skip_torch_cuda_test:
run_python("import torch; assert torch.cuda.is_available(), 'Torch is not able to use GPU; add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'") run_python("import torch; assert torch.cuda.is_available(), 'Torch is not able to use GPU; add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'")
if not is_installed("gfpgan"): if not is_installed("gfpgan"):
run_pip(f"install {gfpgan_package}", "gfpgan") run_pip(f"install {gfpgan_package}", "gfpgan")
os.makedirs(dir_repos, exist_ok=True) if not is_installed("clip"):
run_pip(f"install {clip_package}", "clip")
git_clone("https://github.com/CompVis/stable-diffusion.git", repo_dir('stable-diffusion'), "Stable Diffusion", stable_diffusion_commit_hash) if not is_installed("xformers") and xformers and platform.python_version().startswith("3.10"):
git_clone("https://github.com/CompVis/taming-transformers.git", repo_dir('taming-transformers'), "Taming Transformers", taming_transformers_commit_hash) if platform.system() == "Windows":
git_clone("https://github.com/crowsonkb/k-diffusion.git", repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash) run_pip("install https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/c/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl", "xformers")
git_clone("https://github.com/sczhou/CodeFormer.git", repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash) elif platform.system() == "Linux":
git_clone("https://github.com/salesforce/BLIP.git", repo_dir('BLIP'), "BLIP", blip_commit_hash) run_pip("install xformers", "xformers")
if not is_installed("lpips"): if not is_installed("deepdanbooru") and deepdanbooru:
run_pip("install git+https://github.com/KichangKim/DeepDanbooru.git@edf73df4cdaeea2cf00e9ac08bd8a9026b7a7b26#egg=deepdanbooru[tensorflow] tensorflow==2.10.0 tensorflow-io==0.27.0", "deepdanbooru")
if not is_installed("pyngrok") and ngrok:
run_pip("install pyngrok", "ngrok")
os.makedirs(dir_repos, exist_ok=True)
git_clone("https://github.com/CompVis/stable-diffusion.git", repo_dir('stable-diffusion'), "Stable Diffusion", stable_diffusion_commit_hash)
git_clone("https://github.com/CompVis/taming-transformers.git", repo_dir('taming-transformers'), "Taming Transformers", taming_transformers_commit_hash)
git_clone("https://github.com/crowsonkb/k-diffusion.git", repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
git_clone("https://github.com/sczhou/CodeFormer.git", repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
git_clone("https://github.com/salesforce/BLIP.git", repo_dir('BLIP'), "BLIP", blip_commit_hash)
if not is_installed("lpips"):
run_pip(f"install -r {os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}", "requirements for CodeFormer") run_pip(f"install -r {os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}", "requirements for CodeFormer")
run_pip(f"install -r {requirements_file}", "requirements for Web UI") run_pip(f"install -r {requirements_file}", "requirements for Web UI")
sys.argv += args sys.argv += args
if "--exit" in args: if "--exit" in args:
print("Exiting because of --exit argument") print("Exiting because of --exit argument")
exit(0) exit(0)
def start_webui(): def start_webui():
print(f"Launching Web UI with arguments: {' '.join(sys.argv[1:])}") print(f"Launching Web UI with arguments: {' '.join(sys.argv[1:])}")
import webui import webui
webui.webui() webui.webui()
if __name__ == "__main__": if __name__ == "__main__":
prepare_enviroment()
start_webui() start_webui()
...@@ -8,15 +8,13 @@ import torch ...@@ -8,15 +8,13 @@ import torch
from basicsr.utils.download_util import load_file_from_url from basicsr.utils.download_util import load_file_from_url
import modules.upscaler import modules.upscaler
from modules import shared, modelloader from modules import devices, modelloader
from modules.bsrgan_model_arch import RRDBNet from modules.bsrgan_model_arch import RRDBNet
from modules.paths import models_path
class UpscalerBSRGAN(modules.upscaler.Upscaler): class UpscalerBSRGAN(modules.upscaler.Upscaler):
def __init__(self, dirname): def __init__(self, dirname):
self.name = "BSRGAN" self.name = "BSRGAN"
self.model_path = os.path.join(models_path, self.name)
self.model_name = "BSRGAN 4x" self.model_name = "BSRGAN 4x"
self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/BSRGAN.pth" self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/BSRGAN.pth"
self.user_path = dirname self.user_path = dirname
...@@ -44,13 +42,13 @@ class UpscalerBSRGAN(modules.upscaler.Upscaler): ...@@ -44,13 +42,13 @@ class UpscalerBSRGAN(modules.upscaler.Upscaler):
model = self.load_model(selected_file) model = self.load_model(selected_file)
if model is None: if model is None:
return img return img
model.to(shared.device) model.to(devices.device_bsrgan)
torch.cuda.empty_cache() torch.cuda.empty_cache()
img = np.array(img) img = np.array(img)
img = img[:, :, ::-1] img = img[:, :, ::-1]
img = np.moveaxis(img, 2, 0) / 255 img = np.moveaxis(img, 2, 0) / 255
img = torch.from_numpy(img).float() img = torch.from_numpy(img).float()
img = img.unsqueeze(0).to(shared.device) img = img.unsqueeze(0).to(devices.device_bsrgan)
with torch.no_grad(): with torch.no_grad():
output = model(img) output = model(img)
output = output.squeeze().float().cpu().clamp_(0, 1).numpy() output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
......
...@@ -69,10 +69,14 @@ def setup_model(dirname): ...@@ -69,10 +69,14 @@ def setup_model(dirname):
self.net = net self.net = net
self.face_helper = face_helper self.face_helper = face_helper
self.net.to(devices.device_codeformer)
return net, face_helper return net, face_helper
def send_model_to(self, device):
self.net.to(device)
self.face_helper.face_det.to(device)
self.face_helper.face_parse.to(device)
def restore(self, np_image, w=None): def restore(self, np_image, w=None):
np_image = np_image[:, :, ::-1] np_image = np_image[:, :, ::-1]
...@@ -82,6 +86,8 @@ def setup_model(dirname): ...@@ -82,6 +86,8 @@ def setup_model(dirname):
if self.net is None or self.face_helper is None: if self.net is None or self.face_helper is None:
return np_image return np_image
self.send_model_to(devices.device_codeformer)
self.face_helper.clean_all() self.face_helper.clean_all()
self.face_helper.read_image(np_image) self.face_helper.read_image(np_image)
self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5) self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
...@@ -113,8 +119,10 @@ def setup_model(dirname): ...@@ -113,8 +119,10 @@ def setup_model(dirname):
if original_resolution != restored_img.shape[0:2]: if original_resolution != restored_img.shape[0:2]:
restored_img = cv2.resize(restored_img, (0, 0), fx=original_resolution[1]/restored_img.shape[1], fy=original_resolution[0]/restored_img.shape[0], interpolation=cv2.INTER_LINEAR) restored_img = cv2.resize(restored_img, (0, 0), fx=original_resolution[1]/restored_img.shape[1], fy=original_resolution[0]/restored_img.shape[0], interpolation=cv2.INTER_LINEAR)
self.face_helper.clean_all()
if shared.opts.face_restoration_unload: if shared.opts.face_restoration_unload:
self.net.to(devices.cpu) self.send_model_to(devices.cpu)
return restored_img return restored_img
......
import os.path
from concurrent.futures import ProcessPoolExecutor
import multiprocessing
import time
import re
re_special = re.compile(r'([\\()])')
def get_deepbooru_tags(pil_image):
"""
This method is for running only one image at a time for simple use. Used by the img2img interrogate.
"""
from modules import shared # prevents circular reference
try:
create_deepbooru_process(shared.opts.interrogate_deepbooru_score_threshold, create_deepbooru_opts())
return get_tags_from_process(pil_image)
finally:
release_process()
OPT_INCLUDE_RANKS = "include_ranks"
def create_deepbooru_opts():
from modules import shared
return {
"use_spaces": shared.opts.deepbooru_use_spaces,
"use_escape": shared.opts.deepbooru_escape,
"alpha_sort": shared.opts.deepbooru_sort_alpha,
OPT_INCLUDE_RANKS: shared.opts.interrogate_return_ranks,
}
def deepbooru_process(queue, deepbooru_process_return, threshold, deepbooru_opts):
model, tags = get_deepbooru_tags_model()
while True: # while process is running, keep monitoring queue for new image
pil_image = queue.get()
if pil_image == "QUIT":
break
else:
deepbooru_process_return["value"] = get_deepbooru_tags_from_model(model, tags, pil_image, threshold, deepbooru_opts)
def create_deepbooru_process(threshold, deepbooru_opts):
"""
Creates the deepbooru process. A queue is created to send images into the process, which allows multiple
images to be processed in a row without reloading the model or spawning a new process. Results come back
through a shared dictionary: its value is set to -1, and the method that puts an image on the queue should
wait until that value has been replaced with the tags.
"""
from modules import shared # prevents circular reference
shared.deepbooru_process_manager = multiprocessing.Manager()
shared.deepbooru_process_queue = shared.deepbooru_process_manager.Queue()
shared.deepbooru_process_return = shared.deepbooru_process_manager.dict()
shared.deepbooru_process_return["value"] = -1
shared.deepbooru_process = multiprocessing.Process(target=deepbooru_process, args=(shared.deepbooru_process_queue, shared.deepbooru_process_return, threshold, deepbooru_opts))
shared.deepbooru_process.start()
def get_tags_from_process(image):
from modules import shared
shared.deepbooru_process_return["value"] = -1
shared.deepbooru_process_queue.put(image)
while shared.deepbooru_process_return["value"] == -1:
time.sleep(0.2)
caption = shared.deepbooru_process_return["value"]
shared.deepbooru_process_return["value"] = -1
return caption
def release_process():
"""
Stops the deepbooru process to return used memory
"""
from modules import shared # prevents circular reference
shared.deepbooru_process_queue.put("QUIT")
shared.deepbooru_process.join()
shared.deepbooru_process_queue = None
shared.deepbooru_process = None
shared.deepbooru_process_return = None
shared.deepbooru_process_manager = None
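# Sketch of batched use (hypothetical caller), following the protocol described above:
#   create_deepbooru_process(shared.opts.interrogate_deepbooru_score_threshold, create_deepbooru_opts())
#   try:
#       captions = [get_tags_from_process(img) for img in pil_images]
#   finally:
#       release_process()
# get_deepbooru_tags() above does the same for a single image.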
def get_deepbooru_tags_model():
import deepdanbooru as dd
import tensorflow as tf
import numpy as np
this_folder = os.path.dirname(__file__)
model_path = os.path.abspath(os.path.join(this_folder, '..', 'models', 'deepbooru'))
if not os.path.exists(os.path.join(model_path, 'project.json')):
# there is no point importing these every time
import zipfile
from basicsr.utils.download_util import load_file_from_url
load_file_from_url(
r"https://github.com/KichangKim/DeepDanbooru/releases/download/v3-20211112-sgd-e28/deepdanbooru-v3-20211112-sgd-e28.zip",
model_path)
with zipfile.ZipFile(os.path.join(model_path, "deepdanbooru-v3-20211112-sgd-e28.zip"), "r") as zip_ref:
zip_ref.extractall(model_path)
os.remove(os.path.join(model_path, "deepdanbooru-v3-20211112-sgd-e28.zip"))
tags = dd.project.load_tags_from_project(model_path)
model = dd.project.load_model_from_project(
model_path, compile_model=False
)
return model, tags
def get_deepbooru_tags_from_model(model, tags, pil_image, threshold, deepbooru_opts):
import deepdanbooru as dd
import tensorflow as tf
import numpy as np
alpha_sort = deepbooru_opts['alpha_sort']
use_spaces = deepbooru_opts['use_spaces']
use_escape = deepbooru_opts['use_escape']
include_ranks = deepbooru_opts['include_ranks']
width = model.input_shape[2]
height = model.input_shape[1]
image = np.array(pil_image)
image = tf.image.resize(
image,
size=(height, width),
method=tf.image.ResizeMethod.AREA,
preserve_aspect_ratio=True,
)
image = image.numpy() # EagerTensor to np.array
image = dd.image.transform_and_pad_image(image, width, height)
image = image / 255.0
image_shape = image.shape
image = image.reshape((1, image_shape[0], image_shape[1], image_shape[2]))
y = model.predict(image)[0]
result_dict = {}
for i, tag in enumerate(tags):
result_dict[tag] = y[i]
unsorted_tags_in_theshold = []
result_tags_print = []
for tag in tags:
if result_dict[tag] >= threshold:
if tag.startswith("rating:"):
continue
unsorted_tags_in_theshold.append((result_dict[tag], tag))
result_tags_print.append(f'{result_dict[tag]} {tag}')
# sort tags
result_tags_out = []
sort_ndx = 0
if alpha_sort:
sort_ndx = 1
# sort in descending order of likelihood (or ascending alphabetically when alpha_sort is set), and format tag text as requested
unsorted_tags_in_theshold.sort(key=lambda y: y[sort_ndx], reverse=(not alpha_sort))
for weight, tag in unsorted_tags_in_theshold:
# note: tag_outformat will still have a colon if include_ranks is True
tag_outformat = tag.replace(':', ' ')
if use_spaces:
tag_outformat = tag_outformat.replace('_', ' ')
if use_escape:
tag_outformat = re.sub(re_special, r'\\\1', tag_outformat)
if include_ranks:
tag_outformat = f"({tag_outformat}:{weight:.3f})"
result_tags_out.append(tag_outformat)
print('\n'.join(sorted(result_tags_print, reverse=True)))
return ', '.join(result_tags_out)
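# Worked example of the formatting options (hypothetical tag and score, threshold 0.5):
#   tag "long_hair" scored 0.987 becomes "long_hair" by default,
#   "long hair" with use_spaces, and "(long hair:0.987)" with include_ranks as well;
#   with use_escape, characters matched by re_special ("\", "(", ")") are prefixed with a backslash.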
import contextlib
import torch import torch
# has_mps is only available in nightly pytorch (for now), `getattr` for compatibility
from modules import errors from modules import errors
# has_mps is only available in nightly pytorch (for now), `getattr` for compatibility
has_mps = getattr(torch, 'has_mps', False) has_mps = getattr(torch, 'has_mps', False)
cpu = torch.device("cpu") cpu = torch.device("cpu")
...@@ -32,10 +34,9 @@ def enable_tf32(): ...@@ -32,10 +34,9 @@ def enable_tf32():
errors.run(enable_tf32, "Enabling TF32") errors.run(enable_tf32, "Enabling TF32")
device = device_interrogate = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = get_optimal_device()
device = get_optimal_device() dtype = torch.float16
device_codeformer = cpu if has_mps else device dtype_vae = torch.float16
def randn(seed, shape): def randn(seed, shape):
# Pytorch currently doesn't handle setting randomness correctly when the metal backend is used. # Pytorch currently doesn't handle setting randomness correctly when the metal backend is used.
...@@ -58,3 +59,14 @@ def randn_without_seed(shape): ...@@ -58,3 +59,14 @@ def randn_without_seed(shape):
return torch.randn(shape, device=device) return torch.randn(shape, device=device)
def autocast(disable=False):
from modules import shared
if disable:
return contextlib.nullcontext()
if dtype == torch.float32 or shared.cmd_opts.precision == "full":
return contextlib.nullcontext()
return torch.autocast("cuda")
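# Typical use is as a context manager around model calls, e.g. (sketch):
#   with devices.autocast():
#       samples = shared.sd_model(x, c)
# With --precision full, or when dtype is float32, it degrades to a no-op nullcontext.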
...@@ -5,10 +5,8 @@ import torch ...@@ -5,10 +5,8 @@ import torch
from PIL import Image from PIL import Image
from basicsr.utils.download_util import load_file_from_url from basicsr.utils.download_util import load_file_from_url
import modules.esrgam_model_arch as arch import modules.esrgan_model_arch as arch
from modules import shared, modelloader, images from modules import shared, modelloader, images, devices
from modules.devices import has_mps
from modules.paths import models_path
from modules.upscaler import Upscaler, UpscalerData from modules.upscaler import Upscaler, UpscalerData
from modules.shared import opts from modules.shared import opts
...@@ -73,11 +71,10 @@ def fix_model_layers(crt_model, pretrained_net): ...@@ -73,11 +71,10 @@ def fix_model_layers(crt_model, pretrained_net):
class UpscalerESRGAN(Upscaler): class UpscalerESRGAN(Upscaler):
def __init__(self, dirname): def __init__(self, dirname):
self.name = "ESRGAN" self.name = "ESRGAN"
self.model_url = "https://drive.google.com/u/0/uc?id=1TPrz5QKd8DHHt1k8SRtm6tMiPjz_Qene&export=download" self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/ESRGAN.pth"
self.model_name = "ESRGAN 4x" self.model_name = "ESRGAN_4x"
self.scalers = [] self.scalers = []
self.user_path = dirname self.user_path = dirname
self.model_path = os.path.join(models_path, self.name)
super().__init__() super().__init__()
model_paths = self.find_models(ext_filter=[".pt", ".pth"]) model_paths = self.find_models(ext_filter=[".pt", ".pth"])
scalers = [] scalers = []
...@@ -97,7 +94,7 @@ class UpscalerESRGAN(Upscaler): ...@@ -97,7 +94,7 @@ class UpscalerESRGAN(Upscaler):
model = self.load_model(selected_model) model = self.load_model(selected_model)
if model is None: if model is None:
return img return img
model.to(shared.device) model.to(devices.device_esrgan)
img = esrgan_upscale(model, img) img = esrgan_upscale(model, img)
return img return img
...@@ -112,7 +109,7 @@ class UpscalerESRGAN(Upscaler): ...@@ -112,7 +109,7 @@ class UpscalerESRGAN(Upscaler):
print("Unable to load %s from %s" % (self.model_path, filename)) print("Unable to load %s from %s" % (self.model_path, filename))
return None return None
pretrained_net = torch.load(filename, map_location='cpu' if has_mps else None) pretrained_net = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)
crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32) crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32)
pretrained_net = fix_model_layers(crt_model, pretrained_net) pretrained_net = fix_model_layers(crt_model, pretrained_net)
...@@ -127,7 +124,7 @@ def upscale_without_tiling(model, img): ...@@ -127,7 +124,7 @@ def upscale_without_tiling(model, img):
img = img[:, :, ::-1] img = img[:, :, ::-1]
img = np.moveaxis(img, 2, 0) / 255 img = np.moveaxis(img, 2, 0) / 255
img = torch.from_numpy(img).float() img = torch.from_numpy(img).float()
img = img.unsqueeze(0).to(shared.device) img = img.unsqueeze(0).to(devices.device_esrgan)
with torch.no_grad(): with torch.no_grad():
output = model(img) output = model(img)
output = output.squeeze().float().cpu().clamp_(0, 1).numpy() output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
......
import math
import os import os
import numpy as np import numpy as np
...@@ -19,7 +20,7 @@ import gradio as gr ...@@ -19,7 +20,7 @@ import gradio as gr
cached_images = {} cached_images = {}
def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility): def run_extras(extras_mode, resize_mode, image, image_folder, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility):
devices.torch_gc() devices.torch_gc()
imageArr = [] imageArr = []
...@@ -29,7 +30,7 @@ def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_v ...@@ -29,7 +30,7 @@ def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_v
if extras_mode == 1: if extras_mode == 1:
#convert file to pillow image #convert file to pillow image
for img in image_folder: for img in image_folder:
image = Image.fromarray(np.array(Image.open(img))) image = Image.open(img)
imageArr.append(image) imageArr.append(image)
imageNameArr.append(os.path.splitext(img.orig_name)[0]) imageNameArr.append(os.path.splitext(img.orig_name)[0])
else: else:
...@@ -67,8 +68,13 @@ def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_v ...@@ -67,8 +68,13 @@ def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_v
info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n" info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n"
image = res image = res
if resize_mode == 1:
upscaling_resize = max(upscaling_resize_w/image.width, upscaling_resize_h/image.height)
crop_info = " (crop)" if upscaling_crop else ""
info += f"Resize to: {upscaling_resize_w:g}x{upscaling_resize_h:g}{crop_info}\n"
if upscaling_resize != 1.0: if upscaling_resize != 1.0:
def upscale(image, scaler_index, resize): def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop):
small = image.crop((image.width // 2, image.height // 2, image.width // 2 + 10, image.height // 2 + 10)) small = image.crop((image.width // 2, image.height // 2, image.width // 2 + 10, image.height // 2 + 10))
pixels = tuple(np.array(small).flatten().tolist()) pixels = tuple(np.array(small).flatten().tolist())
key = (resize, scaler_index, image.width, image.height, gfpgan_visibility, codeformer_visibility, codeformer_weight) + pixels key = (resize, scaler_index, image.width, image.height, gfpgan_visibility, codeformer_visibility, codeformer_weight) + pixels
...@@ -77,15 +83,19 @@ def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_v ...@@ -77,15 +83,19 @@ def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_v
if c is None: if c is None:
upscaler = shared.sd_upscalers[scaler_index] upscaler = shared.sd_upscalers[scaler_index]
c = upscaler.scaler.upscale(image, resize, upscaler.data_path) c = upscaler.scaler.upscale(image, resize, upscaler.data_path)
if mode == 1 and crop:
cropped = Image.new("RGB", (resize_w, resize_h))
cropped.paste(c, box=(resize_w // 2 - c.width // 2, resize_h // 2 - c.height // 2))
c = cropped
cached_images[key] = c cached_images[key] = c
return c return c
info += f"Upscale: {round(upscaling_resize, 3)}, model:{shared.sd_upscalers[extras_upscaler_1].name}\n" info += f"Upscale: {round(upscaling_resize, 3)}, model:{shared.sd_upscalers[extras_upscaler_1].name}\n"
res = upscale(image, extras_upscaler_1, upscaling_resize) res = upscale(image, extras_upscaler_1, upscaling_resize, resize_mode, upscaling_resize_w, upscaling_resize_h, upscaling_crop)
if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0: if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0:
res2 = upscale(image, extras_upscaler_2, upscaling_resize) res2 = upscale(image, extras_upscaler_2, upscaling_resize, resize_mode, upscaling_resize_w, upscaling_resize_h, upscaling_crop)
info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {round(extras_upscaler_2_visibility, 3)}, model:{shared.sd_upscalers[extras_upscaler_2].name}\n" info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {round(extras_upscaler_2_visibility, 3)}, model:{shared.sd_upscalers[extras_upscaler_2].name}\n"
res = Image.blend(res, res2, extras_upscaler_2_visibility) res = Image.blend(res, res2, extras_upscaler_2_visibility)
...@@ -98,8 +108,14 @@ def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_v ...@@ -98,8 +108,14 @@ def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_v
no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo,
forced_filename=image_name if opts.use_original_name_batch else None) forced_filename=image_name if opts.use_original_name_batch else None)
if opts.enable_pnginfo:
image.info = existing_pnginfo
image.info["extras"] = info
outputs.append(image) outputs.append(image)
devices.torch_gc()
return outputs, plaintext_to_html(info), '' return outputs, plaintext_to_html(info), ''
...@@ -143,48 +159,52 @@ def run_pnginfo(image): ...@@ -143,48 +159,52 @@ def run_pnginfo(image):
return '', geninfo, info return '', geninfo, info
def run_modelmerger(primary_model_name, secondary_model_name, interp_method, interp_amount, save_as_half, custom_name): def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_name, interp_method, multiplier, save_as_half, custom_name):
# Linear interpolation (https://en.wikipedia.org/wiki/Linear_interpolation) def weighted_sum(theta0, theta1, theta2, alpha):
def weighted_sum(theta0, theta1, alpha):
return ((1 - alpha) * theta0) + (alpha * theta1) return ((1 - alpha) * theta0) + (alpha * theta1)
# Smoothstep (https://en.wikipedia.org/wiki/Smoothstep) def add_difference(theta0, theta1, theta2, alpha):
def sigmoid(theta0, theta1, alpha): return theta0 + (theta1 - theta2) * alpha
alpha = alpha * alpha * (3 - (2 * alpha))
return theta0 + ((theta1 - theta0) * alpha)
# Inverse Smoothstep (https://en.wikipedia.org/wiki/Smoothstep)
def inv_sigmoid(theta0, theta1, alpha):
import math
alpha = 0.5 - math.sin(math.asin(1.0 - 2.0 * alpha) / 3.0)
return theta0 + ((theta1 - theta0) * alpha)
primary_model_info = sd_models.checkpoints_list[primary_model_name] primary_model_info = sd_models.checkpoints_list[primary_model_name]
secondary_model_info = sd_models.checkpoints_list[secondary_model_name] secondary_model_info = sd_models.checkpoints_list[secondary_model_name]
teritary_model_info = sd_models.checkpoints_list.get(teritary_model_name, None)
print(f"Loading {primary_model_info.filename}...") print(f"Loading {primary_model_info.filename}...")
primary_model = torch.load(primary_model_info.filename, map_location='cpu') primary_model = torch.load(primary_model_info.filename, map_location='cpu')
theta_0 = sd_models.get_state_dict_from_checkpoint(primary_model)
print(f"Loading {secondary_model_info.filename}...") print(f"Loading {secondary_model_info.filename}...")
secondary_model = torch.load(secondary_model_info.filename, map_location='cpu') secondary_model = torch.load(secondary_model_info.filename, map_location='cpu')
theta_1 = sd_models.get_state_dict_from_checkpoint(secondary_model)
theta_0 = primary_model['state_dict'] if teritary_model_info is not None:
theta_1 = secondary_model['state_dict'] print(f"Loading {teritary_model_info.filename}...")
teritary_model = torch.load(teritary_model_info.filename, map_location='cpu')
theta_2 = sd_models.get_state_dict_from_checkpoint(teritary_model)
else:
theta_2 = None
theta_funcs = { theta_funcs = {
"Weighted Sum": weighted_sum, "Weighted sum": weighted_sum,
"Sigmoid": sigmoid, "Add difference": add_difference,
"Inverse Sigmoid": inv_sigmoid,
} }
theta_func = theta_funcs[interp_method] theta_func = theta_funcs[interp_method]
print(f"Merging...") print(f"Merging...")
for key in tqdm.tqdm(theta_0.keys()): for key in tqdm.tqdm(theta_0.keys()):
if 'model' in key and key in theta_1: if 'model' in key and key in theta_1:
theta_0[key] = theta_func(theta_0[key], theta_1[key], (float(1.0) - interp_amount)) # Need to reverse the interp_amount to match the desired mix ration in the merged checkpoint t2 = (theta_2 or {}).get(key)
if t2 is None:
t2 = torch.zeros_like(theta_0[key])
theta_0[key] = theta_func(theta_0[key], theta_1[key], t2, multiplier)
if save_as_half: if save_as_half:
theta_0[key] = theta_0[key].half() theta_0[key] = theta_0[key].half()
# I believe this part should be discarded, but I'll leave it for now until I am sure
for key in theta_1.keys(): for key in theta_1.keys():
if 'model' in key and key not in theta_0: if 'model' in key and key not in theta_0:
theta_0[key] = theta_1[key] theta_0[key] = theta_1[key]
...@@ -193,7 +213,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, interp_method, int ...@@ -193,7 +213,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, interp_method, int
ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path
filename = primary_model_info.model_name + '_' + str(round(interp_amount, 2)) + '-' + secondary_model_info.model_name + '_' + str(round((float(1.0) - interp_amount), 2)) + '-' + interp_method.replace(" ", "_") + '-merged.ckpt' filename = primary_model_info.model_name + '_' + str(round(1-multiplier, 2)) + '-' + secondary_model_info.model_name + '_' + str(round(multiplier, 2)) + '-' + interp_method.replace(" ", "_") + '-merged.ckpt'
filename = filename if custom_name == '' else (custom_name + '.ckpt') filename = filename if custom_name == '' else (custom_name + '.ckpt')
output_modelname = os.path.join(ckpt_dir, filename) output_modelname = os.path.join(ckpt_dir, filename)
...@@ -203,4 +223,4 @@ def run_modelmerger(primary_model_name, secondary_model_name, interp_method, int ...@@ -203,4 +223,4 @@ def run_modelmerger(primary_model_name, secondary_model_name, interp_method, int
sd_models.list_models() sd_models.list_models()
print(f"Checkpoint saved.") print(f"Checkpoint saved.")
return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(3)] return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
import os
import re import re
import gradio as gr import gradio as gr
from modules.shared import script_path
from modules import shared
re_param_code = r"\s*([\w ]+):\s*([^,]+)(?:,|$)" re_param_code = r"\s*([\w ]+):\s*([^,]+)(?:,|$)"
re_param = re.compile(re_param_code) re_param = re.compile(re_param_code)
...@@ -61,6 +64,12 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model ...@@ -61,6 +64,12 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
def connect_paste(button, paste_fields, input_comp, js=None): def connect_paste(button, paste_fields, input_comp, js=None):
def paste_func(prompt): def paste_func(prompt):
if not prompt and not shared.cmd_opts.hide_ui_dir_config:
filename = os.path.join(script_path, "params.txt")
if os.path.exists(filename):
with open(filename, "r", encoding="utf8") as file:
prompt = file.read()
params = parse_generation_parameters(prompt) params = parse_generation_parameters(prompt)
res = [] res = []
......
...@@ -21,7 +21,7 @@ def gfpgann(): ...@@ -21,7 +21,7 @@ def gfpgann():
global loaded_gfpgan_model global loaded_gfpgan_model
global model_path global model_path
if loaded_gfpgan_model is not None: if loaded_gfpgan_model is not None:
loaded_gfpgan_model.gfpgan.to(shared.device) loaded_gfpgan_model.gfpgan.to(devices.device_gfpgan)
return loaded_gfpgan_model return loaded_gfpgan_model
if gfpgan_constructor is None: if gfpgan_constructor is None:
...@@ -37,22 +37,32 @@ def gfpgann(): ...@@ -37,22 +37,32 @@ def gfpgann():
print("Unable to load gfpgan model!") print("Unable to load gfpgan model!")
return None return None
model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None) model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
model.gfpgan.to(shared.device)
loaded_gfpgan_model = model loaded_gfpgan_model = model
return model return model
def send_model_to(model, device):
model.gfpgan.to(device)
model.face_helper.face_det.to(device)
model.face_helper.face_parse.to(device)
def gfpgan_fix_faces(np_image): def gfpgan_fix_faces(np_image):
model = gfpgann() model = gfpgann()
if model is None: if model is None:
return np_image return np_image
send_model_to(model, devices.device_gfpgan)
np_image_bgr = np_image[:, :, ::-1] np_image_bgr = np_image[:, :, ::-1]
cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True) cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
np_image = gfpgan_output_bgr[:, :, ::-1] np_image = gfpgan_output_bgr[:, :, ::-1]
model.face_helper.clean_all()
if shared.opts.face_restoration_unload: if shared.opts.face_restoration_unload:
model.gfpgan.to(devices.cpu) send_model_to(model, devices.cpu)
return np_image return np_image
...@@ -97,11 +107,7 @@ def setup_model(dirname): ...@@ -97,11 +107,7 @@ def setup_model(dirname):
return "GFPGAN" return "GFPGAN"
def restore(self, np_image): def restore(self, np_image):
np_image_bgr = np_image[:, :, ::-1] return gfpgan_fix_faces(np_image)
cropped_faces, restored_faces, gfpgan_output_bgr = gfpgann().enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
np_image = gfpgan_output_bgr[:, :, ::-1]
return np_image
shared.face_restorers.append(FaceRestorerGFPGAN()) shared.face_restorers.append(FaceRestorerGFPGAN())
except Exception: except Exception:
......
import datetime
import glob
import html
import os
import sys
import traceback
import tqdm
import csv
import torch
from ldm.util import default
from modules import devices, shared, processing, sd_models
import torch
from torch import einsum
from einops import rearrange, repeat
import modules.textual_inversion.dataset
from modules.textual_inversion import textual_inversion
from modules.textual_inversion.learn_schedule import LearnRateScheduler
class HypernetworkModule(torch.nn.Module):
multiplier = 1.0
def __init__(self, dim, state_dict=None):
super().__init__()
self.linear1 = torch.nn.Linear(dim, dim * 2)
self.linear2 = torch.nn.Linear(dim * 2, dim)
if state_dict is not None:
self.load_state_dict(state_dict, strict=True)
else:
self.linear1.weight.data.normal_(mean=0.0, std=0.01)
self.linear1.bias.data.zero_()
self.linear2.weight.data.normal_(mean=0.0, std=0.01)
self.linear2.bias.data.zero_()
self.to(devices.device)
def forward(self, x):
return x + (self.linear2(self.linear1(x))) * self.multiplier
def apply_strength(value=None):
HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength
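# Shape sketch (768 is just an illustrative context width):
#   m = HypernetworkModule(768)     # two linear layers, 768 -> 1536 -> 768
#   out = m(context)                # context + linear2(linear1(context)) * multiplier
# so with multiplier 0 the module is an identity and the hypernetwork has no effect.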
class Hypernetwork:
filename = None
name = None
def __init__(self, name=None, enable_sizes=None):
self.filename = None
self.name = name
self.layers = {}
self.step = 0
self.sd_checkpoint = None
self.sd_checkpoint_name = None
for size in enable_sizes or []:
self.layers[size] = (HypernetworkModule(size), HypernetworkModule(size))
def weights(self):
res = []
for k, layers in self.layers.items():
for layer in layers:
layer.train()
res += [layer.linear1.weight, layer.linear1.bias, layer.linear2.weight, layer.linear2.bias]
return res
def save(self, filename):
state_dict = {}
for k, v in self.layers.items():
state_dict[k] = (v[0].state_dict(), v[1].state_dict())
state_dict['step'] = self.step
state_dict['name'] = self.name
state_dict['sd_checkpoint'] = self.sd_checkpoint
state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
torch.save(state_dict, filename)
def load(self, filename):
self.filename = filename
if self.name is None:
self.name = os.path.splitext(os.path.basename(filename))[0]
state_dict = torch.load(filename, map_location='cpu')
for size, sd in state_dict.items():
if type(size) == int:
self.layers[size] = (HypernetworkModule(size, sd[0]), HypernetworkModule(size, sd[1]))
self.name = state_dict.get('name', self.name)
self.step = state_dict.get('step', 0)
self.sd_checkpoint = state_dict.get('sd_checkpoint', None)
self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None)
def list_hypernetworks(path):
res = {}
for filename in glob.iglob(os.path.join(path, '**/*.pt'), recursive=True):
name = os.path.splitext(os.path.basename(filename))[0]
res[name] = filename
return res
def load_hypernetwork(filename):
path = shared.hypernetworks.get(filename, None)
if path is not None:
print(f"Loading hypernetwork {filename}")
try:
shared.loaded_hypernetwork = Hypernetwork()
shared.loaded_hypernetwork.load(path)
except Exception:
print(f"Error loading hypernetwork {path}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
else:
if shared.loaded_hypernetwork is not None:
print(f"Unloading hypernetwork")
shared.loaded_hypernetwork = None
def find_closest_hypernetwork_name(search: str):
if not search:
return None
search = search.lower()
applicable = [name for name in shared.hypernetworks if search in name.lower()]
if not applicable:
return None
applicable = sorted(applicable, key=lambda name: len(name))
return applicable[0]
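# e.g. with hypothetical hypernetworks named "anime", "anime_3" and "food" loaded,
# find_closest_hypernetwork_name("ani") returns "anime" -- the shortest name containing
# the search string, case-insensitively; an empty or unmatched search returns None.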
def apply_hypernetwork(hypernetwork, context, layer=None):
hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
if hypernetwork_layers is None:
return context, context
if layer is not None:
layer.hyper_k = hypernetwork_layers[0]
layer.hyper_v = hypernetwork_layers[1]
context_k = hypernetwork_layers[0](context)
context_v = hypernetwork_layers[1](context)
return context_k, context_v
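# In effect: the conditioning that feeds the attention keys/values is passed through the
# pair of hypernetwork modules selected by context width (context.shape[2]), while queries
# still come from the unmodified image features; if no module exists for that width,
# the context is returned unchanged.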
def attention_CrossAttention_forward(self, x, context=None, mask=None):
h = self.heads
q = self.to_q(x)
context = default(context, x)
context_k, context_v = apply_hypernetwork(shared.loaded_hypernetwork, context, self)
k = self.to_k(context_k)
v = self.to_v(context_v)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
if mask is not None:
mask = rearrange(mask, 'b ... -> b (...)')
max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, 'b j -> (b h) () j', h=h)
sim.masked_fill_(~mask, max_neg_value)
# attention, what we cannot get enough of
attn = sim.softmax(dim=-1)
out = einsum('b i j, b j d -> b i d', attn, v)
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
return self.to_out(out)
def stack_conds(conds):
if len(conds) == 1:
return torch.stack(conds)
# same as in reconstruct_multicond_batch
token_count = max([x.shape[0] for x in conds])
for i in range(len(conds)):
if conds[i].shape[0] != token_count:
last_vector = conds[i][-1:]
last_vector_repeated = last_vector.repeat([token_count - conds[i].shape[0], 1])
conds[i] = torch.vstack([conds[i], last_vector_repeated])
return torch.stack(conds)
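# Illustrative padding: conds of shapes (77, 768) and (154, 768) are stacked by repeating
# the last token vector of the shorter one until both have 154 rows, giving a single
# (2, 154, 768) batch tensor for the model.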
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
assert hypernetwork_name, 'hypernetwork not selected'
path = shared.hypernetworks.get(hypernetwork_name, None)
shared.loaded_hypernetwork = Hypernetwork()
shared.loaded_hypernetwork.load(path)
shared.state.textinfo = "Initializing hypernetwork training..."
shared.state.job_count = steps
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name)
unload = shared.opts.unload_models_when_training
if save_hypernetwork_every > 0:
hypernetwork_dir = os.path.join(log_directory, "hypernetworks")
os.makedirs(hypernetwork_dir, exist_ok=True)
else:
hypernetwork_dir = None
if create_image_every > 0:
images_dir = os.path.join(log_directory, "images")
os.makedirs(images_dir, exist_ok=True)
else:
images_dir = None
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"):
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size)
if unload:
shared.sd_model.cond_stage_model.to(devices.cpu)
shared.sd_model.first_stage_model.to(devices.cpu)
hypernetwork = shared.loaded_hypernetwork
weights = hypernetwork.weights()
for weight in weights:
weight.requires_grad = True
losses = torch.zeros((32,))
last_saved_file = "<none>"
last_saved_image = "<none>"
ititial_step = hypernetwork.step or 0
if ititial_step > steps:
return hypernetwork, filename
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
for i, entries in pbar:
hypernetwork.step = i + ititial_step
scheduler.apply(optimizer, hypernetwork.step)
if scheduler.finished:
break
if shared.state.interrupted:
break
with torch.autocast("cuda"):
c = stack_conds([entry.cond for entry in entries]).to(devices.device)
# c = torch.vstack([entry.cond for entry in entries]).to(devices.device)
x = torch.stack([entry.latent for entry in entries]).to(devices.device)
loss = shared.sd_model(x, c)[0]
del x
del c
losses[hypernetwork.step % losses.shape[0]] = loss.item()
optimizer.zero_grad()
loss.backward()
optimizer.step()
pbar.set_description(f"loss: {losses.mean():.7f}")
if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0:
last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name}-{hypernetwork.step}.pt')
hypernetwork.save(last_saved_file)
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
"loss": f"{losses.mean():.7f}",
"learn_rate": scheduler.learn_rate
})
if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png')
optimizer.zero_grad()
shared.sd_model.cond_stage_model.to(devices.device)
shared.sd_model.first_stage_model.to(devices.device)
p = processing.StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
do_not_save_grid=True,
do_not_save_samples=True,
)
if preview_from_txt2img:
p.prompt = preview_prompt
p.negative_prompt = preview_negative_prompt
p.steps = preview_steps
p.sampler_index = preview_sampler_index
p.cfg_scale = preview_cfg_scale
p.seed = preview_seed
p.width = preview_width
p.height = preview_height
else:
p.prompt = entries[0].cond_text
p.steps = 20
preview_text = p.prompt
processed = processing.process_images(p)
image = processed.images[0] if len(processed.images)>0 else None
if unload:
shared.sd_model.cond_stage_model.to(devices.cpu)
shared.sd_model.first_stage_model.to(devices.cpu)
if image is not None:
shared.state.current_image = image
image.save(last_saved_image)
last_saved_image += f", prompt: {preview_text}"
shared.state.job_no = hypernetwork.step
shared.state.textinfo = f"""
<p>
Loss: {losses.mean():.7f}<br/>
Step: {hypernetwork.step}<br/>
Last prompt: {html.escape(entries[0].cond_text)}<br/>
Last saved hypernetwork: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/>
</p>
"""
checkpoint = sd_models.select_checkpoint()
hypernetwork.sd_checkpoint = checkpoint.hash
hypernetwork.sd_checkpoint_name = checkpoint.model_name
hypernetwork.save(filename)
return hypernetwork, filename
import html
import os
import gradio as gr
import modules.textual_inversion.textual_inversion
import modules.textual_inversion.preprocess
from modules import sd_hijack, shared, devices
from modules.hypernetworks import hypernetwork
def create_hypernetwork(name, enable_sizes):
fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
assert not os.path.exists(fn), f"file {fn} already exists"
hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(name=name, enable_sizes=[int(x) for x in enable_sizes])
hypernet.save(fn)
shared.reload_hypernetworks()
return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {fn}", ""
def train_hypernetwork(*args):
initial_hypernetwork = shared.loaded_hypernetwork
assert not shared.cmd_opts.lowvram, 'Training models with lowvram is not possible'
try:
sd_hijack.undo_optimizations()
hypernetwork, filename = modules.hypernetworks.hypernetwork.train_hypernetwork(*args)
res = f"""
Training {'interrupted' if shared.state.interrupted else 'finished'} at {hypernetwork.step} steps.
Hypernetwork saved to {html.escape(filename)}
"""
return res, ""
except Exception:
raise
finally:
shared.loaded_hypernetwork = initial_hypernetwork
shared.sd_model.cond_stage_model.to(devices.device)
shared.sd_model.first_stage_model.to(devices.device)
sd_hijack.apply_optimizations()
import datetime import datetime
import io
import math import math
import os import os
from collections import namedtuple from collections import namedtuple
...@@ -23,6 +24,10 @@ def image_grid(imgs, batch_size=1, rows=None): ...@@ -23,6 +24,10 @@ def image_grid(imgs, batch_size=1, rows=None):
rows = opts.n_rows rows = opts.n_rows
elif opts.n_rows == 0: elif opts.n_rows == 0:
rows = batch_size rows = batch_size
elif opts.grid_prevent_empty_spots:
rows = math.floor(math.sqrt(len(imgs)))
while len(imgs) % rows != 0:
rows -= 1
else: else:
rows = math.sqrt(len(imgs)) rows = math.sqrt(len(imgs))
rows = round(rows) rows = round(rows)
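# e.g. with grid_prevent_empty_spots and 12 images: floor(sqrt(12)) = 3 and 12 % 3 == 0,
# giving a full 3x4 grid; with 10 images rows steps 3 -> 2, giving a full 2x5 grid
# instead of a 3-row grid with empty cells.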
...@@ -287,6 +292,20 @@ def apply_filename_pattern(x, p, seed, prompt): ...@@ -287,6 +292,20 @@ def apply_filename_pattern(x, p, seed, prompt):
if seed is not None: if seed is not None:
x = x.replace("[seed]", str(seed)) x = x.replace("[seed]", str(seed))
if p is not None:
x = x.replace("[steps]", str(p.steps))
x = x.replace("[cfg]", str(p.cfg_scale))
x = x.replace("[width]", str(p.width))
x = x.replace("[height]", str(p.height))
x = x.replace("[styles]", sanitize_filename_part(", ".join([x for x in p.styles if not x == "None"]) or "None", replace_spaces=False))
x = x.replace("[sampler]", sanitize_filename_part(sd_samplers.samplers[p.sampler_index].name, replace_spaces=False))
x = x.replace("[model_hash]", getattr(p, "sd_model_hash", shared.sd_model.sd_model_hash))
x = x.replace("[date]", datetime.date.today().isoformat())
x = x.replace("[datetime]", datetime.datetime.now().strftime("%Y%m%d%H%M%S"))
x = x.replace("[job_timestamp]", getattr(p, "job_timestamp", shared.state.job_timestamp))
# Apply [prompt] last, because it may contain any of the other replacement words.
if prompt is not None: if prompt is not None:
x = x.replace("[prompt]", sanitize_filename_part(prompt)) x = x.replace("[prompt]", sanitize_filename_part(prompt))
if "[prompt_no_styles]" in x: if "[prompt_no_styles]" in x:
...@@ -306,19 +325,6 @@ def apply_filename_pattern(x, p, seed, prompt): ...@@ -306,19 +325,6 @@ def apply_filename_pattern(x, p, seed, prompt):
words = ["empty"] words = ["empty"]
x = x.replace("[prompt_words]", sanitize_filename_part(" ".join(words[0:max_prompt_words]), replace_spaces=False)) x = x.replace("[prompt_words]", sanitize_filename_part(" ".join(words[0:max_prompt_words]), replace_spaces=False))
if p is not None:
x = x.replace("[steps]", str(p.steps))
x = x.replace("[cfg]", str(p.cfg_scale))
x = x.replace("[width]", str(p.width))
x = x.replace("[height]", str(p.height))
x = x.replace("[styles]", sanitize_filename_part(", ".join([x for x in p.styles if not x == "None"]), replace_spaces=False))
x = x.replace("[sampler]", sanitize_filename_part(sd_samplers.samplers[p.sampler_index].name, replace_spaces=False))
x = x.replace("[model_hash]", shared.sd_model.sd_model_hash)
x = x.replace("[date]", datetime.date.today().isoformat())
x = x.replace("[datetime]", datetime.datetime.now().strftime("%Y%m%d%H%M%S"))
x = x.replace("[job_timestamp]", shared.state.job_timestamp)
if cmd_opts.hide_ui_dir_config: if cmd_opts.hide_ui_dir_config:
x = re.sub(r'^[\\/]+|\.{2,}[\\/]+|[\\/]+\.{2,}', '', x) x = re.sub(r'^[\\/]+|\.{2,}[\\/]+|[\\/]+\.{2,}', '', x)
...@@ -347,7 +353,39 @@ def get_next_sequence_number(path, basename): ...@@ -347,7 +353,39 @@ def get_next_sequence_number(path, basename):
return result + 1 return result + 1
def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None, forced_filename=None, suffix=""): def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None, forced_filename=None, suffix="", save_to_dirs=None):
'''Save an image.
Args:
image (`PIL.Image`):
The image to be saved.
path (`str`):
The directory to save the image in. Note that the option `save_to_dirs` will make the image be saved into a subdirectory of this path.
basename (`str`):
The base filename which will be applied to `filename pattern`.
seed, prompt, short_filename,
extension (`str`):
Image file extension, default is `png`.
pnginfo_section_name (`str`):
Specify the name of the section in which `info` will be saved.
info (`str` or `PngImagePlugin.iTXt`):
PNG info chunks.
existing_info (`dict`):
Additional PNG info. `existing_info == {pnginfo_section_name: info, ...}`
no_prompt:
TODO I don't know its meaning.
p (`StableDiffusionProcessing`)
forced_filename (`str`):
If specified, `basename` and filename pattern will be ignored.
save_to_dirs (bool):
If true, the image will be saved into a subdirectory of `path`.
Returns: (fullfn, txt_fullfn)
fullfn (`str`):
The full path of the saved image.
txt_fullfn (`str` or None):
If a text file is saved for this image, this will be its full path. Otherwise None.
'''
if short_filename or prompt is None or seed is None: if short_filename or prompt is None or seed is None:
file_decoration = "" file_decoration = ""
elif opts.save_to_dirs: elif opts.save_to_dirs:
...@@ -371,10 +409,11 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i ...@@ -371,10 +409,11 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
else: else:
pnginfo = None pnginfo = None
if save_to_dirs is None:
save_to_dirs = (grid and opts.grid_save_to_dirs) or (not grid and opts.save_to_dirs and not no_prompt) save_to_dirs = (grid and opts.grid_save_to_dirs) or (not grid and opts.save_to_dirs and not no_prompt)
if save_to_dirs: if save_to_dirs:
dirname = apply_filename_pattern(opts.directories_filename_pattern or "[prompt_words]", p, seed, prompt) dirname = apply_filename_pattern(opts.directories_filename_pattern or "[prompt_words]", p, seed, prompt).strip('\\ /')
path = os.path.join(path, dirname) path = os.path.join(path, dirname)
os.makedirs(path, exist_ok=True) os.makedirs(path, exist_ok=True)
...@@ -422,7 +461,29 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i ...@@ -422,7 +461,29 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
piexif.insert(exif_bytes(), fullfn_without_extension + ".jpg") piexif.insert(exif_bytes(), fullfn_without_extension + ".jpg")
if opts.save_txt and info is not None: if opts.save_txt and info is not None:
with open(f"{fullfn_without_extension}.txt", "w", encoding="utf8") as file: txt_fullfn = f"{fullfn_without_extension}.txt"
with open(txt_fullfn, "w", encoding="utf8") as file:
file.write(info + "\n") file.write(info + "\n")
else:
txt_fullfn = None
return fullfn, txt_fullfn
def image_data(data):
try:
image = Image.open(io.BytesIO(data))
textinfo = image.text["parameters"]
return textinfo, None
except Exception:
pass
try:
text = data.decode('utf8')
assert len(text) < 10000
return text, None
except Exception:
pass
return '', None
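# Sketch of use (hypothetical caller handling a dropped file): given raw bytes, image_data()
# first tries to read the "parameters" PNG text chunk, then falls back to treating short
# files as a plain-text prompt, and finally returns an empty string:
#   geninfo, _ = image_data(open("some_generation.png", "rb").read())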
import os
import shutil
def traverse_all_files(output_dir, image_list, curr_dir=None):
curr_path = output_dir if curr_dir is None else os.path.join(output_dir, curr_dir)
try:
f_list = os.listdir(curr_path)
except:
if curr_dir[-10:].rfind(".") > 0 and curr_dir[-4:] != ".txt":
image_list.append(curr_dir)
return image_list
for file in f_list:
file = file if curr_dir is None else os.path.join(curr_dir, file)
file_path = os.path.join(curr_path, file)
if file[-4:] == ".txt":
pass
elif os.path.isfile(file_path) and file[-10:].rfind(".") > 0:
image_list.append(file)
else:
image_list = traverse_all_files(output_dir, image_list, file)
return image_list
def get_recent_images(dir_name, page_index, step, image_index, tabname):
page_index = int(page_index)
f_list = os.listdir(dir_name)
image_list = []
image_list = traverse_all_files(dir_name, image_list)
image_list = sorted(image_list, key=lambda file: -os.path.getctime(os.path.join(dir_name, file)))
num = 48 if tabname != "extras" else 12
max_page_index = len(image_list) // num + 1
page_index = max_page_index if page_index == -1 else page_index + step
page_index = 1 if page_index < 1 else page_index
page_index = max_page_index if page_index > max_page_index else page_index
idx_frm = (page_index - 1) * num
image_list = image_list[idx_frm:idx_frm + num]
image_index = int(image_index)
if image_index < 0 or image_index > len(image_list) - 1:
current_file = None
hidden = None
else:
current_file = image_list[int(image_index)]
hidden = os.path.join(dir_name, current_file)
return [os.path.join(dir_name, file) for file in image_list], page_index, image_list, current_file, hidden, ""
def first_page_click(dir_name, page_index, image_index, tabname):
return get_recent_images(dir_name, 1, 0, image_index, tabname)
def end_page_click(dir_name, page_index, image_index, tabname):
return get_recent_images(dir_name, -1, 0, image_index, tabname)
def prev_page_click(dir_name, page_index, image_index, tabname):
return get_recent_images(dir_name, page_index, -1, image_index, tabname)
def next_page_click(dir_name, page_index, image_index, tabname):
return get_recent_images(dir_name, page_index, 1, image_index, tabname)
def page_index_change(dir_name, page_index, image_index, tabname):
return get_recent_images(dir_name, page_index, 0, image_index, tabname)
def show_image_info(num, image_path, filenames):
# print(f"select image {num}")
file = filenames[int(num)]
return file, num, os.path.join(image_path, file)
def delete_image(delete_num, tabname, dir_name, name, page_index, filenames, image_index):
if name == "":
return filenames, delete_num
else:
delete_num = int(delete_num)
index = list(filenames).index(name)
i = 0
new_file_list = []
for name in filenames:
if i >= index and i < index + delete_num:
path = os.path.join(dir_name, name)
if os.path.exists(path):
print(f"Delete file {path}")
os.remove(path)
txt_file = os.path.splitext(path)[0] + ".txt"
if os.path.exists(txt_file):
os.remove(txt_file)
else:
print(f"Not exists file {path}")
else:
new_file_list.append(name)
i += 1
return new_file_list, 1
def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict):
if opts.outdir_samples != "":
dir_name = opts.outdir_samples
elif tabname == "txt2img":
dir_name = opts.outdir_txt2img_samples
elif tabname == "img2img":
dir_name = opts.outdir_img2img_samples
elif tabname == "extras":
dir_name = opts.outdir_extras_samples
d = dir_name.split("/")
dir_name = "/" if dir_name.startswith("/") else d[0]
for p in d[1:]:
dir_name = os.path.join(dir_name, p)
with gr.Row():
renew_page = gr.Button('Renew Page', elem_id=tabname + "_images_history_renew_page")
first_page = gr.Button('First Page')
prev_page = gr.Button('Prev Page')
page_index = gr.Number(value=1, label="Page Index")
next_page = gr.Button('Next Page')
end_page = gr.Button('End Page')
with gr.Row(elem_id=tabname + "_images_history"):
with gr.Row():
with gr.Column(scale=2):
history_gallery = gr.Gallery(show_label=False, elem_id=tabname + "_images_history_gallery").style(grid=6)
with gr.Row():
delete_num = gr.Number(value=1, interactive=True, label="number of images to delete consecutively next")
delete = gr.Button('Delete', elem_id=tabname + "_images_history_del_button")
with gr.Column():
with gr.Row():
pnginfo_send_to_txt2img = gr.Button('Send to txt2img')
pnginfo_send_to_img2img = gr.Button('Send to img2img')
with gr.Row():
with gr.Column():
img_file_info = gr.Textbox(label="Generate Info", interactive=False)
img_file_name = gr.Textbox(label="File Name", interactive=False)
with gr.Row():
# hidden items
img_path = gr.Textbox(dir_name.rstrip("/"), visible=False)
tabname_box = gr.Textbox(tabname, visible=False)
image_index = gr.Textbox(value=-1, visible=False)
set_index = gr.Button('set_index', elem_id=tabname + "_images_history_set_index", visible=False)
filenames = gr.State()
hidden = gr.Image(type="pil", visible=False)
info1 = gr.Textbox(visible=False)
info2 = gr.Textbox(visible=False)
# turn pages
gallery_inputs = [img_path, page_index, image_index, tabname_box]
gallery_outputs = [history_gallery, page_index, filenames, img_file_name, hidden, img_file_name]
first_page.click(first_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
next_page.click(next_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
prev_page.click(prev_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
end_page.click(end_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
page_index.submit(page_index_change, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
renew_page.click(page_index_change, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
# page_index.change(page_index_change, inputs=[tabname_box, img_path, page_index], outputs=[history_gallery, page_index])
# other functions
set_index.click(show_image_info, _js="images_history_get_current_img", inputs=[tabname_box, img_path, filenames], outputs=[img_file_name, image_index, hidden])
img_file_name.change(fn=None, _js="images_history_enable_del_buttons", inputs=None, outputs=None)
delete.click(delete_image, _js="images_history_delete", inputs=[delete_num, tabname_box, img_path, img_file_name, page_index, filenames, image_index], outputs=[filenames, delete_num])
hidden.change(fn=run_pnginfo, inputs=[hidden], outputs=[info1, img_file_info, info2])
# pnginfo.click(fn=run_pnginfo, inputs=[hidden], outputs=[info1, img_file_info, info2])
switch_dict["fn"](pnginfo_send_to_txt2img, switch_dict["t2i"], img_file_info, 'switch_to_txt2img')
switch_dict["fn"](pnginfo_send_to_img2img, switch_dict["i2i"], img_file_info, 'switch_to_img2img_img2img')
def create_history_tabs(gr, opts, run_pnginfo, switch_dict):
with gr.Blocks(analytics_enabled=False) as images_history:
with gr.Tabs() as tabs:
with gr.Tab("txt2img history"):
with gr.Blocks(analytics_enabled=False) as images_history_txt2img:
show_images_history(gr, opts, "txt2img", run_pnginfo, switch_dict)
with gr.Tab("img2img history"):
with gr.Blocks(analytics_enabled=False) as images_history_img2img:
show_images_history(gr, opts, "img2img", run_pnginfo, switch_dict)
with gr.Tab("extras history"):
with gr.Blocks(analytics_enabled=False) as images_history_extras:
show_images_history(gr, opts, "extras", run_pnginfo, switch_dict)
return images_history
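# Editor's usage sketch (not part of the original file; the switch_dict values shown are
# hypothetical placeholders, not the actual wiring in modules/ui.py). create_history_tabs
# only needs the gradio module, the shared options, a PNG-info reader and a dict telling it
# how to push the recovered parameters to the txt2img / img2img tabs:
#
#   import gradio as gr
#   from modules import shared
#   switch_dict = {
#       "fn": connect_send_button,     # hypothetical helper that wires a button to a tab switch
#       "t2i": txt2img_interface,      # hypothetical txt2img target
#       "i2i": img2img_interface,      # hypothetical img2img target
#   }
#   history_ui = create_history_tabs(gr, shared.opts, run_pnginfo, switch_dict)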
...@@ -23,13 +23,17 @@ def process_batch(p, input_dir, output_dir, args): ...@@ -23,13 +23,17 @@ def process_batch(p, input_dir, output_dir, args):
print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.") print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.")
save_normally = output_dir == ''
p.do_not_save_grid = True p.do_not_save_grid = True
p.do_not_save_samples = True p.do_not_save_samples = not save_normally
state.job_count = len(images) * p.n_iter state.job_count = len(images) * p.n_iter
for i, image in enumerate(images): for i, image in enumerate(images):
state.job = f"{i+1} out of {len(images)}" state.job = f"{i+1} out of {len(images)}"
if state.skipped:
state.skipped = False
if state.interrupted: if state.interrupted:
break break
...@@ -48,6 +52,7 @@ def process_batch(p, input_dir, output_dir, args): ...@@ -48,6 +52,7 @@ def process_batch(p, input_dir, output_dir, args):
left, right = os.path.splitext(filename) left, right = os.path.splitext(filename)
filename = f"{left}-{n}{right}" filename = f"{left}-{n}{right}"
if not save_normally:
processed_image.save(os.path.join(output_dir, filename)) processed_image.save(os.path.join(output_dir, filename))
...@@ -103,6 +108,8 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro ...@@ -103,6 +108,8 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
inpaint_full_res_padding=inpaint_full_res_padding, inpaint_full_res_padding=inpaint_full_res_padding,
inpainting_mask_invert=inpainting_mask_invert, inpainting_mask_invert=inpainting_mask_invert,
) )
if shared.cmd_opts.enable_console_prompts:
print(f"\nimg2img: {prompt}", file=shared.progress_print_out) print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
p.extra_generation_params["Mask blur"] = mask_blur p.extra_generation_params["Mask blur"] = mask_blur
...@@ -124,4 +131,7 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro ...@@ -124,4 +131,7 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
if opts.samples_log_stdout: if opts.samples_log_stdout:
print(generation_info_js) print(generation_info_js)
if opts.do_not_show_images:
processed.images = []
return processed.images, generation_info_js, plaintext_to_html(processed.info) return processed.images, generation_info_js, plaintext_to_html(processed.info)
...@@ -21,6 +21,7 @@ Category = namedtuple("Category", ["name", "topn", "items"]) ...@@ -21,6 +21,7 @@ Category = namedtuple("Category", ["name", "topn", "items"])
re_topn = re.compile(r"\.top(\d+)\.") re_topn = re.compile(r"\.top(\d+)\.")
class InterrogateModels: class InterrogateModels:
blip_model = None blip_model = None
clip_model = None clip_model = None
...@@ -54,7 +55,7 @@ class InterrogateModels: ...@@ -54,7 +55,7 @@ class InterrogateModels:
model, preprocess = clip.load(clip_model_name) model, preprocess = clip.load(clip_model_name)
model.eval() model.eval()
model = model.to(shared.device) model = model.to(devices.device_interrogate)
return model, preprocess return model, preprocess
...@@ -64,14 +65,14 @@ class InterrogateModels: ...@@ -64,14 +65,14 @@ class InterrogateModels:
if not shared.cmd_opts.no_half: if not shared.cmd_opts.no_half:
self.blip_model = self.blip_model.half() self.blip_model = self.blip_model.half()
self.blip_model = self.blip_model.to(shared.device) self.blip_model = self.blip_model.to(devices.device_interrogate)
if self.clip_model is None: if self.clip_model is None:
self.clip_model, self.clip_preprocess = self.load_clip_model() self.clip_model, self.clip_preprocess = self.load_clip_model()
if not shared.cmd_opts.no_half: if not shared.cmd_opts.no_half:
self.clip_model = self.clip_model.half() self.clip_model = self.clip_model.half()
self.clip_model = self.clip_model.to(shared.device) self.clip_model = self.clip_model.to(devices.device_interrogate)
self.dtype = next(self.clip_model.parameters()).dtype self.dtype = next(self.clip_model.parameters()).dtype
...@@ -98,11 +99,11 @@ class InterrogateModels: ...@@ -98,11 +99,11 @@ class InterrogateModels:
text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)] text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)]
top_count = min(top_count, len(text_array)) top_count = min(top_count, len(text_array))
text_tokens = clip.tokenize([text for text in text_array], truncate=True).to(shared.device) text_tokens = clip.tokenize([text for text in text_array], truncate=True).to(devices.device_interrogate)
text_features = self.clip_model.encode_text(text_tokens).type(self.dtype) text_features = self.clip_model.encode_text(text_tokens).type(self.dtype)
text_features /= text_features.norm(dim=-1, keepdim=True) text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = torch.zeros((1, len(text_array))).to(shared.device) similarity = torch.zeros((1, len(text_array))).to(devices.device_interrogate)
for i in range(image_features.shape[0]): for i in range(image_features.shape[0]):
similarity += (100.0 * image_features[i].unsqueeze(0) @ text_features.T).softmax(dim=-1) similarity += (100.0 * image_features[i].unsqueeze(0) @ text_features.T).softmax(dim=-1)
similarity /= image_features.shape[0] similarity /= image_features.shape[0]
...@@ -115,14 +116,14 @@ class InterrogateModels: ...@@ -115,14 +116,14 @@ class InterrogateModels:
transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC), transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC),
transforms.ToTensor(), transforms.ToTensor(),
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
])(pil_image).unsqueeze(0).type(self.dtype).to(shared.device) ])(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate)
with torch.no_grad(): with torch.no_grad():
caption = self.blip_model.generate(gpu_image, sample=False, num_beams=shared.opts.interrogate_clip_num_beams, min_length=shared.opts.interrogate_clip_min_length, max_length=shared.opts.interrogate_clip_max_length) caption = self.blip_model.generate(gpu_image, sample=False, num_beams=shared.opts.interrogate_clip_num_beams, min_length=shared.opts.interrogate_clip_min_length, max_length=shared.opts.interrogate_clip_max_length)
return caption[0] return caption[0]
def interrogate(self, pil_image): def interrogate(self, pil_image, include_ranks=False):
res = None res = None
try: try:
...@@ -139,11 +140,11 @@ class InterrogateModels: ...@@ -139,11 +140,11 @@ class InterrogateModels:
res = caption res = caption
cilp_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(shared.device) clip_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate)
precision_scope = torch.autocast if shared.cmd_opts.precision == "autocast" else contextlib.nullcontext precision_scope = torch.autocast if shared.cmd_opts.precision == "autocast" else contextlib.nullcontext
with torch.no_grad(), precision_scope("cuda"): with torch.no_grad(), precision_scope("cuda"):
image_features = self.clip_model.encode_image(cilp_image).type(self.dtype) image_features = self.clip_model.encode_image(clip_image).type(self.dtype)
image_features /= image_features.norm(dim=-1, keepdim=True) image_features /= image_features.norm(dim=-1, keepdim=True)
...@@ -155,7 +156,10 @@ class InterrogateModels: ...@@ -155,7 +156,10 @@ class InterrogateModels:
for name, topn, items in self.categories: for name, topn, items in self.categories:
matches = self.rank(image_features, items, top_count=topn) matches = self.rank(image_features, items, top_count=topn)
for match, score in matches: for match, score in matches:
if include_ranks:
res += ", " + match res += ", " + match
else:
res += f", ({match}:{score})"
except Exception: except Exception:
print(f"Error interrogating", file=sys.stderr) print(f"Error interrogating", file=sys.stderr)
......
...@@ -7,13 +7,11 @@ from basicsr.utils.download_util import load_file_from_url ...@@ -7,13 +7,11 @@ from basicsr.utils.download_util import load_file_from_url
from modules.upscaler import Upscaler, UpscalerData from modules.upscaler import Upscaler, UpscalerData
from modules.ldsr_model_arch import LDSR from modules.ldsr_model_arch import LDSR
from modules import shared from modules import shared
from modules.paths import models_path
class UpscalerLDSR(Upscaler): class UpscalerLDSR(Upscaler):
def __init__(self, user_path): def __init__(self, user_path):
self.name = "LDSR" self.name = "LDSR"
self.model_path = os.path.join(models_path, self.name)
self.user_path = user_path self.user_path = user_path
self.model_url = "https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1" self.model_url = "https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1"
self.yaml_url = "https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1" self.yaml_url = "https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1"
......
...@@ -5,7 +5,6 @@ import importlib ...@@ -5,7 +5,6 @@ import importlib
from urllib.parse import urlparse from urllib.parse import urlparse
from basicsr.utils.download_util import load_file_from_url from basicsr.utils.download_util import load_file_from_url
from modules import shared from modules import shared
from modules.upscaler import Upscaler from modules.upscaler import Upscaler
from modules.paths import script_path, models_path from modules.paths import script_path, models_path
...@@ -43,7 +42,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None ...@@ -43,7 +42,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
for place in places: for place in places:
if os.path.exists(place): if os.path.exists(place):
for file in glob.iglob(place + '**/**', recursive=True): for file in glob.iglob(place + '**/**', recursive=True):
full_path = os.path.join(place, file) full_path = file
if os.path.isdir(full_path): if os.path.isdir(full_path):
continue continue
if len(ext_filter) != 0: if len(ext_filter) != 0:
...@@ -121,16 +120,30 @@ def move_files(src_path: str, dest_path: str, ext_filter: str = None): ...@@ -121,16 +120,30 @@ def move_files(src_path: str, dest_path: str, ext_filter: str = None):
def load_upscalers(): def load_upscalers():
sd = shared.script_path
# We can only discover upscalers through Upscaler.__subclasses__() if their modules have
# actually been imported, so try to import every modules/*_model.py file first
# (a sketch of this pattern follows this file's diff)
modules_dir = os.path.join(sd, "modules")
for file in os.listdir(modules_dir):
if "_model.py" in file:
model_name = file.replace("_model.py", "")
full_model = f"modules.{model_name}_model"
try:
importlib.import_module(full_model)
except:
pass
datas = [] datas = []
c_o = vars(shared.cmd_opts)
for cls in Upscaler.__subclasses__(): for cls in Upscaler.__subclasses__():
name = cls.__name__ name = cls.__name__
module_name = cls.__module__ module_name = cls.__module__
module = importlib.import_module(module_name) module = importlib.import_module(module_name)
class_ = getattr(module, name) class_ = getattr(module, name)
cmd_name = f"{name.lower().replace('upscaler', '')}-models-path" cmd_name = f"{name.lower().replace('upscaler', '')}_models_path"
opt_string = None opt_string = None
try: try:
opt_string = shared.opts.__getattr__(cmd_name) if cmd_name in c_o:
opt_string = c_o[cmd_name]
except: except:
pass pass
scaler = class_(opt_string) scaler = class_(opt_string)
......
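# Editor's sketch of the discovery pattern used in load_upscalers above (illustrative only;
# the module names are examples, not an exhaustive list): importing each modules/*_model.py
# file defines its Upscaler subclass, after which __subclasses__() is enough to instantiate
# every upscaler generically.
#
#   import importlib
#   for name in ("modules.realesrgan_model", "modules.ldsr_model"):
#       importlib.import_module(name)                  # side effect: registers the subclass
#   scalers = [cls(None) for cls in Upscaler.__subclasses__()]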
from pyngrok import ngrok, conf, exception
def connect(token, port):
if token == None:
token = 'None'
conf.get_default().auth_token = token
try:
public_url = ngrok.connect(port).public_url
except exception.PyngrokNgrokError:
print(f'Invalid ngrok authtoken, ngrok connection aborted.\n'
f'Your token: {token}, get the right one on https://dashboard.ngrok.com/get-started/your-authtoken')
else:
print(f'ngrok connected to localhost:{port}! URL: {public_url}\n'
'You can use this link after the launch is complete.')
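# Editor's usage sketch (not part of the original file; the argument names are assumptions,
# not the actual launcher code): the launcher can pass the ngrok token from the command line
# and the web UI port straight through, e.g.
#
#   from modules import ngrok
#   ngrok.connect(cmd_opts.ngrok, cmd_opts.port or 7860)   # prints the public URL, or an auth error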
import argparse import argparse
import os import os
import sys import sys
import modules.safe
script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
models_path = os.path.join(script_path, "models") models_path = os.path.join(script_path, "models")
...@@ -12,6 +13,7 @@ possible_sd_paths = [os.path.join(script_path, 'repositories/stable-diffusion'), ...@@ -12,6 +13,7 @@ possible_sd_paths = [os.path.join(script_path, 'repositories/stable-diffusion'),
for possible_sd_path in possible_sd_paths: for possible_sd_path in possible_sd_paths:
if os.path.exists(os.path.join(possible_sd_path, 'ldm/models/diffusion/ddpm.py')): if os.path.exists(os.path.join(possible_sd_path, 'ldm/models/diffusion/ddpm.py')):
sd_path = os.path.abspath(possible_sd_path) sd_path = os.path.abspath(possible_sd_path)
break
assert sd_path is not None, "Couldn't find Stable Diffusion in any of: " + str(possible_sd_paths) assert sd_path is not None, "Couldn't find Stable Diffusion in any of: " + str(possible_sd_paths)
...@@ -20,7 +22,6 @@ path_dirs = [ ...@@ -20,7 +22,6 @@ path_dirs = [
(os.path.join(sd_path, '../taming-transformers'), 'taming', 'Taming Transformers', []), (os.path.join(sd_path, '../taming-transformers'), 'taming', 'Taming Transformers', []),
(os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []), (os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []),
(os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []), (os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []),
(os.path.join(sd_path, '../latent-diffusion'), 'LDSR.py', 'LDSR', []),
(os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]), (os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
] ]
......
import contextlib
import json import json
import math import math
import os import os
...@@ -12,9 +11,8 @@ import cv2 ...@@ -12,9 +11,8 @@ import cv2
from skimage import exposure from skimage import exposure
import modules.sd_hijack import modules.sd_hijack
from modules import devices, prompt_parser, masking from modules import devices, prompt_parser, masking, sd_samplers, lowvram
from modules.sd_hijack import model_hijack from modules.sd_hijack import model_hijack
from modules.sd_samplers import samplers, samplers_for_img2img
from modules.shared import opts, cmd_opts, state from modules.shared import opts, cmd_opts, state
import modules.shared as shared import modules.shared as shared
import modules.face_restoration import modules.face_restoration
...@@ -48,6 +46,12 @@ def apply_color_correction(correction, image): ...@@ -48,6 +46,12 @@ def apply_color_correction(correction, image):
return image return image
def get_correct_sampler(p):
if isinstance(p, modules.processing.StableDiffusionProcessingTxt2Img):
return sd_samplers.samplers
elif isinstance(p, modules.processing.StableDiffusionProcessingImg2Img):
return sd_samplers.samplers_for_img2img
class StableDiffusionProcessing: class StableDiffusionProcessing:
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", styles=None, seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, seed_enable_extras=True, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None, eta=None): def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", styles=None, seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, seed_enable_extras=True, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None, eta=None):
self.sd_model = sd_model self.sd_model = sd_model
...@@ -56,7 +60,7 @@ class StableDiffusionProcessing: ...@@ -56,7 +60,7 @@ class StableDiffusionProcessing:
self.prompt: str = prompt self.prompt: str = prompt
self.prompt_for_display: str = None self.prompt_for_display: str = None
self.negative_prompt: str = (negative_prompt or "") self.negative_prompt: str = (negative_prompt or "")
self.styles: str = styles self.styles: list = styles or []
self.seed: int = seed self.seed: int = seed
self.subseed: int = subseed self.subseed: int = subseed
self.subseed_strength: float = subseed_strength self.subseed_strength: float = subseed_strength
...@@ -111,7 +115,7 @@ class Processed: ...@@ -111,7 +115,7 @@ class Processed:
self.width = p.width self.width = p.width
self.height = p.height self.height = p.height
self.sampler_index = p.sampler_index self.sampler_index = p.sampler_index
self.sampler = samplers[p.sampler_index].name self.sampler = sd_samplers.samplers[p.sampler_index].name
self.cfg_scale = p.cfg_scale self.cfg_scale = p.cfg_scale
self.steps = p.steps self.steps = p.steps
self.batch_size = p.batch_size self.batch_size = p.batch_size
...@@ -123,6 +127,9 @@ class Processed: ...@@ -123,6 +127,9 @@ class Processed:
self.denoising_strength = getattr(p, 'denoising_strength', None) self.denoising_strength = getattr(p, 'denoising_strength', None)
self.extra_generation_params = p.extra_generation_params self.extra_generation_params = p.extra_generation_params
self.index_of_first_image = index_of_first_image self.index_of_first_image = index_of_first_image
self.styles = p.styles
self.job_timestamp = state.job_timestamp
self.clip_skip = opts.CLIP_stop_at_last_layers
self.eta = p.eta self.eta = p.eta
self.ddim_discretize = p.ddim_discretize self.ddim_discretize = p.ddim_discretize
...@@ -133,7 +140,7 @@ class Processed: ...@@ -133,7 +140,7 @@ class Processed:
self.sampler_noise_scheduler_override = p.sampler_noise_scheduler_override self.sampler_noise_scheduler_override = p.sampler_noise_scheduler_override
self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0] self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0]
self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0] self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0]
self.seed = int(self.seed if type(self.seed) != list else self.seed[0]) self.seed = int(self.seed if type(self.seed) != list else self.seed[0]) if self.seed is not None else -1
self.subseed = int(self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1 self.subseed = int(self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1
self.all_prompts = all_prompts or [self.prompt] self.all_prompts = all_prompts or [self.prompt]
...@@ -167,6 +174,9 @@ class Processed: ...@@ -167,6 +174,9 @@ class Processed:
"extra_generation_params": self.extra_generation_params, "extra_generation_params": self.extra_generation_params,
"index_of_first_image": self.index_of_first_image, "index_of_first_image": self.index_of_first_image,
"infotexts": self.infotexts, "infotexts": self.infotexts,
"styles": self.styles,
"job_timestamp": self.job_timestamp,
"clip_skip": self.clip_skip,
} }
return json.dumps(obj) return json.dumps(obj)
...@@ -197,7 +207,7 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see ...@@ -197,7 +207,7 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
# enables the generation of additional tensors with noise that the sampler will use during its processing. # enables the generation of additional tensors with noise that the sampler will use during its processing.
# Using those pre-generated tensors instead of simple torch.randn allows a batch with seeds [100, 101] to # Using those pre-generated tensors instead of simple torch.randn allows a batch with seeds [100, 101] to
# produce the same images as with two batches [100], [101]. # produce the same images as with two batches [100], [101].
if p is not None and p.sampler is not None and len(seeds) > 1 and opts.enable_batch_seeds: if p is not None and p.sampler is not None and (len(seeds) > 1 and opts.enable_batch_seeds or opts.eta_noise_seed_delta > 0):
sampler_noises = [[] for _ in range(p.sampler.number_of_needed_noises(p))] sampler_noises = [[] for _ in range(p.sampler.number_of_needed_noises(p))]
else: else:
sampler_noises = None sampler_noises = None
...@@ -237,6 +247,9 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see ...@@ -237,6 +247,9 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
if sampler_noises is not None: if sampler_noises is not None:
cnt = p.sampler.number_of_needed_noises(p) cnt = p.sampler.number_of_needed_noises(p)
if opts.eta_noise_seed_delta > 0:
torch.manual_seed(seed + opts.eta_noise_seed_delta)
for j in range(cnt): for j in range(cnt):
sampler_noises[j].append(devices.randn_without_seed(tuple(noise_shape))) sampler_noises[j].append(devices.randn_without_seed(tuple(noise_shape)))
...@@ -249,29 +262,49 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see ...@@ -249,29 +262,49 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
return x return x
def decode_first_stage(model, x):
with devices.autocast(disable=x.dtype == devices.dtype_vae):
x = model.decode_first_stage(x)
return x
def get_fixed_seed(seed):
if seed is None or seed == '' or seed == -1:
return int(random.randrange(4294967294))
return seed
def fix_seed(p): def fix_seed(p):
p.seed = int(random.randrange(4294967294)) if p.seed is None or p.seed == '' or p.seed == -1 else p.seed p.seed = get_fixed_seed(p.seed)
p.subseed = int(random.randrange(4294967294)) if p.subseed is None or p.subseed == '' or p.subseed == -1 else p.subseed p.subseed = get_fixed_seed(p.subseed)
def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration=0, position_in_batch=0): def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration=0, position_in_batch=0):
index = position_in_batch + iteration * p.batch_size index = position_in_batch + iteration * p.batch_size
clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)
generation_params = { generation_params = {
"Steps": p.steps, "Steps": p.steps,
"Sampler": samplers[p.sampler_index].name, "Sampler": get_correct_sampler(p)[p.sampler_index].name,
"CFG scale": p.cfg_scale, "CFG scale": p.cfg_scale,
"Seed": all_seeds[index], "Seed": all_seeds[index],
"Face restoration": (opts.face_restoration_model if p.restore_faces else None), "Face restoration": (opts.face_restoration_model if p.restore_faces else None),
"Size": f"{p.width}x{p.height}", "Size": f"{p.width}x{p.height}",
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash), "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
"Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
"Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name.replace(',', '').replace(':', '')),
"Batch size": (None if p.batch_size < 2 else p.batch_size), "Batch size": (None if p.batch_size < 2 else p.batch_size),
"Batch pos": (None if p.batch_size < 2 else position_in_batch), "Batch pos": (None if p.batch_size < 2 else position_in_batch),
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]), "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength), "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
"Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"), "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
"Denoising strength": getattr(p, 'denoising_strength', None), "Denoising strength": getattr(p, 'denoising_strength', None),
"Eta": (None if p.sampler.eta == p.sampler.default_eta else p.sampler.eta), "Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
"Clip skip": None if clip_skip <= 1 else clip_skip,
"ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta,
} }
generation_params.update(p.extra_generation_params) generation_params.update(p.extra_generation_params)
...@@ -291,14 +324,23 @@ def process_images(p: StableDiffusionProcessing) -> Processed: ...@@ -291,14 +324,23 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
else: else:
assert p.prompt is not None assert p.prompt is not None
with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file:
processed = Processed(p, [], p.seed, "")
file.write(processed.infotext(p, 0))
devices.torch_gc() devices.torch_gc()
fix_seed(p) seed = get_fixed_seed(p.seed)
subseed = get_fixed_seed(p.subseed)
if p.outpath_samples is not None:
os.makedirs(p.outpath_samples, exist_ok=True) os.makedirs(p.outpath_samples, exist_ok=True)
if p.outpath_grids is not None:
os.makedirs(p.outpath_grids, exist_ok=True) os.makedirs(p.outpath_grids, exist_ok=True)
modules.sd_hijack.model_hijack.apply_circular(p.tiling) modules.sd_hijack.model_hijack.apply_circular(p.tiling)
modules.sd_hijack.model_hijack.clear_comments()
comments = {} comments = {}
...@@ -309,33 +351,36 @@ def process_images(p: StableDiffusionProcessing) -> Processed: ...@@ -309,33 +351,36 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
else: else:
all_prompts = p.batch_size * p.n_iter * [p.prompt] all_prompts = p.batch_size * p.n_iter * [p.prompt]
if type(p.seed) == list: if type(seed) == list:
all_seeds = p.seed all_seeds = seed
else: else:
all_seeds = [int(p.seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(all_prompts))] all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(all_prompts))]
if type(p.subseed) == list: if type(subseed) == list:
all_subseeds = p.subseed all_subseeds = subseed
else: else:
all_subseeds = [int(p.subseed) + x for x in range(len(all_prompts))] all_subseeds = [int(subseed) + x for x in range(len(all_prompts))]
def infotext(iteration=0, position_in_batch=0): def infotext(iteration=0, position_in_batch=0):
return create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration, position_in_batch) return create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration, position_in_batch)
if os.path.exists(cmd_opts.embeddings_dir): if os.path.exists(cmd_opts.embeddings_dir):
model_hijack.load_textual_inversion_embeddings(cmd_opts.embeddings_dir, p.sd_model) model_hijack.embedding_db.load_textual_inversion_embeddings()
infotexts = [] infotexts = []
output_images = [] output_images = []
precision_scope = torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext
ema_scope = (contextlib.nullcontext if cmd_opts.lowvram else p.sd_model.ema_scope) with torch.no_grad(), p.sd_model.ema_scope():
with torch.no_grad(), precision_scope("cuda"), ema_scope(): with devices.autocast():
p.init(all_prompts, all_seeds, all_subseeds) p.init(all_prompts, all_seeds, all_subseeds)
if state.job_count == -1: if state.job_count == -1:
state.job_count = p.n_iter state.job_count = p.n_iter
for n in range(p.n_iter): for n in range(p.n_iter):
if state.skipped:
state.skipped = False
if state.interrupted: if state.interrupted:
break break
...@@ -348,8 +393,9 @@ def process_images(p: StableDiffusionProcessing) -> Processed: ...@@ -348,8 +393,9 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
#uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt]) #uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt])
#c = p.sd_model.get_learned_conditioning(prompts) #c = p.sd_model.get_learned_conditioning(prompts)
uc = prompt_parser.get_learned_conditioning(len(prompts) * [p.negative_prompt], p.steps) with devices.autocast():
c = prompt_parser.get_learned_conditioning(prompts, p.steps) uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], p.steps)
c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps)
if len(model_hijack.comments) > 0: if len(model_hijack.comments) > 0:
for comment in model_hijack.comments: for comment in model_hijack.comments:
...@@ -358,16 +404,26 @@ def process_images(p: StableDiffusionProcessing) -> Processed: ...@@ -358,16 +404,26 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
if p.n_iter > 1: if p.n_iter > 1:
shared.state.job = f"Batch {n+1} out of {p.n_iter}" shared.state.job = f"Batch {n+1} out of {p.n_iter}"
with devices.autocast():
samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength) samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength)
if state.interrupted:
# if we are interruped, sample returns just noise if state.interrupted or state.skipped:
# if we are interrupted, sample returns just noise
# use the image collected previously in sampler loop # use the image collected previously in sampler loop
samples_ddim = shared.state.current_latent samples_ddim = shared.state.current_latent
x_samples_ddim = p.sd_model.decode_first_stage(samples_ddim) samples_ddim = samples_ddim.to(devices.dtype_vae)
x_samples_ddim = decode_first_stage(p.sd_model, samples_ddim)
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
del samples_ddim
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
lowvram.send_everything_to_cpu()
devices.torch_gc()
if opts.filter_nsfw: if opts.filter_nsfw:
import modules.safety as safety import modules.safety as safety
x_samples_ddim = modules.safety.censor_batch(x_samples_ddim) x_samples_ddim = modules.safety.censor_batch(x_samples_ddim)
...@@ -383,6 +439,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed: ...@@ -383,6 +439,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
devices.torch_gc() devices.torch_gc()
x_sample = modules.face_restoration.restore_faces(x_sample) x_sample = modules.face_restoration.restore_faces(x_sample)
devices.torch_gc()
image = Image.fromarray(x_sample) image = Image.fromarray(x_sample)
...@@ -408,9 +465,16 @@ def process_images(p: StableDiffusionProcessing) -> Processed: ...@@ -408,9 +465,16 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
if opts.samples_save and not p.do_not_save_samples: if opts.samples_save and not p.do_not_save_samples:
images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p) images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p)
infotexts.append(infotext(n, i)) text = infotext(n, i)
infotexts.append(text)
if opts.enable_pnginfo:
image.info["parameters"] = text
output_images.append(image) output_images.append(image)
del x_samples_ddim
devices.torch_gc()
state.nextjob() state.nextjob()
p.color_corrections = None p.color_corrections = None
...@@ -421,7 +485,10 @@ def process_images(p: StableDiffusionProcessing) -> Processed: ...@@ -421,7 +485,10 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
grid = images.image_grid(output_images, p.batch_size) grid = images.image_grid(output_images, p.batch_size)
if opts.return_grid: if opts.return_grid:
infotexts.insert(0, infotext()) text = infotext()
infotexts.insert(0, text)
if opts.enable_pnginfo:
grid.info["parameters"] = text
output_images.insert(0, grid) output_images.insert(0, grid)
index_of_first_image = 1 index_of_first_image = 1
...@@ -434,16 +501,15 @@ def process_images(p: StableDiffusionProcessing) -> Processed: ...@@ -434,16 +501,15 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
sampler = None sampler = None
firstphase_width = 0
firstphase_height = 0
firstphase_width_truncated = 0
firstphase_height_truncated = 0
def __init__(self, enable_hr=False, scale_latent=True, denoising_strength=0.75, **kwargs): def __init__(self, enable_hr=False, denoising_strength=0.75, firstphase_width=0, firstphase_height=0, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.enable_hr = enable_hr self.enable_hr = enable_hr
self.scale_latent = scale_latent
self.denoising_strength = denoising_strength self.denoising_strength = denoising_strength
self.firstphase_width = firstphase_width
self.firstphase_height = firstphase_height
self.truncate_x = 0
self.truncate_y = 0
def init(self, all_prompts, all_seeds, all_subseeds): def init(self, all_prompts, all_seeds, all_subseeds):
if self.enable_hr: if self.enable_hr:
...@@ -452,17 +518,34 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): ...@@ -452,17 +518,34 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
else: else:
state.job_count = state.job_count * 2 state.job_count = state.job_count * 2
if self.firstphase_width == 0 or self.firstphase_height == 0:
desired_pixel_count = 512 * 512 desired_pixel_count = 512 * 512
actual_pixel_count = self.width * self.height actual_pixel_count = self.width * self.height
scale = math.sqrt(desired_pixel_count / actual_pixel_count) scale = math.sqrt(desired_pixel_count / actual_pixel_count)
self.firstphase_width = math.ceil(scale * self.width / 64) * 64 self.firstphase_width = math.ceil(scale * self.width / 64) * 64
self.firstphase_height = math.ceil(scale * self.height / 64) * 64 self.firstphase_height = math.ceil(scale * self.height / 64) * 64
self.firstphase_width_truncated = int(scale * self.width) firstphase_width_truncated = int(scale * self.width)
self.firstphase_height_truncated = int(scale * self.height) firstphase_height_truncated = int(scale * self.height)
else:
self.extra_generation_params["First pass size"] = f"{self.firstphase_width}x{self.firstphase_height}"
width_ratio = self.width / self.firstphase_width
height_ratio = self.height / self.firstphase_height
if width_ratio > height_ratio:
firstphase_width_truncated = self.firstphase_width
firstphase_height_truncated = self.firstphase_width * self.height / self.width
else:
firstphase_width_truncated = self.firstphase_height * self.width / self.height
firstphase_height_truncated = self.firstphase_height
self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f
self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength): def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
self.sampler = samplers[self.sampler_index].constructor(self.sd_model) self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
if not self.enable_hr: if not self.enable_hr:
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
...@@ -472,15 +555,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): ...@@ -472,15 +555,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning) samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
truncate_x = (self.firstphase_width - self.firstphase_width_truncated) // opt_f samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2]
truncate_y = (self.firstphase_height - self.firstphase_height_truncated) // opt_f
samples = samples[:, :, truncate_y//2:samples.shape[2]-truncate_y//2, truncate_x//2:samples.shape[3]-truncate_x//2] decoded_samples = decode_first_stage(self.sd_model, samples)
if self.scale_latent:
samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
else:
decoded_samples = self.sd_model.decode_first_stage(samples)
if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None": if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None":
decoded_samples = torch.nn.functional.interpolate(decoded_samples, size=(self.height, self.width), mode="bilinear") decoded_samples = torch.nn.functional.interpolate(decoded_samples, size=(self.height, self.width), mode="bilinear")
...@@ -505,7 +582,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): ...@@ -505,7 +582,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
shared.state.nextjob() shared.state.nextjob()
self.sampler = samplers[self.sampler_index].constructor(self.sd_model) self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
# GC now before running the next img2img to prevent running out of memory # GC now before running the next img2img to prevent running out of memory
...@@ -540,7 +618,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): ...@@ -540,7 +618,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.nmask = None self.nmask = None
def init(self, all_prompts, all_seeds, all_subseeds): def init(self, all_prompts, all_seeds, all_subseeds):
self.sampler = samplers_for_img2img[self.sampler_index].constructor(self.sd_model) self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers_for_img2img, self.sampler_index, self.sd_model)
crop_region = None crop_region = None
if self.image_mask is not None: if self.image_mask is not None:
...@@ -647,4 +725,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): ...@@ -647,4 +725,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
if self.mask is not None: if self.mask is not None:
samples = samples * self.nmask + self.init_latent * self.mask samples = samples * self.nmask + self.init_latent * self.mask
del x
devices.torch_gc()
return samples return samples
import re import re
from collections import namedtuple from collections import namedtuple
import torch from typing import List
import lark
import modules.shared as shared
re_prompt = re.compile(r'''
(.*?)
\[
([^]:]+):
(?:([^]:]*):)?
([0-9]*\.?[0-9]+)
]
|
(.+)
''', re.X)
# a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]" # a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]"
# will be represented with prompt_schedule like this (assuming steps=100): # will be represented with prompt_schedule like this (assuming steps=100):
...@@ -23,71 +11,117 @@ re_prompt = re.compile(r''' ...@@ -23,71 +11,117 @@ re_prompt = re.compile(r'''
# [75, 'fantasy landscape with a lake and an oak in background masterful'] # [75, 'fantasy landscape with a lake and an oak in background masterful']
# [100, 'fantasy landscape with a lake and a christmas tree in background masterful'] # [100, 'fantasy landscape with a lake and a christmas tree in background masterful']
schedule_parser = lark.Lark(r"""
!start: (prompt | /[][():]/+)*
prompt: (emphasized | scheduled | alternate | plain | WHITESPACE)*
!emphasized: "(" prompt ")"
| "(" prompt ":" prompt ")"
| "[" prompt "]"
scheduled: "[" [prompt ":"] prompt ":" [WHITESPACE] NUMBER "]"
alternate: "[" prompt ("|" prompt)+ "]"
WHITESPACE: /\s+/
plain: /([^\\\[\]():|]|\\.)+/
%import common.SIGNED_NUMBER -> NUMBER
""")
def get_learned_conditioning_prompt_schedules(prompts, steps): def get_learned_conditioning_prompt_schedules(prompts, steps):
res = [] """
cache = {} >>> g = lambda p: get_learned_conditioning_prompt_schedules([p], 10)[0]
>>> g("test")
for prompt in prompts: [[10, 'test']]
prompt_schedule: list[list[str | int]] = [[steps, ""]] >>> g("a [b:3]")
[[3, 'a '], [10, 'a b']]
cached = cache.get(prompt, None) >>> g("a [b: 3]")
if cached is not None: [[3, 'a '], [10, 'a b']]
res.append(cached) >>> g("a [[[b]]:2]")
continue [[2, 'a '], [10, 'a [[b]]']]
>>> g("[(a:2):3]")
for m in re_prompt.finditer(prompt): [[3, ''], [10, '(a:2)']]
plaintext = m.group(1) if m.group(5) is None else m.group(5) >>> g("a [b : c : 1] d")
concept_from = m.group(2) [[1, 'a b d'], [10, 'a c d']]
concept_to = m.group(3) >>> g("a[b:[c:d:2]:1]e")
if concept_to is None: [[1, 'abe'], [2, 'ace'], [10, 'ade']]
concept_to = concept_from >>> g("a [unbalanced")
concept_from = "" [[10, 'a [unbalanced']]
swap_position = float(m.group(4)) if m.group(4) is not None else None >>> g("a [b:.5] c")
[[5, 'a c'], [10, 'a b c']]
if swap_position is not None: >>> g("a [{b|d{:.5] c") # not handling this right now
if swap_position < 1: [[5, 'a c'], [10, 'a {b|d{ c']]
swap_position = swap_position * steps >>> g("((a][:b:c [d:3]")
swap_position = int(min(swap_position, steps)) [[3, '((a][:b:c '], [10, '((a][:b:c d']]
"""
swap_index = None
found_exact_index = False
for i in range(len(prompt_schedule)):
end_step = prompt_schedule[i][0]
prompt_schedule[i][1] += plaintext
if swap_position is not None and swap_index is None:
if swap_position == end_step:
swap_index = i
found_exact_index = True
if swap_position < end_step:
swap_index = i
if swap_index is not None:
if not found_exact_index:
prompt_schedule.insert(swap_index, [swap_position, prompt_schedule[swap_index][1]])
for i in range(len(prompt_schedule)):
end_step = prompt_schedule[i][0]
must_replace = swap_position < end_step
prompt_schedule[i][1] += concept_to if must_replace else concept_from
res.append(prompt_schedule)
cache[prompt] = prompt_schedule
#for t in prompt_schedule:
# print(t)
return res def collect_steps(steps, tree):
l = [steps]
class CollectSteps(lark.Visitor):
def scheduled(self, tree):
tree.children[-1] = float(tree.children[-1])
if tree.children[-1] < 1:
tree.children[-1] *= steps
tree.children[-1] = min(steps, int(tree.children[-1]))
l.append(tree.children[-1])
def alternate(self, tree):
l.extend(range(1, steps+1))
CollectSteps().visit(tree)
return sorted(set(l))
def at_step(step, tree):
class AtStep(lark.Transformer):
def scheduled(self, args):
before, after, _, when = args
yield before or () if step <= when else after
def alternate(self, args):
yield next(args[(step - 1)%len(args)])
def start(self, args):
def flatten(x):
if type(x) == str:
yield x
else:
for gen in x:
yield from flatten(gen)
return ''.join(flatten(args))
def plain(self, args):
yield args[0].value
def __default__(self, data, children, meta):
for child in children:
yield from child
return AtStep().transform(tree)
def get_schedule(prompt):
try:
tree = schedule_parser.parse(prompt)
except lark.exceptions.LarkError as e:
if 0:
import traceback
traceback.print_exc()
return [[steps, prompt]]
return [[t, at_step(t, tree)] for t in collect_steps(steps, tree)]
promptdict = {prompt: get_schedule(prompt) for prompt in set(prompts)}
return [promptdict[prompt] for prompt in prompts]
ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"]) ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"])
ScheduledPromptBatch = namedtuple("ScheduledPromptBatch", ["shape", "schedules"])
def get_learned_conditioning(prompts, steps): def get_learned_conditioning(model, prompts, steps):
"""converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the comdition (cond),
and the sampling step at which this condition is to be replaced by the next one.
Input:
(model, ['a red crown', 'a [blue:green:5] jeweled crown'], 20)
Output:
[
[
ScheduledPromptConditioning(end_at_step=20, cond=tensor([[-0.3886, 0.0229, -0.0523, ..., -0.4901, -0.3066, 0.0674], ..., [ 0.3317, -0.5102, -0.4066, ..., 0.4119, -0.7647, -1.0160]], device='cuda:0'))
],
[
ScheduledPromptConditioning(end_at_step=5, cond=tensor([[-0.3886, 0.0229, -0.0522, ..., -0.4901, -0.3067, 0.0673], ..., [-0.0192, 0.3867, -0.4644, ..., 0.1135, -0.3696, -0.4625]], device='cuda:0')),
ScheduledPromptConditioning(end_at_step=20, cond=tensor([[-0.3886, 0.0229, -0.0522, ..., -0.4901, -0.3067, 0.0673], ..., [-0.7352, -0.4356, -0.7888, ..., 0.6994, -0.4312, -1.2593]], device='cuda:0'))
]
]
"""
res = [] res = []
prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps) prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps)
...@@ -101,7 +135,7 @@ def get_learned_conditioning(prompts, steps): ...@@ -101,7 +135,7 @@ def get_learned_conditioning(prompts, steps):
continue continue
texts = [x[1] for x in prompt_schedule] texts = [x[1] for x in prompt_schedule]
conds = shared.sd_model.get_learned_conditioning(texts) conds = model.get_learned_conditioning(texts)
cond_schedule = [] cond_schedule = []
for i, (end_at_step, text) in enumerate(prompt_schedule): for i, (end_at_step, text) in enumerate(prompt_schedule):
...@@ -110,22 +144,118 @@ def get_learned_conditioning(prompts, steps): ...@@ -110,22 +144,118 @@ def get_learned_conditioning(prompts, steps):
cache[prompt] = cond_schedule cache[prompt] = cond_schedule
res.append(cond_schedule) res.append(cond_schedule)
return ScheduledPromptBatch((len(prompts),) + res[0][0].cond.shape, res) return res
re_AND = re.compile(r"\bAND\b")
re_weight = re.compile(r"^(.*?)(?:\s*:\s*([-+]?(?:\d+\.?|\d*\.\d+)))?\s*$")
def get_multicond_prompt_list(prompts):
res_indexes = []
prompt_flat_list = []
prompt_indexes = {}
for prompt in prompts:
subprompts = re_AND.split(prompt)
indexes = []
for subprompt in subprompts:
match = re_weight.search(subprompt)
text, weight = match.groups() if match is not None else (subprompt, 1.0)
weight = float(weight) if weight is not None else 1.0
index = prompt_indexes.get(text, None)
if index is None:
index = len(prompt_flat_list)
prompt_flat_list.append(text)
prompt_indexes[text] = index
indexes.append((index, weight))
res_indexes.append(indexes)
def reconstruct_cond_batch(c: ScheduledPromptBatch, current_step): return res_indexes, prompt_flat_list, prompt_indexes
res = torch.zeros(c.shape, device=shared.device, dtype=next(shared.sd_model.parameters()).dtype)
for i, cond_schedule in enumerate(c.schedules):
class ComposableScheduledPromptConditioning:
def __init__(self, schedules, weight=1.0):
self.schedules: List[ScheduledPromptConditioning] = schedules
self.weight: float = weight
class MulticondLearnedConditioning:
def __init__(self, shape, batch):
self.shape: tuple = shape # the shape field is needed to send this object to DDIM/PLMS
self.batch: List[List[ComposableScheduledPromptConditioning]] = batch
def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearnedConditioning:
"""same as get_learned_conditioning, but returns a list of ScheduledPromptConditioning along with the weight objects for each prompt.
For each prompt, the list is obtained by splitting the prompt using the AND separator.
https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/
"""
res_indexes, prompt_flat_list, prompt_indexes = get_multicond_prompt_list(prompts)
learned_conditioning = get_learned_conditioning(model, prompt_flat_list, steps)
res = []
for indexes in res_indexes:
res.append([ComposableScheduledPromptConditioning(learned_conditioning[i], weight) for i, weight in indexes])
return MulticondLearnedConditioning(shape=(len(prompts),), batch=res)
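# Editor's worked example (not part of the original file): for the single prompt
#   "a red crown AND a blue crown:0.4"
# get_multicond_prompt_list() returns
#   res_indexes      == [[(0, 1.0), (1, 0.4)]]
#   prompt_flat_list == ['a red crown', ' a blue crown']   (whitespace around AND stays on the text side)
# so the returned batch holds two ComposableScheduledPromptConditioning objects for that one
# image, which the sampler combines with weights 1.0 and 0.4.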
def reconstruct_cond_batch(c: List[List[ScheduledPromptConditioning]], current_step):
param = c[0][0].cond
res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)
for i, cond_schedule in enumerate(c):
target_index = 0 target_index = 0
for curret_index, (end_at, cond) in enumerate(cond_schedule): for current, (end_at, cond) in enumerate(cond_schedule):
if current_step <= end_at: if current_step <= end_at:
target_index = curret_index target_index = current
break break
res[i] = cond_schedule[target_index].cond res[i] = cond_schedule[target_index].cond
return res return res
def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
param = c.batch[0][0].schedules[0].cond
tensors = []
conds_list = []
for batch_no, composable_prompts in enumerate(c.batch):
conds_for_batch = []
for cond_index, composable_prompt in enumerate(composable_prompts):
target_index = 0
for current, (end_at, cond) in enumerate(composable_prompt.schedules):
if current_step <= end_at:
target_index = current
break
conds_for_batch.append((len(tensors), composable_prompt.weight))
tensors.append(composable_prompt.schedules[target_index].cond)
conds_list.append(conds_for_batch)
# if prompts have wildly different lengths above the limit we'll get tensors of different shapes
# and won't be able to torch.stack them. So this fixes that.
token_count = max([x.shape[0] for x in tensors])
for i in range(len(tensors)):
if tensors[i].shape[0] != token_count:
last_vector = tensors[i][-1:]
last_vector_repeated = last_vector.repeat([token_count - tensors[i].shape[0], 1])
tensors[i] = torch.vstack([tensors[i], last_vector_repeated])
return conds_list, torch.stack(tensors).to(device=param.device, dtype=param.dtype)
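# Editor's worked example (not part of the original file; the 77/154/768 shapes are the usual
# SD1.x conditioning sizes, used here purely for illustration): if one subprompt produces a
# (77, 768) conditioning tensor and another a (154, 768) one, the loop above repeats the
# shorter tensor's last vector 77 more times so both become (154, 768) and torch.stack succeeds.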
re_attention = re.compile(r""" re_attention = re.compile(r"""
\\\(| \\\(|
\\\)| \\\)|
...@@ -157,14 +287,18 @@ def parse_prompt_attention(text): ...@@ -157,14 +287,18 @@ def parse_prompt_attention(text):
\\ - literal character '\' \\ - literal character '\'
anything else - just text anything else - just text
Example: >>> parse_prompt_attention('normal text')
[['normal text', 1.0]]
'a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).' >>> parse_prompt_attention('an (important) word')
[['an ', 1.0], ['important', 1.1], [' word', 1.0]]
produces: >>> parse_prompt_attention('(unbalanced')
[['unbalanced', 1.1]]
[ >>> parse_prompt_attention('\(literal\]')
['a ', 1.0], [['(literal]', 1.0]]
>>> parse_prompt_attention('(unnecessary)(parens)')
[['unnecessaryparens', 1.1]]
>>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
[['a ', 1.0],
['house', 1.5730000000000004], ['house', 1.5730000000000004],
[' ', 1.1], [' ', 1.1],
['on', 1.0], ['on', 1.0],
...@@ -172,8 +306,7 @@ def parse_prompt_attention(text): ...@@ -172,8 +306,7 @@ def parse_prompt_attention(text):
['hill', 0.55], ['hill', 0.55],
[', sun, ', 1.1], [', sun, ', 1.1],
['sky', 1.4641000000000006], ['sky', 1.4641000000000006],
['.', 1.1] ['.', 1.1]]
]
""" """
res = [] res = []
...@@ -215,4 +348,19 @@ def parse_prompt_attention(text): ...@@ -215,4 +348,19 @@ def parse_prompt_attention(text):
if len(res) == 0: if len(res) == 0:
res = [["", 1.0]] res = [["", 1.0]]
# merge runs of identical weights
i = 0
while i + 1 < len(res):
if res[i][1] == res[i + 1][1]:
res[i][0] += res[i + 1][0]
res.pop(i + 1)
else:
i += 1
return res return res
if __name__ == "__main__":
import doctest
doctest.testmod(optionflags=doctest.NORMALIZE_WHITESPACE)
else:
import torch # doctest faster
...@@ -8,14 +8,12 @@ from basicsr.utils.download_util import load_file_from_url ...@@ -8,14 +8,12 @@ from basicsr.utils.download_util import load_file_from_url
from realesrgan import RealESRGANer from realesrgan import RealESRGANer
from modules.upscaler import Upscaler, UpscalerData from modules.upscaler import Upscaler, UpscalerData
from modules.paths import models_path
from modules.shared import cmd_opts, opts from modules.shared import cmd_opts, opts
class UpscalerRealESRGAN(Upscaler): class UpscalerRealESRGAN(Upscaler):
def __init__(self, path): def __init__(self, path):
self.name = "RealESRGAN" self.name = "RealESRGAN"
self.model_path = os.path.join(models_path, self.name)
self.user_path = path self.user_path = path
super().__init__() super().__init__()
try: try:
......
# this code is adapted from the script contributed by anon from /h/
import io
import pickle
import collections
import sys
import traceback
import torch
import numpy
import _codecs
import zipfile
import re
# PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage
TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage
def encode(*args):
out = _codecs.encode(*args)
return out
class RestrictedUnpickler(pickle.Unpickler):
def persistent_load(self, saved_id):
assert saved_id[0] == 'storage'
return TypedStorage()
def find_class(self, module, name):
if module == 'collections' and name == 'OrderedDict':
return getattr(collections, name)
if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter']:
return getattr(torch._utils, name)
if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage']:
return getattr(torch, name)
if module == 'torch.nn.modules.container' and name in ['ParameterDict']:
return getattr(torch.nn.modules.container, name)
if module == 'numpy.core.multiarray' and name == 'scalar':
return numpy.core.multiarray.scalar
if module == 'numpy' and name == 'dtype':
return numpy.dtype
if module == '_codecs' and name == 'encode':
return encode
if module == "pytorch_lightning.callbacks" and name == 'model_checkpoint':
import pytorch_lightning.callbacks
return pytorch_lightning.callbacks.model_checkpoint
if module == "pytorch_lightning.callbacks.model_checkpoint" and name == 'ModelCheckpoint':
import pytorch_lightning.callbacks.model_checkpoint
return pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint
if module == "__builtin__" and name == 'set':
return set
# Forbid everything else.
raise pickle.UnpicklingError(f"global '{module}/{name}' is forbidden")
allowed_zip_names = ["archive/data.pkl", "archive/version"]
allowed_zip_names_re = re.compile(r"^archive/data/\d+$")
def check_zip_filenames(filename, names):
for name in names:
if name in allowed_zip_names:
continue
if allowed_zip_names_re.match(name):
continue
raise Exception(f"bad file inside {filename}: {name}")
def check_pt(filename):
try:
# new pytorch format is a zip file
with zipfile.ZipFile(filename) as z:
check_zip_filenames(filename, z.namelist())
with z.open('archive/data.pkl') as file:
unpickler = RestrictedUnpickler(file)
unpickler.load()
except zipfile.BadZipfile:
# if it's not a zip file, it's an old pytorch format, with five objects written to pickle
with open(filename, "rb") as file:
unpickler = RestrictedUnpickler(file)
for i in range(5):
unpickler.load()
def load(filename, *args, **kwargs):
from modules import shared
try:
if not shared.cmd_opts.disable_safe_unpickle:
check_pt(filename)
except pickle.UnpicklingError:
print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
print(f"-----> !!!! The file is most likely corrupted !!!! <-----", file=sys.stderr)
print(f"You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n", file=sys.stderr)
return None
except Exception:
print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
print(f"\nThe file may be malicious, so the program is not going to read it.", file=sys.stderr)
print(f"You can skip this check with --disable-safe-unpickle commandline argument.\n\n", file=sys.stderr)
return None
return unsafe_torch_load(filename, *args, **kwargs)
unsafe_torch_load = torch.load
torch.load = load
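A brief usage sketch, assuming the webui imports this module at startup (the import site is not shown in this commit): once imported, torch.load is the checking wrapper, so loading an untrusted checkpoint either passes check_pt or returns None. The checkpoint path below is hypothetical.
import torch
import modules.safe  # importing the module replaces torch.load with the checking wrapper
state_dict = torch.load("model.ckpt", map_location="cpu")  # goes through check_pt() first unless --disable-safe-unpickle
if state_dict is None:
    print("checkpoint failed the safety check and was not loaded")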
...@@ -162,6 +162,40 @@ class ScriptRunner: ...@@ -162,6 +162,40 @@ class ScriptRunner:
return processed return processed
def reload_sources(self):
for si, script in list(enumerate(self.scripts)):
with open(script.filename, "r", encoding="utf8") as file:
args_from = script.args_from
args_to = script.args_to
filename = script.filename
text = file.read()
from types import ModuleType
compiled = compile(text, filename, 'exec')
module = ModuleType(script.filename)
exec(compiled, module.__dict__)
for key, script_class in module.__dict__.items():
if type(script_class) == type and issubclass(script_class, Script):
self.scripts[si] = script_class()
self.scripts[si].filename = filename
self.scripts[si].args_from = args_from
self.scripts[si].args_to = args_to
scripts_txt2img = ScriptRunner() scripts_txt2img = ScriptRunner()
scripts_img2img = ScriptRunner() scripts_img2img = ScriptRunner()
def reload_script_body_only():
scripts_txt2img.reload_sources()
scripts_img2img.reload_sources()
def reload_scripts(basedir):
global scripts_txt2img, scripts_img2img
scripts_data.clear()
load_scripts(basedir)
scripts_txt2img = ScriptRunner()
scripts_img2img = ScriptRunner()
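A short usage sketch of the two reload paths above (the directory name is an assumption for illustration): reload_script_body_only re-executes the existing script files in place and keeps each script's argument slots, while reload_scripts clears scripts_data and rebuilds both runners from a basedir.
import modules.scripts as scripts
scripts.reload_script_body_only()   # re-exec script sources, preserving args_from/args_to
scripts.reload_scripts("scripts")   # full reload from a hypothetical basedir, rebuilding both runners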
import os.path
import sys
import traceback
import PIL.Image
import numpy as np
import torch
from basicsr.utils.download_util import load_file_from_url
import modules.upscaler
from modules import devices, modelloader
from modules.scunet_model_arch import SCUNet as net
class UpscalerScuNET(modules.upscaler.Upscaler):
def __init__(self, dirname):
self.name = "ScuNET"
self.model_name = "ScuNET GAN"
self.model_name2 = "ScuNET PSNR"
self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_gan.pth"
self.model_url2 = "https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_psnr.pth"
self.user_path = dirname
super().__init__()
model_paths = self.find_models(ext_filter=[".pth"])
scalers = []
add_model2 = True
for file in model_paths:
if "http" in file:
name = self.model_name
else:
name = modelloader.friendly_name(file)
if name == self.model_name2 or file == self.model_url2:
add_model2 = False
try:
scaler_data = modules.upscaler.UpscalerData(name, file, self, 4)
scalers.append(scaler_data)
except Exception:
print(f"Error loading ScuNET model: {file}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
if add_model2:
scaler_data2 = modules.upscaler.UpscalerData(self.model_name2, self.model_url2, self)
scalers.append(scaler_data2)
self.scalers = scalers
def do_upscale(self, img: PIL.Image, selected_file):
torch.cuda.empty_cache()
model = self.load_model(selected_file)
if model is None:
return img
device = devices.device_scunet
img = np.array(img)
img = img[:, :, ::-1]
img = np.moveaxis(img, 2, 0) / 255
img = torch.from_numpy(img).float()
img = img.unsqueeze(0).to(device)
img = img.to(device)
with torch.no_grad():
output = model(img)
output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
output = 255. * np.moveaxis(output, 0, 2)
output = output.astype(np.uint8)
output = output[:, :, ::-1]
torch.cuda.empty_cache()
return PIL.Image.fromarray(output, 'RGB')
def load_model(self, path: str):
device = devices.device_scunet
if "http" in path:
filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name,
progress=True)
else:
filename = path
if not os.path.exists(os.path.join(self.model_path, filename)) or filename is None:
print(f"ScuNET: Unable to load model from {filename}", file=sys.stderr)
return None
model = net(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
model.load_state_dict(torch.load(filename), strict=True)
model.eval()
for k, v in model.named_parameters():
v.requires_grad = False
model = model.to(device)
return model
# -*- coding: utf-8 -*-
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from einops.layers.torch import Rearrange
from timm.models.layers import trunc_normal_, DropPath
class WMSA(nn.Module):
""" Self-attention module in Swin Transformer
"""
def __init__(self, input_dim, output_dim, head_dim, window_size, type):
super(WMSA, self).__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.head_dim = head_dim
self.scale = self.head_dim ** -0.5
self.n_heads = input_dim // head_dim
self.window_size = window_size
self.type = type
self.embedding_layer = nn.Linear(self.input_dim, 3 * self.input_dim, bias=True)
self.relative_position_params = nn.Parameter(
torch.zeros((2 * window_size - 1) * (2 * window_size - 1), self.n_heads))
self.linear = nn.Linear(self.input_dim, self.output_dim)
trunc_normal_(self.relative_position_params, std=.02)
self.relative_position_params = torch.nn.Parameter(
self.relative_position_params.view(2 * window_size - 1, 2 * window_size - 1, self.n_heads).transpose(1, 2).transpose(0, 1))
def generate_mask(self, h, w, p, shift):
""" generating the mask of SW-MSA
Args:
shift: shift parameters in CyclicShift.
Returns:
attn_mask: should be (1 1 w p p),
"""
# supporting square.
attn_mask = torch.zeros(h, w, p, p, p, p, dtype=torch.bool, device=self.relative_position_params.device)
if self.type == 'W':
return attn_mask
s = p - shift
attn_mask[-1, :, :s, :, s:, :] = True
attn_mask[-1, :, s:, :, :s, :] = True
attn_mask[:, -1, :, :s, :, s:] = True
attn_mask[:, -1, :, s:, :, :s] = True
attn_mask = rearrange(attn_mask, 'w1 w2 p1 p2 p3 p4 -> 1 1 (w1 w2) (p1 p2) (p3 p4)')
return attn_mask
def forward(self, x):
""" Forward pass of Window Multi-head Self-attention module.
Args:
x: input tensor with shape of [b h w c];
attn_mask: attention mask, fill -inf where the value is True;
Returns:
output: tensor shape [b h w c]
"""
if self.type != 'W': x = torch.roll(x, shifts=(-(self.window_size // 2), -(self.window_size // 2)), dims=(1, 2))
x = rearrange(x, 'b (w1 p1) (w2 p2) c -> b w1 w2 p1 p2 c', p1=self.window_size, p2=self.window_size)
h_windows = x.size(1)
w_windows = x.size(2)
# square validation
# assert h_windows == w_windows
x = rearrange(x, 'b w1 w2 p1 p2 c -> b (w1 w2) (p1 p2) c', p1=self.window_size, p2=self.window_size)
qkv = self.embedding_layer(x)
q, k, v = rearrange(qkv, 'b nw np (threeh c) -> threeh b nw np c', c=self.head_dim).chunk(3, dim=0)
sim = torch.einsum('hbwpc,hbwqc->hbwpq', q, k) * self.scale
# Adding learnable relative embedding
sim = sim + rearrange(self.relative_embedding(), 'h p q -> h 1 1 p q')
# Using Attn Mask to distinguish different subwindows.
if self.type != 'W':
attn_mask = self.generate_mask(h_windows, w_windows, self.window_size, shift=self.window_size // 2)
sim = sim.masked_fill_(attn_mask, float("-inf"))
probs = nn.functional.softmax(sim, dim=-1)
output = torch.einsum('hbwij,hbwjc->hbwic', probs, v)
output = rearrange(output, 'h b w p c -> b w p (h c)')
output = self.linear(output)
output = rearrange(output, 'b (w1 w2) (p1 p2) c -> b (w1 p1) (w2 p2) c', w1=h_windows, p1=self.window_size)
if self.type != 'W': output = torch.roll(output, shifts=(self.window_size // 2, self.window_size // 2), dims=(1, 2))
return output
def relative_embedding(self):
cord = torch.tensor(np.array([[i, j] for i in range(self.window_size) for j in range(self.window_size)]))
relation = cord[:, None, :] - cord[None, :, :] + self.window_size - 1
# negative is allowed
return self.relative_position_params[:, relation[:, :, 0].long(), relation[:, :, 1].long()]
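A small standalone sketch of the relative-index computation above, using a hypothetical window_size of 2: the pairwise coordinate differences shifted by window_size - 1 land in 0..2*window_size-2, so they index the (2*window_size-1) x (2*window_size-1) parameter table.
import numpy as np
import torch
window_size = 2
cord = torch.tensor(np.array([[i, j] for i in range(window_size) for j in range(window_size)]))
relation = cord[:, None, :] - cord[None, :, :] + window_size - 1
print(relation.shape)                                   # torch.Size([4, 4, 2])
print(relation[..., 0].min(), relation[..., 0].max())   # tensor(0) tensor(2)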
class Block(nn.Module):
def __init__(self, input_dim, output_dim, head_dim, window_size, drop_path, type='W', input_resolution=None):
""" SwinTransformer Block
"""
super(Block, self).__init__()
self.input_dim = input_dim
self.output_dim = output_dim
assert type in ['W', 'SW']
self.type = type
if input_resolution <= window_size:
self.type = 'W'
self.ln1 = nn.LayerNorm(input_dim)
self.msa = WMSA(input_dim, input_dim, head_dim, window_size, self.type)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.ln2 = nn.LayerNorm(input_dim)
self.mlp = nn.Sequential(
nn.Linear(input_dim, 4 * input_dim),
nn.GELU(),
nn.Linear(4 * input_dim, output_dim),
)
def forward(self, x):
x = x + self.drop_path(self.msa(self.ln1(x)))
x = x + self.drop_path(self.mlp(self.ln2(x)))
return x
class ConvTransBlock(nn.Module):
def __init__(self, conv_dim, trans_dim, head_dim, window_size, drop_path, type='W', input_resolution=None):
""" SwinTransformer and Conv Block
"""
super(ConvTransBlock, self).__init__()
self.conv_dim = conv_dim
self.trans_dim = trans_dim
self.head_dim = head_dim
self.window_size = window_size
self.drop_path = drop_path
self.type = type
self.input_resolution = input_resolution
assert self.type in ['W', 'SW']
if self.input_resolution <= self.window_size:
self.type = 'W'
self.trans_block = Block(self.trans_dim, self.trans_dim, self.head_dim, self.window_size, self.drop_path,
self.type, self.input_resolution)
self.conv1_1 = nn.Conv2d(self.conv_dim + self.trans_dim, self.conv_dim + self.trans_dim, 1, 1, 0, bias=True)
self.conv1_2 = nn.Conv2d(self.conv_dim + self.trans_dim, self.conv_dim + self.trans_dim, 1, 1, 0, bias=True)
self.conv_block = nn.Sequential(
nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False),
nn.ReLU(True),
nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False)
)
def forward(self, x):
conv_x, trans_x = torch.split(self.conv1_1(x), (self.conv_dim, self.trans_dim), dim=1)
conv_x = self.conv_block(conv_x) + conv_x
trans_x = Rearrange('b c h w -> b h w c')(trans_x)
trans_x = self.trans_block(trans_x)
trans_x = Rearrange('b h w c -> b c h w')(trans_x)
res = self.conv1_2(torch.cat((conv_x, trans_x), dim=1))
x = x + res
return x
class SCUNet(nn.Module):
# def __init__(self, in_nc=3, config=[2, 2, 2, 2, 2, 2, 2], dim=64, drop_path_rate=0.0, input_resolution=256):
def __init__(self, in_nc=3, config=None, dim=64, drop_path_rate=0.0, input_resolution=256):
super(SCUNet, self).__init__()
if config is None:
config = [2, 2, 2, 2, 2, 2, 2]
self.config = config
self.dim = dim
self.head_dim = 32
self.window_size = 8
# drop path rate for each layer
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(config))]
self.m_head = [nn.Conv2d(in_nc, dim, 3, 1, 1, bias=False)]
begin = 0
self.m_down1 = [ConvTransBlock(dim // 2, dim // 2, self.head_dim, self.window_size, dpr[i + begin],
'W' if not i % 2 else 'SW', input_resolution)
for i in range(config[0])] + \
[nn.Conv2d(dim, 2 * dim, 2, 2, 0, bias=False)]
begin += config[0]
self.m_down2 = [ConvTransBlock(dim, dim, self.head_dim, self.window_size, dpr[i + begin],
'W' if not i % 2 else 'SW', input_resolution // 2)
for i in range(config[1])] + \
[nn.Conv2d(2 * dim, 4 * dim, 2, 2, 0, bias=False)]
begin += config[1]
self.m_down3 = [ConvTransBlock(2 * dim, 2 * dim, self.head_dim, self.window_size, dpr[i + begin],
'W' if not i % 2 else 'SW', input_resolution // 4)
for i in range(config[2])] + \
[nn.Conv2d(4 * dim, 8 * dim, 2, 2, 0, bias=False)]
begin += config[2]
self.m_body = [ConvTransBlock(4 * dim, 4 * dim, self.head_dim, self.window_size, dpr[i + begin],
'W' if not i % 2 else 'SW', input_resolution // 8)
for i in range(config[3])]
begin += config[3]
self.m_up3 = [nn.ConvTranspose2d(8 * dim, 4 * dim, 2, 2, 0, bias=False), ] + \
[ConvTransBlock(2 * dim, 2 * dim, self.head_dim, self.window_size, dpr[i + begin],
'W' if not i % 2 else 'SW', input_resolution // 4)
for i in range(config[4])]
begin += config[4]
self.m_up2 = [nn.ConvTranspose2d(4 * dim, 2 * dim, 2, 2, 0, bias=False), ] + \
[ConvTransBlock(dim, dim, self.head_dim, self.window_size, dpr[i + begin],
'W' if not i % 2 else 'SW', input_resolution // 2)
for i in range(config[5])]
begin += config[5]
self.m_up1 = [nn.ConvTranspose2d(2 * dim, dim, 2, 2, 0, bias=False), ] + \
[ConvTransBlock(dim // 2, dim // 2, self.head_dim, self.window_size, dpr[i + begin],
'W' if not i % 2 else 'SW', input_resolution)
for i in range(config[6])]
self.m_tail = [nn.Conv2d(dim, in_nc, 3, 1, 1, bias=False)]
self.m_head = nn.Sequential(*self.m_head)
self.m_down1 = nn.Sequential(*self.m_down1)
self.m_down2 = nn.Sequential(*self.m_down2)
self.m_down3 = nn.Sequential(*self.m_down3)
self.m_body = nn.Sequential(*self.m_body)
self.m_up3 = nn.Sequential(*self.m_up3)
self.m_up2 = nn.Sequential(*self.m_up2)
self.m_up1 = nn.Sequential(*self.m_up1)
self.m_tail = nn.Sequential(*self.m_tail)
# self.apply(self._init_weights)
def forward(self, x0):
h, w = x0.size()[-2:]
paddingBottom = int(np.ceil(h / 64) * 64 - h)
paddingRight = int(np.ceil(w / 64) * 64 - w)
x0 = nn.ReplicationPad2d((0, paddingRight, 0, paddingBottom))(x0)
x1 = self.m_head(x0)
x2 = self.m_down1(x1)
x3 = self.m_down2(x2)
x4 = self.m_down3(x3)
x = self.m_body(x4)
x = self.m_up3(x + x4)
x = self.m_up2(x + x3)
x = self.m_up1(x + x2)
x = self.m_tail(x + x1)
x = x[..., :h, :w]
return x
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
\ No newline at end of file
...@@ -5,245 +5,66 @@ import traceback ...@@ -5,245 +5,66 @@ import traceback
import torch import torch
import numpy as np import numpy as np
from torch import einsum from torch import einsum
from torch.nn.functional import silu
from modules import prompt_parser import modules.textual_inversion.textual_inversion
from modules import prompt_parser, devices, sd_hijack_optimizations, shared
from modules.shared import opts, device, cmd_opts from modules.shared import opts, device, cmd_opts
from modules.sd_hijack_optimizations import invokeAI_mps_available
from ldm.util import default
from einops import rearrange
import ldm.modules.attention import ldm.modules.attention
import ldm.modules.diffusionmodules.model import ldm.modules.diffusionmodules.model
attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
def apply_optimizations():
undo_optimizations()
ldm.modules.diffusionmodules.model.nonlinearity = silu
if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (8, 6)):
print("Applying xformers cross attention optimization.")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward
elif cmd_opts.opt_split_attention_v1:
print("Applying v1 cross attention optimization.")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not torch.cuda.is_available()):
if not invokeAI_mps_available and shared.device.type == 'mps':
print("The InvokeAI cross attention optimization for MPS requires the psutil package which is not installed.")
print("Applying v1 cross attention optimization.")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
else:
print("Applying cross attention optimization (InvokeAI).")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI
elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
print("Applying cross attention optimization (Doggettx).")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward
# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
def split_cross_attention_forward_v1(self, x, context=None, mask=None):
h = self.heads
q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)
del context, x
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
for i in range(0, q.shape[0], 2):
end = i + 2
s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
s1 *= self.scale
s2 = s1.softmax(dim=-1)
del s1
r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
del s2
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
del r1
return self.to_out(r2)
# taken from https://github.com/Doggettx/stable-diffusion
def split_cross_attention_forward(self, x, context=None, mask=None):
h = self.heads
q_in = self.to_q(x)
context = default(context, x)
k_in = self.to_k(context) * self.scale
v_in = self.to_v(context)
del context, x
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
del q_in, k_in, v_in
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
stats = torch.cuda.memory_stats(q.device)
mem_active = stats['active_bytes.all.current']
mem_reserved = stats['reserved_bytes.all.current']
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_cuda + mem_free_torch
gb = 1024 ** 3
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
modifier = 3 if q.element_size() == 2 else 2.5
mem_required = tensor_size * modifier
steps = 1
if mem_required > mem_free_total:
steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
# print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
# f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
if steps > 64:
max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
s2 = s1.softmax(dim=-1, dtype=q.dtype)
del s1
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
del s2
del q, k, v
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
del r1
return self.to_out(r2)
def nonlinearity_hijack(x):
# swish
t = torch.sigmoid(x)
x *= t
del t
return x
def cross_attention_attnblock_forward(self, x):
h_ = x
h_ = self.norm(h_)
q1 = self.q(h_)
k1 = self.k(h_)
v = self.v(h_)
# compute attention
b, c, h, w = q1.shape
q2 = q1.reshape(b, c, h*w)
del q1
q = q2.permute(0, 2, 1) # b,hw,c
del q2
k = k1.reshape(b, c, h*w) # b,c,hw
del k1
h_ = torch.zeros_like(k, device=q.device)
stats = torch.cuda.memory_stats(q.device)
mem_active = stats['active_bytes.all.current']
mem_reserved = stats['reserved_bytes.all.current']
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_cuda + mem_free_torch
tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
mem_required = tensor_size * 2.5
steps = 1
if mem_required > mem_free_total:
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
w2 = w1 * (int(c)**(-0.5))
del w1
w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype)
del w2
# attend to values
v1 = v.reshape(b, c, h*w)
w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
del w3
h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] def undo_optimizations():
del v1, w4 from modules.hypernetworks import hypernetwork
h2 = h_.reshape(b, c, h, w) ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
del h_ ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
h3 = self.proj_out(h2)
del h2
h3 += x def get_target_prompt_token_count(token_count):
return math.ceil(max(token_count, 1) / 75) * 75
return h3
class StableDiffusionModelHijack: class StableDiffusionModelHijack:
ids_lookup = {}
word_embeddings = {}
word_embeddings_checksums = {}
fixes = None fixes = None
comments = [] comments = []
dir_mtime = None
layers = None layers = None
circular_enabled = False circular_enabled = False
clip = None clip = None
def load_textual_inversion_embeddings(self, dirname, model): embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir)
mt = os.path.getmtime(dirname)
if self.dir_mtime is not None and mt <= self.dir_mtime:
return
self.dir_mtime = mt
self.ids_lookup.clear()
self.word_embeddings.clear()
tokenizer = model.cond_stage_model.tokenizer
def const_hash(a):
r = 0
for v in a:
r = (r * 281 ^ int(v) * 997) & 0xFFFFFFFF
return r
def process_file(path, filename):
name = os.path.splitext(filename)[0]
data = torch.load(path, map_location="cpu")
# textual inversion embeddings
if 'string_to_param' in data:
param_dict = data['string_to_param']
if hasattr(param_dict, '_parameters'):
param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
emb = next(iter(param_dict.items()))[1]
# diffuser concepts
elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
emb = next(iter(data.values()))
if len(emb.shape) == 1:
emb = emb.unsqueeze(0)
self.word_embeddings[name] = emb.detach().to(device)
self.word_embeddings_checksums[name] = f'{const_hash(emb.reshape(-1)*100)&0xffff:04x}'
ids = tokenizer([name], add_special_tokens=False)['input_ids'][0]
first_id = ids[0]
if first_id not in self.ids_lookup:
self.ids_lookup[first_id] = []
self.ids_lookup[first_id].append((ids, name))
for fn in os.listdir(dirname):
try:
fullfn = os.path.join(dirname, fn)
if os.stat(fullfn).st_size == 0:
continue
process_file(fullfn, fn)
except Exception:
print(f"Error loading emedding {fn}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
continue
print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.")
def hijack(self, m): def hijack(self, m):
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
...@@ -253,12 +74,7 @@ class StableDiffusionModelHijack: ...@@ -253,12 +74,7 @@ class StableDiffusionModelHijack:
self.clip = m.cond_stage_model self.clip = m.cond_stage_model
if cmd_opts.opt_split_attention_v1: apply_optimizations()
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1
elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward
ldm.modules.diffusionmodules.model.nonlinearity = nonlinearity_hijack
ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward
def flatten(el): def flatten(el):
flattened = [flatten(children) for children in el.children()] flattened = [flatten(children) for children in el.children()]
...@@ -286,21 +102,24 @@ class StableDiffusionModelHijack: ...@@ -286,21 +102,24 @@ class StableDiffusionModelHijack:
for layer in [layer for layer in self.layers if type(layer) == torch.nn.Conv2d]: for layer in [layer for layer in self.layers if type(layer) == torch.nn.Conv2d]:
layer.padding_mode = 'circular' if enable else 'zeros' layer.padding_mode = 'circular' if enable else 'zeros'
def clear_comments(self):
self.comments = []
def tokenize(self, text): def tokenize(self, text):
max_length = self.clip.max_length - 2
_, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text]) _, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text])
return remade_batch_tokens[0], token_count, max_length return remade_batch_tokens[0], token_count, get_target_prompt_token_count(token_count)
class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
def __init__(self, wrapped, hijack): def __init__(self, wrapped, hijack):
super().__init__() super().__init__()
self.wrapped = wrapped self.wrapped = wrapped
self.hijack = hijack self.hijack: StableDiffusionModelHijack = hijack
self.tokenizer = wrapped.tokenizer self.tokenizer = wrapped.tokenizer
self.max_length = wrapped.max_length
self.token_mults = {} self.token_mults = {}
self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ',</w>'][0]
tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k] tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k]
for text, ident in tokens_with_parens: for text, ident in tokens_with_parens:
mult = 1.0 mult = 1.0
...@@ -317,11 +136,8 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): ...@@ -317,11 +136,8 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
if mult != 1.0: if mult != 1.0:
self.token_mults[ident] = mult self.token_mults[ident] = mult
def tokenize_line(self, line, used_custom_terms, hijack_comments): def tokenize_line(self, line, used_custom_terms, hijack_comments):
id_start = self.wrapped.tokenizer.bos_token_id
id_end = self.wrapped.tokenizer.eos_token_id id_end = self.wrapped.tokenizer.eos_token_id
maxlen = self.wrapped.max_length
if opts.enable_emphasis: if opts.enable_emphasis:
parsed = prompt_parser.parse_prompt_attention(line) parsed = prompt_parser.parse_prompt_attention(line)
...@@ -333,48 +149,53 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): ...@@ -333,48 +149,53 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
fixes = [] fixes = []
remade_tokens = [] remade_tokens = []
multipliers = [] multipliers = []
last_comma = -1
for tokens, (text, weight) in zip(tokenized, parsed): for tokens, (text, weight) in zip(tokenized, parsed):
i = 0 i = 0
while i < len(tokens): while i < len(tokens):
token = tokens[i] token = tokens[i]
possible_matches = self.hijack.ids_lookup.get(token, None) embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
if token == self.comma_token:
last_comma = len(remade_tokens)
elif opts.comma_padding_backtrack != 0 and max(len(remade_tokens), 1) % 75 == 0 and last_comma != -1 and len(remade_tokens) - last_comma <= opts.comma_padding_backtrack:
last_comma += 1
reloc_tokens = remade_tokens[last_comma:]
reloc_mults = multipliers[last_comma:]
if possible_matches is None: remade_tokens = remade_tokens[:last_comma]
length = len(remade_tokens)
rem = int(math.ceil(length / 75)) * 75 - length
remade_tokens += [id_end] * rem + reloc_tokens
multipliers = multipliers[:last_comma] + [1.0] * rem + reloc_mults
if embedding is None:
remade_tokens.append(token) remade_tokens.append(token)
multipliers.append(weight) multipliers.append(weight)
i += 1
else: else:
found = False emb_len = int(embedding.vec.shape[0])
for ids, word in possible_matches: iteration = len(remade_tokens) // 75
if tokens[i:i + len(ids)] == ids: if (len(remade_tokens) + emb_len) // 75 != iteration:
emb_len = int(self.hijack.word_embeddings[word].shape[0]) rem = (75 * (iteration + 1) - len(remade_tokens))
fixes.append((len(remade_tokens), word)) remade_tokens += [id_end] * rem
multipliers += [1.0] * rem
iteration += 1
fixes.append((iteration, (len(remade_tokens) % 75, embedding)))
remade_tokens += [0] * emb_len remade_tokens += [0] * emb_len
multipliers += [weight] * emb_len multipliers += [weight] * emb_len
i += len(ids) - 1 used_custom_terms.append((embedding.name, embedding.checksum()))
found = True i += embedding_length_in_tokens
used_custom_terms.append((word, self.hijack.word_embeddings_checksums[word]))
break
if not found:
remade_tokens.append(token)
multipliers.append(weight)
i += 1
if len(remade_tokens) > maxlen - 2:
vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
ovf = remade_tokens[maxlen - 2:]
overflowing_words = [vocab.get(int(x), "") for x in ovf]
overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
token_count = len(remade_tokens) token_count = len(remade_tokens)
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens)) prompt_target_length = get_target_prompt_token_count(token_count)
remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end] tokens_to_add = prompt_target_length - len(remade_tokens)
multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers)) remade_tokens = remade_tokens + [id_end] * tokens_to_add
multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0] multipliers = multipliers + [1.0] * tokens_to_add
return remade_tokens, fixes, multipliers, token_count return remade_tokens, fixes, multipliers, token_count
...@@ -391,7 +212,8 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): ...@@ -391,7 +212,8 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
if line in cache: if line in cache:
remade_tokens, fixes, multipliers = cache[line] remade_tokens, fixes, multipliers = cache[line]
else: else:
remade_tokens, fixes, multipliers, token_count = self.tokenize_line(line, used_custom_terms, hijack_comments) remade_tokens, fixes, multipliers, current_token_count = self.tokenize_line(line, used_custom_terms, hijack_comments)
token_count = max(current_token_count, token_count)
cache[line] = (remade_tokens, fixes, multipliers) cache[line] = (remade_tokens, fixes, multipliers)
...@@ -405,7 +227,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): ...@@ -405,7 +227,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
def process_text_old(self, text): def process_text_old(self, text):
id_start = self.wrapped.tokenizer.bos_token_id id_start = self.wrapped.tokenizer.bos_token_id
id_end = self.wrapped.tokenizer.eos_token_id id_end = self.wrapped.tokenizer.eos_token_id
maxlen = self.wrapped.max_length maxlen = self.wrapped.max_length # you get to stay at 77
used_custom_terms = [] used_custom_terms = []
remade_batch_tokens = [] remade_batch_tokens = []
overflowing_words = [] overflowing_words = []
...@@ -431,32 +253,23 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): ...@@ -431,32 +253,23 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
while i < len(tokens): while i < len(tokens):
token = tokens[i] token = tokens[i]
possible_matches = self.hijack.ids_lookup.get(token, None) embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
mult_change = self.token_mults.get(token) if opts.enable_emphasis else None mult_change = self.token_mults.get(token) if opts.enable_emphasis else None
if mult_change is not None: if mult_change is not None:
mult *= mult_change mult *= mult_change
elif possible_matches is None: i += 1
elif embedding is None:
remade_tokens.append(token) remade_tokens.append(token)
multipliers.append(mult) multipliers.append(mult)
i += 1
else: else:
found = False emb_len = int(embedding.vec.shape[0])
for ids, word in possible_matches: fixes.append((len(remade_tokens), embedding))
if tokens[i:i+len(ids)] == ids:
emb_len = int(self.hijack.word_embeddings[word].shape[0])
fixes.append((len(remade_tokens), word))
remade_tokens += [0] * emb_len remade_tokens += [0] * emb_len
multipliers += [mult] * emb_len multipliers += [mult] * emb_len
i += len(ids) - 1 used_custom_terms.append((embedding.name, embedding.checksum()))
found = True i += embedding_length_in_tokens
used_custom_terms.append((word, self.hijack.word_embeddings_checksums[word]))
break
if not found:
remade_tokens.append(token)
multipliers.append(mult)
i += 1
if len(remade_tokens) > maxlen - 2: if len(remade_tokens) > maxlen - 2:
vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()} vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
...@@ -464,6 +277,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): ...@@ -464,6 +277,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
overflowing_words = [vocab.get(int(x), "") for x in ovf] overflowing_words = [vocab.get(int(x), "") for x in ovf]
overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words)) overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n") hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
token_count = len(remade_tokens) token_count = len(remade_tokens)
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens)) remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end] remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end]
...@@ -478,25 +292,72 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): ...@@ -478,25 +292,72 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
def forward(self, text): def forward(self, text):
use_old = opts.use_old_emphasis_implementation
if opts.use_old_emphasis_implementation: if use_old:
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text) batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text)
else: else:
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text) batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
self.hijack.comments += hijack_comments
self.hijack.fixes = hijack_fixes
self.hijack.comments = hijack_comments
if len(used_custom_terms) > 0: if len(used_custom_terms) > 0:
self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms])) self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
if use_old:
self.hijack.fixes = hijack_fixes
return self.process_tokens(remade_batch_tokens, batch_multipliers)
z = None
i = 0
while max(map(len, remade_batch_tokens)) != 0:
rem_tokens = [x[75:] for x in remade_batch_tokens]
rem_multipliers = [x[75:] for x in batch_multipliers]
self.hijack.fixes = []
for unfiltered in hijack_fixes:
fixes = []
for fix in unfiltered:
if fix[0] == i:
fixes.append(fix[1])
self.hijack.fixes.append(fixes)
tokens = []
multipliers = []
for j in range(len(remade_batch_tokens)):
if len(remade_batch_tokens[j]) > 0:
tokens.append(remade_batch_tokens[j][:75])
multipliers.append(batch_multipliers[j][:75])
else:
tokens.append([self.wrapped.tokenizer.eos_token_id] * 75)
multipliers.append([1.0] * 75)
z1 = self.process_tokens(tokens, multipliers)
z = z1 if z is None else torch.cat((z, z1), axis=-2)
remade_batch_tokens = rem_tokens
batch_multipliers = rem_multipliers
i += 1
return z
def process_tokens(self, remade_batch_tokens, batch_multipliers):
if not opts.use_old_emphasis_implementation:
remade_batch_tokens = [[self.wrapped.tokenizer.bos_token_id] + x[:75] + [self.wrapped.tokenizer.eos_token_id] for x in remade_batch_tokens]
batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers]
tokens = torch.asarray(remade_batch_tokens).to(device) tokens = torch.asarray(remade_batch_tokens).to(device)
outputs = self.wrapped.transformer(input_ids=tokens) outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)
if opts.CLIP_stop_at_last_layers > 1:
z = outputs.hidden_states[-opts.CLIP_stop_at_last_layers]
z = self.wrapped.transformer.text_model.final_layer_norm(z)
else:
z = outputs.last_hidden_state z = outputs.last_hidden_state
# restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
batch_multipliers = torch.asarray(batch_multipliers).to(device) batch_multipliers_of_same_length = [x + [1.0] * (75 - len(x)) for x in batch_multipliers]
batch_multipliers = torch.asarray(batch_multipliers_of_same_length).to(device)
original_mean = z.mean() original_mean = z.mean()
z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape) z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
new_mean = z.mean() new_mean = z.mean()
...@@ -517,14 +378,19 @@ class EmbeddingsWithFixes(torch.nn.Module): ...@@ -517,14 +378,19 @@ class EmbeddingsWithFixes(torch.nn.Module):
inputs_embeds = self.wrapped(input_ids) inputs_embeds = self.wrapped(input_ids)
if batch_fixes is not None: if batch_fixes is None or len(batch_fixes) == 0 or max([len(x) for x in batch_fixes]) == 0:
return inputs_embeds
vecs = []
for fixes, tensor in zip(batch_fixes, inputs_embeds): for fixes, tensor in zip(batch_fixes, inputs_embeds):
for offset, word in fixes: for offset, embedding in fixes:
emb = self.embeddings.word_embeddings[word] emb = embedding.vec
emb_len = min(tensor.shape[0]-offset-1, emb.shape[0]) emb_len = min(tensor.shape[0]-offset-1, emb.shape[0])
tensor[offset+1:offset+1+emb_len] = self.embeddings.word_embeddings[word][0:emb_len] tensor = torch.cat([tensor[0:offset+1], emb[0:emb_len], tensor[offset+1+emb_len:]])
return inputs_embeds vecs.append(tensor)
return torch.stack(vecs)
def add_circular_option_to_conv_2d(): def add_circular_option_to_conv_2d():
......
import math
import sys
import traceback
import importlib
import torch
from torch import einsum
from ldm.util import default
from einops import rearrange
from modules import shared
from modules.hypernetworks import hypernetwork
if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
try:
import xformers.ops
shared.xformers_available = True
except Exception:
print("Cannot import xformers", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
def split_cross_attention_forward_v1(self, x, context=None, mask=None):
h = self.heads
q_in = self.to_q(x)
context = default(context, x)
context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
k_in = self.to_k(context_k)
v_in = self.to_v(context_v)
del context, context_k, context_v, x
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
del q_in, k_in, v_in
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
for i in range(0, q.shape[0], 2):
end = i + 2
s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
s1 *= self.scale
s2 = s1.softmax(dim=-1)
del s1
r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
del s2
del q, k, v
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
del r1
return self.to_out(r2)
# taken from https://github.com/Doggettx/stable-diffusion and modified
def split_cross_attention_forward(self, x, context=None, mask=None):
h = self.heads
q_in = self.to_q(x)
context = default(context, x)
context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
k_in = self.to_k(context_k)
v_in = self.to_v(context_v)
k_in *= self.scale
del context, x
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
del q_in, k_in, v_in
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
stats = torch.cuda.memory_stats(q.device)
mem_active = stats['active_bytes.all.current']
mem_reserved = stats['reserved_bytes.all.current']
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_cuda + mem_free_torch
gb = 1024 ** 3
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
modifier = 3 if q.element_size() == 2 else 2.5
mem_required = tensor_size * modifier
steps = 1
if mem_required > mem_free_total:
steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
# print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
# f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
if steps > 64:
max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
s2 = s1.softmax(dim=-1, dtype=q.dtype)
del s1
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
del s2
del q, k, v
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
del r1
return self.to_out(r2)
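Illustrative arithmetic for the slicing step count above, with hypothetical memory numbers: if the attention matrix would need about 3.2 GB but only about 1 GB is free, the work is split into the next power of two of the ratio.
import math
mem_required = 3.2 * 1024 ** 3    # hypothetical size of the full attention matrix
mem_free_total = 1.0 * 1024 ** 3  # hypothetical free CUDA memory plus reserved-but-inactive torch memory
steps = 2 ** math.ceil(math.log(mem_required / mem_free_total, 2))
print(steps)   # 4 -> each einsum slice covers a quarter of the query positions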
def check_for_psutil():
try:
spec = importlib.util.find_spec('psutil')
return spec is not None
except ModuleNotFoundError:
return False
invokeAI_mps_available = check_for_psutil()
# -- Taken from https://github.com/invoke-ai/InvokeAI --
if invokeAI_mps_available:
import psutil
mem_total_gb = psutil.virtual_memory().total // (1 << 30)
def einsum_op_compvis(q, k, v):
s = einsum('b i d, b j d -> b i j', q, k)
s = s.softmax(dim=-1, dtype=s.dtype)
return einsum('b i j, b j d -> b i d', s, v)
def einsum_op_slice_0(q, k, v, slice_size):
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
for i in range(0, q.shape[0], slice_size):
end = i + slice_size
r[i:end] = einsum_op_compvis(q[i:end], k[i:end], v[i:end])
return r
def einsum_op_slice_1(q, k, v, slice_size):
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
r[:, i:end] = einsum_op_compvis(q[:, i:end], k, v)
return r
def einsum_op_mps_v1(q, k, v):
if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
return einsum_op_compvis(q, k, v)
else:
slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
return einsum_op_slice_1(q, k, v, slice_size)
def einsum_op_mps_v2(q, k, v):
if mem_total_gb > 8 and q.shape[1] <= 4096:
return einsum_op_compvis(q, k, v)
else:
return einsum_op_slice_0(q, k, v, 1)
def einsum_op_tensor_mem(q, k, v, max_tensor_mb):
size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
if size_mb <= max_tensor_mb:
return einsum_op_compvis(q, k, v)
div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
if div <= q.shape[0]:
return einsum_op_slice_0(q, k, v, q.shape[0] // div)
return einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1))
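Illustrative arithmetic for the divisor chosen above, with hypothetical sizes: a 100 MB attention tensor against a 32 MB budget gives div = 4, so the op is sliced into four chunks along dim 0 if the batch allows it, otherwise along dim 1.
size_mb = 100        # hypothetical attention tensor size in MB
max_tensor_mb = 32
div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
print(div)   # int(3.09...) == 3, bit_length() == 2, so div == 4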
def einsum_op_cuda(q, k, v):
stats = torch.cuda.memory_stats(q.device)
mem_active = stats['active_bytes.all.current']
mem_reserved = stats['reserved_bytes.all.current']
mem_free_cuda, _ = torch.cuda.mem_get_info(q.device)
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_cuda + mem_free_torch
# Divide by a safety factor, since there's copying and fragmentation
return einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
def einsum_op(q, k, v):
if q.device.type == 'cuda':
return einsum_op_cuda(q, k, v)
if q.device.type == 'mps':
if mem_total_gb >= 32:
return einsum_op_mps_v1(q, k, v)
return einsum_op_mps_v2(q, k, v)
# Smaller slices are faster due to L2/L3/SLC caches.
# Tested on i7 with 8MB L3 cache.
return einsum_op_tensor_mem(q, k, v, 32)
def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
h = self.heads
q = self.to_q(x)
context = default(context, x)
context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
k = self.to_k(context_k) * self.scale
v = self.to_v(context_v)
del context, context_k, context_v, x
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
r = einsum_op(q, k, v)
return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h))
# -- End of code from https://github.com/invoke-ai/InvokeAI --
def xformers_attention_forward(self, x, context=None, mask=None):
h = self.heads
q_in = self.to_q(x)
context = default(context, x)
context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
k_in = self.to_k(context_k)
v_in = self.to_v(context_v)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in))
del q_in, k_in, v_in
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
out = rearrange(out, 'b n h d -> b n (h d)', h=h)
return self.to_out(out)
def cross_attention_attnblock_forward(self, x):
h_ = x
h_ = self.norm(h_)
q1 = self.q(h_)
k1 = self.k(h_)
v = self.v(h_)
# compute attention
b, c, h, w = q1.shape
q2 = q1.reshape(b, c, h*w)
del q1
q = q2.permute(0, 2, 1) # b,hw,c
del q2
k = k1.reshape(b, c, h*w) # b,c,hw
del k1
h_ = torch.zeros_like(k, device=q.device)
stats = torch.cuda.memory_stats(q.device)
mem_active = stats['active_bytes.all.current']
mem_reserved = stats['reserved_bytes.all.current']
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_cuda + mem_free_torch
tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
mem_required = tensor_size * 2.5
steps = 1
if mem_required > mem_free_total:
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
w2 = w1 * (int(c)**(-0.5))
del w1
w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype)
del w2
# attend to values
v1 = v.reshape(b, c, h*w)
w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
del w3
h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
del v1, w4
h2 = h_.reshape(b, c, h, w)
del h_
h3 = self.proj_out(h2)
del h2
h3 += x
return h3
def xformers_attnblock_forward(self, x):
try:
h_ = x
h_ = self.norm(h_)
q1 = self.q(h_).contiguous()
k1 = self.k(h_).contiguous()
v = self.v(h_).contiguous()
out = xformers.ops.memory_efficient_attention(q1, k1, v)
out = self.proj_out(out)
return x + out
except NotImplementedError:
return cross_attention_attnblock_forward(self, x)
import glob import collections
import os.path import os.path
import sys import sys
from collections import namedtuple from collections import namedtuple
import torch import torch
from omegaconf import OmegaConf from omegaconf import OmegaConf
from ldm.util import instantiate_from_config from ldm.util import instantiate_from_config
from modules import shared, modelloader from modules import shared, modelloader, devices
from modules.paths import models_path from modules.paths import models_path
model_dir = "Stable-diffusion" model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(models_path, model_dir)) model_path = os.path.abspath(os.path.join(models_path, model_dir))
model_name = "sd-v1-4.ckpt"
model_url = "https://drive.yerf.org/wl/?id=EBfTrmcCCUAGaQBXVIj5lJmEhjoP1tgl&mode=grid&download=1"
user_dir = None
CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name']) CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name', 'config'])
checkpoints_list = {} checkpoints_list = {}
checkpoints_loaded = collections.OrderedDict()
try: try:
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start. # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
...@@ -30,12 +27,10 @@ except Exception: ...@@ -30,12 +27,10 @@ except Exception:
pass pass
def setup_model(dirname): def setup_model():
global user_dir
user_dir = dirname
if not os.path.exists(model_path): if not os.path.exists(model_path):
os.makedirs(model_path) os.makedirs(model_path)
checkpoints_list.clear()
list_models() list_models()
...@@ -45,13 +40,13 @@ def checkpoint_tiles(): ...@@ -45,13 +40,13 @@ def checkpoint_tiles():
def list_models(): def list_models():
checkpoints_list.clear() checkpoints_list.clear()
model_list = modelloader.load_models(model_path=model_path, model_url=model_url, command_path=user_dir, ext_filter=[".ckpt"], download_name=model_name) model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt"])
def modeltitle(path, shorthash): def modeltitle(path, shorthash):
abspath = os.path.abspath(path) abspath = os.path.abspath(path)
if user_dir is not None and abspath.startswith(user_dir): if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir):
name = abspath.replace(user_dir, '') name = abspath.replace(shared.cmd_opts.ckpt_dir, '')
elif abspath.startswith(model_path): elif abspath.startswith(model_path):
name = abspath.replace(model_path, '') name = abspath.replace(model_path, '')
else: else:
...@@ -68,14 +63,20 @@ def list_models(): ...@@ -68,14 +63,20 @@ def list_models():
if os.path.exists(cmd_ckpt): if os.path.exists(cmd_ckpt):
h = model_hash(cmd_ckpt) h = model_hash(cmd_ckpt)
title, short_model_name = modeltitle(cmd_ckpt, h) title, short_model_name = modeltitle(cmd_ckpt, h)
checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name) checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name, shared.cmd_opts.config)
shared.opts.sd_model_checkpoint = title shared.opts.data['sd_model_checkpoint'] = title
elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file: elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr) print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
for filename in model_list: for filename in model_list:
h = model_hash(filename) h = model_hash(filename)
title, short_model_name = modeltitle(filename, h) title, short_model_name = modeltitle(filename, h)
checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name)
basename, _ = os.path.splitext(filename)
config = basename + ".yaml"
if not os.path.exists(config):
config = shared.cmd_opts.config
checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name, config)
def get_closet_checkpoint_match(searchString): def get_closet_checkpoint_match(searchString):
...@@ -106,7 +107,10 @@ def select_checkpoint(): ...@@ -106,7 +107,10 @@ def select_checkpoint():
if len(checkpoints_list) == 0: if len(checkpoints_list) == 0:
print(f"No checkpoints found. When searching for checkpoints, looked at:", file=sys.stderr) print(f"No checkpoints found. When searching for checkpoints, looked at:", file=sys.stderr)
if shared.cmd_opts.ckpt is not None:
print(f" - file {os.path.abspath(shared.cmd_opts.ckpt)}", file=sys.stderr) print(f" - file {os.path.abspath(shared.cmd_opts.ckpt)}", file=sys.stderr)
print(f" - directory {model_path}", file=sys.stderr)
if shared.cmd_opts.ckpt_dir is not None:
print(f" - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}", file=sys.stderr) print(f" - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}", file=sys.stderr)
print(f"Can't run without a checkpoint. Find and place a .ckpt file into any of those locations. The program will exit.", file=sys.stderr) print(f"Can't run without a checkpoint. Find and place a .ckpt file into any of those locations. The program will exit.", file=sys.stderr)
exit(1) exit(1)
...@@ -118,14 +122,25 @@ def select_checkpoint(): ...@@ -118,14 +122,25 @@ def select_checkpoint():
return checkpoint_info return checkpoint_info
def load_model_weights(model, checkpoint_file, sd_model_hash): def get_state_dict_from_checkpoint(pl_sd):
if "state_dict" in pl_sd:
return pl_sd["state_dict"]
return pl_sd
def load_model_weights(model, checkpoint_info):
checkpoint_file = checkpoint_info.filename
sd_model_hash = checkpoint_info.hash
if checkpoint_info not in checkpoints_loaded:
print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}") print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
pl_sd = torch.load(checkpoint_file, map_location="cpu") pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location)
if "global_step" in pl_sd: if "global_step" in pl_sd:
print(f"Global Step: {pl_sd['global_step']}") print(f"Global Step: {pl_sd['global_step']}")
sd = pl_sd["state_dict"]
sd = get_state_dict_from_checkpoint(pl_sd)
model.load_state_dict(sd, strict=False) model.load_state_dict(sd, strict=False)
if shared.cmd_opts.opt_channelslast: if shared.cmd_opts.opt_channelslast:
...@@ -134,17 +149,45 @@ def load_model_weights(model, checkpoint_file, sd_model_hash): ...@@ -134,17 +149,45 @@ def load_model_weights(model, checkpoint_file, sd_model_hash):
if not shared.cmd_opts.no_half: if not shared.cmd_opts.no_half:
model.half() model.half()
devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
vae_file = os.path.splitext(checkpoint_file)[0] + ".vae.pt"
if not os.path.exists(vae_file) and shared.cmd_opts.vae_path is not None:
vae_file = shared.cmd_opts.vae_path
if os.path.exists(vae_file):
print(f"Loading VAE weights from: {vae_file}")
vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"}
model.first_stage_model.load_state_dict(vae_dict)
model.first_stage_model.to(devices.dtype_vae)
checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
checkpoints_loaded.popitem(last=False) # LRU
else:
print(f"Loading weights [{sd_model_hash}] from cache")
checkpoints_loaded.move_to_end(checkpoint_info)
model.load_state_dict(checkpoints_loaded[checkpoint_info])
model.sd_model_hash = sd_model_hash model.sd_model_hash = sd_model_hash
model.sd_model_checkpint = checkpoint_file model.sd_model_checkpoint = checkpoint_file
model.sd_checkpoint_info = checkpoint_info
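Note: a small sketch of the RAM cache used by load_model_weights above, assuming checkpoints_loaded is a module-level collections.OrderedDict sized by the sd_checkpoint_cache setting (CACHE_SLOTS and the helper names are stand-ins).
import collections
checkpoints_loaded = collections.OrderedDict()   # stand-in for the module-level cache
CACHE_SLOTS = 2                                  # stand-in for shared.opts.sd_checkpoint_cache
def cache_weights(checkpoint_info, model):
    # Keep a copy of the freshly loaded weights and evict the least recently
    # used entry once the cache grows past its size limit.
    checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
    while len(checkpoints_loaded) > CACHE_SLOTS:
        checkpoints_loaded.popitem(last=False)   # drop the oldest entry (LRU)
def restore_from_cache(checkpoint_info, model):
    # Mark the entry as most recently used, then load its weights back.
    checkpoints_loaded.move_to_end(checkpoint_info)
    model.load_state_dict(checkpoints_loaded[checkpoint_info])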
def load_model(): def load_model():
from modules import lowvram, sd_hijack from modules import lowvram, sd_hijack
checkpoint_info = select_checkpoint() checkpoint_info = select_checkpoint()
sd_config = OmegaConf.load(shared.cmd_opts.config) if checkpoint_info.config != shared.cmd_opts.config:
print(f"Loading config from: {checkpoint_info.config}")
sd_config = OmegaConf.load(checkpoint_info.config)
sd_model = instantiate_from_config(sd_config.model) sd_model = instantiate_from_config(sd_config.model)
load_model_weights(sd_model, checkpoint_info.filename, checkpoint_info.hash) load_model_weights(sd_model, checkpoint_info)
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram) lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)
...@@ -163,9 +206,14 @@ def reload_model_weights(sd_model, info=None): ...@@ -163,9 +206,14 @@ def reload_model_weights(sd_model, info=None):
from modules import lowvram, devices, sd_hijack from modules import lowvram, devices, sd_hijack
checkpoint_info = info or select_checkpoint() checkpoint_info = info or select_checkpoint()
if sd_model.sd_model_checkpint == checkpoint_info.filename: if sd_model.sd_model_checkpoint == checkpoint_info.filename:
return return
if sd_model.sd_checkpoint_info.config != checkpoint_info.config:
checkpoints_loaded.clear()
shared.sd_model = load_model()
return shared.sd_model
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
lowvram.send_everything_to_cpu() lowvram.send_everything_to_cpu()
else: else:
...@@ -173,7 +221,7 @@ def reload_model_weights(sd_model, info=None): ...@@ -173,7 +221,7 @@ def reload_model_weights(sd_model, info=None):
sd_hijack.model_hijack.undo_hijack(sd_model) sd_hijack.model_hijack.undo_hijack(sd_model)
load_model_weights(sd_model, checkpoint_info.filename, checkpoint_info.hash) load_model_weights(sd_model, checkpoint_info)
sd_hijack.model_hijack.hijack(sd_model) sd_hijack.model_hijack.hijack(sd_model)
......
...@@ -7,37 +7,63 @@ import inspect ...@@ -7,37 +7,63 @@ import inspect
import k_diffusion.sampling import k_diffusion.sampling
import ldm.models.diffusion.ddim import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms import ldm.models.diffusion.plms
from modules import prompt_parser from modules import prompt_parser, devices, processing
from modules.shared import opts, cmd_opts, state from modules.shared import opts, cmd_opts, state
import modules.shared as shared import modules.shared as shared
SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases']) SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
samplers_k_diffusion = [ samplers_k_diffusion = [
('Euler a', 'sample_euler_ancestral', ['k_euler_a']), ('Euler a', 'sample_euler_ancestral', ['k_euler_a'], {}),
('Euler', 'sample_euler', ['k_euler']), ('Euler', 'sample_euler', ['k_euler'], {}),
('LMS', 'sample_lms', ['k_lms']), ('LMS', 'sample_lms', ['k_lms'], {}),
('Heun', 'sample_heun', ['k_heun']), ('Heun', 'sample_heun', ['k_heun'], {}),
('DPM2', 'sample_dpm_2', ['k_dpm_2']), ('DPM2', 'sample_dpm_2', ['k_dpm_2'], {}),
('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a']), ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {}),
('DPM fast', 'sample_dpm_fast', ['k_dpm_fast']), ('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {}),
('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad']), ('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {}),
('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}),
('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras'}),
('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras'}),
] ]
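Note: each sampler entry now carries an options dict, which create_sampler_with_index below attaches to the sampler as sampler.config; a hedged sketch of how the {'scheduler': 'karras'} flag is later turned into a noise schedule (sigma range copied from the sampling code further down, helper name hypothetical).
import k_diffusion.sampling as K
def choose_sigmas(config, model_wrap, steps, device):
    # Entries registered with {'scheduler': 'karras'} get a Karras et al.
    # schedule; everything else keeps the model's default sigmas.
    if config is not None and config.options.get('scheduler') == 'karras':
        return K.get_sigmas_karras(n=steps, sigma_min=0.1, sigma_max=10, device=device)
    return model_wrap.get_sigmas(steps)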
samplers_data_k_diffusion = [ samplers_data_k_diffusion = [
SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases) SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases, options)
for label, funcname, aliases in samplers_k_diffusion for label, funcname, aliases, options in samplers_k_diffusion
if hasattr(k_diffusion.sampling, funcname) if hasattr(k_diffusion.sampling, funcname)
] ]
samplers = [ all_samplers = [
*samplers_data_k_diffusion, *samplers_data_k_diffusion,
SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), []), SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}),
SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), []), SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}),
] ]
samplers_for_img2img = [x for x in samplers if x.name not in ['PLMS', 'DPM fast', 'DPM adaptive']]
samplers = []
samplers_for_img2img = []
def create_sampler_with_index(list_of_configs, index, model):
config = list_of_configs[index]
sampler = config.constructor(model)
sampler.config = config
return sampler
def set_samplers():
global samplers, samplers_for_img2img
hidden = set(opts.hide_samplers)
hidden_img2img = set(opts.hide_samplers + ['PLMS'])
samplers = [x for x in all_samplers if x.name not in hidden]
samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img]
set_samplers()
sampler_extra_params = { sampler_extra_params = {
'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'], 'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
...@@ -57,7 +83,7 @@ def setup_img2img_steps(p, steps=None): ...@@ -57,7 +83,7 @@ def setup_img2img_steps(p, steps=None):
def sample_to_image(samples): def sample_to_image(samples):
x_sample = shared.sd_model.decode_first_stage(samples[0:1].type(shared.sd_model.dtype))[0] x_sample = processing.decode_first_stage(shared.sd_model, samples[0:1])[0]
x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0) x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
x_sample = x_sample.astype(np.uint8) x_sample = x_sample.astype(np.uint8)
...@@ -77,8 +103,10 @@ def extended_tdqm(sequence, *args, desc=None, **kwargs): ...@@ -77,8 +103,10 @@ def extended_tdqm(sequence, *args, desc=None, **kwargs):
state.sampling_steps = len(sequence) state.sampling_steps = len(sequence)
state.sampling_step = 0 state.sampling_step = 0
for x in tqdm.tqdm(sequence, *args, desc=state.job, file=shared.progress_print_out, **kwargs): seq = sequence if cmd_opts.disable_console_progressbars else tqdm.tqdm(sequence, *args, desc=state.job, file=shared.progress_print_out, **kwargs)
if state.interrupted:
for x in seq:
if state.interrupted or state.skipped:
break break
yield x yield x
...@@ -102,14 +130,28 @@ class VanillaStableDiffusionSampler: ...@@ -102,14 +130,28 @@ class VanillaStableDiffusionSampler:
self.step = 0 self.step = 0
self.eta = None self.eta = None
self.default_eta = 0.0 self.default_eta = 0.0
self.config = None
def number_of_needed_noises(self, p): def number_of_needed_noises(self, p):
return 0 return 0
def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs): def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
cond = prompt_parser.reconstruct_cond_batch(cond, self.step) conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step) unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
assert all([len(conds) == 1 for conds in conds_list]), 'composition via AND is not supported for DDIM/PLMS samplers'
cond = tensor
# for DDIM, shapes must match, we can't just process cond and uncond independently;
# filling unconditional_conditioning with repeats of the last vector to match length is
# not 100% correct but should work well enough
if unconditional_conditioning.shape[1] < cond.shape[1]:
last_vector = unconditional_conditioning[:, -1:]
last_vector_repeated = last_vector.repeat([1, cond.shape[1] - unconditional_conditioning.shape[1], 1])
unconditional_conditioning = torch.hstack([unconditional_conditioning, last_vector_repeated])
elif unconditional_conditioning.shape[1] > cond.shape[1]:
unconditional_conditioning = unconditional_conditioning[:, :cond.shape[1]]
if self.mask is not None: if self.mask is not None:
img_orig = self.sampler.model.q_sample(self.init_latent, ts) img_orig = self.sampler.model.q_sample(self.init_latent, ts)
x_dec = img_orig * self.mask + self.nmask * x_dec x_dec = img_orig * self.mask + self.nmask * x_dec
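Note: a self-contained sketch of the shape-matching step above (helper name hypothetical): the unconditional conditioning is padded by repeating its last vector, or truncated, so DDIM/PLMS can process cond and uncond in a single batch.
import torch
def match_uncond_length(cond, uncond):
    # Pad uncond along the token dimension with copies of its last vector,
    # or truncate it, so that it matches cond's length.
    if uncond.shape[1] < cond.shape[1]:
        last = uncond[:, -1:]
        pad = last.repeat([1, cond.shape[1] - uncond.shape[1], 1])
        uncond = torch.hstack([uncond, pad])
    elif uncond.shape[1] > cond.shape[1]:
        uncond = uncond[:, :cond.shape[1]]
    return uncond
# e.g. a 154-token prompt against a 77-token negative prompt:
print(match_uncond_length(torch.zeros(1, 154, 768), torch.zeros(1, 77, 768)).shape)  # (1, 154, 768)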
...@@ -125,7 +167,7 @@ class VanillaStableDiffusionSampler: ...@@ -125,7 +167,7 @@ class VanillaStableDiffusionSampler:
return res return res
def initialize(self, p): def initialize(self, p):
self.eta = p.eta or opts.eta_ddim self.eta = p.eta if p.eta is not None else opts.eta_ddim
for fieldname in ['p_sample_ddim', 'p_sample_plms']: for fieldname in ['p_sample_ddim', 'p_sample_plms']:
if hasattr(self.sampler, fieldname): if hasattr(self.sampler, fieldname):
...@@ -139,7 +181,7 @@ class VanillaStableDiffusionSampler: ...@@ -139,7 +181,7 @@ class VanillaStableDiffusionSampler:
self.initialize(p) self.initialize(p)
# existing code fails with cetain step counts, like 9 # existing code fails with certain step counts, like 9
try: try:
self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False) self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
except Exception: except Exception:
...@@ -162,7 +204,7 @@ class VanillaStableDiffusionSampler: ...@@ -162,7 +204,7 @@ class VanillaStableDiffusionSampler:
steps = steps or p.steps steps = steps or p.steps
# existing code fails with cetin step counts, like 9 # existing code fails with certain step counts, like 9
try: try:
samples_ddim, _ = self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta) samples_ddim, _ = self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)
except Exception: except Exception:
...@@ -181,19 +223,42 @@ class CFGDenoiser(torch.nn.Module): ...@@ -181,19 +223,42 @@ class CFGDenoiser(torch.nn.Module):
self.step = 0 self.step = 0
def forward(self, x, sigma, uncond, cond, cond_scale): def forward(self, x, sigma, uncond, cond, cond_scale):
cond = prompt_parser.reconstruct_cond_batch(cond, self.step) conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step) uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
batch_size = len(conds_list)
repeats = [len(conds_list[i]) for i in range(batch_size)]
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
if tensor.shape[1] == uncond.shape[1]:
cond_in = torch.cat([tensor, uncond])
if shared.batch_cond_uncond: if shared.batch_cond_uncond:
x_in = torch.cat([x] * 2) x_out = self.inner_model(x_in, sigma_in, cond=cond_in)
sigma_in = torch.cat([sigma] * 2)
cond_in = torch.cat([uncond, cond])
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
denoised = uncond + (cond - uncond) * cond_scale
else: else:
uncond = self.inner_model(x, sigma, cond=uncond) x_out = torch.zeros_like(x_in)
cond = self.inner_model(x, sigma, cond=cond) for batch_offset in range(0, x_out.shape[0], batch_size):
denoised = uncond + (cond - uncond) * cond_scale a = batch_offset
b = a + batch_size
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=cond_in[a:b])
else:
x_out = torch.zeros_like(x_in)
batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size
for batch_offset in range(0, tensor.shape[0], batch_size):
a = batch_offset
b = min(a + batch_size, tensor.shape[0])
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=tensor[a:b])
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=uncond)
denoised_uncond = x_out[-uncond.shape[0]:]
denoised = torch.clone(denoised_uncond)
for i, conds in enumerate(conds_list):
for cond_index, weight in conds:
denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)
if self.mask is not None: if self.mask is not None:
denoised = self.init_latent * self.mask + self.nmask * denoised denoised = self.init_latent * self.mask + self.nmask * denoised
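Note: a hedged sketch of the weighting loop above, pulled out as a standalone function (name hypothetical). It generalises classifier-free guidance so that every AND-composed sub-prompt pushes the result away from the unconditional prediction by its own weight.
import torch
def combine_denoised(x_out, conds_list, uncond_count, cond_scale):
    # x_out holds the model output for every sub-prompt, followed by the
    # unconditional outputs; conds_list[i] lists (output_index, weight)
    # pairs for image i of the batch.
    denoised_uncond = x_out[-uncond_count:]
    denoised = torch.clone(denoised_uncond)
    for i, conds in enumerate(conds_list):
        for cond_index, weight in conds:
            # with a single cond of weight 1.0 this reduces to ordinary CFG:
            # uncond + (cond - uncond) * cond_scale
            denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)
    return denoised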
...@@ -207,8 +272,10 @@ def extended_trange(sampler, count, *args, **kwargs): ...@@ -207,8 +272,10 @@ def extended_trange(sampler, count, *args, **kwargs):
state.sampling_steps = count state.sampling_steps = count
state.sampling_step = 0 state.sampling_step = 0
for x in tqdm.trange(count, *args, desc=state.job, file=shared.progress_print_out, **kwargs): seq = range(count) if cmd_opts.disable_console_progressbars else tqdm.trange(count, *args, desc=state.job, file=shared.progress_print_out, **kwargs)
if state.interrupted:
for x in seq:
if state.interrupted or state.skipped:
break break
if sampler.stop_at is not None and x > sampler.stop_at: if sampler.stop_at is not None and x > sampler.stop_at:
...@@ -246,6 +313,7 @@ class KDiffusionSampler: ...@@ -246,6 +313,7 @@ class KDiffusionSampler:
self.stop_at = None self.stop_at = None
self.eta = None self.eta = None
self.default_eta = 1.0 self.default_eta = 1.0
self.config = None
def callback_state(self, d): def callback_state(self, d):
store_latent(d["denoised"]) store_latent(d["denoised"])
...@@ -292,27 +360,42 @@ class KDiffusionSampler: ...@@ -292,27 +360,42 @@ class KDiffusionSampler:
if p.sampler_noise_scheduler_override: if p.sampler_noise_scheduler_override:
sigmas = p.sampler_noise_scheduler_override(steps) sigmas = p.sampler_noise_scheduler_override(steps)
elif self.config is not None and self.config.options.get('scheduler', None) == 'karras':
sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=0.1, sigma_max=10, device=shared.device)
else: else:
sigmas = self.model_wrap.get_sigmas(steps) sigmas = self.model_wrap.get_sigmas(steps)
noise = noise * sigmas[steps - t_enc - 1] sigma_sched = sigmas[steps - t_enc - 1:]
xi = x + noise xi = x + noise * sigma_sched[0]
extra_params_kwargs = self.initialize(p) extra_params_kwargs = self.initialize(p)
if 'sigma_min' in inspect.signature(self.func).parameters:
sigma_sched = sigmas[steps - t_enc - 1:] ## last sigma is zero which isn't allowed by DPM Fast & Adaptive so taking value before last
extra_params_kwargs['sigma_min'] = sigma_sched[-2]
if 'sigma_max' in inspect.signature(self.func).parameters:
extra_params_kwargs['sigma_max'] = sigma_sched[0]
if 'n' in inspect.signature(self.func).parameters:
extra_params_kwargs['n'] = len(sigma_sched) - 1
if 'sigma_sched' in inspect.signature(self.func).parameters:
extra_params_kwargs['sigma_sched'] = sigma_sched
if 'sigmas' in inspect.signature(self.func).parameters:
extra_params_kwargs['sigmas'] = sigma_sched
self.model_wrap_cfg.init_latent = x self.model_wrap_cfg.init_latent = x
return self.func(self.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs) return self.func(self.model_wrap_cfg, xi, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)
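Note: different k-diffusion samplers accept different keyword arguments (DPM fast/adaptive want sigma_min, sigma_max and n, the rest take a sigma schedule), so the code above probes each function's signature before calling it; a minimal sketch of that selection (helper name hypothetical).
import inspect
def build_extra_kwargs(func, sigma_sched):
    # Pass only the parameters this particular sampler function accepts.
    params = inspect.signature(func).parameters
    extra = {}
    if 'sigma_min' in params:
        extra['sigma_min'] = sigma_sched[-2]   # the final sigma is 0, which DPM fast/adaptive reject
    if 'sigma_max' in params:
        extra['sigma_max'] = sigma_sched[0]
    if 'n' in params:
        extra['n'] = len(sigma_sched) - 1
    if 'sigma_sched' in params:
        extra['sigma_sched'] = sigma_sched
    if 'sigmas' in params:
        extra['sigmas'] = sigma_sched
    return extra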
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None): def sample(self, p, x, conditioning, unconditional_conditioning, steps=None):
steps = steps or p.steps steps = steps or p.steps
if p.sampler_noise_scheduler_override: if p.sampler_noise_scheduler_override:
sigmas = p.sampler_noise_scheduler_override(steps) sigmas = p.sampler_noise_scheduler_override(steps)
elif self.config is not None and self.config.options.get('scheduler', None) == 'karras':
sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=0.1, sigma_max=10, device=shared.device)
else: else:
sigmas = self.model_wrap.get_sigmas(steps) sigmas = self.model_wrap.get_sigmas(steps)
x = x * sigmas[0] x = x * sigmas[0]
extra_params_kwargs = self.initialize(p) extra_params_kwargs = self.initialize(p)
......
...@@ -12,12 +12,13 @@ import modules.interrogate ...@@ -12,12 +12,13 @@ import modules.interrogate
import modules.memmon import modules.memmon
import modules.sd_models import modules.sd_models
import modules.styles import modules.styles
from modules.devices import get_optimal_device import modules.devices as devices
from modules.paths import script_path, sd_path from modules import sd_samplers, sd_models
from modules.hypernetworks import hypernetwork
from modules.paths import models_path, script_path, sd_path
sd_model_file = os.path.join(script_path, 'model.ckpt') sd_model_file = os.path.join(script_path, 'model.ckpt')
default_sd_model_file = sd_model_file default_sd_model_file = sd_model_file
model_path = os.path.join(script_path, 'models')
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default=os.path.join(sd_path, "configs/stable-diffusion/v1-inference.yaml"), help="path to config which constructs model",) parser.add_argument("--config", type=str, default=os.path.join(sd_path, "configs/stable-diffusion/v1-inference.yaml"), help="path to config which constructs model",)
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",) parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
...@@ -25,26 +26,36 @@ parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to director ...@@ -25,26 +26,36 @@ parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to director
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN')) parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None) parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None)
parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats") parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
parser.add_argument("--no-half-vae", action='store_true', help="do not switch the VAE model to 16-bit floats")
parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)") parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)")
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI") parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)") parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")
parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory")
parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui") parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage") parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage")
parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage") parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage")
parser.add_argument("--lowram", action='store_true', help="load stable diffusion checkpoint weights to VRAM instead of RAM")
parser.add_argument("--always-batch-cond-uncond", action='store_true', help="disables cond/uncond batching that is enabled to save memory with --medvram or --lowvram") parser.add_argument("--always-batch-cond-uncond", action='store_true', help="disables cond/uncond batching that is enabled to save memory with --medvram or --lowvram")
parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.") parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.")
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast") parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site (doesn't work for me but you might have better luck)") parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site (doesn't work for me but you might have better luck)")
parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(model_path, 'Codeformer')) parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None)
parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory with GFPGAN model file(s).", default=os.path.join(model_path, 'GFPGAN')) parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(models_path, 'Codeformer'))
parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(model_path, 'ESRGAN')) parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory with GFPGAN model file(s).", default=os.path.join(models_path, 'GFPGAN'))
parser.add_argument("--bsrgan-models-path", type=str, help="Path to directory with BSRGAN model file(s).", default=os.path.join(model_path, 'BSRGAN')) parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(models_path, 'ESRGAN'))
parser.add_argument("--realesrgan-models-path", type=str, help="Path to directory with RealESRGAN model file(s).", default=os.path.join(model_path, 'RealESRGAN')) parser.add_argument("--bsrgan-models-path", type=str, help="Path to directory with BSRGAN model file(s).", default=os.path.join(models_path, 'BSRGAN'))
parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(model_path, 'SwinIR')) parser.add_argument("--realesrgan-models-path", type=str, help="Path to directory with RealESRGAN model file(s).", default=os.path.join(models_path, 'RealESRGAN'))
parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(model_path, 'LDSR')) parser.add_argument("--scunet-models-path", type=str, help="Path to directory with ScuNET model file(s).", default=os.path.join(models_path, 'ScuNET'))
parser.add_argument("--opt-split-attention", action='store_true', help="force-enables cross-attention layer optimization. By default, it's on for torch.cuda and off for other torch devices.") parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(models_path, 'SwinIR'))
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization") parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(models_path, 'LDSR'))
parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers")
parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work")
parser.add_argument("--deepdanbooru", action='store_true', help="enable deepdanbooru interrogator")
parser.add_argument("--opt-split-attention", action='store_true', help="force-enables Doggettx's cross-attention layer optimization. By default, it's on for torch cuda.")
parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.")
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find") parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
parser.add_argument("--use-cpu", nargs='+',choices=['all', 'sd', 'interrogate', 'gfpgan', 'bsrgan', 'esrgan', 'scunet', 'codeformer'], help="use CPU as torch device for specified modules", default=[], type=str.lower)
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests") parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None) parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False) parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)
...@@ -53,21 +64,44 @@ parser.add_argument("--hide-ui-dir-config", action='store_true', help="hide dire ...@@ -53,21 +64,44 @@ parser.add_argument("--hide-ui-dir-config", action='store_true', help="hide dire
parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(script_path, 'config.json')) parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(script_path, 'config.json'))
parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option") parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option")
parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None) parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
parser.add_argument("--gradio-img2img-tool", type=str, help='gradio image uploader tool: can be either editor for ctopping, or color-sketch for drawing', choices=["color-sketch", "editor"], default="editor")
parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last") parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(script_path, 'styles.csv')) parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(script_path, 'styles.csv'))
parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False) parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False) parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False)
parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False)
parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencoders model', default=None)
parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
cmd_opts = parser.parse_args() cmd_opts = parser.parse_args()
device = get_optimal_device()
devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_bsrgan, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \
(devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'bsrgan', 'esrgan', 'scunet', 'codeformer'])
device = devices.device
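Note: a small sketch of the --use-cpu mapping above (helper name hypothetical): a module lands on the CPU when it, or 'all', is named in the flag, and on the optimal device otherwise.
import torch
def pick_device(module_name, use_cpu, optimal_device):
    # use_cpu is the lower-cased list from --use-cpu; 'all' forces CPU everywhere.
    return torch.device("cpu") if module_name in use_cpu or "all" in use_cpu else optimal_device
# e.g. with --use-cpu interrogate bsrgan:
# pick_device("interrogate", ["interrogate", "bsrgan"], torch.device("cuda")) -> device(type='cpu')
# pick_device("sd",          ["interrogate", "bsrgan"], torch.device("cuda")) -> device(type='cuda')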
weight_load_location = None if cmd_opts.lowram else "cpu"
batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram) batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram)
parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram
xformers_available = False
config_filename = cmd_opts.ui_settings_file config_filename = cmd_opts.ui_settings_file
os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True)
hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
loaded_hypernetwork = None
def reload_hypernetworks():
global hypernetworks
hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
hypernetwork.load_hypernetwork(opts.sd_hypernetwork)
class State: class State:
skipped = False
interrupted = False interrupted = False
job = "" job = ""
job_no = 0 job_no = 0
...@@ -78,6 +112,10 @@ class State: ...@@ -78,6 +112,10 @@ class State:
current_latent = None current_latent = None
current_image = None current_image = None
current_image_sampling_step = 0 current_image_sampling_step = 0
textinfo = None
def skip(self):
self.skipped = True
def interrupt(self): def interrupt(self):
self.interrupted = True self.interrupted = True
...@@ -88,7 +126,7 @@ class State: ...@@ -88,7 +126,7 @@ class State:
self.current_image_sampling_step = 0 self.current_image_sampling_step = 0
def get_job_timestamp(self): def get_job_timestamp(self):
return datetime.datetime.now().strftime("%Y%m%d%H%M%S") return datetime.datetime.now().strftime("%Y%m%d%H%M%S") # shouldn't this return job_timestamp?
state = State() state = State()
...@@ -101,8 +139,6 @@ prompt_styles = modules.styles.StyleDatabase(styles_filename) ...@@ -101,8 +139,6 @@ prompt_styles = modules.styles.StyleDatabase(styles_filename)
interrogator = modules.interrogate.InterrogateModels("interrogate") interrogator = modules.interrogate.InterrogateModels("interrogate")
face_restorers = [] face_restorers = []
# This was moved to webui.py with the other model "setup" calls.
# modules.sd_models.list_models()
def realesrgan_models_names(): def realesrgan_models_names():
...@@ -111,18 +147,19 @@ def realesrgan_models_names(): ...@@ -111,18 +147,19 @@ def realesrgan_models_names():
class OptionInfo: class OptionInfo:
def __init__(self, default=None, label="", component=None, component_args=None, onchange=None): def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, show_on_main_page=False, refresh=None):
self.default = default self.default = default
self.label = label self.label = label
self.component = component self.component = component
self.component_args = component_args self.component_args = component_args
self.onchange = onchange self.onchange = onchange
self.section = None self.section = None
self.refresh = refresh
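Note: the new refresh argument lets a settings entry carry a callback that re-scans its choices (for example refresh=sd_models.list_models on the checkpoint dropdown below); a hypothetical illustration of how a refresh action could use it.
def refresh_option_choices(info):
    # Re-scan whatever backs this option (checkpoints, hypernetworks, ...),
    # then re-evaluate the lazy component_args to get the updated choices.
    if info.refresh is not None:
        info.refresh()
    args = info.component_args() if callable(info.component_args) else (info.component_args or {})
    return args.get("choices", [])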
def options_section(section_identifer, options_dict): def options_section(section_identifier, options_dict):
for k, v in options_dict.items(): for k, v in options_dict.items():
v.section = section_identifer v.section = section_identifier
return options_dict return options_dict
...@@ -140,6 +177,7 @@ options_templates.update(options_section(('saving-images', "Saving images/grids" ...@@ -140,6 +177,7 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
"grid_format": OptionInfo('png', 'File format for grids'), "grid_format": OptionInfo('png', 'File format for grids'),
"grid_extended_filename": OptionInfo(False, "Add extended info (seed, prompt) to filename when saving grid"), "grid_extended_filename": OptionInfo(False, "Add extended info (seed, prompt) to filename when saving grid"),
"grid_only_if_multiple": OptionInfo(True, "Do not save grids consisting of one picture"), "grid_only_if_multiple": OptionInfo(True, "Do not save grids consisting of one picture"),
"grid_prevent_empty_spots": OptionInfo(False, "Prevent empty spots in grid (when set to autodetect)"),
"n_rows": OptionInfo(-1, "Grid row count; use -1 for autodetect and 0 for it to be same as batch size", gr.Slider, {"minimum": -1, "maximum": 16, "step": 1}), "n_rows": OptionInfo(-1, "Grid row count; use -1 for autodetect and 0 for it to be same as batch size", gr.Slider, {"minimum": -1, "maximum": 16, "step": 1}),
"enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"), "enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"),
...@@ -150,6 +188,7 @@ options_templates.update(options_section(('saving-images', "Saving images/grids" ...@@ -150,6 +188,7 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
"use_original_name_batch": OptionInfo(False, "Use original name for output filename during batch process in extras tab"), "use_original_name_batch": OptionInfo(False, "Use original name for output filename during batch process in extras tab"),
"save_selected_only": OptionInfo(True, "When using 'Save' button, only save a single selected image"), "save_selected_only": OptionInfo(True, "When using 'Save' button, only save a single selected image"),
"do_not_add_watermark": OptionInfo(False, "Do not add watermark to images"),
})) }))
options_templates.update(options_section(('saving-paths', "Paths for saving"), { options_templates.update(options_section(('saving-paths', "Paths for saving"), {
...@@ -165,9 +204,10 @@ options_templates.update(options_section(('saving-paths', "Paths for saving"), { ...@@ -165,9 +204,10 @@ options_templates.update(options_section(('saving-paths', "Paths for saving"), {
options_templates.update(options_section(('saving-to-dirs', "Saving to a directory"), { options_templates.update(options_section(('saving-to-dirs', "Saving to a directory"), {
"save_to_dirs": OptionInfo(False, "Save images to a subdirectory"), "save_to_dirs": OptionInfo(False, "Save images to a subdirectory"),
"grid_save_to_dirs": OptionInfo(False, "Save grids to subdirectory"), "grid_save_to_dirs": OptionInfo(False, "Save grids to a subdirectory"),
"use_save_to_dirs_for_ui": OptionInfo(False, "When using \"Save\" button, save images to a subdirectory"),
"directories_filename_pattern": OptionInfo("", "Directory name pattern"), "directories_filename_pattern": OptionInfo("", "Directory name pattern"),
"directories_max_prompt_words": OptionInfo(8, "Max prompt words", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1}), "directories_max_prompt_words": OptionInfo(8, "Max prompt words for [prompt_words] pattern", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1}),
})) }))
options_templates.update(options_section(('upscaling', "Upscaling"), { options_templates.update(options_section(('upscaling', "Upscaling"), {
...@@ -177,7 +217,7 @@ options_templates.update(options_section(('upscaling', "Upscaling"), { ...@@ -177,7 +217,7 @@ options_templates.update(options_section(('upscaling', "Upscaling"), {
"SWIN_tile": OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}), "SWIN_tile": OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}),
"SWIN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}), "SWIN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
"ldsr_steps": OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}), "ldsr_steps": OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}),
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Radio, lambda: {"choices": [x.name for x in sd_upscalers]}), "upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}),
})) }))
options_templates.update(options_section(('face-restoration', "Face restoration"), { options_templates.update(options_section(('face-restoration', "Face restoration"), {
...@@ -189,50 +229,75 @@ options_templates.update(options_section(('face-restoration', "Face restoration" ...@@ -189,50 +229,75 @@ options_templates.update(options_section(('face-restoration', "Face restoration"
options_templates.update(options_section(('system', "System"), { options_templates.update(options_section(('system', "System"), {
"memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation. Set to 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}), "memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation. Set to 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}),
"samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"), "samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"),
"multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job. Broken in PyCharm console."), "multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
}))
options_templates.update(options_section(('training', "Training"), {
"unload_models_when_training": OptionInfo(False, "Unload VAE and CLIP from VRAM when training"),
"dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
"dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
"training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
"training_write_csv_every": OptionInfo(500, "Save an csv containing the loss to log directory every N steps, 0 to disable"),
})) }))
options_templates.update(options_section(('sd', "Stable Diffusion"), { options_templates.update(options_section(('sd', "Stable Diffusion"), {
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}), "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models),
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
"sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
"sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."), "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"), "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."), "img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."),
"enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."), "enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."),
"enable_emphasis": OptionInfo(True, "Eemphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"), "enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"),
"use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."), "use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"), "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
"comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }),
"filter_nsfw": OptionInfo(False, "Filter NSFW content"), "filter_nsfw": OptionInfo(False, "Filter NSFW content"),
'CLIP_stop_at_last_layers': OptionInfo(1, "Stop At last layers of CLIP model", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
"random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}), "random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}),
'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"),
})) }))
options_templates.update(options_section(('interrogate', "Interrogate Options"), { options_templates.update(options_section(('interrogate', "Interrogate Options"), {
"interrogate_keep_models_in_memory": OptionInfo(False, "Interrogate: keep models in VRAM"), "interrogate_keep_models_in_memory": OptionInfo(False, "Interrogate: keep models in VRAM"),
"interrogate_use_builtin_artists": OptionInfo(True, "Interrogate: use artists from artists.csv"), "interrogate_use_builtin_artists": OptionInfo(True, "Interrogate: use artists from artists.csv"),
"interrogate_return_ranks": OptionInfo(False, "Interrogate: include ranks of model tags matches in results (Has no effect on caption-based interrogators)."),
"interrogate_clip_num_beams": OptionInfo(1, "Interrogate: num_beams for BLIP", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}), "interrogate_clip_num_beams": OptionInfo(1, "Interrogate: num_beams for BLIP", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}),
"interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}), "interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
"interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}), "interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
"interrogate_clip_dict_limit": OptionInfo(1500, "Interrogate: maximum number of lines in text file (0 = No limit)"), "interrogate_clip_dict_limit": OptionInfo(1500, "CLIP: maximum number of lines in text file (0 = No limit)"),
"interrogate_deepbooru_score_threshold": OptionInfo(0.5, "Interrogate: deepbooru score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
"deepbooru_sort_alpha": OptionInfo(True, "Interrogate: deepbooru sort alphabetically"),
"deepbooru_use_spaces": OptionInfo(False, "use spaces for tags in deepbooru"),
"deepbooru_escape": OptionInfo(True, "escape (\\) brackets in deepbooru (so they are used as literal brackets and not for emphasis)"),
})) }))
options_templates.update(options_section(('ui', "User interface"), { options_templates.update(options_section(('ui', "User interface"), {
"show_progressbar": OptionInfo(True, "Show progressbar"), "show_progressbar": OptionInfo(True, "Show progressbar"),
"show_progress_every_n_steps": OptionInfo(0, "Show show image creation progress every N sampling steps. Set 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 32, "step": 1}), "show_progress_every_n_steps": OptionInfo(0, "Show image creation progress every N sampling steps. Set 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 32, "step": 1}),
"return_grid": OptionInfo(True, "Show grid in results for web"), "return_grid": OptionInfo(True, "Show grid in results for web"),
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"), "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
"add_model_name_to_info": OptionInfo(False, "Add model name to generation information"),
"font": OptionInfo("", "Font for image grids that have text"), "font": OptionInfo("", "Font for image grids that have text"),
"js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"), "js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
"js_modal_lightbox_initialy_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"), "js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
"show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
})) }))
options_templates.update(options_section(('sampler-params', "Sampler parameters"), { options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
"hide_samplers": OptionInfo([], "Hide samplers in user interface (requires restart)", gr.CheckboxGroup, lambda: {"choices": [x.name for x in sd_samplers.all_samplers]}),
"eta_ddim": OptionInfo(0.0, "eta (noise multiplier) for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), "eta_ddim": OptionInfo(0.0, "eta (noise multiplier) for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
"eta_ancestral": OptionInfo(1.0, "eta (noise multiplier) for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), "eta_ancestral": OptionInfo(1.0, "eta (noise multiplier) for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
"ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}), "ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}),
's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), 's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), 's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), 's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}),
})) }))
class Options: class Options:
data = None data = None
data_labels = options_templates data_labels = options_templates
...@@ -289,6 +354,8 @@ class Options: ...@@ -289,6 +354,8 @@ class Options:
item = self.data_labels.get(key) item = self.data_labels.get(key)
item.onchange = func item.onchange = func
func()
def dumpjson(self): def dumpjson(self):
d = {k: self.data.get(k, self.data_labels.get(k).default) for k in self.data_labels.keys()} d = {k: self.data.get(k, self.data_labels.get(k).default) for k in self.data_labels.keys()}
return json.dumps(d) return json.dumps(d)
...@@ -318,14 +385,14 @@ class TotalTQDM: ...@@ -318,14 +385,14 @@ class TotalTQDM:
) )
def update(self): def update(self):
if not opts.multiple_tqdm: if not opts.multiple_tqdm or cmd_opts.disable_console_progressbars:
return return
if self._tqdm is None: if self._tqdm is None:
self.reset() self.reset()
self._tqdm.update() self._tqdm.update()
def updateTotal(self, new_total): def updateTotal(self, new_total):
if not opts.multiple_tqdm: if not opts.multiple_tqdm or cmd_opts.disable_console_progressbars:
return return
if self._tqdm is None: if self._tqdm is None:
self.reset() self.reset()
......
...@@ -5,11 +5,12 @@ import numpy as np ...@@ -5,11 +5,12 @@ import numpy as np
import torch import torch
from PIL import Image from PIL import Image
from basicsr.utils.download_util import load_file_from_url from basicsr.utils.download_util import load_file_from_url
from tqdm import tqdm
from modules import modelloader from modules import modelloader
from modules.paths import models_path
from modules.shared import cmd_opts, opts, device from modules.shared import cmd_opts, opts, device
from modules.swinir_model_arch import SwinIR as net from modules.swinir_model_arch import SwinIR as net
from modules.swinir_model_arch_v2 import Swin2SR as net2
from modules.upscaler import Upscaler, UpscalerData from modules.upscaler import Upscaler, UpscalerData
precision_scope = ( precision_scope = (
...@@ -24,7 +25,6 @@ class UpscalerSwinIR(Upscaler): ...@@ -24,7 +25,6 @@ class UpscalerSwinIR(Upscaler):
"/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR" \ "/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR" \
"-L_x4_GAN.pth " "-L_x4_GAN.pth "
self.model_name = "SwinIR 4x" self.model_name = "SwinIR 4x"
self.model_path = os.path.join(models_path, self.name)
self.user_path = dirname self.user_path = dirname
super().__init__() super().__init__()
scalers = [] scalers = []
...@@ -58,6 +58,22 @@ class UpscalerSwinIR(Upscaler): ...@@ -58,6 +58,22 @@ class UpscalerSwinIR(Upscaler):
filename = path filename = path
if filename is None or not os.path.exists(filename): if filename is None or not os.path.exists(filename):
return None return None
if filename.endswith(".v2.pth"):
model = net2(
upscale=scale,
in_chans=3,
img_size=64,
window_size=8,
img_range=1.0,
depths=[6, 6, 6, 6, 6, 6],
embed_dim=180,
num_heads=[6, 6, 6, 6, 6, 6],
mlp_ratio=2,
upsampler="nearest+conv",
resi_connection="1conv",
)
params = None
else:
model = net( model = net(
upscale=scale, upscale=scale,
in_chans=3, in_chans=3,
...@@ -71,9 +87,13 @@ class UpscalerSwinIR(Upscaler): ...@@ -71,9 +87,13 @@ class UpscalerSwinIR(Upscaler):
upsampler="nearest+conv", upsampler="nearest+conv",
resi_connection="3conv", resi_connection="3conv",
) )
params = "params_ema"
pretrained_model = torch.load(filename) pretrained_model = torch.load(filename)
model.load_state_dict(pretrained_model["params_ema"], strict=True) if params is not None:
model.load_state_dict(pretrained_model[params], strict=True)
else:
model.load_state_dict(pretrained_model, strict=True)
if not cmd_opts.no_half: if not cmd_opts.no_half:
model = model.half() model = model.half()
return model return model
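Note: a condensed sketch of the two checkpoint layouts handled above: original SwinIR files nest their weights under "params_ema", while Swin2SR ".v2.pth" files store the state dict at the top level (params is set to None for them).
import torch
def load_swinir_weights(model, filename, params_key):
    # params_key is "params_ema" for the original SwinIR checkpoint and
    # None for Swin2SR ".v2.pth" files.
    checkpoint = torch.load(filename)
    state_dict = checkpoint[params_key] if params_key is not None else checkpoint
    model.load_state_dict(state_dict, strict=True)
    return model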
...@@ -122,6 +142,7 @@ def inference(img, model, tile, tile_overlap, window_size, scale): ...@@ -122,6 +142,7 @@ def inference(img, model, tile, tile_overlap, window_size, scale):
E = torch.zeros(b, c, h * sf, w * sf, dtype=torch.half, device=device).type_as(img) E = torch.zeros(b, c, h * sf, w * sf, dtype=torch.half, device=device).type_as(img)
W = torch.zeros_like(E, dtype=torch.half, device=device) W = torch.zeros_like(E, dtype=torch.half, device=device)
with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar:
for h_idx in h_idx_list: for h_idx in h_idx_list:
for w_idx in w_idx_list: for w_idx in w_idx_list:
in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile] in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
...@@ -134,6 +155,7 @@ def inference(img, model, tile, tile_overlap, window_size, scale): ...@@ -134,6 +155,7 @@ def inference(img, model, tile, tile_overlap, window_size, scale):
W[ W[
..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
].add_(out_patch_mask) ].add_(out_patch_mask)
pbar.update(1)
output = E.div_(W) output = E.div_(W)
return output return output
...@@ -166,7 +166,7 @@ class SwinTransformerBlock(nn.Module): ...@@ -166,7 +166,7 @@ class SwinTransformerBlock(nn.Module):
Args: Args:
dim (int): Number of input channels. dim (int): Number of input channels.
input_resolution (tuple[int]): Input resulotion. input_resolution (tuple[int]): Input resolution.
num_heads (int): Number of attention heads. num_heads (int): Number of attention heads.
window_size (int): Window size. window_size (int): Window size.
shift_size (int): Shift size for SW-MSA. shift_size (int): Shift size for SW-MSA.
......
# -----------------------------------------------------------------------------------
# Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration, https://arxiv.org/abs/
# Written by Conde and Choi et al.
# -----------------------------------------------------------------------------------
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
def window_partition(x, window_size):
"""
Args:
x: (B, H, W, C)
window_size (int): window size
Returns:
windows: (num_windows*B, window_size, window_size, C)
"""
B, H, W, C = x.shape
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows
def window_reverse(windows, window_size, H, W):
"""
Args:
windows: (num_windows*B, window_size, window_size, C)
window_size (int): Window size
H (int): Height of image
W (int): Width of image
Returns:
x: (B, H, W, C)
"""
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
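# Editor's sketch (not upstream code): window_partition and window_reverse are exact
# inverses whenever H and W are multiples of window_size, which check_image_size below
# guarantees by padding the input.
def _window_roundtrip_sketch():
    x = torch.randn(2, 16, 16, 96)                    # B, H, W, C
    windows = window_partition(x, 8)                  # (2 * 2 * 2, 8, 8, 96)
    assert torch.equal(window_reverse(windows, 8, 16, 16), x)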
class WindowAttention(nn.Module):
r""" Window based multi-head self attention (W-MSA) module with relative position bias.
It supports both shifted and non-shifted windows.
Args:
dim (int): Number of input channels.
window_size (tuple[int]): The height and width of the window.
num_heads (int): Number of attention heads.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
pretrained_window_size (tuple[int]): The height and width of the window in pre-training.
"""
def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
pretrained_window_size=[0, 0]):
super().__init__()
self.dim = dim
self.window_size = window_size # Wh, Ww
self.pretrained_window_size = pretrained_window_size
self.num_heads = num_heads
self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True)
# mlp to generate continuous relative position bias
self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True),
nn.ReLU(inplace=True),
nn.Linear(512, num_heads, bias=False))
# get relative_coords_table
relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
relative_coords_table = torch.stack(
torch.meshgrid([relative_coords_h,
relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2
if pretrained_window_size[0] > 0:
relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1)
relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1)
else:
relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1)
relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1)
relative_coords_table *= 8 # normalize to -8, 8
relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
torch.abs(relative_coords_table) + 1.0) / np.log2(8)
self.register_buffer("relative_coords_table", relative_coords_table)
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)
self.qkv = nn.Linear(dim, dim * 3, bias=False)
if qkv_bias:
self.q_bias = nn.Parameter(torch.zeros(dim))
self.v_bias = nn.Parameter(torch.zeros(dim))
else:
self.q_bias = None
self.v_bias = None
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x, mask=None):
"""
Args:
x: input features with shape of (num_windows*B, N, C)
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
"""
B_, N, C = x.shape
qkv_bias = None
if self.q_bias is not None:
qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
# cosine attention
attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1))
logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. / 0.01)).to(self.logit_scale.device)).exp()
attn = attn * logit_scale
relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads)
relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
attn = attn + relative_position_bias.unsqueeze(0)
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
def extra_repr(self) -> str:
return f'dim={self.dim}, window_size={self.window_size}, ' \
f'pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}'
def flops(self, N):
# calculate flops for 1 window with token length of N
flops = 0
# qkv = self.qkv(x)
flops += N * self.dim * 3 * self.dim
# attn = (q @ k.transpose(-2, -1))
flops += self.num_heads * N * (self.dim // self.num_heads) * N
# x = (attn @ v)
flops += self.num_heads * N * N * (self.dim // self.num_heads)
# x = self.proj(x)
flops += N * self.dim * self.dim
return flops
class SwinTransformerBlock(nn.Module):
r""" Swin Transformer Block.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resolution.
num_heads (int): Number of attention heads.
window_size (int): Window size.
shift_size (int): Shift size for SW-MSA.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float, optional): Stochastic depth rate. Default: 0.0
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
pretrained_window_size (int): Window size in pre-training.
"""
def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
act_layer=nn.GELU, norm_layer=nn.LayerNorm, pretrained_window_size=0):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.num_heads = num_heads
self.window_size = window_size
self.shift_size = shift_size
self.mlp_ratio = mlp_ratio
if min(self.input_resolution) <= self.window_size:
# if window size is larger than input resolution, we don't partition windows
self.shift_size = 0
self.window_size = min(self.input_resolution)
assert 0 <= self.shift_size < self.window_size, "shift_size must be in 0-window_size"
self.norm1 = norm_layer(dim)
self.attn = WindowAttention(
dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
pretrained_window_size=to_2tuple(pretrained_window_size))
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
if self.shift_size > 0:
attn_mask = self.calculate_mask(self.input_resolution)
else:
attn_mask = None
self.register_buffer("attn_mask", attn_mask)
def calculate_mask(self, x_size):
# calculate attention mask for SW-MSA
H, W = x_size
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
h_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
return attn_mask
def forward(self, x, x_size):
H, W = x_size
B, L, C = x.shape
#assert L == H * W, "input feature has wrong size"
shortcut = x
x = x.view(B, H, W, C)
# cyclic shift
if self.shift_size > 0:
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
else:
shifted_x = x
# partition windows
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
# W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size
if self.input_resolution == x_size:
attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
else:
attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
# merge windows
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
# reverse cyclic shift
if self.shift_size > 0:
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
else:
x = shifted_x
x = x.view(B, H * W, C)
x = shortcut + self.drop_path(self.norm1(x))
# FFN
x = x + self.drop_path(self.norm2(self.mlp(x)))
return x
def extra_repr(self) -> str:
return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
def flops(self):
flops = 0
H, W = self.input_resolution
# norm1
flops += self.dim * H * W
# W-MSA/SW-MSA
nW = H * W / self.window_size / self.window_size
flops += nW * self.attn.flops(self.window_size * self.window_size)
# mlp
flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
# norm2
flops += self.dim * H * W
return flops
class PatchMerging(nn.Module):
r""" Patch Merging Layer.
Args:
input_resolution (tuple[int]): Resolution of input feature.
dim (int): Number of input channels.
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
super().__init__()
self.input_resolution = input_resolution
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(2 * dim)
def forward(self, x):
"""
x: B, H*W, C
"""
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
x = x.view(B, H, W, C)
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
x = self.reduction(x)
x = self.norm(x)
return x
def extra_repr(self) -> str:
return f"input_resolution={self.input_resolution}, dim={self.dim}"
def flops(self):
H, W = self.input_resolution
flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
flops += H * W * self.dim // 2
return flops
class BasicLayer(nn.Module):
""" A basic Swin Transformer layer for one stage.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resolution.
depth (int): Number of blocks.
num_heads (int): Number of attention heads.
window_size (int): Local window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
pretrained_window_size (int): Local window size in pre-training.
"""
def __init__(self, dim, input_resolution, depth, num_heads, window_size,
mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
pretrained_window_size=0):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.depth = depth
self.use_checkpoint = use_checkpoint
# build blocks
self.blocks = nn.ModuleList([
SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
num_heads=num_heads, window_size=window_size,
shift_size=0 if (i % 2 == 0) else window_size // 2,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop=drop, attn_drop=attn_drop,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
norm_layer=norm_layer,
pretrained_window_size=pretrained_window_size)
for i in range(depth)])
# patch merging layer
if downsample is not None:
self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
else:
self.downsample = None
def forward(self, x, x_size):
for blk in self.blocks:
if self.use_checkpoint:
x = checkpoint.checkpoint(blk, x, x_size)
else:
x = blk(x, x_size)
if self.downsample is not None:
x = self.downsample(x)
return x
def extra_repr(self) -> str:
return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
def flops(self):
flops = 0
for blk in self.blocks:
flops += blk.flops()
if self.downsample is not None:
flops += self.downsample.flops()
return flops
def _init_respostnorm(self):
for blk in self.blocks:
nn.init.constant_(blk.norm1.bias, 0)
nn.init.constant_(blk.norm1.weight, 0)
nn.init.constant_(blk.norm2.bias, 0)
nn.init.constant_(blk.norm2.weight, 0)
class PatchEmbed(nn.Module):
r""" Image to Patch Embedding
Args:
img_size (int): Image size. Default: 224.
patch_size (int): Patch token size. Default: 4.
in_chans (int): Number of input image channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
norm_layer (nn.Module, optional): Normalization layer. Default: None
"""
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
self.img_size = img_size
self.patch_size = patch_size
self.patches_resolution = patches_resolution
self.num_patches = patches_resolution[0] * patches_resolution[1]
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
if norm_layer is not None:
self.norm = norm_layer(embed_dim)
else:
self.norm = None
def forward(self, x):
B, C, H, W = x.shape
# FIXME look at relaxing size constraints
# assert H == self.img_size[0] and W == self.img_size[1],
# f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
if self.norm is not None:
x = self.norm(x)
return x
def flops(self):
Ho, Wo = self.patches_resolution
flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
if self.norm is not None:
flops += Ho * Wo * self.embed_dim
return flops
class RSTB(nn.Module):
"""Residual Swin Transformer Block (RSTB).
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resolution.
depth (int): Number of blocks.
num_heads (int): Number of attention heads.
window_size (int): Local window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
img_size: Input image size.
patch_size: Patch size.
resi_connection: The convolutional block before residual connection.
"""
def __init__(self, dim, input_resolution, depth, num_heads, window_size,
mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
img_size=224, patch_size=4, resi_connection='1conv'):
super(RSTB, self).__init__()
self.dim = dim
self.input_resolution = input_resolution
self.residual_group = BasicLayer(dim=dim,
input_resolution=input_resolution,
depth=depth,
num_heads=num_heads,
window_size=window_size,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop=drop, attn_drop=attn_drop,
drop_path=drop_path,
norm_layer=norm_layer,
downsample=downsample,
use_checkpoint=use_checkpoint)
if resi_connection == '1conv':
self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
elif resi_connection == '3conv':
# to save parameters and memory
self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Conv2d(dim // 4, dim, 3, 1, 1))
self.patch_embed = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=dim, embed_dim=dim,
norm_layer=None)
self.patch_unembed = PatchUnEmbed(
img_size=img_size, patch_size=patch_size, in_chans=dim, embed_dim=dim,
norm_layer=None)
def forward(self, x, x_size):
return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x
def flops(self):
flops = 0
flops += self.residual_group.flops()
H, W = self.input_resolution
flops += H * W * self.dim * self.dim * 9
flops += self.patch_embed.flops()
flops += self.patch_unembed.flops()
return flops
class PatchUnEmbed(nn.Module):
r""" Image to Patch Unembedding
Args:
img_size (int): Image size. Default: 224.
patch_size (int): Patch token size. Default: 4.
in_chans (int): Number of input image channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
norm_layer (nn.Module, optional): Normalization layer. Default: None
"""
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
self.img_size = img_size
self.patch_size = patch_size
self.patches_resolution = patches_resolution
self.num_patches = patches_resolution[0] * patches_resolution[1]
self.in_chans = in_chans
self.embed_dim = embed_dim
def forward(self, x, x_size):
B, HW, C = x.shape
x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C
return x
def flops(self):
flops = 0
return flops
class Upsample(nn.Sequential):
"""Upsample module.
Args:
scale (int): Scale factor. Supported scales: 2^n and 3.
num_feat (int): Channel number of intermediate features.
"""
def __init__(self, scale, num_feat):
m = []
if (scale & (scale - 1)) == 0: # scale = 2^n
for _ in range(int(math.log(scale, 2))):
m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
m.append(nn.PixelShuffle(2))
elif scale == 3:
m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
m.append(nn.PixelShuffle(3))
else:
raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
super(Upsample, self).__init__(*m)
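# Editor's sketch (not upstream code): for a power-of-two scale, Upsample stacks
# log2(scale) conv + PixelShuffle(2) stages, so the spatial size doubles per stage
# while the channel count is preserved.
def _upsample_shape_sketch():
    up = Upsample(4, 64)
    y = up(torch.randn(1, 64, 16, 16))
    assert y.shape == (1, 64, 64, 64)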
class Upsample_hf(nn.Sequential):
"""Upsample module.
Args:
scale (int): Scale factor. Supported scales: 2^n and 3.
num_feat (int): Channel number of intermediate features.
"""
def __init__(self, scale, num_feat):
m = []
if (scale & (scale - 1)) == 0: # scale = 2^n
for _ in range(int(math.log(scale, 2))):
m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
m.append(nn.PixelShuffle(2))
elif scale == 3:
m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
m.append(nn.PixelShuffle(3))
else:
raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
super(Upsample_hf, self).__init__(*m)
class UpsampleOneStep(nn.Sequential):
"""UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
Used in lightweight SR to save parameters.
Args:
scale (int): Scale factor. Supported scales: 2^n and 3.
num_feat (int): Channel number of intermediate features.
"""
def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
self.num_feat = num_feat
self.input_resolution = input_resolution
m = []
m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1))
m.append(nn.PixelShuffle(scale))
super(UpsampleOneStep, self).__init__(*m)
def flops(self):
H, W = self.input_resolution
flops = H * W * self.num_feat * 3 * 9
return flops
class Swin2SR(nn.Module):
r""" Swin2SR
A PyTorch impl of : `Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration`.
Args:
img_size (int | tuple(int)): Input image size. Default 64
patch_size (int | tuple(int)): Patch size. Default: 1
in_chans (int): Number of input image channels. Default: 3
embed_dim (int): Patch embedding dimension. Default: 96
depths (tuple(int)): Depth of each Swin Transformer layer.
num_heads (tuple(int)): Number of attention heads in different layers.
window_size (int): Window size. Default: 7
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
drop_rate (float): Dropout rate. Default: 0
attn_drop_rate (float): Attention dropout rate. Default: 0
drop_path_rate (float): Stochastic depth rate. Default: 0.1
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
patch_norm (bool): If True, add normalization after patch embedding. Default: True
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compression artifact reduction
img_range: Image range. 1. or 255.
upsampler: The reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
"""
def __init__(self, img_size=64, patch_size=1, in_chans=3,
embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
window_size=7, mlp_ratio=4., qkv_bias=True,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
**kwargs):
super(Swin2SR, self).__init__()
num_in_ch = in_chans
num_out_ch = in_chans
num_feat = 64
self.img_range = img_range
if in_chans == 3:
rgb_mean = (0.4488, 0.4371, 0.4040)
self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
else:
self.mean = torch.zeros(1, 1, 1, 1)
self.upscale = upscale
self.upsampler = upsampler
self.window_size = window_size
#####################################################################################################
################################### 1, shallow feature extraction ###################################
self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
#####################################################################################################
################################### 2, deep feature extraction ######################################
self.num_layers = len(depths)
self.embed_dim = embed_dim
self.ape = ape
self.patch_norm = patch_norm
self.num_features = embed_dim
self.mlp_ratio = mlp_ratio
# split image into non-overlapping patches
self.patch_embed = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None)
num_patches = self.patch_embed.num_patches
patches_resolution = self.patch_embed.patches_resolution
self.patches_resolution = patches_resolution
# merge non-overlapping patches into image
self.patch_unembed = PatchUnEmbed(
img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None)
# absolute position embedding
if self.ape:
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
trunc_normal_(self.absolute_pos_embed, std=.02)
self.pos_drop = nn.Dropout(p=drop_rate)
# stochastic depth
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
# build Residual Swin Transformer blocks (RSTB)
self.layers = nn.ModuleList()
for i_layer in range(self.num_layers):
layer = RSTB(dim=embed_dim,
input_resolution=(patches_resolution[0],
patches_resolution[1]),
depth=depths[i_layer],
num_heads=num_heads[i_layer],
window_size=window_size,
mlp_ratio=self.mlp_ratio,
qkv_bias=qkv_bias,
drop=drop_rate, attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
norm_layer=norm_layer,
downsample=None,
use_checkpoint=use_checkpoint,
img_size=img_size,
patch_size=patch_size,
resi_connection=resi_connection
)
self.layers.append(layer)
if self.upsampler == 'pixelshuffle_hf':
self.layers_hf = nn.ModuleList()
for i_layer in range(self.num_layers):
layer = RSTB(dim=embed_dim,
input_resolution=(patches_resolution[0],
patches_resolution[1]),
depth=depths[i_layer],
num_heads=num_heads[i_layer],
window_size=window_size,
mlp_ratio=self.mlp_ratio,
qkv_bias=qkv_bias,
drop=drop_rate, attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
norm_layer=norm_layer,
downsample=None,
use_checkpoint=use_checkpoint,
img_size=img_size,
patch_size=patch_size,
resi_connection=resi_connection
)
self.layers_hf.append(layer)
self.norm = norm_layer(self.num_features)
# build the last conv layer in deep feature extraction
if resi_connection == '1conv':
self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
elif resi_connection == '3conv':
# to save parameters and memory
self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))
#####################################################################################################
################################ 3, high quality image reconstruction ################################
if self.upsampler == 'pixelshuffle':
# for classical SR
self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
nn.LeakyReLU(inplace=True))
self.upsample = Upsample(upscale, num_feat)
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
elif self.upsampler == 'pixelshuffle_aux':
self.conv_bicubic = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
self.conv_before_upsample = nn.Sequential(
nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
nn.LeakyReLU(inplace=True))
self.conv_aux = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
self.conv_after_aux = nn.Sequential(
nn.Conv2d(3, num_feat, 3, 1, 1),
nn.LeakyReLU(inplace=True))
self.upsample = Upsample(upscale, num_feat)
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
elif self.upsampler == 'pixelshuffle_hf':
self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
nn.LeakyReLU(inplace=True))
self.upsample = Upsample(upscale, num_feat)
self.upsample_hf = Upsample_hf(upscale, num_feat)
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
self.conv_first_hf = nn.Sequential(nn.Conv2d(num_feat, embed_dim, 3, 1, 1),
nn.LeakyReLU(inplace=True))
self.conv_after_body_hf = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
self.conv_before_upsample_hf = nn.Sequential(
nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
nn.LeakyReLU(inplace=True))
self.conv_last_hf = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
elif self.upsampler == 'pixelshuffledirect':
# for lightweight SR (to save parameters)
self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
(patches_resolution[0], patches_resolution[1]))
elif self.upsampler == 'nearest+conv':
# for real-world SR (less artifacts)
assert self.upscale == 4, 'only support x4 now.'
self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
nn.LeakyReLU(inplace=True))
self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
else:
# for image denoising and JPEG compression artifact reduction
self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
@torch.jit.ignore
def no_weight_decay(self):
return {'absolute_pos_embed'}
@torch.jit.ignore
def no_weight_decay_keywords(self):
return {'relative_position_bias_table'}
def check_image_size(self, x):
_, _, h, w = x.size()
mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
return x
def forward_features(self, x):
x_size = (x.shape[2], x.shape[3])
x = self.patch_embed(x)
if self.ape:
x = x + self.absolute_pos_embed
x = self.pos_drop(x)
for layer in self.layers:
x = layer(x, x_size)
x = self.norm(x) # B L C
x = self.patch_unembed(x, x_size)
return x
def forward_features_hf(self, x):
x_size = (x.shape[2], x.shape[3])
x = self.patch_embed(x)
if self.ape:
x = x + self.absolute_pos_embed
x = self.pos_drop(x)
for layer in self.layers_hf:
x = layer(x, x_size)
x = self.norm(x) # B L C
x = self.patch_unembed(x, x_size)
return x
def forward(self, x):
H, W = x.shape[2:]
x = self.check_image_size(x)
self.mean = self.mean.type_as(x)
x = (x - self.mean) * self.img_range
if self.upsampler == 'pixelshuffle':
# for classical SR
x = self.conv_first(x)
x = self.conv_after_body(self.forward_features(x)) + x
x = self.conv_before_upsample(x)
x = self.conv_last(self.upsample(x))
elif self.upsampler == 'pixelshuffle_aux':
bicubic = F.interpolate(x, size=(H * self.upscale, W * self.upscale), mode='bicubic', align_corners=False)
bicubic = self.conv_bicubic(bicubic)
x = self.conv_first(x)
x = self.conv_after_body(self.forward_features(x)) + x
x = self.conv_before_upsample(x)
aux = self.conv_aux(x) # b, 3, LR_H, LR_W
x = self.conv_after_aux(aux)
x = self.upsample(x)[:, :, :H * self.upscale, :W * self.upscale] + bicubic[:, :, :H * self.upscale, :W * self.upscale]
x = self.conv_last(x)
aux = aux / self.img_range + self.mean
elif self.upsampler == 'pixelshuffle_hf':
# for classical SR with HF
x = self.conv_first(x)
x = self.conv_after_body(self.forward_features(x)) + x
x_before = self.conv_before_upsample(x)
x_out = self.conv_last(self.upsample(x_before))
x_hf = self.conv_first_hf(x_before)
x_hf = self.conv_after_body_hf(self.forward_features_hf(x_hf)) + x_hf
x_hf = self.conv_before_upsample_hf(x_hf)
x_hf = self.conv_last_hf(self.upsample_hf(x_hf))
x = x_out + x_hf
x_hf = x_hf / self.img_range + self.mean
elif self.upsampler == 'pixelshuffledirect':
# for lightweight SR
x = self.conv_first(x)
x = self.conv_after_body(self.forward_features(x)) + x
x = self.upsample(x)
elif self.upsampler == 'nearest+conv':
# for real-world SR
x = self.conv_first(x)
x = self.conv_after_body(self.forward_features(x)) + x
x = self.conv_before_upsample(x)
x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
x = self.conv_last(self.lrelu(self.conv_hr(x)))
else:
# for image denoising and JPEG compression artifact reduction
x_first = self.conv_first(x)
res = self.conv_after_body(self.forward_features(x_first)) + x_first
x = x + self.conv_last(res)
x = x / self.img_range + self.mean
if self.upsampler == "pixelshuffle_aux":
return x[:, :, :H*self.upscale, :W*self.upscale], aux
elif self.upsampler == "pixelshuffle_hf":
x_out = x_out / self.img_range + self.mean
return x_out[:, :, :H*self.upscale, :W*self.upscale], x[:, :, :H*self.upscale, :W*self.upscale], x_hf[:, :, :H*self.upscale, :W*self.upscale]
else:
return x[:, :, :H*self.upscale, :W*self.upscale]
def flops(self):
flops = 0
H, W = self.patches_resolution
flops += H * W * 3 * self.embed_dim * 9
flops += self.patch_embed.flops()
for i, layer in enumerate(self.layers):
flops += layer.flops()
flops += H * W * 3 * self.embed_dim * self.embed_dim
flops += self.upsample.flops()
return flops
if __name__ == '__main__':
upscale = 4
window_size = 8
height = (1024 // upscale // window_size + 1) * window_size
width = (720 // upscale // window_size + 1) * window_size
model = Swin2SR(upscale=2, img_size=(height, width),
window_size=window_size, img_range=1., depths=[6, 6, 6, 6],
embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect')
print(model)
print(height, width, model.flops() / 1e9)
x = torch.randn((1, 3, height, width))
x = model(x)
print(x.shape)
\ No newline at end of file
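The loader change near the top of this diff instantiates this v2 architecture for any checkpoint whose filename ends in .v2.pth and loads the raw state dict directly (no params_ema wrapper). A minimal sketch of that path, using the webui's parameter choices shown above and a hypothetical checkpoint filename:
model = Swin2SR(upscale=4, in_chans=3, img_size=64, window_size=8, img_range=1.0,
                depths=[6, 6, 6, 6, 6, 6], embed_dim=180, num_heads=[6, 6, 6, 6, 6, 6],
                mlp_ratio=2, upsampler="nearest+conv", resi_connection="1conv")
# state = torch.load("some_upscaler.v2.pth")   # hypothetical filename
# model.load_state_dict(state, strict=True)    # v2 checkpoints store weights at the top level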
import os
import numpy as np
import PIL
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
import random
import tqdm
from modules import devices, shared
import re
re_numbers_at_start = re.compile(r"^[-\d]+\s*")
class DatasetEntry:
def __init__(self, filename=None, latent=None, filename_text=None):
self.filename = filename
self.latent = latent
self.filename_text = filename_text
self.cond = None
self.cond_text = None
class PersonalizedBase(Dataset):
def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False, batch_size=1):
re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None
self.placeholder_token = placeholder_token
self.batch_size = batch_size
self.width = width
self.height = height
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
self.dataset = []
with open(template_file, "r") as file:
lines = [x.strip() for x in file.readlines()]
self.lines = lines
assert data_root, 'dataset directory not specified'
cond_model = shared.sd_model.cond_stage_model
self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
print("Preparing dataset...")
for path in tqdm.tqdm(self.image_paths):
try:
image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.BICUBIC)
except Exception:
continue
text_filename = os.path.splitext(path)[0] + ".txt"
filename = os.path.basename(path)
if os.path.exists(text_filename):
with open(text_filename, "r", encoding="utf8") as file:
filename_text = file.read()
else:
filename_text = os.path.splitext(filename)[0]
filename_text = re.sub(re_numbers_at_start, '', filename_text)
if re_word:
tokens = re_word.findall(filename_text)
filename_text = (shared.opts.dataset_filename_join_string or "").join(tokens)
npimage = np.array(image).astype(np.uint8)
npimage = (npimage / 127.5 - 1.0).astype(np.float32)
torchdata = torch.from_numpy(npimage).to(device=device, dtype=torch.float32)
torchdata = torch.moveaxis(torchdata, 2, 0)
init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze()
init_latent = init_latent.to(devices.cpu)
entry = DatasetEntry(filename=path, filename_text=filename_text, latent=init_latent)
if include_cond:
entry.cond_text = self.create_text(filename_text)
entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)
self.dataset.append(entry)
assert len(self.dataset) > 1, "dataset must contain more than one image"
self.length = len(self.dataset) * repeats // batch_size
self.initial_indexes = np.arange(len(self.dataset))
self.indexes = None
self.shuffle()
def shuffle(self):
self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])]
def create_text(self, filename_text):
text = random.choice(self.lines)
text = text.replace("[name]", self.placeholder_token)
text = text.replace("[filewords]", filename_text)
return text
def __len__(self):
return self.length
def __getitem__(self, i):
res = []
for j in range(self.batch_size):
position = i * self.batch_size + j
if position % len(self.indexes) == 0:
self.shuffle()
index = self.indexes[position % len(self.indexes)]
entry = self.dataset[index]
if entry.cond is None:
entry.cond_text = self.create_text(entry.filename_text)
res.append(entry)
return res
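A minimal sketch (not part of the source) of what create_text above does with one template line: [name] is replaced by the placeholder token and [filewords] by the caption derived from the image filename or its companion .txt file.
line = "a photo of [name], [filewords]"
prompt = line.replace("[name]", "my-token").replace("[filewords]", "cat sitting on a chair")
assert prompt == "a photo of my-token, cat sitting on a chair"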
import base64
import json
import numpy as np
import zlib
from PIL import Image, PngImagePlugin, ImageDraw, ImageFont
from fonts.ttf import Roboto
import torch
class EmbeddingEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, torch.Tensor):
return {'TORCHTENSOR': obj.cpu().detach().numpy().tolist()}
return json.JSONEncoder.default(self, obj)
class EmbeddingDecoder(json.JSONDecoder):
def __init__(self, *args, **kwargs):
json.JSONDecoder.__init__(self, object_hook=self.object_hook, *args, **kwargs)
def object_hook(self, d):
if 'TORCHTENSOR' in d:
return torch.from_numpy(np.array(d['TORCHTENSOR']))
return d
def embedding_to_b64(data):
d = json.dumps(data, cls=EmbeddingEncoder)
return base64.b64encode(d.encode())
def embedding_from_b64(data):
d = base64.b64decode(data)
return json.loads(d, cls=EmbeddingDecoder)
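# Editor's sketch (not upstream code): tensors survive the base64 round trip because
# EmbeddingEncoder serialises them as nested lists and EmbeddingDecoder rebuilds them.
def _b64_roundtrip_sketch():
    emb = {'string_to_param': {'*': torch.arange(6, dtype=torch.float64).reshape(2, 3)}}
    restored = embedding_from_b64(embedding_to_b64(emb))
    assert torch.equal(restored['string_to_param']['*'], emb['string_to_param']['*'])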
def lcg(m=2**32, a=1664525, c=1013904223, seed=0):
while True:
seed = (a * seed + c) % m
yield seed % 255
def xor_block(block):
g = lcg()
randblock = np.array([next(g) for _ in range(np.product(block.shape))]).astype(np.uint8).reshape(block.shape)
return np.bitwise_xor(block.astype(np.uint8), randblock & 0x0F)
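# Editor's sketch (not upstream code): because the LCG stream is regenerated from the
# same fixed seed on every call, xor_block is its own inverse; extract_image_data_embed
# relies on this to undo the scrambling applied by insert_image_data_embed.
def _xor_roundtrip_sketch():
    block = np.random.randint(0, 16, size=(4, 4, 3)).astype(np.uint8)
    assert np.array_equal(xor_block(xor_block(block)), block)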
def style_block(block, sequence):
im = Image.new('RGB', (block.shape[1], block.shape[0]))
draw = ImageDraw.Draw(im)
i = 0
for x in range(-6, im.size[0], 8):
for yi, y in enumerate(range(-6, im.size[1], 8)):
offset = 0
if yi % 2 == 0:
offset = 4
shade = sequence[i % len(sequence)]
i += 1
draw.ellipse((x+offset, y, x+6+offset, y+6), fill=(shade, shade, shade))
fg = np.array(im).astype(np.uint8) & 0xF0
return block ^ fg
def insert_image_data_embed(image, data):
d = 3
data_compressed = zlib.compress(json.dumps(data, cls=EmbeddingEncoder).encode(), level=9)
data_np_ = np.frombuffer(data_compressed, np.uint8).copy()
data_np_high = data_np_ >> 4
data_np_low = data_np_ & 0x0F
h = image.size[1]
next_size = data_np_low.shape[0] + (h-(data_np_low.shape[0] % h))
next_size = next_size + ((h*d)-(next_size % (h*d)))
data_np_low.resize(next_size)
data_np_low = data_np_low.reshape((h, -1, d))
data_np_high.resize(next_size)
data_np_high = data_np_high.reshape((h, -1, d))
edge_style = list(data['string_to_param'].values())[0].cpu().detach().numpy().tolist()[0][:1024]
edge_style = (np.abs(edge_style)/np.max(np.abs(edge_style))*255).astype(np.uint8)
data_np_low = style_block(data_np_low, sequence=edge_style)
data_np_low = xor_block(data_np_low)
data_np_high = style_block(data_np_high, sequence=edge_style[::-1])
data_np_high = xor_block(data_np_high)
im_low = Image.fromarray(data_np_low, mode='RGB')
im_high = Image.fromarray(data_np_high, mode='RGB')
background = Image.new('RGB', (image.size[0]+im_low.size[0]+im_high.size[0]+2, image.size[1]), (0, 0, 0))
background.paste(im_low, (0, 0))
background.paste(image, (im_low.size[0]+1, 0))
background.paste(im_high, (im_low.size[0]+1+image.size[0]+1, 0))
return background
def crop_black(img, tol=0):
mask = (img > tol).all(2)
mask0, mask1 = mask.any(0), mask.any(1)
col_start, col_end = mask0.argmax(), mask.shape[1]-mask0[::-1].argmax()
row_start, row_end = mask1.argmax(), mask.shape[0]-mask1[::-1].argmax()
return img[row_start:row_end, col_start:col_end]
def extract_image_data_embed(image):
d = 3
outarr = crop_black(np.array(image.convert('RGB').getdata()).reshape(image.size[1], image.size[0], d).astype(np.uint8)) & 0x0F
black_cols = np.where(np.sum(outarr, axis=(0, 2)) == 0)
if black_cols[0].shape[0] < 2:
print('No Image data blocks found.')
return None
data_block_lower = outarr[:, :black_cols[0].min(), :].astype(np.uint8)
data_block_upper = outarr[:, black_cols[0].max()+1:, :].astype(np.uint8)
data_block_lower = xor_block(data_block_lower)
data_block_upper = xor_block(data_block_upper)
data_block = (data_block_upper << 4) | (data_block_lower)
data_block = data_block.flatten().tobytes()
data = zlib.decompress(data_block)
return json.loads(data, cls=EmbeddingDecoder)
def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, textfont=None):
from math import cos
image = srcimage.copy()
if textfont is None:
try:
# probe that the configured font can actually be loaded; any failure here falls
# through to the bundled Roboto fallback below
textfont = ImageFont.truetype(opts.font or Roboto, fontsize)
textfont = opts.font or Roboto
except Exception:
textfont = Roboto
factor = 1.5
gradient = Image.new('RGBA', (1, image.size[1]), color=(0, 0, 0, 0))
for y in range(image.size[1]):
mag = 1-cos(y/image.size[1]*factor)
mag = max(mag, 1-cos((image.size[1]-y)/image.size[1]*factor*1.1))
gradient.putpixel((0, y), (0, 0, 0, int(mag*255)))
image = Image.alpha_composite(image.convert('RGBA'), gradient.resize(image.size))
draw = ImageDraw.Draw(image)
fontsize = 32
font = ImageFont.truetype(textfont, fontsize)
padding = 10
_, _, w, h = draw.textbbox((0, 0), title, font=font)
fontsize = min(int(fontsize * (((image.size[0]*0.75)-(padding*4))/w)), 72)
font = ImageFont.truetype(textfont, fontsize)
_, _, w, h = draw.textbbox((0, 0), title, font=font)
draw.text((padding, padding), title, anchor='lt', font=font, fill=(255, 255, 255, 230))
_, _, w, h = draw.textbbox((0, 0), footerLeft, font=font)
fontsize_left = min(int(fontsize * (((image.size[0]/3)-(padding))/w)), 72)
_, _, w, h = draw.textbbox((0, 0), footerMid, font=font)
fontsize_mid = min(int(fontsize * (((image.size[0]/3)-(padding))/w)), 72)
_, _, w, h = draw.textbbox((0, 0), footerRight, font=font)
fontsize_right = min(int(fontsize * (((image.size[0]/3)-(padding))/w)), 72)
font = ImageFont.truetype(textfont, min(fontsize_left, fontsize_mid, fontsize_right))
draw.text((padding, image.size[1]-padding), footerLeft, anchor='ls', font=font, fill=(255, 255, 255, 230))
draw.text((image.size[0]/2, image.size[1]-padding), footerMid, anchor='ms', font=font, fill=(255, 255, 255, 230))
draw.text((image.size[0]-padding, image.size[1]-padding), footerRight, anchor='rs', font=font, fill=(255, 255, 255, 230))
return image
if __name__ == '__main__':
testEmbed = Image.open('test_embedding.png')
data = extract_image_data_embed(testEmbed)
assert data is not None
data = embedding_from_b64(testEmbed.text['sd-ti-embedding'])
assert data is not None
image = Image.new('RGBA', (512, 512), (255, 255, 200, 255))
cap_image = caption_image_overlay(image, 'title', 'footerLeft', 'footerMid', 'footerRight')
test_embed = {'string_to_param': {'*': torch.from_numpy(np.random.random((2, 4096)))}}
embedded_image = insert_image_data_embed(cap_image, test_embed)
retrived_embed = extract_image_data_embed(embedded_image)
assert str(retrived_embed) == str(test_embed)
embedded_image2 = insert_image_data_embed(cap_image, retrived_embed)
assert embedded_image == embedded_image2
g = lcg()
shared_random = np.array([next(g) for _ in range(100)]).astype(np.uint8).tolist()
reference_random = [253, 242, 127, 44, 157, 27, 239, 133, 38, 79, 167, 4, 177,
95, 130, 79, 78, 14, 52, 215, 220, 194, 126, 28, 240, 179,
160, 153, 149, 50, 105, 14, 21, 218, 199, 18, 54, 198, 193,
38, 128, 19, 53, 195, 124, 75, 205, 12, 6, 145, 0, 28,
30, 148, 8, 45, 218, 171, 55, 249, 97, 166, 12, 35, 0,
41, 221, 122, 215, 170, 31, 113, 186, 97, 119, 31, 23, 185,
66, 140, 30, 41, 37, 63, 137, 109, 216, 55, 159, 145, 82,
204, 86, 73, 222, 44, 198, 118, 240, 97]
assert shared_random == reference_random
hunna_kay_random_sum = sum(np.array([next(g) for _ in range(100000)]).astype(np.uint8).tolist())
assert 12731374 == hunna_kay_random_sum
import tqdm
class LearnScheduleIterator:
def __init__(self, learn_rate, max_steps, cur_step=0):
"""
specify learn_rate as "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until step 1000, and 1e-5 until step 10000
"""
pairs = learn_rate.split(',')
self.rates = []
self.it = 0
self.maxit = 0
for i, pair in enumerate(pairs):
tmp = pair.split(':')
if len(tmp) == 2:
step = int(tmp[1])
if step > cur_step:
self.rates.append((float(tmp[0]), min(step, max_steps)))
self.maxit += 1
if step > max_steps:
return
elif step == -1:
self.rates.append((float(tmp[0]), max_steps))
self.maxit += 1
return
else:
self.rates.append((float(tmp[0]), max_steps))
self.maxit += 1
return
def __iter__(self):
return self
def __next__(self):
if self.it < self.maxit:
self.it += 1
return self.rates[self.it - 1]
else:
raise StopIteration
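# Editor's sketch (not upstream code): a schedule string is a comma-separated list of
# "rate:step" pairs; a bare rate (no step) applies until max_steps.
def _learn_schedule_sketch():
    rates = list(LearnScheduleIterator("0.005:100, 1e-5", max_steps=1000))
    assert rates == [(0.005, 100), (1e-05, 1000)]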
class LearnRateScheduler:
def __init__(self, learn_rate, max_steps, cur_step=0, verbose=True):
self.schedules = LearnScheduleIterator(learn_rate, max_steps, cur_step)
(self.learn_rate, self.end_step) = next(self.schedules)
self.verbose = verbose
if self.verbose:
print(f'Training at rate of {self.learn_rate} until step {self.end_step}')
self.finished = False
def apply(self, optimizer, step_number):
if step_number <= self.end_step:
return
try:
(self.learn_rate, self.end_step) = next(self.schedules)
except Exception:
self.finished = True
return
if self.verbose:
tqdm.tqdm.write(f'Training at rate of {self.learn_rate} until step {self.end_step}')
for pg in optimizer.param_groups:
pg['lr'] = self.learn_rate
import os
from PIL import Image, ImageOps
import platform
import sys
import tqdm
import time
from modules import shared, images
from modules.shared import opts, cmd_opts
if cmd_opts.deepdanbooru:
import modules.deepbooru as deepbooru
def preprocess(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False):
try:
if process_caption:
shared.interrogator.load()
if process_caption_deepbooru:
db_opts = deepbooru.create_deepbooru_opts()
db_opts[deepbooru.OPT_INCLUDE_RANKS] = False
deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, db_opts)
preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru)
finally:
if process_caption:
shared.interrogator.send_blip_to_ram()
if process_caption_deepbooru:
deepbooru.release_process()
def preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False):
width = process_width
height = process_height
src = os.path.abspath(process_src)
dst = os.path.abspath(process_dst)
assert src != dst, 'same directory specified as source and destination'
os.makedirs(dst, exist_ok=True)
files = os.listdir(src)
shared.state.textinfo = "Preprocessing..."
shared.state.job_count = len(files)
def save_pic_with_caption(image, index):
caption = ""
if process_caption:
caption += shared.interrogator.generate_caption(image)
if process_caption_deepbooru:
if len(caption) > 0:
caption += ", "
caption += deepbooru.get_tags_from_process(image)
filename_part = filename
filename_part = os.path.splitext(filename_part)[0]
filename_part = os.path.basename(filename_part)
basename = f"{index:05}-{subindex[0]}-{filename_part}"
image.save(os.path.join(dst, f"{basename}.png"))
if len(caption) > 0:
with open(os.path.join(dst, f"{basename}.txt"), "w", encoding="utf8") as file:
file.write(caption)
subindex[0] += 1
def save_pic(image, index):
save_pic_with_caption(image, index)
if process_flip:
save_pic_with_caption(ImageOps.mirror(image), index)
for index, imagefile in enumerate(tqdm.tqdm(files)):
subindex = [0]
filename = os.path.join(src, imagefile)
try:
img = Image.open(filename).convert("RGB")
except Exception:
continue
if shared.state.interrupted:
break
ratio = img.height / img.width
is_tall = ratio > 1.35
is_wide = ratio < 1 / 1.35
if process_split and is_tall:
img = img.resize((width, height * img.height // img.width))
top = img.crop((0, 0, width, height))
save_pic(top, index)
bot = img.crop((0, img.height - height, width, img.height))
save_pic(bot, index)
elif process_split and is_wide:
img = img.resize((width * img.width // img.height, height))
left = img.crop((0, 0, width, height))
save_pic(left, index)
right = img.crop((img.width - width, 0, img.width, height))
save_pic(right, index)
else:
img = images.resize_image(1, img, width, height)
save_pic(img, index)
shared.state.nextjob()
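A small sketch (not part of the source) of the output naming used by save_pic_with_caption above: a zero-padded index of the source image, a per-image crop subindex, and the source filename without its extension.
index, subindex, filename_part = 2, 1, "cat photo"
basename = f"{index:05}-{subindex}-{filename_part}"
assert basename == "00002-1-cat photo"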
import os
import sys
import traceback
import torch
import tqdm
import html
import datetime
import csv
from PIL import Image, PngImagePlugin
from modules import shared, devices, sd_hijack, processing, sd_models
import modules.textual_inversion.dataset
from modules.textual_inversion.learn_schedule import LearnRateScheduler
from modules.textual_inversion.image_embedding import (embedding_to_b64, embedding_from_b64,
insert_image_data_embed, extract_image_data_embed,
caption_image_overlay)
class Embedding:
def __init__(self, vec, name, step=None):
self.vec = vec
self.name = name
self.step = step
self.cached_checksum = None
self.sd_checkpoint = None
self.sd_checkpoint_name = None
def save(self, filename):
embedding_data = {
"string_to_token": {"*": 265},
"string_to_param": {"*": self.vec},
"name": self.name,
"step": self.step,
"sd_checkpoint": self.sd_checkpoint,
"sd_checkpoint_name": self.sd_checkpoint_name,
}
torch.save(embedding_data, filename)
def checksum(self):
if self.cached_checksum is not None:
return self.cached_checksum
def const_hash(a):
r = 0
for v in a:
r = (r * 281 ^ int(v) * 997) & 0xFFFFFFFF
return r
self.cached_checksum = f'{const_hash(self.vec.reshape(-1) * 100) & 0xffff:04x}'
return self.cached_checksum
class EmbeddingDatabase:
def __init__(self, embeddings_dir):
self.ids_lookup = {}
self.word_embeddings = {}
self.dir_mtime = None
self.embeddings_dir = embeddings_dir
def register_embedding(self, embedding, model):
self.word_embeddings[embedding.name] = embedding
ids = model.cond_stage_model.tokenizer([embedding.name], add_special_tokens=False)['input_ids'][0]
first_id = ids[0]
if first_id not in self.ids_lookup:
self.ids_lookup[first_id] = []
self.ids_lookup[first_id] = sorted(self.ids_lookup[first_id] + [(ids, embedding)], key=lambda x: len(x[0]), reverse=True)
return embedding
def load_textual_inversion_embeddings(self):
mt = os.path.getmtime(self.embeddings_dir)
if self.dir_mtime is not None and mt <= self.dir_mtime:
return
self.dir_mtime = mt
self.ids_lookup.clear()
self.word_embeddings.clear()
def process_file(path, filename):
name = os.path.splitext(filename)[0]
data = []
if filename.upper().endswith('.PNG'):
embed_image = Image.open(path)
if 'sd-ti-embedding' in embed_image.text:
data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
name = data.get('name', name)
else:
data = extract_image_data_embed(embed_image)
name = data.get('name', name)
else:
data = torch.load(path, map_location="cpu")
# textual inversion embeddings
if 'string_to_param' in data:
param_dict = data['string_to_param']
if hasattr(param_dict, '_parameters'):
param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
emb = next(iter(param_dict.items()))[1]
# diffuser concepts
elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
emb = next(iter(data.values()))
if len(emb.shape) == 1:
emb = emb.unsqueeze(0)
else:
raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
vec = emb.detach().to(devices.device, dtype=torch.float32)
embedding = Embedding(vec, name)
embedding.step = data.get('step', None)
embedding.sd_checkpoint = data.get('hash', None)
embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
self.register_embedding(embedding, shared.sd_model)
for fn in os.listdir(self.embeddings_dir):
try:
fullfn = os.path.join(self.embeddings_dir, fn)
if os.stat(fullfn).st_size == 0:
continue
process_file(fullfn, fn)
except Exception:
print(f"Error loading emedding {fn}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
continue
print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.")
def find_embedding_at_position(self, tokens, offset):
token = tokens[offset]
possible_matches = self.ids_lookup.get(token, None)
if possible_matches is None:
return None, None
for ids, embedding in possible_matches:
if tokens[offset:offset + len(ids)] == ids:
return embedding, len(ids)
return None, None
def create_embedding(name, num_vectors_per_token, init_text='*'):
cond_model = shared.sd_model.cond_stage_model
embedding_layer = cond_model.wrapped.transformer.text_model.embeddings
ids = cond_model.tokenizer(init_text, max_length=num_vectors_per_token, return_tensors="pt", add_special_tokens=False)["input_ids"]
embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)
vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device)
for i in range(num_vectors_per_token):
vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt")
assert not os.path.exists(fn), f"file {fn} already exists"
embedding = Embedding(vec, name)
embedding.step = 0
embedding.save(fn)
return fn
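
# Worked example of the index mapping above (pure arithmetic, hypothetical sizes):
# if the initialization text produces 3 token embeddings and 5 vectors are requested,
# the source rows copied into vec are [i * 3 // 5 for i in range(5)] == [0, 0, 1, 1, 2],
# i.e. the available embeddings are spread evenly across the new vectors.
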
def write_loss(log_directory, filename, step, epoch_len, values):
if shared.opts.training_write_csv_every == 0:
return
if step % shared.opts.training_write_csv_every != 0:
return
    write_csv_header = not os.path.exists(os.path.join(log_directory, filename))
with open(os.path.join(log_directory, filename), "a+", newline='') as fout:
csv_writer = csv.DictWriter(fout, fieldnames=["step", "epoch", "epoch_step", *(values.keys())])
if write_csv_header:
csv_writer.writeheader()
epoch = step // epoch_len
epoch_step = step - epoch * epoch_len
csv_writer.writerow({
"step": step + 1,
"epoch": epoch + 1,
"epoch_step": epoch_step + 1,
**values,
})
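
# Illustrative example (hypothetical numbers): with training_write_csv_every = 500 and
# an epoch length of 400, a call at step 500 passes the modulo check and appends a row
# with step=501, epoch=2, epoch_step=101 plus whatever key/value pairs were handed in;
# any other step is skipped, so the CSV only samples the loss curve.
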
def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
assert embedding_name, 'embedding not selected'
shared.state.textinfo = "Initializing textual inversion training..."
shared.state.job_count = steps
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), embedding_name)
if save_embedding_every > 0:
embedding_dir = os.path.join(log_directory, "embeddings")
os.makedirs(embedding_dir, exist_ok=True)
else:
embedding_dir = None
if create_image_every > 0:
images_dir = os.path.join(log_directory, "images")
os.makedirs(images_dir, exist_ok=True)
else:
images_dir = None
if create_image_every > 0 and save_image_with_stored_embedding:
images_embeds_dir = os.path.join(log_directory, "image_embeddings")
os.makedirs(images_embeds_dir, exist_ok=True)
else:
images_embeds_dir = None
cond_model = shared.sd_model.cond_stage_model
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"):
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size)
hijack = sd_hijack.model_hijack
embedding = hijack.embedding_db.word_embeddings[embedding_name]
embedding.vec.requires_grad = True
losses = torch.zeros((32,))
last_saved_file = "<none>"
last_saved_image = "<none>"
    initial_step = embedding.step or 0
    if initial_step > steps:
        return embedding, filename
    scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
    optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
    pbar = tqdm.tqdm(enumerate(ds), total=steps-initial_step)
    for i, entries in pbar:
        embedding.step = i + initial_step
scheduler.apply(optimizer, embedding.step)
if scheduler.finished:
break
if shared.state.interrupted:
break
with torch.autocast("cuda"):
c = cond_model([entry.cond_text for entry in entries])
x = torch.stack([entry.latent for entry in entries]).to(devices.device)
loss = shared.sd_model(x, c)[0]
del x
losses[embedding.step % losses.shape[0]] = loss.item()
optimizer.zero_grad()
loss.backward()
optimizer.step()
epoch_num = embedding.step // len(ds)
epoch_step = embedding.step - (epoch_num * len(ds)) + 1
pbar.set_description(f"[Epoch {epoch_num}: {epoch_step}/{len(ds)}]loss: {losses.mean():.7f}")
if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0:
last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt')
embedding.save(last_saved_file)
write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), {
"loss": f"{losses.mean():.7f}",
"learn_rate": scheduler.learn_rate
})
if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png')
p = processing.StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
do_not_save_grid=True,
do_not_save_samples=True,
)
if preview_from_txt2img:
p.prompt = preview_prompt
p.negative_prompt = preview_negative_prompt
p.steps = preview_steps
p.sampler_index = preview_sampler_index
p.cfg_scale = preview_cfg_scale
p.seed = preview_seed
p.width = preview_width
p.height = preview_height
else:
p.prompt = entries[0].cond_text
p.steps = 20
p.width = training_width
p.height = training_height
preview_text = p.prompt
processed = processing.process_images(p)
image = processed.images[0]
shared.state.current_image = image
if save_image_with_stored_embedding and os.path.exists(last_saved_file):
last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{embedding.step}.png')
info = PngImagePlugin.PngInfo()
data = torch.load(last_saved_file)
info.add_text("sd-ti-embedding", embedding_to_b64(data))
title = "<{}>".format(data.get('name', '???'))
checkpoint = sd_models.select_checkpoint()
footer_left = checkpoint.model_name
footer_mid = '[{}]'.format(checkpoint.hash)
footer_right = '{}'.format(embedding.step)
captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
captioned_image = insert_image_data_embed(captioned_image, data)
captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
image.save(last_saved_image)
last_saved_image += f", prompt: {preview_text}"
shared.state.job_no = embedding.step
shared.state.textinfo = f"""
<p>
Loss: {losses.mean():.7f}<br/>
Step: {embedding.step}<br/>
Last prompt: {html.escape(entries[0].cond_text)}<br/>
Last saved embedding: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/>
</p>
"""
checkpoint = sd_models.select_checkpoint()
embedding.sd_checkpoint = checkpoint.hash
embedding.sd_checkpoint_name = checkpoint.model_name
embedding.cached_checksum = None
embedding.save(filename)
return embedding, filename
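
# Illustrative sketch (assumption about intent, not part of the original module): in
# train_embedding above only embedding.vec has requires_grad=True and it is the sole
# parameter handed to AdamW, so the frozen model weights never receive updates. The
# same single-tensor optimization pattern in isolation:
def _example_optimize_single_tensor(loss_fn, vec, lr=0.005, steps=10):
    import torch  # local import keeps the sketch self-contained
    vec.requires_grad = True
    optimizer = torch.optim.AdamW([vec], lr=lr)
    for _ in range(steps):
        loss = loss_fn(vec)  # loss_fn must return a scalar tensor
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return vec
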
import html
import gradio as gr
import modules.textual_inversion.textual_inversion
import modules.textual_inversion.preprocess
from modules import sd_hijack, shared
def create_embedding(name, initialization_text, nvpt):
filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, init_text=initialization_text)
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
return gr.Dropdown.update(choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())), f"Created: {filename}", ""
def preprocess(*args):
modules.textual_inversion.preprocess.preprocess(*args)
return "Preprocessing finished.", ""
def train_embedding(*args):
    assert not shared.cmd_opts.lowvram, 'Training models with --lowvram is not possible'
try:
sd_hijack.undo_optimizations()
embedding, filename = modules.textual_inversion.textual_inversion.train_embedding(*args)
res = f"""
Training {'interrupted' if shared.state.interrupted else 'finished'} at {embedding.step} steps.
Embedding saved to {html.escape(filename)}
"""
return res, ""
except Exception:
raise
finally:
sd_hijack.apply_optimizations()
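
# Illustrative note (sketch of the pattern above, hypothetical helper name): the
# try/finally pair guarantees that the attention optimizations removed for training
# are re-applied even when train_embedding raises; the generic shape of that pattern:
def _example_run_without_optimizations(train_fn, undo, apply):
    undo()
    try:
        return train_fn()
    finally:
        apply()
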
...@@ -6,7 +6,7 @@ import modules.processing as processing ...@@ -6,7 +6,7 @@ import modules.processing as processing
from modules.ui import plaintext_to_html from modules.ui import plaintext_to_html
def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, scale_latent: bool, denoising_strength: float, *args): def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, firstphase_width: int, firstphase_height: int, *args):
p = StableDiffusionProcessingTxt2Img( p = StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model, sd_model=shared.sd_model,
outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples, outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
...@@ -30,11 +30,14 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: ...@@ -30,11 +30,14 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
restore_faces=restore_faces, restore_faces=restore_faces,
tiling=tiling, tiling=tiling,
enable_hr=enable_hr, enable_hr=enable_hr,
scale_latent=scale_latent if enable_hr else None,
denoising_strength=denoising_strength if enable_hr else None, denoising_strength=denoising_strength if enable_hr else None,
firstphase_width=firstphase_width if enable_hr else None,
firstphase_height=firstphase_height if enable_hr else None,
) )
if cmd_opts.enable_console_prompts:
print(f"\ntxt2img: {prompt}", file=shared.progress_print_out) print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)
processed = modules.scripts.scripts_txt2img.run(p, *args) processed = modules.scripts.scripts_txt2img.run(p, *args)
if processed is None: if processed is None:
...@@ -46,5 +49,8 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: ...@@ -46,5 +49,8 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
if opts.samples_log_stdout: if opts.samples_log_stdout:
print(generation_info_js) print(generation_info_js)
if opts.do_not_show_images:
processed.images = []
return processed.images, generation_info_js, plaintext_to_html(processed.info) return processed.images, generation_info_js, plaintext_to_html(processed.info)
...@@ -11,6 +11,7 @@ import time ...@@ -11,6 +11,7 @@ import time
import traceback import traceback
import platform import platform
import subprocess as sp import subprocess as sp
from functools import reduce
import numpy as np import numpy as np
import torch import torch
...@@ -21,8 +22,11 @@ import gradio as gr ...@@ -21,8 +22,11 @@ import gradio as gr
import gradio.utils import gradio.utils
import gradio.routes import gradio.routes
from modules import sd_hijack, sd_models
from modules.paths import script_path from modules.paths import script_path
from modules.shared import opts, cmd_opts from modules.shared import opts, cmd_opts
if cmd_opts.deepdanbooru:
from modules.deepbooru import get_deepbooru_tags
import modules.shared as shared import modules.shared as shared
from modules.sd_samplers import samplers, samplers_for_img2img from modules.sd_samplers import samplers, samplers_for_img2img
from modules.sd_hijack import model_hijack from modules.sd_hijack import model_hijack
...@@ -32,8 +36,13 @@ import modules.gfpgan_model ...@@ -32,8 +36,13 @@ import modules.gfpgan_model
import modules.codeformer_model import modules.codeformer_model
import modules.styles import modules.styles
import modules.generation_parameters_copypaste import modules.generation_parameters_copypaste
from modules import prompt_parser
from modules.images import save_image
import modules.textual_inversion.ui
import modules.hypernetworks.ui
import modules.images_history as img_his
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the bowser will not show any UI # this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
mimetypes.init() mimetypes.init()
mimetypes.add_type('application/javascript', '.js') mimetypes.add_type('application/javascript', '.js')
...@@ -43,6 +52,11 @@ if not cmd_opts.share and not cmd_opts.listen: ...@@ -43,6 +52,11 @@ if not cmd_opts.share and not cmd_opts.listen:
gradio.utils.version_check = lambda: None gradio.utils.version_check = lambda: None
gradio.utils.get_local_ip_address = lambda: '127.0.0.1' gradio.utils.get_local_ip_address = lambda: '127.0.0.1'
if cmd_opts.ngrok != None:
import modules.ngrok as ngrok
print('ngrok authtoken detected, trying to connect...')
ngrok.connect(cmd_opts.ngrok, cmd_opts.port if cmd_opts.port != None else 7860)
def gr_show(visible=True): def gr_show(visible=True):
return {"visible": visible, "__type__": "update"} return {"visible": visible, "__type__": "update"}
...@@ -64,7 +78,9 @@ random_symbol = '\U0001f3b2\ufe0f' # 🎲️ ...@@ -64,7 +78,9 @@ random_symbol = '\U0001f3b2\ufe0f' # 🎲️
reuse_symbol = '\u267b\ufe0f' # ♻️ reuse_symbol = '\u267b\ufe0f' # ♻️
art_symbol = '\U0001f3a8' # 🎨 art_symbol = '\U0001f3a8' # 🎨
paste_symbol = '\u2199\ufe0f' # ↙ paste_symbol = '\u2199\ufe0f' # ↙
folder_symbol = '\uD83D\uDCC2' folder_symbol = '\U0001f4c2' # 📂
refresh_symbol = '\U0001f504' # 🔄
def plaintext_to_html(text): def plaintext_to_html(text):
text = "<p>" + "<br>\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "</p>" text = "<p>" + "<br>\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "</p>"
...@@ -93,19 +109,32 @@ def send_gradio_gallery_to_image(x): ...@@ -93,19 +109,32 @@ def send_gradio_gallery_to_image(x):
return image_from_url_text(x[0]) return image_from_url_text(x[0])
def save_files(js_data, images, index): def save_files(js_data, images, do_make_zip, index):
import csv import csv
os.makedirs(opts.outdir_save, exist_ok=True)
filenames = [] filenames = []
fullfns = []
    # quick dictionary to class object conversion. It's necessary because apply_filename_pattern requires it
class MyObject:
def __init__(self, d=None):
if d is not None:
for key, value in d.items():
setattr(self, key, value)
data = json.loads(js_data) data = json.loads(js_data)
p = MyObject(data)
path = opts.outdir_save
save_to_dirs = opts.use_save_to_dirs_for_ui
extension: str = opts.samples_format
start_index = 0
if index > -1 and opts.save_selected_only and (index >= data["index_of_first_image"]): # ensures we are looking at a specific non-grid picture, and we have save_selected_only if index > -1 and opts.save_selected_only and (index >= data["index_of_first_image"]): # ensures we are looking at a specific non-grid picture, and we have save_selected_only
images = [images[index]] images = [images[index]]
infotexts = [data["infotexts"][index]] start_index = index
else:
infotexts = data["infotexts"] os.makedirs(opts.outdir_save, exist_ok=True)
with open(os.path.join(opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file: with open(os.path.join(opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file:
at_start = file.tell() == 0 at_start = file.tell() == 0
...@@ -113,37 +142,42 @@ def save_files(js_data, images, index): ...@@ -113,37 +142,42 @@ def save_files(js_data, images, index):
if at_start: if at_start:
writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"]) writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"])
filename_base = str(int(time.time() * 1000)) for image_index, filedata in enumerate(images, start_index):
extension = opts.samples_format.lower()
for i, filedata in enumerate(images):
filename = filename_base + ("" if len(images) == 1 else "-" + str(i + 1)) + f".{extension}"
filepath = os.path.join(opts.outdir_save, filename)
if filedata.startswith("data:image/png;base64,"): if filedata.startswith("data:image/png;base64,"):
filedata = filedata[len("data:image/png;base64,"):] filedata = filedata[len("data:image/png;base64,"):]
image = Image.open(io.BytesIO(base64.decodebytes(filedata.encode('utf-8')))) image = Image.open(io.BytesIO(base64.decodebytes(filedata.encode('utf-8'))))
if opts.enable_pnginfo and extension == 'png':
pnginfo = PngImagePlugin.PngInfo()
pnginfo.add_text('parameters', infotexts[i])
image.save(filepath, pnginfo=pnginfo)
else:
image.save(filepath, quality=opts.jpeg_quality)
if opts.enable_pnginfo and extension in ("jpg", "jpeg", "webp"): is_grid = image_index < p.index_of_first_image
piexif.insert(piexif.dump({"Exif": { i = 0 if is_grid else (image_index - p.index_of_first_image)
piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(infotexts[i], encoding="unicode")
}}), filepath)
fullfn, txt_fullfn = save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs)
filename = os.path.relpath(fullfn, path)
filenames.append(filename) filenames.append(filename)
fullfns.append(fullfn)
if txt_fullfn:
filenames.append(os.path.basename(txt_fullfn))
fullfns.append(txt_fullfn)
writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]]) writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]])
return '', '', plaintext_to_html(f"Saved: {filenames[0]}") # Make Zip
if do_make_zip:
zip_filepath = os.path.join(path, "images.zip")
from zipfile import ZipFile
with ZipFile(zip_filepath, "w") as zip_file:
for i in range(len(fullfns)):
with open(fullfns[i], mode="rb") as f:
zip_file.writestr(filenames[i], f.read())
fullfns.insert(0, zip_filepath)
return gr.File.update(value=fullfns, visible=True), '', '', plaintext_to_html(f"Saved: {filenames[0]}")
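
# Illustrative note on the block above (hypothetical file name): ZipFile.writestr with the
# file's bytes stores each image under its short relative name, e.g. "00001-12345.png",
# whereas ZipFile.write(fullfns[i]) would record the longer on-disk path inside the archive.
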
def wrap_gradio_call(func):
def f(*args, **kwargs): def wrap_gradio_call(func, extra_outputs=None):
def f(*args, extra_outputs_array=extra_outputs, **kwargs):
run_memmon = opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled run_memmon = opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled
if run_memmon: if run_memmon:
shared.mem_mon.monitor() shared.mem_mon.monitor()
...@@ -152,16 +186,31 @@ def wrap_gradio_call(func): ...@@ -152,16 +186,31 @@ def wrap_gradio_call(func):
try: try:
res = list(func(*args, **kwargs)) res = list(func(*args, **kwargs))
except Exception as e: except Exception as e:
            # When printing out our debug argument list, do not print more than 128 K characters of text
            max_debug_str_len = 131072  # (1024*1024)/8
print("Error completing request", file=sys.stderr) print("Error completing request", file=sys.stderr)
print("Arguments:", args, kwargs, file=sys.stderr) argStr = f"Arguments: {str(args)} {str(kwargs)}"
print(argStr[:max_debug_str_len], file=sys.stderr)
if len(argStr) > max_debug_str_len:
print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr) print(traceback.format_exc(), file=sys.stderr)
shared.state.job = "" shared.state.job = ""
shared.state.job_count = 0 shared.state.job_count = 0
res = [None, '', f"<div class='error'>{plaintext_to_html(type(e).__name__+': '+str(e))}</div>"] if extra_outputs_array is None:
extra_outputs_array = [None, '']
res = extra_outputs_array + [f"<div class='error'>{plaintext_to_html(type(e).__name__+': '+str(e))}</div>"]
elapsed = time.perf_counter() - t elapsed = time.perf_counter() - t
elapsed_m = int(elapsed // 60)
elapsed_s = elapsed % 60
elapsed_text = f"{elapsed_s:.2f}s"
if (elapsed_m > 0):
elapsed_text = f"{elapsed_m}m "+elapsed_text
if run_memmon: if run_memmon:
mem_stats = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.stop().items()} mem_stats = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.stop().items()}
...@@ -176,9 +225,11 @@ def wrap_gradio_call(func): ...@@ -176,9 +225,11 @@ def wrap_gradio_call(func):
vram_html = '' vram_html = ''
# last item is always HTML # last item is always HTML
res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed:.2f}s</p>{vram_html}</div>" res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed_text}</p>{vram_html}</div>"
shared.state.skipped = False
shared.state.interrupted = False shared.state.interrupted = False
shared.state.job_count = 0
return tuple(res) return tuple(res)
...@@ -187,7 +238,7 @@ def wrap_gradio_call(func): ...@@ -187,7 +238,7 @@ def wrap_gradio_call(func):
def check_progress_call(id_part): def check_progress_call(id_part):
if shared.state.job_count == 0: if shared.state.job_count == 0:
return "", gr_show(False), gr_show(False) return "", gr_show(False), gr_show(False), gr_show(False)
progress = 0 progress = 0
...@@ -219,13 +270,19 @@ def check_progress_call(id_part): ...@@ -219,13 +270,19 @@ def check_progress_call(id_part):
else: else:
preview_visibility = gr_show(True) preview_visibility = gr_show(True)
return f"<span id='{id_part}_progress_span' style='display: none'>{time.time()}</span><p>{progressbar}</p>", preview_visibility, image if shared.state.textinfo is not None:
textinfo_result = gr.HTML.update(value=shared.state.textinfo, visible=True)
else:
textinfo_result = gr_show(False)
return f"<span id='{id_part}_progress_span' style='display: none'>{time.time()}</span><p>{progressbar}</p>", preview_visibility, image, textinfo_result
def check_progress_call_initial(id_part): def check_progress_call_initial(id_part):
shared.state.job_count = -1 shared.state.job_count = -1
shared.state.current_latent = None shared.state.current_latent = None
shared.state.current_image = None shared.state.current_image = None
shared.state.textinfo = None
return check_progress_call(id_part) return check_progress_call(id_part)
...@@ -271,6 +328,11 @@ def interrogate(image): ...@@ -271,6 +328,11 @@ def interrogate(image):
return gr_show(True) if prompt is None else prompt return gr_show(True) if prompt is None else prompt
def interrogate_deepbooru(image):
prompt = get_deepbooru_tags(image)
return gr_show(True) if prompt is None else prompt
def create_seed_inputs(): def create_seed_inputs():
with gr.Row(): with gr.Row():
with gr.Box(): with gr.Box():
...@@ -345,11 +407,24 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: ...@@ -345,11 +407,24 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info:
outputs=[seed, dummy_component] outputs=[seed, dummy_component]
) )
def update_token_counter(text):
tokens, token_count, max_length = model_hijack.tokenize(text) def update_token_counter(text, steps):
try:
_, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text])
prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps)
except Exception:
# a parsing error can happen here during typing, and we don't want to bother the user with
# messages related to it in console
prompt_schedules = [[[steps, text]]]
flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules)
prompts = [prompt_text for step, prompt_text in flat_prompts]
tokens, token_count, max_length = max([model_hijack.tokenize(prompt) for prompt in prompts], key=lambda args: args[1])
style_class = ' class="red"' if (token_count > max_length) else "" style_class = ' class="red"' if (token_count > max_length) else ""
return f"<span {style_class}>{token_count}/{max_length}</span>" return f"<span {style_class}>{token_count}/{max_length}</span>"
def create_toprow(is_img2img): def create_toprow(is_img2img):
id_part = "img2img" if is_img2img else "txt2img" id_part = "img2img" if is_img2img else "txt2img"
...@@ -358,54 +433,76 @@ def create_toprow(is_img2img): ...@@ -358,54 +433,76 @@ def create_toprow(is_img2img):
with gr.Row(): with gr.Row():
with gr.Column(scale=80): with gr.Column(scale=80):
with gr.Row(): with gr.Row():
prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, placeholder="Prompt", lines=2) prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=2,
placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)"
)
with gr.Column(scale=1, elem_id="roll_col"): with gr.Column(scale=1, elem_id="roll_col"):
roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0) roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0)
paste = gr.Button(value=paste_symbol, elem_id="paste") paste = gr.Button(value=paste_symbol, elem_id="paste")
token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_token_counter") token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_token_counter")
hidden_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button") token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
hidden_button.click(fn=update_token_counter, inputs=[prompt], outputs=[token_counter])
with gr.Column(scale=10, elem_id="style_pos_col"): with gr.Column(scale=10, elem_id="style_pos_col"):
prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())), visible=len(shared.prompt_styles.styles) > 1) prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())), visible=len(shared.prompt_styles.styles) > 1)
with gr.Row(): with gr.Row():
with gr.Column(scale=8): with gr.Column(scale=8):
negative_prompt = gr.Textbox(label="Negative prompt", elem_id="negative_prompt", show_label=False, placeholder="Negative prompt", lines=2) with gr.Row():
negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2,
placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)"
)
with gr.Column(scale=1, elem_id="roll_col"):
sh = gr.Button(elem_id="sh", visible=True)
with gr.Column(scale=1, elem_id="style_neg_col"): with gr.Column(scale=1, elem_id="style_neg_col"):
prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())), visible=len(shared.prompt_styles.styles) > 1) prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())), visible=len(shared.prompt_styles.styles) > 1)
with gr.Column(scale=1): with gr.Column(scale=1):
with gr.Row(): with gr.Row():
skip = gr.Button('Skip', elem_id=f"{id_part}_skip")
interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt") interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt")
submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary') submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary')
skip.click(
fn=lambda: shared.state.skip(),
inputs=[],
outputs=[],
)
interrupt.click( interrupt.click(
fn=lambda: shared.state.interrupt(), fn=lambda: shared.state.interrupt(),
inputs=[], inputs=[],
outputs=[], outputs=[],
) )
with gr.Row(): with gr.Row(scale=1):
if is_img2img: if is_img2img:
interrogate = gr.Button('Interrogate', elem_id="interrogate") interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate")
if cmd_opts.deepdanbooru:
deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
else:
deepbooru = None
else: else:
interrogate = None interrogate = None
deepbooru = None
prompt_style_apply = gr.Button('Apply style', elem_id="style_apply") prompt_style_apply = gr.Button('Apply style', elem_id="style_apply")
save_style = gr.Button('Create style', elem_id="style_create") save_style = gr.Button('Create style', elem_id="style_create")
return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, interrogate, prompt_style_apply, save_style, paste return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, interrogate, deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button
def setup_progressbar(progressbar, preview, id_part, textinfo=None):
if textinfo is None:
textinfo = gr.HTML(visible=False)
def setup_progressbar(progressbar, preview, id_part):
check_progress = gr.Button('Check progress', elem_id=f"{id_part}_check_progress", visible=False) check_progress = gr.Button('Check progress', elem_id=f"{id_part}_check_progress", visible=False)
check_progress.click( check_progress.click(
fn=lambda: check_progress_call(id_part), fn=lambda: check_progress_call(id_part),
show_progress=False, show_progress=False,
inputs=[], inputs=[],
outputs=[progressbar, preview, preview], outputs=[progressbar, preview, preview, textinfo],
) )
check_progress_initial = gr.Button('Check progress (first)', elem_id=f"{id_part}_check_progress_initial", visible=False) check_progress_initial = gr.Button('Check progress (first)', elem_id=f"{id_part}_check_progress_initial", visible=False)
...@@ -413,14 +510,44 @@ def setup_progressbar(progressbar, preview, id_part): ...@@ -413,14 +510,44 @@ def setup_progressbar(progressbar, preview, id_part):
fn=lambda: check_progress_call_initial(id_part), fn=lambda: check_progress_call_initial(id_part),
show_progress=False, show_progress=False,
inputs=[], inputs=[],
outputs=[progressbar, preview, preview], outputs=[progressbar, preview, preview, textinfo],
) )
def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): def apply_setting(key, value):
if value is None:
return gr.update()
if key == "sd_model_checkpoint":
ckpt_info = sd_models.get_closet_checkpoint_match(value)
if ckpt_info is not None:
value = ckpt_info.title
else:
return gr.update()
comp_args = opts.data_labels[key].component_args
if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False:
return
valtype = type(opts.data_labels[key].default)
oldval = opts.data[key]
opts.data[key] = valtype(value) if valtype != type(None) else value
if oldval != value and opts.data_labels[key].onchange is not None:
opts.data_labels[key].onchange()
opts.save(shared.config_filename)
return value
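
# Illustrative usage sketch (hypothetical value): pasting parameters can call
#     apply_setting("sd_model_checkpoint", "some-model.ckpt")
# which resolves the value to a known checkpoint title, coerces it to the option's
# declared type, fires the option's onchange hook if the value changed, and saves
# the config before returning the applied value.
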
def create_ui(wrap_gradio_gpu_call):
import modules.img2img
import modules.txt2img
with gr.Blocks(analytics_enabled=False) as txt2img_interface: with gr.Blocks(analytics_enabled=False) as txt2img_interface:
txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, txt2img_prompt_style_apply, txt2img_save_style, paste = create_toprow(is_img2img=False) txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False)
dummy_component = gr.Label(visible=False) dummy_component = gr.Label(visible=False)
txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False)
with gr.Row(elem_id='txt2img_progress_row'): with gr.Row(elem_id='txt2img_progress_row'):
with gr.Column(scale=1): with gr.Column(scale=1):
...@@ -446,11 +573,12 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): ...@@ -446,11 +573,12 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
enable_hr = gr.Checkbox(label='Highres. fix', value=False) enable_hr = gr.Checkbox(label='Highres. fix', value=False)
with gr.Row(visible=False) as hr_options: with gr.Row(visible=False) as hr_options:
scale_latent = gr.Checkbox(label='Scale latent', value=False) firstphase_width = gr.Slider(minimum=0, maximum=1024, step=64, label="Firstpass width", value=0)
firstphase_height = gr.Slider(minimum=0, maximum=1024, step=64, label="Firstpass height", value=0)
denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7) denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7)
with gr.Row(): with gr.Row(equal_height=True):
batch_count = gr.Slider(minimum=1, maximum=cmd_opts.max_batch_count, step=1, label='Batch count', value=1) batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1)
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1) batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1)
cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0) cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0)
...@@ -475,6 +603,12 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): ...@@ -475,6 +603,12 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder' button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder'
open_txt2img_folder = gr.Button(folder_symbol, elem_id=button_id) open_txt2img_folder = gr.Button(folder_symbol, elem_id=button_id)
with gr.Row():
do_make_zip = gr.Checkbox(label="Make Zip when Save?", value=False)
with gr.Row():
download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False)
with gr.Group(): with gr.Group():
html_info = gr.HTML() html_info = gr.HTML()
generation_info = gr.Textbox(visible=False) generation_info = gr.Textbox(visible=False)
...@@ -483,7 +617,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): ...@@ -483,7 +617,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
txt2img_args = dict( txt2img_args = dict(
fn=txt2img, fn=wrap_gradio_gpu_call(modules.txt2img.txt2img),
_js="submit", _js="submit",
inputs=[ inputs=[
txt2img_prompt, txt2img_prompt,
...@@ -502,8 +636,9 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): ...@@ -502,8 +636,9 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
height, height,
width, width,
enable_hr, enable_hr,
scale_latent,
denoising_strength, denoising_strength,
firstphase_width,
firstphase_height,
] + custom_inputs, ] + custom_inputs,
outputs=[ outputs=[
txt2img_gallery, txt2img_gallery,
...@@ -516,6 +651,17 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): ...@@ -516,6 +651,17 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
txt2img_prompt.submit(**txt2img_args) txt2img_prompt.submit(**txt2img_args)
submit.click(**txt2img_args) submit.click(**txt2img_args)
txt_prompt_img.change(
fn=modules.images.image_data,
inputs=[
txt_prompt_img
],
outputs=[
txt2img_prompt,
txt_prompt_img
]
)
enable_hr.change( enable_hr.change(
fn=lambda x: gr_show(x), fn=lambda x: gr_show(x),
inputs=[enable_hr], inputs=[enable_hr],
...@@ -524,13 +670,15 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): ...@@ -524,13 +670,15 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
save.click( save.click(
fn=wrap_gradio_call(save_files), fn=wrap_gradio_call(save_files),
_js="(x, y, z) => [x, y, selected_gallery_index()]", _js="(x, y, z, w) => [x, y, z, selected_gallery_index()]",
inputs=[ inputs=[
generation_info, generation_info,
txt2img_gallery, txt2img_gallery,
do_make_zip,
html_info, html_info,
], ],
outputs=[ outputs=[
download_files,
html_info, html_info,
html_info, html_info,
html_info, html_info,
...@@ -539,6 +687,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): ...@@ -539,6 +687,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
roll.click( roll.click(
fn=roll_artist, fn=roll_artist,
_js="update_txt2img_tokens",
inputs=[ inputs=[
txt2img_prompt, txt2img_prompt,
], ],
...@@ -565,13 +714,29 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): ...@@ -565,13 +714,29 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
(denoising_strength, "Denoising strength"), (denoising_strength, "Denoising strength"),
(enable_hr, lambda d: "Denoising strength" in d), (enable_hr, lambda d: "Denoising strength" in d),
(hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)), (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)),
(firstphase_width, "First pass size-1"),
(firstphase_height, "First pass size-2"),
] ]
modules.generation_parameters_copypaste.connect_paste(paste, txt2img_paste_fields, txt2img_prompt)
txt2img_preview_params = [
txt2img_prompt,
txt2img_negative_prompt,
steps,
sampler_index,
cfg_scale,
seed,
width,
height,
]
token_button.click(fn=update_token_counter, inputs=[txt2img_prompt, steps], outputs=[token_counter])
with gr.Blocks(analytics_enabled=False) as img2img_interface: with gr.Blocks(analytics_enabled=False) as img2img_interface:
img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_prompt_style_apply, img2img_save_style, paste = create_toprow(is_img2img=True) img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, token_counter, token_button = create_toprow(is_img2img=True)
with gr.Row(elem_id='img2img_progress_row'): with gr.Row(elem_id='img2img_progress_row'):
img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False)
with gr.Column(scale=1): with gr.Column(scale=1):
pass pass
...@@ -585,7 +750,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): ...@@ -585,7 +750,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
with gr.Tabs(elem_id="mode_img2img") as tabs_img2img_mode: with gr.Tabs(elem_id="mode_img2img") as tabs_img2img_mode:
with gr.TabItem('img2img', id='img2img'): with gr.TabItem('img2img', id='img2img'):
init_img = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil") init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool)
with gr.TabItem('Inpaint', id='inpaint'): with gr.TabItem('Inpaint', id='inpaint'):
init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA") init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA")
...@@ -607,7 +772,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): ...@@ -607,7 +772,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
with gr.TabItem('Batch img2img', id='batch'): with gr.TabItem('Batch img2img', id='batch'):
hidden = '<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else '' hidden = '<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
gr.HTML(f"<p class=\"text-gray-500\">Process images in a directory on the same machine where the server is running.{hidden}</p>") gr.HTML(f"<p class=\"text-gray-500\">Process images in a directory on the same machine where the server is running.<br>Use an empty output directory to save pictures normally instead of writing to the output directory.{hidden}</p>")
img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs) img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs)
img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs) img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs)
...@@ -626,7 +791,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): ...@@ -626,7 +791,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
tiling = gr.Checkbox(label='Tiling', value=False) tiling = gr.Checkbox(label='Tiling', value=False)
with gr.Row(): with gr.Row():
batch_count = gr.Slider(minimum=1, maximum=cmd_opts.max_batch_count, step=1, label='Batch count', value=1) batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1)
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1) batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1)
with gr.Group(): with gr.Group():
...@@ -653,6 +818,12 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): ...@@ -653,6 +818,12 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder' button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder'
open_img2img_folder = gr.Button(folder_symbol, elem_id=button_id) open_img2img_folder = gr.Button(folder_symbol, elem_id=button_id)
with gr.Row():
do_make_zip = gr.Checkbox(label="Make Zip when Save?", value=False)
with gr.Row():
download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False)
with gr.Group(): with gr.Group():
html_info = gr.HTML() html_info = gr.HTML()
generation_info = gr.Textbox(visible=False) generation_info = gr.Textbox(visible=False)
...@@ -660,6 +831,17 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): ...@@ -660,6 +831,17 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
img2img_prompt_img.change(
fn=modules.images.image_data,
inputs=[
img2img_prompt_img
],
outputs=[
img2img_prompt,
img2img_prompt_img
]
)
mask_mode.change( mask_mode.change(
lambda mode, img: { lambda mode, img: {
init_img_with_mask: gr_show(mode == 0), init_img_with_mask: gr_show(mode == 0),
...@@ -675,7 +857,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): ...@@ -675,7 +857,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
) )
img2img_args = dict( img2img_args = dict(
fn=img2img, fn=wrap_gradio_gpu_call(modules.img2img.img2img),
_js="submit_img2img", _js="submit_img2img",
inputs=[ inputs=[
dummy_component, dummy_component,
...@@ -726,15 +908,24 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): ...@@ -726,15 +908,24 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
outputs=[img2img_prompt], outputs=[img2img_prompt],
) )
if cmd_opts.deepdanbooru:
img2img_deepbooru.click(
fn=interrogate_deepbooru,
inputs=[init_img],
outputs=[img2img_prompt],
)
save.click( save.click(
fn=wrap_gradio_call(save_files), fn=wrap_gradio_call(save_files),
_js="(x, y, z) => [x, y, selected_gallery_index()]", _js="(x, y, z, w) => [x, y, z, selected_gallery_index()]",
inputs=[ inputs=[
generation_info, generation_info,
img2img_gallery, img2img_gallery,
html_info do_make_zip,
html_info,
], ],
outputs=[ outputs=[
download_files,
html_info, html_info,
html_info, html_info,
html_info, html_info,
...@@ -743,6 +934,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): ...@@ -743,6 +934,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
roll.click( roll.click(
fn=roll_artist, fn=roll_artist,
_js="update_img2img_tokens",
inputs=[ inputs=[
img2img_prompt, img2img_prompt,
], ],
...@@ -753,6 +945,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): ...@@ -753,6 +945,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
prompts = [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)] prompts = [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)]
style_dropdowns = [(txt2img_prompt_style, txt2img_prompt_style2), (img2img_prompt_style, img2img_prompt_style2)] style_dropdowns = [(txt2img_prompt_style, txt2img_prompt_style2), (img2img_prompt_style, img2img_prompt_style2)]
style_js_funcs = ["update_txt2img_tokens", "update_img2img_tokens"]
for button, (prompt, negative_prompt) in zip([txt2img_save_style, img2img_save_style], prompts): for button, (prompt, negative_prompt) in zip([txt2img_save_style, img2img_save_style], prompts):
button.click( button.click(
...@@ -764,9 +957,10 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): ...@@ -764,9 +957,10 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
outputs=[txt2img_prompt_style, img2img_prompt_style, txt2img_prompt_style2, img2img_prompt_style2], outputs=[txt2img_prompt_style, img2img_prompt_style, txt2img_prompt_style2, img2img_prompt_style2],
) )
for button, (prompt, negative_prompt), (style1, style2) in zip([txt2img_prompt_style_apply, img2img_prompt_style_apply], prompts, style_dropdowns): for button, (prompt, negative_prompt), (style1, style2), js_func in zip([txt2img_prompt_style_apply, img2img_prompt_style_apply], prompts, style_dropdowns, style_js_funcs):
button.click( button.click(
fn=apply_styles, fn=apply_styles,
_js=js_func,
inputs=[prompt, negative_prompt, style1, style2], inputs=[prompt, negative_prompt, style1, style2],
outputs=[prompt, negative_prompt, style1, style2], outputs=[prompt, negative_prompt, style1, style2],
) )
...@@ -788,7 +982,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): ...@@ -788,7 +982,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
(seed_resize_from_h, "Seed resize from-2"), (seed_resize_from_h, "Seed resize from-2"),
(denoising_strength, "Denoising strength"), (denoising_strength, "Denoising strength"),
] ]
modules.generation_parameters_copypaste.connect_paste(paste, img2img_paste_fields, img2img_prompt) token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter])
with gr.Blocks(analytics_enabled=False) as extras_interface: with gr.Blocks(analytics_enabled=False) as extras_interface:
with gr.Row().style(equal_height=False): with gr.Row().style(equal_height=False):
...@@ -800,7 +994,15 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): ...@@ -800,7 +994,15 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
with gr.TabItem('Batch Process'): with gr.TabItem('Batch Process'):
image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file") image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file")
with gr.Tabs(elem_id="extras_resize_mode"):
with gr.TabItem('Scale by'):
upscaling_resize = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Resize", value=2) upscaling_resize = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Resize", value=2)
with gr.TabItem('Scale to'):
with gr.Group():
with gr.Row():
upscaling_resize_w = gr.Number(label="Width", value=512, precision=0)
upscaling_resize_h = gr.Number(label="Height", value=512, precision=0)
upscaling_crop = gr.Checkbox(label='Crop to fit', value=True)
with gr.Group(): with gr.Group():
extras_upscaler_1 = gr.Radio(label='Upscaler 1', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") extras_upscaler_1 = gr.Radio(label='Upscaler 1', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
...@@ -827,10 +1029,12 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): ...@@ -827,10 +1029,12 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else '' button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else ''
open_extras_folder = gr.Button('Open output directory', elem_id=button_id) open_extras_folder = gr.Button('Open output directory', elem_id=button_id)
submit.click( submit.click(
fn=run_extras, fn=wrap_gradio_gpu_call(modules.extras.run_extras),
_js="get_extras_tab_index", _js="get_extras_tab_index",
inputs=[ inputs=[
dummy_component,
dummy_component, dummy_component,
extras_image, extras_image,
image_batch, image_batch,
...@@ -838,6 +1042,9 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): ...@@ -838,6 +1042,9 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
codeformer_visibility, codeformer_visibility,
codeformer_weight, codeformer_weight,
upscaling_resize, upscaling_resize,
upscaling_resize_w,
upscaling_resize_h,
upscaling_crop,
extras_upscaler_1, extras_upscaler_1,
extras_upscaler_2, extras_upscaler_2,
extras_upscaler_2_visibility, extras_upscaler_2_visibility,
...@@ -858,7 +1065,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): ...@@ -858,7 +1065,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
extras_send_to_inpaint.click( extras_send_to_inpaint.click(
fn=lambda x: image_from_url_text(x), fn=lambda x: image_from_url_text(x),
_js="extract_image_from_gallery_img2img", _js="extract_image_from_gallery_inpaint",
inputs=[result_images], inputs=[result_images],
outputs=[init_img_with_mask], outputs=[init_img_with_mask],
) )
...@@ -878,10 +1085,18 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): ...@@ -878,10 +1085,18 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
pnginfo_send_to_img2img = gr.Button('Send to img2img') pnginfo_send_to_img2img = gr.Button('Send to img2img')
image.change( image.change(
fn=wrap_gradio_call(run_pnginfo), fn=wrap_gradio_call(modules.extras.run_pnginfo),
inputs=[image], inputs=[image],
outputs=[html, generation_info, html2], outputs=[html, generation_info, html2],
) )
#images history
images_history_switch_dict = {
"fn":modules.generation_parameters_copypaste.connect_paste,
"t2i":txt2img_paste_fields,
"i2i":img2img_paste_fields
}
images_history = img_his.create_history_tabs(gr, opts, wrap_gradio_call(modules.extras.run_pnginfo), images_history_switch_dict)
with gr.Blocks() as modelmerger_interface: with gr.Blocks() as modelmerger_interface:
with gr.Row().style(equal_height=False): with gr.Row().style(equal_height=False):
...@@ -889,18 +1104,201 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): ...@@ -889,18 +1104,201 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
gr.HTML(value="<p>A merger of the two checkpoints will be generated in your <b>checkpoint</b> directory.</p>") gr.HTML(value="<p>A merger of the two checkpoints will be generated in your <b>checkpoint</b> directory.</p>")
with gr.Row(): with gr.Row():
primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary Model Name") primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)")
secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary Model Name") secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary model (B)")
tertiary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_tertiary_model_name", label="Tertiary model (C)")
custom_name = gr.Textbox(label="Custom Name (Optional)") custom_name = gr.Textbox(label="Custom Name (Optional)")
interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Interpolation Amount', value=0.3) interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3)
interp_method = gr.Radio(choices=["Weighted Sum", "Sigmoid", "Inverse Sigmoid"], value="Weighted Sum", label="Interpolation Method") interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method")
save_as_half = gr.Checkbox(value=False, label="Safe as float16") save_as_half = gr.Checkbox(value=False, label="Save as float16")
modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary') modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary')
with gr.Column(variant='panel'): with gr.Column(variant='panel'):
submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False) submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False)
def create_setting_component(key): sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
with gr.Blocks() as train_interface:
with gr.Row().style(equal_height=False):
gr.HTML(value="<p style='margin-bottom: 0.7em'>See <b><a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\">wiki</a></b> for detailed explanation.</p>")
with gr.Row().style(equal_height=False):
with gr.Tabs(elem_id="train_tabs"):
with gr.Tab(label="Create embedding"):
new_embedding_name = gr.Textbox(label="Name")
initialization_text = gr.Textbox(label="Initialization text", value="*")
nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1)
with gr.Row():
with gr.Column(scale=3):
gr.HTML(value="")
with gr.Column():
create_embedding = gr.Button(value="Create embedding", variant='primary')
with gr.Tab(label="Create hypernetwork"):
new_hypernetwork_name = gr.Textbox(label="Name")
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
with gr.Row():
with gr.Column(scale=3):
gr.HTML(value="")
with gr.Column():
create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary')
with gr.Tab(label="Preprocess images"):
process_src = gr.Textbox(label='Source directory')
process_dst = gr.Textbox(label='Destination directory')
process_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
process_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
with gr.Row():
process_flip = gr.Checkbox(label='Create flipped copies')
process_split = gr.Checkbox(label='Split oversized images into two')
process_caption = gr.Checkbox(label='Use BLIP for caption')
process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True if cmd_opts.deepdanbooru else False)
with gr.Row():
with gr.Column(scale=3):
gr.HTML(value="")
with gr.Column():
run_preprocess = gr.Button(value="Preprocess", variant='primary')
with gr.Tab(label="Train"):
gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding; must specify a directory with a set of 1:1 ratio images</p>")
train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', choices=[x for x in shared.hypernetworks.keys()])
learn_rate = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.005")
batch_size = gr.Number(label='Batch size', value=1, precision=0)
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"))
training_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
steps = gr.Number(label='Max steps', value=100000, precision=0)
create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0)
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True)
preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False)
with gr.Row():
interrupt_training = gr.Button(value="Interrupt")
train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary')
train_embedding = gr.Button(value="Train Embedding", variant='primary')
with gr.Column():
progressbar = gr.HTML(elem_id="ti_progressbar")
ti_output = gr.Text(elem_id="ti_output", value="", show_label=False)
ti_gallery = gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(grid=4)
ti_preview = gr.Image(elem_id='ti_preview', visible=False)
ti_progress = gr.HTML(elem_id="ti_progress", value="")
ti_outcome = gr.HTML(elem_id="ti_error", value="")
setup_progressbar(progressbar, ti_preview, 'ti', textinfo=ti_progress)
create_embedding.click(
fn=modules.textual_inversion.ui.create_embedding,
inputs=[
new_embedding_name,
initialization_text,
nvpt,
],
outputs=[
train_embedding_name,
ti_output,
ti_outcome,
]
)
create_hypernetwork.click(
fn=modules.hypernetworks.ui.create_hypernetwork,
inputs=[
new_hypernetwork_name,
new_hypernetwork_sizes,
],
outputs=[
train_hypernetwork_name,
ti_output,
ti_outcome,
]
)
run_preprocess.click(
fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.preprocess, extra_outputs=[gr.update()]),
_js="start_training_textual_inversion",
inputs=[
process_src,
process_dst,
process_width,
process_height,
process_flip,
process_split,
process_caption,
process_caption_deepbooru
],
outputs=[
ti_output,
ti_outcome,
],
)
train_embedding.click(
fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.train_embedding, extra_outputs=[gr.update()]),
_js="start_training_textual_inversion",
inputs=[
train_embedding_name,
learn_rate,
batch_size,
dataset_directory,
log_directory,
training_width,
training_height,
steps,
create_image_every,
save_embedding_every,
template_file,
save_image_with_stored_embedding,
preview_from_txt2img,
*txt2img_preview_params,
],
outputs=[
ti_output,
ti_outcome,
]
)
train_hypernetwork.click(
fn=wrap_gradio_gpu_call(modules.hypernetworks.ui.train_hypernetwork, extra_outputs=[gr.update()]),
_js="start_training_textual_inversion",
inputs=[
train_hypernetwork_name,
learn_rate,
batch_size,
dataset_directory,
log_directory,
steps,
create_image_every,
save_embedding_every,
template_file,
preview_from_txt2img,
*txt2img_preview_params,
],
outputs=[
ti_output,
ti_outcome,
]
)
interrupt_training.click(
fn=lambda: shared.state.interrupt(),
inputs=[],
outputs=[],
)
def create_setting_component(key, is_quicksettings=False):
def fun(): def fun():
return opts.data[key] if key in opts.data else opts.data_labels[key].default return opts.data[key] if key in opts.data else opts.data_labels[key].default
...@@ -920,12 +1318,48 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): ...@@ -920,12 +1318,48 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
else: else:
raise Exception(f'bad options item type: {str(t)} for key {key}') raise Exception(f'bad options item type: {str(t)} for key {key}')
return comp(label=info.label, value=fun, **(args or {})) if info.refresh is not None:
if is_quicksettings:
res = comp(label=info.label, value=fun, **(args or {}))
refresh_button = gr.Button(value=refresh_symbol, elem_id="refresh_"+key)
else:
with gr.Row(variant="compact"):
res = comp(label=info.label, value=fun, **(args or {}))
refresh_button = gr.Button(value=refresh_symbol, elem_id="refresh_" + key)
def refresh():
info.refresh()
refreshed_args = info.component_args() if callable(info.component_args) else info.component_args
for k, v in refreshed_args.items():
setattr(res, k, v)
return gr.update(**(refreshed_args or {}))
refresh_button.click(
fn=refresh,
inputs=[],
outputs=[res],
)
else:
res = comp(label=info.label, value=fun, **(args or {}))
return res
components = [] components = []
component_dict = {} component_dict = {}
def open_folder(f): def open_folder(f):
if not os.path.isdir(f):
print(f"""
WARNING
An open_folder request was made with an argument that is not a folder.
This could be an error or a malicious attempt to run code on your computer.
Requested path was: {f}
""", file=sys.stderr)
return
if not shared.cmd_opts.hide_ui_dir_config: if not shared.cmd_opts.hide_ui_dir_config:
path = os.path.normpath(f) path = os.path.normpath(f)
if platform.system() == "Windows": if platform.system() == "Windows":
...@@ -939,10 +1373,13 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): ...@@ -939,10 +1373,13 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
changed = 0 changed = 0
for key, value, comp in zip(opts.data_labels.keys(), args, components): for key, value, comp in zip(opts.data_labels.keys(), args, components):
if not opts.same_type(value, opts.data_labels[key].default): if comp != dummy_component and not opts.same_type(value, opts.data_labels[key].default):
return f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}" return f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}", opts.dumpjson()
for key, value, comp in zip(opts.data_labels.keys(), args, components): for key, value, comp in zip(opts.data_labels.keys(), args, components):
if comp == dummy_component:
continue
comp_args = opts.data_labels[key].component_args comp_args = opts.data_labels[key].component_args
if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False: if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False:
continue continue
...@@ -960,6 +1397,21 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): ...@@ -960,6 +1397,21 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
return f'{changed} settings changed.', opts.dumpjson() return f'{changed} settings changed.', opts.dumpjson()
def run_settings_single(value, key):
if not opts.same_type(value, opts.data_labels[key].default):
return gr.update(visible=True), opts.dumpjson()
oldval = opts.data.get(key, None)
opts.data[key] = value
if oldval != value:
if opts.data_labels[key].onchange is not None:
opts.data_labels[key].onchange()
opts.save(shared.config_filename)
return gr.update(value=value), opts.dumpjson()
with gr.Blocks(analytics_enabled=False) as settings_interface: with gr.Blocks(analytics_enabled=False) as settings_interface:
settings_submit = gr.Button(value="Apply settings", variant='primary') settings_submit = gr.Button(value="Apply settings", variant='primary')
result = gr.HTML() result = gr.HTML()
...@@ -967,6 +1419,11 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): ...@@ -967,6 +1419,11 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
settings_cols = 3 settings_cols = 3
items_per_col = int(len(opts.data_labels) * 0.9 / settings_cols) items_per_col = int(len(opts.data_labels) * 0.9 / settings_cols)
quicksettings_names = [x.strip() for x in opts.quicksettings.split(",")]
quicksettings_names = set(x for x in quicksettings_names if x != 'quicksettings')
quicksettings_list = []
cols_displayed = 0 cols_displayed = 0
items_displayed = 0 items_displayed = 0
previous_section = None previous_section = None
...@@ -989,12 +1446,20 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): ...@@ -989,12 +1446,20 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
gr.HTML(elem_id="settings_header_text_{}".format(item.section[0]), value='<h1 class="gr-button-lg">{}</h1>'.format(item.section[1])) gr.HTML(elem_id="settings_header_text_{}".format(item.section[0]), value='<h1 class="gr-button-lg">{}</h1>'.format(item.section[1]))
if k in quicksettings_names:
quicksettings_list.append((i, k, item))
components.append(dummy_component)
else:
component = create_setting_component(k) component = create_setting_component(k)
component_dict[k] = component component_dict[k] = component
components.append(component) components.append(component)
items_displayed += 1 items_displayed += 1
with gr.Row():
request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications") request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary')
restart_gradio = gr.Button(value='Restart Gradio and Refresh components (Custom Scripts, ui.py, js and css only)', variant='primary')
request_notifications.click( request_notifications.click(
fn=lambda: None, fn=lambda: None,
inputs=[], inputs=[],
...@@ -1002,6 +1467,27 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): ...@@ -1002,6 +1467,27 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
_js='function(){}' _js='function(){}'
) )
def reload_scripts():
modules.scripts.reload_script_body_only()
reload_script_bodies.click(
fn=reload_scripts,
inputs=[],
outputs=[],
_js='function(){}'
)
def request_restart():
shared.state.interrupt()
settings_interface.gradio_ref.do_restart = True
restart_gradio.click(
fn=request_restart,
inputs=[],
outputs=[],
_js='function(){restart_reload()}'
)
if column is not None: if column is not None:
column.__exit__() column.__exit__()
...@@ -1010,7 +1496,9 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): ...@@ -1010,7 +1496,9 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
(img2img_interface, "img2img", "img2img"), (img2img_interface, "img2img", "img2img"),
(extras_interface, "Extras", "extras"), (extras_interface, "Extras", "extras"),
(pnginfo_interface, "PNG Info", "pnginfo"), (pnginfo_interface, "PNG Info", "pnginfo"),
(images_history, "History", "images_history"),
(modelmerger_interface, "Checkpoint Merger", "modelmerger"), (modelmerger_interface, "Checkpoint Merger", "modelmerger"),
(train_interface, "Train", "ti"),
(settings_interface, "Settings", "settings"), (settings_interface, "Settings", "settings"),
] ]
...@@ -1026,10 +1514,16 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): ...@@ -1026,10 +1514,16 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
css += css_hide_progressbar css += css_hide_progressbar
with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo: with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo:
with gr.Row(elem_id="quicksettings"):
for i, k, item in quicksettings_list:
component = create_setting_component(k, is_quicksettings=True)
component_dict[k] = component
settings_interface.gradio_ref = demo
with gr.Tabs() as tabs: with gr.Tabs(elem_id="tabs") as tabs:
for interface, label, ifid in interfaces: for interface, label, ifid in interfaces:
with gr.TabItem(label, id=ifid): with gr.TabItem(label, id=ifid, elem_id='tab_' + ifid):
interface.render() interface.render()
if os.path.exists(os.path.join(script_path, "notification.mp3")): if os.path.exists(os.path.join(script_path, "notification.mp3")):
...@@ -1042,13 +1536,22 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): ...@@ -1042,13 +1536,22 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
outputs=[result, text_settings], outputs=[result, text_settings],
) )
for i, k, item in quicksettings_list:
component = component_dict[k]
component.change(
fn=lambda value, k=k: run_settings_single(value, key=k),
inputs=[component],
outputs=[component, text_settings],
)
def modelmerger(*args): def modelmerger(*args):
try: try:
results = run_modelmerger(*args) results = modules.extras.run_modelmerger(*args)
except Exception as e: except Exception as e:
print("Error loading/saving model file:", file=sys.stderr) print("Error loading/saving model file:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr) print(traceback.format_exc(), file=sys.stderr)
modules.sd_models.list_models() #To remove the potentially missing models from the list modules.sd_models.list_models() # to remove the potentially missing models from the list
return ["Error loading/saving model file. It doesn't exist or the name contains illegal characters"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(3)] return ["Error loading/saving model file. It doesn't exist or the name contains illegal characters"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(3)]
return results return results
...@@ -1057,6 +1560,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): ...@@ -1057,6 +1560,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
inputs=[ inputs=[
primary_model_name, primary_model_name,
secondary_model_name, secondary_model_name,
tertiary_model_name,
interp_method, interp_method,
interp_amount, interp_amount,
save_as_half, save_as_half,
...@@ -1066,6 +1570,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): ...@@ -1066,6 +1570,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
submit_result, submit_result,
primary_model_name, primary_model_name,
secondary_model_name, secondary_model_name,
tertiary_model_name,
component_dict['sd_model_checkpoint'], component_dict['sd_model_checkpoint'],
] ]
) )
...@@ -1132,8 +1637,22 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): ...@@ -1132,8 +1637,22 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
outputs=[extras_image], outputs=[extras_image],
) )
modules.generation_parameters_copypaste.connect_paste(pnginfo_send_to_txt2img, txt2img_paste_fields, generation_info, 'switch_to_txt2img') settings_map = {
modules.generation_parameters_copypaste.connect_paste(pnginfo_send_to_img2img, img2img_paste_fields, generation_info, 'switch_to_img2img_img2img') 'sd_hypernetwork': 'Hypernet',
'CLIP_stop_at_last_layers': 'Clip skip',
'sd_model_checkpoint': 'Model hash',
}
settings_paste_fields = [
(component_dict[k], lambda d, k=k, v=v: apply_setting(k, d.get(v, None)))
for k, v in settings_map.items()
]
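# A minimal sketch of how the mapping above is used (the parsed dict below is
# hypothetical; only the labels on the right-hand side of settings_map are real):
# when parameters are pasted, each label is looked up in the parsed infotext and
# forwarded to apply_setting, which updates both opts and the bound component.
#
#   parsed = {"Hypernet": "some_hypernet", "Clip skip": "2", "Model hash": "7460a6fa"}
#   for setting_key, infotext_label in settings_map.items():
#       value = parsed.get(infotext_label)
#       if value is not None:
#           apply_setting(setting_key, value)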
modules.generation_parameters_copypaste.connect_paste(txt2img_paste, txt2img_paste_fields + settings_paste_fields, txt2img_prompt)
modules.generation_parameters_copypaste.connect_paste(img2img_paste, img2img_paste_fields + settings_paste_fields, img2img_prompt)
modules.generation_parameters_copypaste.connect_paste(pnginfo_send_to_txt2img, txt2img_paste_fields + settings_paste_fields, generation_info, 'switch_to_txt2img')
modules.generation_parameters_copypaste.connect_paste(pnginfo_send_to_img2img, img2img_paste_fields + settings_paste_fields, generation_info, 'switch_to_img2img_img2img')
ui_config_file = cmd_opts.ui_config_file ui_config_file = cmd_opts.ui_config_file
ui_settings = {} ui_settings = {}
...@@ -1206,12 +1725,13 @@ for filename in sorted(os.listdir(jsdir)): ...@@ -1206,12 +1725,13 @@ for filename in sorted(os.listdir(jsdir)):
javascript += f"\n<script>{jsfile.read()}</script>" javascript += f"\n<script>{jsfile.read()}</script>"
def template_response(*args, **kwargs): if 'gradio_routes_templates_response' not in globals():
def template_response(*args, **kwargs):
res = gradio_routes_templates_response(*args, **kwargs) res = gradio_routes_templates_response(*args, **kwargs)
res.body = res.body.replace(b'</head>', f'{javascript}</head>'.encode("utf8")) res.body = res.body.replace(b'</head>', f'{javascript}</head>'.encode("utf8"))
res.init_headers() res.init_headers()
return res return res
gradio_routes_templates_response = gradio.routes.templates.TemplateResponse
gradio.routes.templates.TemplateResponse = template_response
gradio_routes_templates_response = gradio.routes.templates.TemplateResponse
gradio.routes.templates.TemplateResponse = template_response
...@@ -36,10 +36,11 @@ class Upscaler: ...@@ -36,10 +36,11 @@ class Upscaler:
self.half = not modules.shared.cmd_opts.no_half self.half = not modules.shared.cmd_opts.no_half
self.pre_pad = 0 self.pre_pad = 0
self.mod_scale = None self.mod_scale = None
if self.name is not None and create_dirs:
if self.model_path is None and self.name:
self.model_path = os.path.join(models_path, self.name) self.model_path = os.path.join(models_path, self.name)
if not os.path.exists(self.model_path): if self.model_path and create_dirs:
os.makedirs(self.model_path) os.makedirs(self.model_path, exist_ok=True)
try: try:
import cv2 import cv2
......
...@@ -4,7 +4,7 @@ fairscale==0.4.4 ...@@ -4,7 +4,7 @@ fairscale==0.4.4
fonts fonts
font-roboto font-roboto
gfpgan gfpgan
gradio==3.4b3 gradio==3.4.1
invisible-watermark invisible-watermark
numpy numpy
omegaconf omegaconf
...@@ -13,14 +13,13 @@ Pillow ...@@ -13,14 +13,13 @@ Pillow
pytorch_lightning pytorch_lightning
realesrgan realesrgan
scikit-image>=0.19 scikit-image>=0.19
git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379
timm==0.4.12 timm==0.4.12
transformers==4.19.2 transformers==4.19.2
torch torch
einops einops
jsonmerge jsonmerge
clean-fid clean-fid
git+https://github.com/openai/CLIP@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1
resize-right resize-right
torchdiffeq torchdiffeq
kornia kornia
lark
...@@ -2,7 +2,7 @@ transformers==4.19.2 ...@@ -2,7 +2,7 @@ transformers==4.19.2
diffusers==0.3.0 diffusers==0.3.0
basicsr==1.4.2 basicsr==1.4.2
gfpgan==1.3.8 gfpgan==1.3.8
gradio==3.4b3 gradio==3.4.1
numpy==1.23.3 numpy==1.23.3
Pillow==9.2.0 Pillow==9.2.0
realesrgan==0.3.0 realesrgan==0.3.0
...@@ -18,7 +18,7 @@ piexif==1.1.3 ...@@ -18,7 +18,7 @@ piexif==1.1.3
einops==0.4.1 einops==0.4.1
jsonmerge==1.8.0 jsonmerge==1.8.0
clean-fid==0.1.29 clean-fid==0.1.29
git+https://github.com/openai/CLIP@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1
resize-right==0.0.2 resize-right==0.0.2
torchdiffeq==0.2.3 torchdiffeq==0.2.3
kornia==0.6.7 kornia==0.6.7
lark==1.1.2
...@@ -6,6 +6,10 @@ function get_uiCurrentTab() { ...@@ -6,6 +6,10 @@ function get_uiCurrentTab() {
return gradioApp().querySelector('.tabs button:not(.border-transparent)') return gradioApp().querySelector('.tabs button:not(.border-transparent)')
} }
function get_uiCurrentTabContent() {
return gradioApp().querySelector('.tabitem[id^=tab_]:not([style*="display: none"])')
}
uiUpdateCallbacks = [] uiUpdateCallbacks = []
uiTabChangeCallbacks = [] uiTabChangeCallbacks = []
let uiCurrentTab = null let uiCurrentTab = null
...@@ -40,6 +44,25 @@ document.addEventListener("DOMContentLoaded", function() { ...@@ -40,6 +44,25 @@ document.addEventListener("DOMContentLoaded", function() {
mutationObserver.observe( gradioApp(), { childList:true, subtree:true }) mutationObserver.observe( gradioApp(), { childList:true, subtree:true })
}); });
/**
* Add a ctrl+enter as a shortcut to start a generation
*/
document.addEventListener('keydown', function(e) {
var handled = false;
if (e.key !== undefined) {
if((e.key == "Enter" && (e.metaKey || e.ctrlKey || e.altKey))) handled = true;
} else if (e.keyCode !== undefined) {
if((e.keyCode == 13 && (e.metaKey || e.ctrlKey || e.altKey))) handled = true;
}
if (handled) {
button = get_uiCurrentTabContent().querySelector('button[id$=_generate]');
if (button) {
button.click();
}
e.preventDefault();
}
})
/** /**
* checks that a UI element is not in another hidden element or tab content * checks that a UI element is not in another hidden element or tab content
*/ */
......
...@@ -8,7 +8,6 @@ import gradio as gr ...@@ -8,7 +8,6 @@ import gradio as gr
from modules import processing, shared, sd_samplers, prompt_parser from modules import processing, shared, sd_samplers, prompt_parser
from modules.processing import Processed from modules.processing import Processed
from modules.sd_samplers import samplers
from modules.shared import opts, cmd_opts, state from modules.shared import opts, cmd_opts, state
import torch import torch
...@@ -121,17 +120,45 @@ class Script(scripts.Script): ...@@ -121,17 +120,45 @@ class Script(scripts.Script):
return is_img2img return is_img2img
def ui(self, is_img2img): def ui(self, is_img2img):
info = gr.Markdown('''
* `CFG Scale` should be 2 or lower.
''')
override_sampler = gr.Checkbox(label="Override `Sampling method` to Euler? (this method is built for it)", value=True)
override_prompt = gr.Checkbox(label="Override `prompt` to the same value as `original prompt`? (and `negative prompt`)", value=True)
original_prompt = gr.Textbox(label="Original prompt", lines=1) original_prompt = gr.Textbox(label="Original prompt", lines=1)
original_negative_prompt = gr.Textbox(label="Original negative prompt", lines=1) original_negative_prompt = gr.Textbox(label="Original negative prompt", lines=1)
cfg = gr.Slider(label="Decode CFG scale", minimum=0.0, maximum=15.0, step=0.1, value=1.0)
override_steps = gr.Checkbox(label="Override `Sampling Steps` to the same value as `Decode steps`?", value=True)
st = gr.Slider(label="Decode steps", minimum=1, maximum=150, step=1, value=50) st = gr.Slider(label="Decode steps", minimum=1, maximum=150, step=1, value=50)
override_strength = gr.Checkbox(label="Override `Denoising strength` to 1?", value=True)
cfg = gr.Slider(label="Decode CFG scale", minimum=0.0, maximum=15.0, step=0.1, value=1.0)
randomness = gr.Slider(label="Randomness", minimum=0.0, maximum=1.0, step=0.01, value=0.0) randomness = gr.Slider(label="Randomness", minimum=0.0, maximum=1.0, step=0.01, value=0.0)
sigma_adjustment = gr.Checkbox(label="Sigma adjustment for finding noise for image", value=False) sigma_adjustment = gr.Checkbox(label="Sigma adjustment for finding noise for image", value=False)
return [original_prompt, original_negative_prompt, cfg, st, randomness, sigma_adjustment]
def run(self, p, original_prompt, original_negative_prompt, cfg, st, randomness, sigma_adjustment): return [
p.batch_size = 1 info,
p.batch_count = 1 override_sampler,
override_prompt, original_prompt, original_negative_prompt,
override_steps, st,
override_strength,
cfg, randomness, sigma_adjustment,
]
def run(self, p, _, override_sampler, override_prompt, original_prompt, original_negative_prompt, override_steps, st, override_strength, cfg, randomness, sigma_adjustment):
# Override
if override_sampler:
p.sampler_index = [sampler.name for sampler in sd_samplers.samplers].index("Euler")
if override_prompt:
p.prompt = original_prompt
p.negative_prompt = original_negative_prompt
if override_steps:
p.steps = st
if override_strength:
p.denoising_strength = 1.0
def sample_extra(conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength): def sample_extra(conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
...@@ -155,11 +182,11 @@ class Script(scripts.Script): ...@@ -155,11 +182,11 @@ class Script(scripts.Script):
rec_noise = find_noise_for_image(p, cond, uncond, cfg, st) rec_noise = find_noise_for_image(p, cond, uncond, cfg, st)
self.cache = Cached(rec_noise, cfg, st, lat, original_prompt, original_negative_prompt, sigma_adjustment) self.cache = Cached(rec_noise, cfg, st, lat, original_prompt, original_negative_prompt, sigma_adjustment)
rand_noise = processing.create_random_tensors(p.init_latent.shape[1:], [p.seed + x + 1 for x in range(p.init_latent.shape[0])]) rand_noise = processing.create_random_tensors(p.init_latent.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w, p=p)
combined_noise = ((1 - randomness) * rec_noise + randomness * rand_noise) / ((randomness**2 + (1-randomness)**2) ** 0.5) combined_noise = ((1 - randomness) * rec_noise + randomness * rand_noise) / ((randomness**2 + (1-randomness)**2) ** 0.5)
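# Why the denominator (a sketch, assuming rec_noise and rand_noise are independent,
# roughly unit-variance tensors): Var[(1 - r) * a + r * b] = (1 - r)**2 + r**2, so
# dividing by sqrt((1 - r)**2 + r**2) keeps the blended noise at unit variance for
# any randomness value r in [0, 1]. A quick standalone check of the identity:
#
#   import torch
#   r = 0.3
#   a, b = torch.randn(1_000_000), torch.randn(1_000_000)
#   mixed = ((1 - r) * a + r * b) / ((r ** 2 + (1 - r) ** 2) ** 0.5)
#   print(mixed.std())  # ~= 1.0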
sampler = samplers[p.sampler_index].constructor(p.sd_model) sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, p.sampler_index, p.sd_model)
sigmas = sampler.model_wrap.get_sigmas(p.steps) sigmas = sampler.model_wrap.get_sigmas(p.steps)
......
...@@ -38,6 +38,7 @@ class Script(scripts.Script): ...@@ -38,6 +38,7 @@ class Script(scripts.Script):
grids = [] grids = []
all_images = [] all_images = []
original_init_image = p.init_images
state.job_count = loops * batch_count state.job_count = loops * batch_count
initial_color_corrections = [processing.setup_color_correction(p.init_images[0])] initial_color_corrections = [processing.setup_color_correction(p.init_images[0])]
...@@ -45,6 +46,9 @@ class Script(scripts.Script): ...@@ -45,6 +46,9 @@ class Script(scripts.Script):
for n in range(batch_count): for n in range(batch_count):
history = [] history = []
# Reset to original init image at the start of each batch
p.init_images = original_init_image
for i in range(loops): for i in range(loops):
p.n_iter = 1 p.n_iter = 1
p.batch_size = 1 p.batch_size = 1
......
...@@ -85,8 +85,11 @@ def get_matched_noise(_np_src_image, np_mask_rgb, noise_q=1, color_variation=0.0 ...@@ -85,8 +85,11 @@ def get_matched_noise(_np_src_image, np_mask_rgb, noise_q=1, color_variation=0.0
src_dist = np.absolute(src_fft) src_dist = np.absolute(src_fft)
src_phase = src_fft / src_dist src_phase = src_fft / src_dist
# create a generator with a static seed so outpainting is deterministic and only follows the global seed
rng = np.random.default_rng(0)
noise_window = _get_gaussian_window(width, height, mode=1) # start with simple gaussian noise noise_window = _get_gaussian_window(width, height, mode=1) # start with simple gaussian noise
noise_rgb = np.random.random_sample((width, height, num_channels)) noise_rgb = rng.random((width, height, num_channels))
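# A small illustration of the determinism this buys (standalone, not taken from the
# diff): seeding a local Generator means repeated calls reproduce the same noise
# without touching NumPy's global random state.
#
#   import numpy as np
#   a = np.random.default_rng(0).random((4, 4, 3))
#   b = np.random.default_rng(0).random((4, 4, 3))
#   assert np.array_equal(a, b)  # identical on every run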
noise_grey = (np.sum(noise_rgb, axis=2) / 3.) noise_grey = (np.sum(noise_rgb, axis=2) / 3.)
noise_rgb *= color_variation # the colorfulness of the starting noise is blended to greyscale with a parameter noise_rgb *= color_variation # the colorfulness of the starting noise is blended to greyscale with a parameter
for c in range(num_channels): for c in range(num_channels):
......
...@@ -10,7 +10,6 @@ from modules.processing import Processed, process_images ...@@ -10,7 +10,6 @@ from modules.processing import Processed, process_images
from PIL import Image from PIL import Image
from modules.shared import opts, cmd_opts, state from modules.shared import opts, cmd_opts, state
class Script(scripts.Script): class Script(scripts.Script):
def title(self): def title(self):
return "Prompts from file or textbox" return "Prompts from file or textbox"
...@@ -67,6 +66,9 @@ class Script(scripts.Script): ...@@ -67,6 +66,9 @@ class Script(scripts.Script):
"do_not_save_grid": process_boolean_tag "do_not_save_grid": process_boolean_tag
} }
def on_show(self, checkbox_txt, file, prompt_txt):
return [ gr.Checkbox.update(visible = True), gr.File.update(visible = not checkbox_txt), gr.TextArea.update(visible = checkbox_txt) ]
def run(self, p, checkbox_txt, data: bytes, prompt_txt: str): def run(self, p, checkbox_txt, data: bytes, prompt_txt: str):
if (checkbox_txt): if (checkbox_txt):
lines = [x.strip() for x in prompt_txt.splitlines()] lines = [x.strip() for x in prompt_txt.splitlines()]
......
...@@ -34,7 +34,11 @@ class Script(scripts.Script): ...@@ -34,7 +34,11 @@ class Script(scripts.Script):
seed = p.seed seed = p.seed
init_img = p.init_images[0] init_img = p.init_images[0]
if(upscaler.name != "None"):
img = upscaler.scaler.upscale(init_img, 2, upscaler.data_path) img = upscaler.scaler.upscale(init_img, 2, upscaler.data_path)
else:
img = init_img
devices.torch_gc() devices.torch_gc()
......
from collections import namedtuple from collections import namedtuple
from copy import copy from copy import copy
from itertools import permutations, chain
import random import random
import csv
from io import StringIO
from PIL import Image from PIL import Image
import numpy as np import numpy as np
...@@ -9,7 +11,8 @@ import modules.scripts as scripts ...@@ -9,7 +11,8 @@ import modules.scripts as scripts
import gradio as gr import gradio as gr
from modules import images from modules import images
from modules.processing import process_images, Processed from modules.hypernetworks import hypernetwork
from modules.processing import process_images, Processed, get_correct_sampler, StableDiffusionProcessingTxt2Img
from modules.shared import opts, cmd_opts, state from modules.shared import opts, cmd_opts, state
import modules.shared as shared import modules.shared as shared
import modules.sd_samplers import modules.sd_samplers
...@@ -25,31 +28,101 @@ def apply_field(field): ...@@ -25,31 +28,101 @@ def apply_field(field):
def apply_prompt(p, x, xs): def apply_prompt(p, x, xs):
if xs[0] not in p.prompt and xs[0] not in p.negative_prompt:
raise RuntimeError(f"Prompt S/R did not find {xs[0]} in prompt or negative prompt.")
p.prompt = p.prompt.replace(xs[0], x) p.prompt = p.prompt.replace(xs[0], x)
p.negative_prompt = p.negative_prompt.replace(xs[0], x) p.negative_prompt = p.negative_prompt.replace(xs[0], x)
samplers_dict = {} def apply_order(p, x, xs):
for i, sampler in enumerate(modules.sd_samplers.samplers): token_order = []
# Initally grab the tokens from the prompt, so they can be replaced in order of earliest seen
for token in x:
token_order.append((p.prompt.find(token), token))
token_order.sort(key=lambda t: t[0])
prompt_parts = []
# Split the prompt up, taking out the tokens
for _, token in token_order:
n = p.prompt.find(token)
prompt_parts.append(p.prompt[0:n])
p.prompt = p.prompt[n + len(token):]
# Rebuild the prompt with the tokens in the order we want
prompt_tmp = ""
for idx, part in enumerate(prompt_parts):
prompt_tmp += part
prompt_tmp += x[idx]
p.prompt = prompt_tmp + p.prompt
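# Worked example of apply_order (the prompt and tokens below are made up for
# illustration). With
#   p.prompt = "a photo of a cat and a dog on grass"
#   x = ("dog", "cat")   # one permutation produced by the "Prompt order" axis
# the tokens are found in their original order ("cat" before "dog"), the prompt is
# split around them, and the pieces are reassembled with the tokens in the order
# given by x, producing:
#   "a photo of a dog and a cat on grass"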
def build_samplers_dict(p):
samplers_dict = {}
for i, sampler in enumerate(get_correct_sampler(p)):
samplers_dict[sampler.name.lower()] = i samplers_dict[sampler.name.lower()] = i
for alias in sampler.aliases: for alias in sampler.aliases:
samplers_dict[alias.lower()] = i samplers_dict[alias.lower()] = i
return samplers_dict
def apply_sampler(p, x, xs): def apply_sampler(p, x, xs):
sampler_index = samplers_dict.get(x.lower(), None) sampler_index = build_samplers_dict(p).get(x.lower(), None)
if sampler_index is None: if sampler_index is None:
raise RuntimeError(f"Unknown sampler: {x}") raise RuntimeError(f"Unknown sampler: {x}")
p.sampler_index = sampler_index p.sampler_index = sampler_index
def confirm_samplers(p, xs):
samplers_dict = build_samplers_dict(p)
for x in xs:
if x.lower() not in samplers_dict.keys():
raise RuntimeError(f"Unknown sampler: {x}")
def apply_checkpoint(p, x, xs): def apply_checkpoint(p, x, xs):
info = modules.sd_models.get_closet_checkpoint_match(x) info = modules.sd_models.get_closet_checkpoint_match(x)
assert info is not None, f'Checkpoint for {x} not found' if info is None:
raise RuntimeError(f"Unknown checkpoint: {x}")
modules.sd_models.reload_model_weights(shared.sd_model, info) modules.sd_models.reload_model_weights(shared.sd_model, info)
def confirm_checkpoints(p, xs):
for x in xs:
if modules.sd_models.get_closet_checkpoint_match(x) is None:
raise RuntimeError(f"Unknown checkpoint: {x}")
def apply_hypernetwork(p, x, xs):
if x.lower() in ["", "none"]:
name = None
else:
name = hypernetwork.find_closest_hypernetwork_name(x)
if not name:
raise RuntimeError(f"Unknown hypernetwork: {x}")
hypernetwork.load_hypernetwork(name)
def apply_hypernetwork_strength(p, x, xs):
hypernetwork.apply_strength(x)
def confirm_hypernetworks(p, xs):
for x in xs:
if x.lower() in ["", "none"]:
continue
if not hypernetwork.find_closest_hypernetwork_name(x):
raise RuntimeError(f"Unknown hypernetwork: {x}")
def apply_clip_skip(p, x, xs):
opts.data["CLIP_stop_at_last_layers"] = x
def format_value_add_label(p, opt, x): def format_value_add_label(p, opt, x):
if type(x) == float: if type(x) == float:
x = round(x, 8) x = round(x, 8)
...@@ -60,46 +133,64 @@ def format_value_add_label(p, opt, x): ...@@ -60,46 +133,64 @@ def format_value_add_label(p, opt, x):
def format_value(p, opt, x): def format_value(p, opt, x):
if type(x) == float: if type(x) == float:
x = round(x, 8) x = round(x, 8)
return x return x
def format_value_join_list(p, opt, x):
return ", ".join(x)
def do_nothing(p, x, xs): def do_nothing(p, x, xs):
pass pass
def format_nothing(p, opt, x): def format_nothing(p, opt, x):
return "" return ""
AxisOption = namedtuple("AxisOption", ["label", "type", "apply", "format_value"]) def str_permutations(x):
AxisOptionImg2Img = namedtuple("AxisOptionImg2Img", ["label", "type", "apply", "format_value"]) """dummy function for specifying it in AxisOption's type when you want to get a list of permutations"""
return x
AxisOption = namedtuple("AxisOption", ["label", "type", "apply", "format_value", "confirm"])
AxisOptionImg2Img = namedtuple("AxisOptionImg2Img", ["label", "type", "apply", "format_value", "confirm"])
axis_options = [ axis_options = [
AxisOption("Nothing", str, do_nothing, format_nothing), AxisOption("Nothing", str, do_nothing, format_nothing, None),
AxisOption("Seed", int, apply_field("seed"), format_value_add_label), AxisOption("Seed", int, apply_field("seed"), format_value_add_label, None),
AxisOption("Var. seed", int, apply_field("subseed"), format_value_add_label), AxisOption("Var. seed", int, apply_field("subseed"), format_value_add_label, None),
AxisOption("Var. strength", float, apply_field("subseed_strength"), format_value_add_label), AxisOption("Var. strength", float, apply_field("subseed_strength"), format_value_add_label, None),
AxisOption("Steps", int, apply_field("steps"), format_value_add_label), AxisOption("Steps", int, apply_field("steps"), format_value_add_label, None),
AxisOption("CFG Scale", float, apply_field("cfg_scale"), format_value_add_label), AxisOption("CFG Scale", float, apply_field("cfg_scale"), format_value_add_label, None),
AxisOption("Prompt S/R", str, apply_prompt, format_value), AxisOption("Prompt S/R", str, apply_prompt, format_value, None),
AxisOption("Sampler", str, apply_sampler, format_value), AxisOption("Prompt order", str_permutations, apply_order, format_value_join_list, None),
AxisOption("Checkpoint name", str, apply_checkpoint, format_value), AxisOption("Sampler", str, apply_sampler, format_value, confirm_samplers),
AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label), AxisOption("Checkpoint name", str, apply_checkpoint, format_value, confirm_checkpoints),
AxisOption("Sigma min", float, apply_field("s_tmin"), format_value_add_label), AxisOption("Hypernetwork", str, apply_hypernetwork, format_value, confirm_hypernetworks),
AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label), AxisOption("Hypernet str.", float, apply_hypernetwork_strength, format_value_add_label, None),
AxisOption("Sigma noise", float, apply_field("s_noise"), format_value_add_label), AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label, None),
AxisOption("Eta", float, apply_field("eta"), format_value_add_label), AxisOption("Sigma min", float, apply_field("s_tmin"), format_value_add_label, None),
AxisOptionImg2Img("Denoising", float, apply_field("denoising_strength"), format_value_add_label), # as it is now all AxisOptionImg2Img items must go after AxisOption ones AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label, None),
AxisOption("Sigma noise", float, apply_field("s_noise"), format_value_add_label, None),
AxisOption("Eta", float, apply_field("eta"), format_value_add_label, None),
AxisOption("Clip skip", int, apply_clip_skip, format_value_add_label, None),
AxisOption("Denoising", float, apply_field("denoising_strength"), format_value_add_label, None),
] ]
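# Sketch of extending the table above (names here are hypothetical): the new
# "confirm" slot lets an axis validate all of its values before any image is
# rendered, in the same way confirm_samplers / confirm_checkpoints do.
#
#   def confirm_non_negative(p, xs):
#       for x in xs:
#           if x < 0:
#               raise RuntimeError(f"Value must be non-negative: {x}")
#
#   axis_options.append(AxisOption("My option", int, apply_field("my_field"),
#                                  format_value_add_label, confirm_non_negative))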
def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend): def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend, include_lone_images):
res = []
ver_texts = [[images.GridAnnotation(y)] for y in y_labels] ver_texts = [[images.GridAnnotation(y)] for y in y_labels]
hor_texts = [[images.GridAnnotation(x)] for x in x_labels] hor_texts = [[images.GridAnnotation(x)] for x in x_labels]
first_pocessed = None # Temporary list of all the images that are generated to be populated into the grid.
# Will be filled with empty images for any individual step that fails to process properly
image_cache = []
processed_result = None
cell_mode = "P"
cell_size = (1,1)
state.job_count = len(xs) * len(ys) * p.n_iter state.job_count = len(xs) * len(ys) * p.n_iter
...@@ -107,22 +198,39 @@ def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend): ...@@ -107,22 +198,39 @@ def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend):
for ix, x in enumerate(xs): for ix, x in enumerate(xs):
state.job = f"{ix + iy * len(xs) + 1} out of {len(xs) * len(ys)}" state.job = f"{ix + iy * len(xs) + 1} out of {len(xs) * len(ys)}"
processed = cell(x, y) processed:Processed = cell(x, y)
if first_pocessed is None:
first_pocessed = processed
try: try:
res.append(processed.images[0]) # this dereference will throw an exception if the image was not processed
# (this happens in cases such as if the user stops the process from the UI)
processed_image = processed.images[0]
if processed_result is None:
# Use our first valid processed result as a template container to hold our full results
processed_result = copy(processed)
cell_mode = processed_image.mode
cell_size = processed_image.size
processed_result.images = [Image.new(cell_mode, cell_size)]
image_cache.append(processed_image)
if include_lone_images:
processed_result.images.append(processed_image)
processed_result.all_prompts.append(processed.prompt)
processed_result.all_seeds.append(processed.seed)
processed_result.infotexts.append(processed.infotexts[0])
except: except:
res.append(Image.new(res[0].mode, res[0].size)) image_cache.append(Image.new(cell_mode, cell_size))
grid = images.image_grid(res, rows=len(ys)) if not processed_result:
print("Unexpected error: draw_xy_grid failed to return even a single processed image")
return Processed()
grid = images.image_grid(image_cache, rows=len(ys))
if draw_legend: if draw_legend:
grid = images.draw_grid_annotations(grid, res[0].width, res[0].height, hor_texts, ver_texts) grid = images.draw_grid_annotations(grid, cell_size[0], cell_size[1], hor_texts, ver_texts)
first_pocessed.images = [grid] processed_result.images[0] = grid
return first_pocessed return processed_result
re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*") re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*")
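# Examples of the range syntax this pattern accepts (the later, partly elided
# expansion code appears to turn them into inclusive integer ranges):
#
#   import re
#   re_range.fullmatch("1-5").groups()        # ('1', '5', None)  -> 1, 2, 3, 4, 5
#   re_range.fullmatch("1-10 (+2)").groups()  # ('1', '10', '+2') -> 1, 3, 5, 7, 9
#   re_range.fullmatch("banana") is None      # non-range values fall through to
#                                             # plain parsing of the value list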
...@@ -143,23 +251,30 @@ class Script(scripts.Script): ...@@ -143,23 +251,30 @@ class Script(scripts.Script):
x_values = gr.Textbox(label="X values", visible=False, lines=1) x_values = gr.Textbox(label="X values", visible=False, lines=1)
with gr.Row(): with gr.Row():
y_type = gr.Dropdown(label="Y type", choices=[x.label for x in current_axis_options], value=current_axis_options[4].label, visible=False, type="index", elem_id="y_type") y_type = gr.Dropdown(label="Y type", choices=[x.label for x in current_axis_options], value=current_axis_options[0].label, visible=False, type="index", elem_id="y_type")
y_values = gr.Textbox(label="Y values", visible=False, lines=1) y_values = gr.Textbox(label="Y values", visible=False, lines=1)
draw_legend = gr.Checkbox(label='Draw legend', value=True) draw_legend = gr.Checkbox(label='Draw legend', value=True)
include_lone_images = gr.Checkbox(label='Include Separate Images', value=False)
no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False) no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False)
return [x_type, x_values, y_type, y_values, draw_legend, no_fixed_seeds] return [x_type, x_values, y_type, y_values, draw_legend, include_lone_images, no_fixed_seeds]
def run(self, p, x_type, x_values, y_type, y_values, draw_legend, no_fixed_seeds): def run(self, p, x_type, x_values, y_type, y_values, draw_legend, include_lone_images, no_fixed_seeds):
if not no_fixed_seeds:
modules.processing.fix_seed(p) modules.processing.fix_seed(p)
if not opts.return_grid:
p.batch_size = 1 p.batch_size = 1
CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers
def process_axis(opt, vals): def process_axis(opt, vals):
if opt.label == 'Nothing': if opt.label == 'Nothing':
return [0] return [0]
valslist = [x.strip() for x in vals.split(",")] valslist = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(vals)))]
if opt.type == int: if opt.type == int:
valslist_ext = [] valslist_ext = []
...@@ -168,7 +283,6 @@ class Script(scripts.Script): ...@@ -168,7 +283,6 @@ class Script(scripts.Script):
m = re_range.fullmatch(val) m = re_range.fullmatch(val)
mc = re_range_count.fullmatch(val) mc = re_range_count.fullmatch(val)
if m is not None: if m is not None:
start = int(m.group(1)) start = int(m.group(1))
end = int(m.group(2))+1 end = int(m.group(2))+1
step = int(m.group(3)) if m.group(3) is not None else 1 step = int(m.group(3)) if m.group(3) is not None else 1
...@@ -206,9 +320,15 @@ class Script(scripts.Script): ...@@ -206,9 +320,15 @@ class Script(scripts.Script):
valslist_ext.append(val) valslist_ext.append(val)
valslist = valslist_ext valslist = valslist_ext
elif opt.type == str_permutations:
valslist = list(permutations(valslist))
valslist = [opt.type(x) for x in valslist] valslist = [opt.type(x) for x in valslist]
# Confirm options are valid before starting
if opt.confirm:
opt.confirm(p, valslist)
return valslist return valslist
x_opt = axis_options[x_type] x_opt = axis_options[x_type]
...@@ -218,7 +338,7 @@ class Script(scripts.Script): ...@@ -218,7 +338,7 @@ class Script(scripts.Script):
ys = process_axis(y_opt, y_values) ys = process_axis(y_opt, y_values)
def fix_axis_seeds(axis_opt, axis_list): def fix_axis_seeds(axis_opt, axis_list):
if axis_opt.label == 'Seed': if axis_opt.label in ['Seed','Var. seed']:
return [int(random.randrange(4294967294)) if val is None or val == '' or val == -1 else val for val in axis_list] return [int(random.randrange(4294967294)) if val is None or val == '' or val == -1 else val for val in axis_list]
else: else:
return axis_list return axis_list
...@@ -234,6 +354,9 @@ class Script(scripts.Script): ...@@ -234,6 +354,9 @@ class Script(scripts.Script):
else: else:
total_steps = p.steps * len(xs) * len(ys) total_steps = p.steps * len(xs) * len(ys)
if isinstance(p, StableDiffusionProcessingTxt2Img) and p.enable_hr:
total_steps *= 2
print(f"X/Y plot will create {len(xs) * len(ys) * p.n_iter} images on a {len(xs)}x{len(ys)} grid. (Total steps to process: {total_steps * p.n_iter})") print(f"X/Y plot will create {len(xs) * len(ys) * p.n_iter} images on a {len(xs)}x{len(ys)} grid. (Total steps to process: {total_steps * p.n_iter})")
shared.total_tqdm.updateTotal(total_steps * p.n_iter) shared.total_tqdm.updateTotal(total_steps * p.n_iter)
...@@ -251,7 +374,8 @@ class Script(scripts.Script): ...@@ -251,7 +374,8 @@ class Script(scripts.Script):
x_labels=[x_opt.format_value(p, x_opt, x) for x in xs], x_labels=[x_opt.format_value(p, x_opt, x) for x in xs],
y_labels=[y_opt.format_value(p, y_opt, y) for y in ys], y_labels=[y_opt.format_value(p, y_opt, y) for y in ys],
cell=cell, cell=cell,
draw_legend=draw_legend draw_legend=draw_legend,
include_lone_images=include_lone_images
) )
if opts.grid_save: if opts.grid_save:
...@@ -260,4 +384,10 @@ class Script(scripts.Script): ...@@ -260,4 +384,10 @@ class Script(scripts.Script):
# restore checkpoint in case it was changed by axes # restore checkpoint in case it was changed by axes
modules.sd_models.reload_model_weights(shared.sd_model) modules.sd_models.reload_model_weights(shared.sd_model)
hypernetwork.load_hypernetwork(opts.sd_hypernetwork)
hypernetwork.apply_strength()
opts.data["CLIP_stop_at_last_layers"] = CLIP_stop_at_last_layers
return processed return processed
.container {
max-width: 100%;
}
#txt2img_token_counter {
height: 0px;
}
#img2img_token_counter {
height: 0px;
}
#sh{
min-width: 2em;
min-height: 2em;
max-width: 2em;
max-height: 2em;
flex-grow: 0;
padding-left: 0.25em;
padding-right: 0.25em;
margin: 0.1em 0;
opacity: 0%;
cursor: default;
}
.output-html p {margin: 0 0.5em;} .output-html p {margin: 0 0.5em;}
.row > *, .row > *,
...@@ -103,7 +128,12 @@ ...@@ -103,7 +128,12 @@
#style_apply, #style_create, #interrogate{ #style_apply, #style_create, #interrogate{
margin: 0.75em 0.25em 0.25em 0.25em; margin: 0.75em 0.25em 0.25em 0.25em;
min-width: 3em; min-width: 5em;
}
#style_apply, #style_create, #deepbooru{
margin: 0.75em 0.25em 0.25em 0.25em;
min-width: 5em;
} }
#style_pos_col, #style_neg_col{ #style_pos_col, #style_neg_col{
...@@ -137,14 +167,6 @@ button{ ...@@ -137,14 +167,6 @@ button{
align-self: stretch !important; align-self: stretch !important;
} }
#prompt, #negative_prompt{
border: none !important;
}
#prompt textarea, #negative_prompt textarea{
border: none !important;
}
#img2maskimg .h-60{ #img2maskimg .h-60{
height: 30rem; height: 30rem;
} }
...@@ -157,7 +179,7 @@ button{ ...@@ -157,7 +179,7 @@ button{
max-width: 10em; max-width: 10em;
} }
#txt2img_preview, #img2img_preview{ #txt2img_preview, #img2img_preview, #ti_preview{
position: absolute; position: absolute;
width: 320px; width: 320px;
left: 0; left: 0;
...@@ -172,18 +194,18 @@ button{ ...@@ -172,18 +194,18 @@ button{
} }
@media screen and (min-width: 768px) { @media screen and (min-width: 768px) {
#txt2img_preview, #img2img_preview { #txt2img_preview, #img2img_preview, #ti_preview {
position: absolute; position: absolute;
} }
} }
@media screen and (max-width: 767px) { @media screen and (max-width: 767px) {
#txt2img_preview, #img2img_preview { #txt2img_preview, #img2img_preview, #ti_preview {
position: relative; position: relative;
} }
} }
#txt2img_preview div.left-0.top-0, #img2img_preview div.left-0.top-0{ #txt2img_preview div.left-0.top-0, #img2img_preview div.left-0.top-0, #ti_preview div.left-0.top-0{
display: none; display: none;
} }
...@@ -198,6 +220,8 @@ fieldset span.text-gray-500, .gr-block.gr-box span.text-gray-500, label.block s ...@@ -198,6 +220,8 @@ fieldset span.text-gray-500, .gr-block.gr-box span.text-gray-500, label.block s
border-top: 1px solid #eee; border-top: 1px solid #eee;
border-left: 1px solid #eee; border-left: 1px solid #eee;
border-right: 1px solid #eee; border-right: 1px solid #eee;
z-index: 300;
} }
.dark fieldset span.text-gray-500, .dark .gr-block.gr-box span.text-gray-500, .dark label.block span{ .dark fieldset span.text-gray-500, .dark .gr-block.gr-box span.text-gray-500, .dark label.block span{
...@@ -210,6 +234,7 @@ fieldset span.text-gray-500, .gr-block.gr-box span.text-gray-500, label.block s ...@@ -210,6 +234,7 @@ fieldset span.text-gray-500, .gr-block.gr-box span.text-gray-500, label.block s
#settings fieldset span.text-gray-500, #settings .gr-block.gr-box span.text-gray-500, #settings label.block span{ #settings fieldset span.text-gray-500, #settings .gr-block.gr-box span.text-gray-500, #settings label.block span{
position: relative; position: relative;
border: none; border: none;
margin-right: 8em;
} }
.gr-panel div.flex-col div.justify-between label span{ .gr-panel div.flex-col div.justify-between label span{
...@@ -247,7 +272,7 @@ input[type="range"]{ ...@@ -247,7 +272,7 @@ input[type="range"]{
#txt2img_negative_prompt, #img2img_negative_prompt{ #txt2img_negative_prompt, #img2img_negative_prompt{
} }
#txt2img_progressbar, #img2img_progressbar{ #txt2img_progressbar, #img2img_progressbar, #ti_progressbar{
position: absolute; position: absolute;
z-index: 1000; z-index: 1000;
right: 0; right: 0;
...@@ -393,13 +418,106 @@ input[type="range"]{ ...@@ -393,13 +418,106 @@ input[type="range"]{
#txt2img_interrupt, #img2img_interrupt{ #txt2img_interrupt, #img2img_interrupt{
position: absolute; position: absolute;
width: 100%; width: 50%;
height: 72px; height: 72px;
background: #b4c0cc; background: #b4c0cc;
border-radius: 8px; border-radius: 0px;
display: none;
}
#txt2img_skip, #img2img_skip{
position: absolute;
width: 50%;
right: 0px;
height: 72px;
background: #b4c0cc;
border-radius: 0px;
display: none; display: none;
} }
.red { .red {
color: red; color: red;
} }
.gallery-item {
--tw-bg-opacity: 0 !important;
}
#img2img_image div.h-60{
height: 480px;
}
#context-menu{
z-index:9999;
position:absolute;
display:block;
padding:0px 0;
border:2px solid #a55000;
border-radius:8px;
box-shadow:1px 1px 2px #CE6400;
width: 200px;
}
.context-menu-items{
list-style: none;
margin: 0;
padding: 0;
}
.context-menu-items a{
display:block;
padding:5px;
cursor:pointer;
}
.context-menu-items a:hover{
background: #a55000;
}
#quicksettings {
gap: 0.4em;
}
#quicksettings > div{
border: none;
background: none;
flex: unset;
gap: 0.5em;
}
#quicksettings > div > div{
max-width: 32em;
min-width: 24em;
padding: 0;
}
#refresh_sd_model_checkpoint, #refresh_sd_hypernetwork{
max-width: 2.5em;
min-width: 2.5em;
height: 2.4em;
}
canvas[key="mask"] {
z-index: 12 !important;
filter: invert();
mix-blend-mode: multiply;
pointer-events: none;
}
/* gradio 3.4.1 stuff for editable scrollbar values */
.gr-box > div > div > input.gr-text-input{
position: absolute;
right: 0.5em;
top: -0.6em;
z-index: 200;
width: 8em;
}
#quicksettings .gr-box > div > div > input.gr-text-input {
top: -1.12em;
}
.row.gr-compact{
overflow: visible;
}
a photo of a [filewords]
a rendering of a [filewords]
a cropped photo of the [filewords]
the photo of a [filewords]
a photo of a clean [filewords]
a photo of a dirty [filewords]
a dark photo of the [filewords]
a photo of my [filewords]
a photo of the cool [filewords]
a close-up photo of a [filewords]
a bright photo of the [filewords]
a cropped photo of a [filewords]
a photo of the [filewords]
a good photo of the [filewords]
a photo of one [filewords]
a close-up photo of the [filewords]
a rendition of the [filewords]
a photo of the clean [filewords]
a rendition of a [filewords]
a photo of a nice [filewords]
a good photo of a [filewords]
a photo of the nice [filewords]
a photo of the small [filewords]
a photo of the weird [filewords]
a photo of the large [filewords]
a photo of a cool [filewords]
a photo of a small [filewords]
a painting, art by [name]
a rendering, art by [name]
a cropped painting, art by [name]
the painting, art by [name]
a clean painting, art by [name]
a dirty painting, art by [name]
a dark painting, art by [name]
a picture, art by [name]
a cool painting, art by [name]
a close-up painting, art by [name]
a bright painting, art by [name]
a cropped painting, art by [name]
a good painting, art by [name]
a close-up painting, art by [name]
a rendition, art by [name]
a nice painting, art by [name]
a small painting, art by [name]
a weird painting, art by [name]
a large painting, art by [name]
a painting of [filewords], art by [name]
a rendering of [filewords], art by [name]
a cropped painting of [filewords], art by [name]
the painting of [filewords], art by [name]
a clean painting of [filewords], art by [name]
a dirty painting of [filewords], art by [name]
a dark painting of [filewords], art by [name]
a picture of [filewords], art by [name]
a cool painting of [filewords], art by [name]
a close-up painting of [filewords], art by [name]
a bright painting of [filewords], art by [name]
a cropped painting of [filewords], art by [name]
a good painting of [filewords], art by [name]
a close-up painting of [filewords], art by [name]
a rendition of [filewords], art by [name]
a nice painting of [filewords], art by [name]
a small painting of [filewords], art by [name]
a weird painting of [filewords], art by [name]
a large painting of [filewords], art by [name]
a photo of a [name]
a rendering of a [name]
a cropped photo of the [name]
the photo of a [name]
a photo of a clean [name]
a photo of a dirty [name]
a dark photo of the [name]
a photo of my [name]
a photo of the cool [name]
a close-up photo of a [name]
a bright photo of the [name]
a cropped photo of a [name]
a photo of the [name]
a good photo of the [name]
a photo of one [name]
a close-up photo of the [name]
a rendition of the [name]
a photo of the clean [name]
a rendition of a [name]
a photo of a nice [name]
a good photo of a [name]
a photo of the nice [name]
a photo of the small [name]
a photo of the weird [name]
a photo of the large [name]
a photo of a cool [name]
a photo of a small [name]
a photo of a [name], [filewords]
a rendering of a [name], [filewords]
a cropped photo of the [name], [filewords]
the photo of a [name], [filewords]
a photo of a clean [name], [filewords]
a photo of a dirty [name], [filewords]
a dark photo of the [name], [filewords]
a photo of my [name], [filewords]
a photo of the cool [name], [filewords]
a close-up photo of a [name], [filewords]
a bright photo of the [name], [filewords]
a cropped photo of a [name], [filewords]
a photo of the [name], [filewords]
a good photo of the [name], [filewords]
a photo of one [name], [filewords]
a close-up photo of the [name], [filewords]
a rendition of the [name], [filewords]
a photo of the clean [name], [filewords]
a rendition of a [name], [filewords]
a photo of a nice [name], [filewords]
a good photo of a [name], [filewords]
a photo of the nice [name], [filewords]
a photo of the small [name], [filewords]
a photo of the weird [name], [filewords]
a photo of the large [name], [filewords]
a photo of a cool [name], [filewords]
a photo of a small [name], [filewords]
txt2img_Screenshot.png (screenshot updated: 526 KB → 329 KB)
import os import os
import threading import threading
import time
from modules import devices import importlib
from modules.paths import script_path
import signal import signal
import threading import threading
import modules.paths
from fastapi.middleware.gzip import GZipMiddleware
from modules.paths import script_path
from modules import devices, sd_samplers
import modules.codeformer_model as codeformer import modules.codeformer_model as codeformer
import modules.esrgan_model as esrgan
import modules.bsrgan_model as bsrgan
import modules.extras import modules.extras
import modules.face_restoration import modules.face_restoration
import modules.gfpgan_model as gfpgan import modules.gfpgan_model as gfpgan
import modules.img2img import modules.img2img
import modules.ldsr_model as ldsr
import modules.lowvram import modules.lowvram
import modules.realesrgan_model as realesrgan import modules.paths
import modules.scripts import modules.scripts
import modules.sd_hijack import modules.sd_hijack
import modules.sd_models import modules.sd_models
import modules.shared as shared import modules.shared as shared
import modules.swinir_model as swinir
import modules.txt2img import modules.txt2img
import modules.ui import modules.ui
from modules import devices
from modules import modelloader from modules import modelloader
from modules.paths import script_path from modules.paths import script_path
from modules.shared import cmd_opts from modules.shared import cmd_opts
import modules.hypernetworks.hypernetwork
modelloader.cleanup_models()
modules.sd_models.setup_model(cmd_opts.ckpt_dir)
codeformer.setup_model(cmd_opts.codeformer_models_path)
gfpgan.setup_model(cmd_opts.gfpgan_models_path)
shared.face_restorers.append(modules.face_restoration.FaceRestoration())
modelloader.load_upscalers()
queue_lock = threading.Lock() queue_lock = threading.Lock()
...@@ -46,7 +45,7 @@ def wrap_queued_call(func): ...@@ -46,7 +45,7 @@ def wrap_queued_call(func):
return f return f
def wrap_gradio_gpu_call(func): def wrap_gradio_gpu_call(func, extra_outputs=None):
def f(*args, **kwargs): def f(*args, **kwargs):
devices.torch_gc() devices.torch_gc()
...@@ -57,7 +56,9 @@ def wrap_gradio_gpu_call(func): ...@@ -57,7 +56,9 @@ def wrap_gradio_gpu_call(func):
shared.state.current_latent = None shared.state.current_latent = None
shared.state.current_image = None shared.state.current_image = None
shared.state.current_image_sampling_step = 0 shared.state.current_image_sampling_step = 0
shared.state.skipped = False
shared.state.interrupted = False shared.state.interrupted = False
shared.state.textinfo = None
with queue_lock: with queue_lock:
res = func(*args, **kwargs) res = func(*args, **kwargs)
...@@ -69,16 +70,27 @@ def wrap_gradio_gpu_call(func): ...@@ -69,16 +70,27 @@ def wrap_gradio_gpu_call(func):
return res return res
return modules.ui.wrap_gradio_call(f) return modules.ui.wrap_gradio_call(f, extra_outputs=extra_outputs)
def initialize():
modelloader.cleanup_models()
modules.sd_models.setup_model()
codeformer.setup_model(cmd_opts.codeformer_models_path)
gfpgan.setup_model(cmd_opts.gfpgan_models_path)
shared.face_restorers.append(modules.face_restoration.FaceRestoration())
modelloader.load_upscalers()
modules.scripts.load_scripts(os.path.join(script_path, "scripts")) modules.scripts.load_scripts(os.path.join(script_path, "scripts"))
shared.sd_model = modules.sd_models.load_model() shared.sd_model = modules.sd_models.load_model()
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model))) shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model)))
shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)))
shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength)
def webui(): def webui():
initialize()
# make the program just exit at ctrl+c without waiting for anything # make the program just exit at ctrl+c without waiting for anything
def sigint_handler(sig, frame): def sigint_handler(sig, frame):
print(f'Interrupted with signal {sig} in {frame}') print(f'Interrupted with signal {sig} in {frame}')
...@@ -86,23 +98,40 @@ def webui(): ...@@ -86,23 +98,40 @@ def webui():
signal.signal(signal.SIGINT, sigint_handler) signal.signal(signal.SIGINT, sigint_handler)
demo = modules.ui.create_ui( while 1:
txt2img=wrap_gradio_gpu_call(modules.txt2img.txt2img),
img2img=wrap_gradio_gpu_call(modules.img2img.img2img),
run_extras=wrap_gradio_gpu_call(modules.extras.run_extras),
run_pnginfo=modules.extras.run_pnginfo,
run_modelmerger=modules.extras.run_modelmerger
)
demo.launch( demo = modules.ui.create_ui(wrap_gradio_gpu_call=wrap_gradio_gpu_call)
app, local_url, share_url = demo.launch(
share=cmd_opts.share, share=cmd_opts.share,
server_name="0.0.0.0" if cmd_opts.listen else None, server_name="0.0.0.0" if cmd_opts.listen else None,
server_port=cmd_opts.port, server_port=cmd_opts.port,
debug=cmd_opts.gradio_debug, debug=cmd_opts.gradio_debug,
auth=[tuple(cred.split(':')) for cred in cmd_opts.gradio_auth.strip('"').split(',')] if cmd_opts.gradio_auth else None, auth=[tuple(cred.split(':')) for cred in cmd_opts.gradio_auth.strip('"').split(',')] if cmd_opts.gradio_auth else None,
inbrowser=cmd_opts.autolaunch, inbrowser=cmd_opts.autolaunch,
prevent_thread_lock=True
) )
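# Format of the optional auth credentials passed to launch() above (user names here
# are invented): --gradio-auth "alice:secret1,bob:secret2" is parsed into
# [("alice", "secret1"), ("bob", "secret2")] and handed to Gradio's launch().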
app.add_middleware(GZipMiddleware, minimum_size=1000)
while 1:
time.sleep(0.5)
if getattr(demo, 'do_restart', False):
time.sleep(0.5)
demo.close()
time.sleep(0.5)
break
sd_samplers.set_samplers()
print('Reloading Custom Scripts')
modules.scripts.reload_scripts(os.path.join(script_path, "scripts"))
print('Reloading modules: modules.ui')
importlib.reload(modules.ui)
print('Refreshing Model List')
modules.sd_models.list_models()
print('Restarting Gradio')
if __name__ == "__main__": if __name__ == "__main__":
webui() webui()
...@@ -82,8 +82,8 @@ then ...@@ -82,8 +82,8 @@ then
clone_dir="${PWD##*/}" clone_dir="${PWD##*/}"
fi fi
# Check prequisites # Check prerequisites
for preq in git python3 for preq in "${GIT}" "${python_cmd}"
do do
if ! hash "${preq}" &>/dev/null if ! hash "${preq}" &>/dev/null
then then
......