Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 63 additions & 1 deletion gato/policy/embeddings.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,70 @@
import torch
import torch.nn as nn
from einops import rearrange
import timm
from PIL import Image
from torchvision import transforms

from einops import rearrange
import math

class Pretrained_ImageEmbedding(nn.Module):
    """Pre-trained timm vision model used as a fixed image feature extractor.

    A PIL image is resized to 224x224, converted to a tensor, normalized with
    the ImageNet mean/std and passed through the backbone's
    ``forward_features`` with gradient computation disabled.

    Args:
        embed_dim: Placeholder, not used. The embedding width is determined by
            the pre-trained model (768 for the default model).
        patch_size: Placeholder, not used. The patch size is determined by the
            pre-trained model (16 for the default model).
        model_name: Name of the timm model to load with pre-trained weights.
            The default is a ViT with 16x16 patches and 224x224 input images.
    """

    def __init__(
        self,
        embed_dim=768,
        patch_size=16,
        model_name='vit_base_patch16_224',
    ):
        super().__init__()
        # Load a pre-trained Vision Transformer model and put it in evaluation mode
        self.model = timm.create_model(model_name, pretrained=True)
        self.model.eval()

        # Image pre-processing pipeline expected by the pre-trained model
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),  # Resize the image to the required input size
            transforms.ToTensor(),  # Convert the image to a PyTorch tensor
            transforms.Normalize(
                mean=(0.485, 0.456, 0.406),  # Normalization mean (ImageNet values)
                std=(0.229, 0.224, 0.225),  # Normalization std (ImageNet values)
            ),
        ])

        # NOTE: an earlier experiment captured the output of
        # self.model.patch_embed through a forward hook. It was replaced by
        # the direct forward_features() call in forward() below.

    def forward(self, img):
        """Return the backbone features for a single image.

        Args:
            img: The original PIL image as returned by ``Image.open()``
                (not an image data tensor).

        Returns:
            The output of ``self.model.forward_features`` for a batch of one
            image. NOTE(review): the exact shape depends on the model and the
            installed timm version (for ViT models the token sequence may
            include the class token) -- confirm against callers.
        """
        # The normalization above assumes exactly 3 channels; grayscale or
        # RGBA inputs would otherwise raise inside transforms.Normalize.
        if img.mode != 'RGB':
            img = img.convert('RGB')

        # Follow the backbone's device so the module also works after .to(device)
        device = next(self.model.parameters()).device
        batch = self.transform(img).unsqueeze(0).to(device)  # Add a batch dimension

        with torch.no_grad():  # Disable gradient computation
            embedding = self.model.forward_features(batch)
        return embedding

class ImageEmbedding(nn.Module):
def __init__(
Expand Down Expand Up @@ -129,3 +190,4 @@ def forward(self, x):
h = self.conv1(self.act1(self.gn1(x)))
h = self.conv2(self.act2(self.gn2(h)))
return x + h

12 changes: 5 additions & 7 deletions gato/policy/gato_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,9 +474,8 @@ def predict_text(self, batch_dict, max_length=20, deterministic=True):

# This function can be used to generate a response from an image, such as generating a caption or
# an answer to a question about the image. It is adapted from the original predict_caption() function.
def predict_response(self, image, prompt_tokens = [], max_length=128, deterministic=True):
def predict_response(self, image_embeddings, prompt_tokens = [], max_length=128, deterministic=True):
"""
image is in the format of 1 x 3 x H x W, where 1 is the num_images, 3 is the 3 RGB channels, default value for H and W is 256
prompt_tokens is a list of text tokens:
if the predicted response is the caption for the image, then it is an empty list
if the predicted response is an answer to a question about the image, then this list are the tokens of the question
Expand All @@ -486,7 +485,6 @@ def predict_response(self, image, prompt_tokens = [], max_length=128, determinis
start_token = self.token_starts[action_str]
end_token = self.token_ends[action_str]

image_embeddings = self.image_embedding(image.to(self.device)) # the image embedding that will be used to generate response
n_images = image_embeddings.shape[0]
n_patches = image_embeddings.shape[1]
assert n_images == 1, "number of images should always be 1 for predicting response"
Expand Down Expand Up @@ -543,13 +541,13 @@ def predict_response(self, image, prompt_tokens = [], max_length=128, determinis

return pred_logits, pred_response

def predict_caption(self, image, max_length=128, deterministic=True):
pred_logits, pred_caption = self.predict_response(image, prompt_tokens = [], max_length=max_length, deterministic=deterministic)
def predict_caption(self, image_embeddings, max_length=128, deterministic=True):
pred_logits, pred_caption = self.predict_response(image_embeddings, prompt_tokens = [], max_length=max_length, deterministic=deterministic)
return pred_logits, pred_caption

def predict_answer(self, image, question, max_length=16, deterministic=True):
def predict_answer(self, image_embeddings, question, max_length=16, deterministic=True):
prompt_tokens = self.text_tokenizer.encode(question)
pred_logits, pred_answer = self.predict_response(image, prompt_tokens = prompt_tokens, max_length=max_length, deterministic=deterministic)
pred_logits, pred_answer = self.predict_response(image_embeddings, prompt_tokens = prompt_tokens, max_length=max_length, deterministic=deterministic)
return pred_logits, pred_answer

# infer how many tokens needed to generate using environment, and restrict tokens generated to valid tokens for env
Expand Down
36 changes: 21 additions & 15 deletions gato/tasks/caption_task.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Assume all datasets are downloaded and available from local directories
from gato.tasks.task import Task
from gato.policy.embeddings import Pretrained_ImageEmbedding

import os
import tarfile
Expand Down Expand Up @@ -43,6 +44,7 @@ def __init__(self, tokenizer_model:str, caption_dataset, train_data, test_data =
assert len(train_data) > 0, "Must provide train datasets for caption task"
self.text_tokenizer = AutoTokenizer.from_pretrained(tokenizer_model)
self.dataset = {}
self.image_embedding = Pretrained_ImageEmbedding()

if len(test_data) > 0: # Note: len(train_data_directories)>0 also holds due to the abpve-mentioned assert
self.dataset['train'] = self.process_data(caption_dataset, train_data)
Expand Down Expand Up @@ -87,22 +89,23 @@ def process_data(self, caption_dataset, data_directories):
# Iterate through all of the bundles to extract the jpg and txt (caption) and place them into the designated data structure
for idx, bundle in enumerate(data_loader):
item = {}
img = Image.open(io.BytesIO(bundle['jpg'][0])) # bundle['jpg'] is a list of length 1
img_data = np.asarray(img)
# Through testing of processing multiple .tar files, we have figured out that we need "try except" in the following
# because sometimes the img_data is only (256, 256) instead of (256, 256, 3) (assuming all image sizes are 256x256),
# and the following transpose will then raise an error and everything grinds to a halt. It is perhaps a bug in the "img2dataset" utility
# used to download datasets into tar files. When such an error occurs, we just ignore the current bundle and move on to the next one.
try:
img_data = img_data.transpose(2, 0, 1) # reshape from (256, 256, 3) to (3, 256, 256)
try:
# Place the following into this try block: in case it fails to open the image, we just ignore this sample and continue
img = Image.open(io.BytesIO(bundle['jpg'][0])) # bundle['jpg'] is a list of length 1
img_data = np.asarray(img)

# Through testing of processing multiple .tar files, we have seen that normally img_data should be of the shape
# (256, 256, 3) (assuming all image sizes are 256x256), so img_data.ndim is 3. But sometimes we see that img_data is
# of the shape (256, 256) only; it is perhaps a bug in the "img2dataset" utility used to download datasets into tar files.
# When such an error occurs, we need to catch it, ignore the current bundle and move on to the next one.
# If this is not caught, it will throw an exception and everything grinds to a halt.
if img_data.ndim < 3:
continue
except:
continue

# Need to add a new dimension to (3, 256, 256) so it becomes (1, 3, 256, 256) where the added dummy dimension at dim 0 is the num_images.
# In this case, num_images is always 1. This is for the purpose of aligning the data structure with that in the model training
item['image'] = torch.tensor(img_data[np.newaxis, :])
item['image'] = img
item['text'] = bundle['txt'][0].decode('utf-8')
dataset.append(item)
dataset.append(item)
return dataset

def sample_batch(self, batch_size):
Expand All @@ -112,7 +115,10 @@ def sample_batch(self, batch_size):
batch_dicts = []
for item in selected_examples:
batch_dict = {
'images': item['image'],
# By default, the image will be resized to 224*224 and the patch size is 16, so image_embedding() will return an embedding tensor
# of the shape (1, 196, 768), where 1 is an added batch dimension and 196 is the number of patches, which is the
# result of 14 multiplied by 14, where 14 is the image size divided by the patch size, i.e. 224/16=14
'image_imbeddings': self.image_embedding(item['image']),
'text':self.text_tokenizer.encode(item['text'])
}
batch_dicts.append(batch_dict)
Expand Down Expand Up @@ -141,7 +147,7 @@ def evaluate(self, model, num_examples_to_test=50, deterministic=True, log_examp
target_tokens = tokenizer.encode(target_caption)

# Generate prediction
pred_logits, pred_caption = model.module.predict_caption(image, max_length = len(target_tokens),deterministic=deterministic)
pred_logits, pred_caption = model.module.predict_caption(self.image_embedding(image).to(model.device), max_length = len(target_tokens),deterministic=deterministic)
if log_examples_to_output and idx%10==0:
print(f'Target caption: {target_caption} \n Predicted caption : {pred_caption}')
print("----")
Expand Down
23 changes: 11 additions & 12 deletions gato/tasks/vqa_task.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
# Assume all datasets are downloaded and available from local directories
from gato.tasks.task import Task

from gato.policy.embeddings import Pretrained_ImageEmbedding
import os
from PIL import Image
import io # need to use BytesIO

import numpy as np
import math
Expand Down Expand Up @@ -36,6 +35,8 @@ def __init__(self, tokenizer_model:str,

self.text_tokenizer = AutoTokenizer.from_pretrained(tokenizer_model)
self.dataset = {}
self.image_embedding = Pretrained_ImageEmbedding()

if not vqa_dataset.endswith('/'):
vqa_dataset = vqa_dataset + '/'

Expand Down Expand Up @@ -66,16 +67,11 @@ def process_data(self, vqa_dataset, data_directories, img_name_prefix, img_file_
image_id = str(annotations[idx]['image_id'])
img_file_name = img_name_prefix[dir_idx] + '0' * (img_file_name_len[dir_idx] - len(image_id) - len(img_name_prefix[dir_idx])) + image_id + '.jpg'
try:
# if the image file does not exist or transpose does not work due to damaged data, we simply discard this sample and move to next
img = Image.open(data_directory + img_file_name)
img= img.resize((256, 256))
img_data = np.asarray(img)
img_data = img_data.transpose(2, 0, 1) # reshape from (256, 256, 3) to (3, 256, 256)
# if the image file does not exist or some other error occurs, we simply discard this sample and move to next
img = Image.open(data_directory + img_file_name)
except:
continue
# Need to add a new dimension to (3, 256, 256) so it becomes (1, 3, 256, 256) where the added dummy dimension at dim 0 is the num_images.
# In this case, num_images is always 1. This is for the purpose of aligning the data structure with that in the model training
item['image'] = torch.tensor(img_data[np.newaxis, :])
item['image'] = img
item['question'] = questions[idx]['question']
item['answers'] = annotations[idx]['answers']
dataset.append(item)
Expand All @@ -90,7 +86,10 @@ def sample_batch(self, batch_size):
for item in selected_examples:
answer_idx = random.randint(0, len(item['answers'])-1) # randomly choose an answer out of the set of answers
batch_dict = {
'images': item['image'],
# By default, the image will be resized to 224*224 and the patch size is 16, so image_embedding() will return an embedding tensor
# of the shape (1, 196, 768), where 1 is an added batch dimension and 196 is the number of patches, which is the
# result of 14 multiplied by 14, where 14 is the image size divided by the patch size, i.e. 224/16=14
'image_imbeddings': self.image_embedding(item['image']),
# 'text' is to concat the question and a randomly chosen answer with a space in between
'text': self.text_tokenizer.encode(item['question'] + ' ' + item['answers'][answer_idx]['answer'])
}
Expand Down Expand Up @@ -121,7 +120,7 @@ def evaluate(self, model, num_examples_to_test=50, deterministic=True, log_examp
target_tokens = tokenizer.encode(target_answer)

# Generate prediction
pred_logits, pred_answer = model.predict_answer(image, question, max_length = len(target_tokens),deterministic=deterministic)
pred_logits, pred_answer = model.predict_answer(self.image_embedding(image).to(model.device), question, max_length = len(target_tokens),deterministic=deterministic)
if log_examples_to_output and idx%10==0:
print(f'Target answer: {target_answer} \n Predicted answer : {pred_answer}')
print("----")
Expand Down