From 368c497638332d5fca56165af2aebf5cd07ee279 Mon Sep 17 00:00:00 2001 From: henryj18 <49120145+henryj18@users.noreply.github.com> Date: Tue, 28 May 2024 06:36:24 -0500 Subject: [PATCH 1/2] Test the codebase with timm vit --- gato/policy/embeddings.py | 54 +++++++++++++++++++++++++++++++++++++- gato/policy/gato_policy.py | 12 ++++----- gato/tasks/caption_task.py | 36 ++++++++++++++----------- gato/tasks/vqa_task.py | 23 ++++++++-------- 4 files changed, 90 insertions(+), 35 deletions(-) diff --git a/gato/policy/embeddings.py b/gato/policy/embeddings.py index 3b328d4..c0a4965 100644 --- a/gato/policy/embeddings.py +++ b/gato/policy/embeddings.py @@ -1,9 +1,60 @@ import torch import torch.nn as nn -from einops import rearrange +import timm +from PIL import Image +from torchvision import transforms +from einops import rearrange import math +class Pretrained_ImageEmbedding(nn.Module): + def __init__( + self, + embed_dim=768, + patch_size=16, + # the batch_size and embed_dim passed in here are just placeholders, the embedding model used below will use the deame default value 768 and 16 + model_name = 'vit_base_patch16_224' # You can choose different variants, patch size is 16*16, image size is 224*224 + + ): + super().__init__() + # Load a pre-trained Vision Transformer model + self.model = timm.create_model(model_name, pretrained=True) + self.model.eval() # Set the model to evaluation mode + + # Define the image transformation pipeline + self.transform = transforms.Compose([ + transforms.Resize((224, 224)), # Resize the image to the required input size + transforms.ToTensor(), # Convert the image to a PyTorch tensor + transforms.Normalize( + mean=(0.485, 0.456, 0.406), # Normalization mean (ImageNet values) + std=(0.229, 0.224, 0.225) # Normalization std (ImageNet values) + ), + ]) + + self.embedding = None + # Define a hook function to capture the embedding + def hook(module, input, output): + global embedding + embedding = output + + # Register the hook to the patch embedding layer. + self.hook_handle = self.model.patch_embed.register_forward_hook(hook) + + # After all data processed, remember to call image_embedding.cleanup() which is defined below + # to unregister the hook, assuming the instantiated object from this class is named "image_embedding" + # But this might not be necessary because image_embedding might need to be used even after all data processed? + def cleanup(self): + #unregister the hook + self.hook_handle.remove() + + def forward(self, img): # the img here is the orignal image returned from Image.open(), it is not the image data tensor + img = self.transform(img).unsqueeze(0) # Add a batch dimension + # Perform a forward pass to trigger the hook + with torch.no_grad(): + _ = self.model(img) + + # Return the captured embedding + return embedding class ImageEmbedding(nn.Module): def __init__( @@ -129,3 +180,4 @@ def forward(self, x): h = self.conv1(self.act1(self.gn1(x))) h = self.conv2(self.act2(self.gn2(h))) return x + h + \ No newline at end of file diff --git a/gato/policy/gato_policy.py b/gato/policy/gato_policy.py index c42fd69..5aef4f2 100644 --- a/gato/policy/gato_policy.py +++ b/gato/policy/gato_policy.py @@ -474,9 +474,8 @@ def predict_text(self, batch_dict, max_length=20, deterministic=True): # This funciton can be used to generate a response from an image, such as generating the caption or # an answer to a question about an image, it is adapted from the original predict_caption() function - def predict_response(self, image, prompt_tokens = [], max_length=128, deterministic=True): + def predict_response(self, image_embeddings, prompt_tokens = [], max_length=128, deterministic=True): """ - image is in the format of 1 x 3 x H x W, where 1 is the num_images, 3 is the 3 RGB channels, default value for H and W is 256 prompt_tokens is a list of text tokens: if the predicted response is the caption for the image, then it is an empty list if the predicted response is an answer to a question about the image, then this list are the tokens of the question @@ -486,7 +485,6 @@ def predict_response(self, image, prompt_tokens = [], max_length=128, determinis start_token = self.token_starts[action_str] end_token = self.token_ends[action_str] - image_embeddings = self.image_embedding(image.to(self.device)) # the image embedding that will be used to generate response n_images = image_embeddings.shape[0] n_patches = image_embeddings.shape[1] assert n_images == 1, "number of images should always be 1 for predicting response" @@ -543,13 +541,13 @@ def predict_response(self, image, prompt_tokens = [], max_length=128, determinis return pred_logits, pred_response - def predict_caption(self, image, max_length=128, deterministic=True): - pred_logits, pred_caption = self.predict_response(image, prompt_tokens = [], max_length=max_length, deterministic=deterministic) + def predict_caption(self, image_embeddings, max_length=128, deterministic=True): + pred_logits, pred_caption = self.predict_response(image_embeddings, prompt_tokens = [], max_length=max_length, deterministic=deterministic) return pred_logits, pred_caption - def predict_answer(self, image, question, max_length=16, deterministic=True): + def predict_answer(self, image_embeddings, question, max_length=16, deterministic=True): prompt_tokens = self.text_tokenizer.encode(question) - pred_logits, pred_answer = self.predict_response(image, prompt_tokens = prompt_tokens, max_length=max_length, deterministic=deterministic) + pred_logits, pred_answer = self.predict_response(image_embeddings, prompt_tokens = prompt_tokens, max_length=max_length, deterministic=deterministic) return pred_logits, pred_answer # infer how many tokens needed to generate using environment, and restrict tokens generated to valid tokens for env diff --git a/gato/tasks/caption_task.py b/gato/tasks/caption_task.py index 54a91d0..c2ff867 100644 --- a/gato/tasks/caption_task.py +++ b/gato/tasks/caption_task.py @@ -1,5 +1,6 @@ # Assume all datasets are downloaded and available from local directories from gato.tasks.task import Task +from gato.policy.embeddings import Pretrained_ImageEmbedding import os import tarfile @@ -43,6 +44,7 @@ def __init__(self, tokenizer_model:str, caption_dataset, train_data, test_data = assert len(train_data) > 0, "Must provide train datasets for caption task" self.text_tokenizer = AutoTokenizer.from_pretrained(tokenizer_model) self.dataset = {} + self.image_embedding = Pretrained_ImageEmbedding() if len(test_data) > 0: # Note: len(train_data_directories)>0 also holds due to the abpve-mentioned assert self.dataset['train'] = self.process_data(caption_dataset, train_data) @@ -87,22 +89,23 @@ def process_data(self, caption_dataset, data_directories): # Iterate through all of the bundles to extract jpg and txt (caption) and place them into the desiganted data structure for idx, bundle in enumerate(data_loader): item = {} - img = Image.open(io.BytesIO(bundle['jpg'][0])) # bundle['jpg'] is a list of length 1 - img_data = np.asarray(img) - # Through testing of processing multiple .tar files, we have figured out that we need "try except" in the following - # because sometimes the img_data is only (256, 256) insetad of (256, 256,3) (assuming all image sizes are 256x256) - # and the following transpose will raise an error and everything grinds to a halt. It is perhaps a bug in the "img2dataset" unitlity - # used to downlaod datasets into tar files. When such error occurs, we just ignore the current bundle and move to the next one - try: - img_data = img_data.transpose(2, 0, 1) # reshape from (256, 256, 3) to (3, 256, 256) + try: + # place the following into this try block in case it failes to open the image, we just imgnore this sameple and continue + img = Image.open(io.BytesIO(bundle['jpg'][0])) # bundle['jpg'] is a list of length 1 + img_data = np.asarray(img) + + # Through testing of processing multiple .tar files, we have seen that normally imag_data should be of the shape + # (256, 256, 3) (assuming all image sizes are 256x256), so img_data.ndim is 3. But sometimes, we see img_data is + # of the shape (256, 256) only, it is perhaps a bug in the "img2dataset" unitlity used to downlaod datasets into tar files. + # When such error occurs, we need to catch it and ignore the current bundle and move to the next one. + # If this is not caught, it will throw exception and everything grinds to a halt + if img_data.ndim < 3: + continue except: continue - - # Need to add a new dimension to (3, 256, 256) so it becomes (1, 3, 256, 256) where the added dummy dimension at dim 0 is the num_images. - # In this case, num_images is always 1. This is for the purpose of aligning the data structure with that in the model training - item['image'] = torch.tensor(img_data[np.newaxis, :]) + item['image'] = img item['text'] = bundle['txt'][0].decode('utf-8') - dataset.append(item) + dataset.append(item) return dataset def sample_batch(self, batch_size): @@ -112,7 +115,10 @@ def sample_batch(self, batch_size): batch_dicts = [] for item in selected_examples: batch_dict = { - 'images': item['image'], + # By default, image will be resized to 224*224, patch size is 16, image_embedding() will return embedding tensor + # of the shape (1, 196, 768), where 1 is an added batch dimension, 196 is the numbe of patches, which is the + # result of 14 multiplied by 14, where 14 is the result of image size divided by patch size, i.e. 224/16=14 + 'image_imbeddings': self.image_embedding(item['image']), 'text':self.text_tokenizer.encode(item['text']) } batch_dicts.append(batch_dict) @@ -141,7 +147,7 @@ def evaluate(self, model, num_examples_to_test=50, deterministic=True, log_examp target_tokens = tokenizer.encode(target_caption) # Generate prediction - pred_logits, pred_caption = model.module.predict_caption(image, max_length = len(target_tokens),deterministic=deterministic) + pred_logits, pred_caption = model.module.predict_caption(self.image_embedding(image).to(model.device), max_length = len(target_tokens),deterministic=deterministic) if log_examples_to_output and idx%10==0: print(f'Target caption: {target_caption} \n Predicted caption : {pred_caption}') print("----") diff --git a/gato/tasks/vqa_task.py b/gato/tasks/vqa_task.py index 16b1f58..db79066 100644 --- a/gato/tasks/vqa_task.py +++ b/gato/tasks/vqa_task.py @@ -1,9 +1,8 @@ # Assume all datasets are downloaded and available from local directories from gato.tasks.task import Task - +from gato.policy.embeddings import Pretrained_ImageEmbedding import os from PIL import Image -import io # need to use BytesIO import numpy as np import math @@ -36,6 +35,8 @@ def __init__(self, tokenizer_model:str, self.text_tokenizer = AutoTokenizer.from_pretrained(tokenizer_model) self.dataset = {} + self.image_embedding = Pretrained_ImageEmbedding() + if not vqa_dataset.endswith('/'): vqa_dataset = vqa_dataset + '/' @@ -66,16 +67,11 @@ def process_data(self, vqa_dataset, data_directories, img_name_prefix, img_file_ image_id = str(annotations[idx]['image_id']) img_file_name = img_name_prefix[dir_idx] + '0' * (img_file_name_len[dir_idx] - len(image_id) - len(img_name_prefix[dir_idx])) + image_id + '.jpg' try: - # if the image file does not exist or transpose does not work due to damaged data, we simply discard this sample and move to next - img = Image.open(data_directory + img_file_name) - img= img.resize((256, 256)) - img_data = np.asarray(img) - img_data = img_data.transpose(2, 0, 1) # reshape from (256, 256, 3) to (3, 256, 256) + # if the image file does not exist or some other error occurs, we simply discard this sample and move to next + img = Image.open(data_directory + img_file_name) except: continue - # Need to add a new dimension to (3, 256, 256) so it becomes (1, 3, 256, 256) where the added dummy dimension at dim 0 is the num_images. - # In this case, num_images is always 1. This is for the purpose of aligning the data structure with that in the model training - item['image'] = torch.tensor(img_data[np.newaxis, :]) + item['image'] = img item['question'] = questions[idx]['question'] item['answers'] = annotations[idx]['answers'] dataset.append(item) @@ -90,7 +86,10 @@ def sample_batch(self, batch_size): for item in selected_examples: answer_idx = random.randint(0, len(item['answers'])-1) # randomly choose an answer out of the set of answers batch_dict = { - 'images': item['image'], + # By default, image will be resized to 224*224, patch size is 16, image_embedding() will return embedding tensor + # of the shape (1, 196, 768), where 1 is an added batch dimension, 196 is the numbe of patches, which is the + # result of 14 multiplied by 14, where 14 is the result of image size divided by patch size, i.e. 224/16=14 + 'image_imbeddings': self.image_embedding(item['image']), # 'text' is to concat the question and a randomly chosen answer with a space in between 'text': self.text_tokenizer.encode(item['question'] + ' ' + item['answers'][answer_idx]['answer']) } @@ -121,7 +120,7 @@ def evaluate(self, model, num_examples_to_test=50, deterministic=True, log_examp target_tokens = tokenizer.encode(target_answer) # Generate prediction - pred_logits, pred_answer = model.predict_answer(image, question, max_length = len(target_tokens),deterministic=deterministic) + pred_logits, pred_answer = model.predict_answer(self.image_embedding(image).to(model.device), question, max_length = len(target_tokens),deterministic=deterministic) if log_examples_to_output and idx%10==0: print(f'Target answer: {target_answer} \n Predicted answer : {pred_answer}') print("----") From bb7d42b9cf58c3931201f9eecb4cdf40e4f1f993 Mon Sep 17 00:00:00 2001 From: henryj18 <49120145+henryj18@users.noreply.github.com> Date: Tue, 28 May 2024 22:26:52 -0500 Subject: [PATCH 2/2] Test another method to extract image embeddings from timm vit --- gato/policy/embeddings.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/gato/policy/embeddings.py b/gato/policy/embeddings.py index c0a4965..f7c3c35 100644 --- a/gato/policy/embeddings.py +++ b/gato/policy/embeddings.py @@ -31,6 +31,8 @@ def __init__( ), ]) + """ + # This commented-out block is the first method we tested to extract the image embedding self.embedding = None # Define a hook function to capture the embedding def hook(module, input, output): @@ -55,6 +57,14 @@ def forward(self, img): # the img here is the orignal image returned from Image. # Return the captured embedding return embedding + """ + # The following is the second method we tested to extract the image embedding + def forward(self, img): # the img here is the orignal image returned from Image.open(), it is not the image data tensor + img = self.transform(img).unsqueeze(0) # Add a batch dimension + # Perform a forward pass to trigger the hook + with torch.no_grad(): # Disable gradient computation + embedding = self.model.forward_features(img) # Get the features from the model, this is the image embedding + return embedding class ImageEmbedding(nn.Module): def __init__(