Having explored several variations of ControlNet, we will now implement the framework ourselves.
Let's begin by creating our custom ControlNet pipeline, which requires several important components to be initialized. These components can be seen in the code below:
import torch
from torch import autocast
from tqdm import tqdm

class ControlNetDiffusionPipelineCustom:
    """custom implementation of the ControlNet Diffusion Pipeline"""
    def __init__(self, vae, tokenizer, text_encoder,
                 unet, controlnet, scheduler, image_processor,
                 control_image_processor):
        self.vae = vae
        self.tokenizer = tokenizer
        self.text_encoder = text_encoder
        self.unet = unet
        self.scheduler = scheduler
        self.controlnet = controlnet
        self.image_processor = image_processor
        self.control_image_processor = control_image_processor
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
We will now define methods to get text embeddings and prompt embeddings:
    def get_text_embeds(self, text):
        """returns embeddings for the given `text`"""
        # tokenize the text
        text_input = self.tokenizer(text,
                                    padding='max_length',
                                    max_length=self.tokenizer.model_max_length,
                                    truncation=True,
                                    return_tensors='pt')
        # embed the text
        with torch.no_grad():
            text_embeds = self.text_encoder(text_input.input_ids.to(self.device))[0]
        return text_embeds
    def get_prompt_embeds(self, prompt):
        """returns prompt embeddings based on classifier free guidance"""
        if isinstance(prompt, str):
            prompt = [prompt]
        # get conditional prompt embeddings
        cond_embeds = self.get_text_embeds(prompt)
        # get unconditional prompt embeddings
        uncond_embeds = self.get_text_embeds([''] * len(prompt))
        # concatenate the above 2 embeds
        prompt_embeds = torch.cat([uncond_embeds, cond_embeds])
        return prompt_embeds
The get_prompt_embeds() method uses the get_text_embeds() method to build both the conditional and the unconditional embeddings needed for classifier-free guidance.
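As a quick sanity check, here is a small sketch of what the method returns; it assumes the custom pipeline has been instantiated as pipeline (which we do at the end of this section), the prompt is arbitrary, and the shapes shown assume Stable Diffusion v1.5's CLIP text encoder (77 tokens, 768-dimensional embeddings):

prompt_embeds = pipeline.get_prompt_embeds('a person dancing in the rain')
print(prompt_embeds.shape)  # torch.Size([2, 77, 768]) -> [unconditional, conditional] stacked along the batch dimension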
We now define a method to post-process images for us. This method takes the raw output of the VAE and converts it to the PIL image format:
    def transform_image(self, image):
        """convert image from pytorch tensor to PIL format"""
        image = self.image_processor.postprocess(image, output_type='pil')
        return image
Next, we define a method that samples a random image latent of the appropriate shape and scales it as required by the scheduler:
    def get_initial_latents(self, height, width, num_channels_latents, batch_size):
        """returns noise latent tensor of relevant shape scaled by the scheduler"""
        image_latents = torch.randn((batch_size,
                                     num_channels_latents,
                                     height // 8,
                                     width // 8)).to(self.device)
        # scale the initial noise by the standard deviation required by the scheduler
        image_latents = image_latents * self.scheduler.init_noise_sigma
        return image_latents
Next, we come to the most important method, which denoises the latents to create the actual image:
    def denoise_latents(self, prompt_embeds, controlnet_image,
                        timesteps, latents, guidance_scale=7.5):
        """denoises latents from noisy latent to a meaningful latent as conditioned by controlnet"""
        # use autocast for automatic mixed precision (AMP) inference
        with autocast('cuda'):
            for i, t in tqdm(enumerate(timesteps)):
                # duplicate image latents to do classifier free guidance
                latent_model_input = torch.cat([latents] * 2)
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
                control_model_input = latent_model_input
                controlnet_prompt_embeds = prompt_embeds
                # get output from the controlnet blocks
                down_block_res_samples, mid_block_res_sample = self.controlnet(
                    control_model_input,
                    t,
                    encoder_hidden_states=controlnet_prompt_embeds,
                    controlnet_cond=controlnet_image,
                    conditioning_scale=1.0,
                    return_dict=False,
                )
                # predict noise residuals
                with torch.no_grad():
                    noise_pred = self.unet(
                        latent_model_input,
                        t,
                        encoder_hidden_states=prompt_embeds,
                        down_block_additional_residuals=down_block_res_samples,
                        mid_block_additional_residual=mid_block_res_sample,
                    )['sample']
                # separate predictions for unconditional and conditional outputs
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                # perform guidance
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
                # remove the noise from the current sample i.e. go from x_t to x_{t-1}
                latents = self.scheduler.step(noise_pred, t, latents)['prev_sample']
        return latents
We use autocasting for mixed-precision inference, which speeds things up. At each timestep, we duplicate the image latents because we are doing classifier-free guidance, prepare the input latents and prompt embeddings for ControlNet, and pass them to the ControlNet model to obtain the residual outputs of its down blocks and middle block. We then pass these outputs as additional residuals to our U-Net, which adds them to the corresponding skip connections, that is, to its middle block and to the inputs of its decoder blocks, as we previously discussed in the ControlNet diagram. This gives us the noise predictions for the conditional and unconditional inputs, on which we perform classifier-free guidance to get the final noise prediction. This predicted noise is removed from the latents through the scheduler.step() method.
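To make the guidance step concrete, here is a tiny standalone sketch with random tensors standing in for the U-Net's output; the shapes assume Stable Diffusion's 4-channel, 64x64 latents for a 512x512 image:

import torch

# stacked predictions for the [unconditional, conditional] inputs, as produced
# by the duplicated batch in denoise_latents(); random values for illustration
noise_pred = torch.randn(2, 4, 64, 64)
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)

# classifier-free guidance: push the prediction away from the unconditional
# output and towards the text-conditioned output
guidance_scale = 7.5
guided = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
print(guided.shape)  # torch.Size([1, 4, 64, 64])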
Now all that's left is to define a helper method that handles preprocessing of the conditioning image:
    def prepare_controlnet_image(self, image, height, width):
        """preprocesses the controlnet image"""
        # process the image
        image = self.control_image_processor.preprocess(image, height, width).to(dtype=torch.float32)
        # send image to the device
        image = image.to(self.device)
        # repeat the image for classifier free guidance
        image = torch.cat([image] * 2)
        return image
Here, we use the control_image_processor component to preprocess the conditioning image and then duplicate it along the batch dimension for classifier-free guidance.
Finally, we combine all the helper methods we just wrote into a single ready-to-use __call__() method. The code should be easy to follow, since we have already covered what each of the previously defined methods does:
    def __call__(self, prompt, image, num_inference_steps=20,
                 guidance_scale=7.5, height=512, width=512):
        """generates new image based on the `prompt` and the `image`"""
        # encode input prompt
        prompt_embeds = self.get_prompt_embeds(prompt)
        # prepare image for controlnet
        controlnet_image = self.prepare_controlnet_image(image, height, width)
        height, width = controlnet_image.shape[-2:]
        # prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps)
        timesteps = self.scheduler.timesteps
        # prepare the initial image in the latent space (noise on which we will do reverse diffusion)
        num_channels_latents = self.unet.config.in_channels
        batch_size = prompt_embeds.shape[0] // 2
        latents = self.get_initial_latents(height, width, num_channels_latents, batch_size)
        # denoise latents
        latents = self.denoise_latents(prompt_embeds,
                                       controlnet_image,
                                       timesteps,
                                       latents,
                                       guidance_scale)
        # decode latents to get the image into pixel space
        latents = latents.to(torch.float16)  # change dtype of latents to match the half-precision VAE
        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
        # convert to PIL Image format
        image = image.detach()  # detach to remove any computed gradients
        image = self.transform_image(image)
        return image
An important thing to notice is that before converting the output of the VAE's decoder into a PIL image (i.e., via the transform_image() method), we need to remove any gradients associated with it. This is because PyTorch doesn't allow tensors that carry gradients to be converted to the intermediate NumPy arrays.
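To see this restriction in action, here is a tiny standalone snippet, independent of the pipeline, demonstrating why the detach() call is needed:

import torch

t = torch.randn(3, requires_grad=True)
# t.numpy() would raise a RuntimeError asking us to call .detach() first
arr = t.detach().numpy()  # works: the detached tensor carries no gradient information
print(arr)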
Now that we're done coding our custom pipeline, let's test it. We first need all the components for our pipeline, which can be extracted from any of the Hugging Face pipelines we used earlier. We'll continue with the OpenPose variant and use the same example image:
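A minimal sketch of what this setup could look like is shown below. The model IDs ('lllyasviel/sd-controlnet-openpose' and 'runwayml/stable-diffusion-v1-5') are assumptions and may differ from the checkpoints used in your earlier code, and the sketch assumes a diffusers version in which the pipeline exposes image_processor and control_image_processor attributes:

import torch
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline

# load the OpenPose ControlNet and the Stable Diffusion pipeline in half precision
controlnet = ControlNetModel.from_pretrained(
    'lllyasviel/sd-controlnet-openpose', torch_dtype=torch.float16)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    'runwayml/stable-diffusion-v1-5', controlnet=controlnet,
    torch_dtype=torch.float16).to('cuda')

# reuse the Hugging Face pipeline's components for our custom pipeline
pipeline = ControlNetDiffusionPipelineCustom(
    vae=pipe.vae,
    tokenizer=pipe.tokenizer,
    text_encoder=pipe.text_encoder,
    unet=pipe.unet,
    controlnet=pipe.controlnet,
    scheduler=pipe.scheduler,
    image_processor=pipe.image_processor,
    control_image_processor=pipe.control_image_processor,
)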