From da6b37042a74aca39d24cd5f581d0d853361336b Mon Sep 17 00:00:00 2001 From: Christopher Date: Wed, 28 Feb 2024 18:21:58 -0500 Subject: [PATCH 1/4] Fix printing logging --- .../algorithms/llama_2_single_kdma_adm.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/align_system/algorithms/llama_2_single_kdma_adm.py b/align_system/algorithms/llama_2_single_kdma_adm.py index c3fdc12d..cd106eb5 100644 --- a/align_system/algorithms/llama_2_single_kdma_adm.py +++ b/align_system/algorithms/llama_2_single_kdma_adm.py @@ -124,11 +124,11 @@ def __init__(self, device='cuda', hf_model='meta-llama/Llama-2-7b-chat-hf', prec def load_model(self, model=None, tokenizer=None): assert (model is None) == (tokenizer is None), "model and tokenizer must both be None or both be not None." if model is not None: - print('Loading model and tokenizer from provided objects.') + log.info('Loading model and tokenizer from provided objects.') self.model = model self.tokenizer = tokenizer else: - print('Loading model:', self.hf_model) + log.info('Loading model: %s', self.hf_model) if self.device == 'auto': self.model = AutoModelForCausalLM.from_pretrained(self.hf_model, torch_dtype=self.precision, device_map='auto') else: @@ -282,7 +282,7 @@ def respond_to_dialog(self, dialog, prefix=None): else: new_dialog.append(message) dialog = new_dialog - print('INPUT\n', dialog) + log.info('INPUT\n %s', dialog) prompt_tokens = [self.tokenizer.apply_chat_template(dialog, tokenize=True)] inference_pair['input'] = self.tokenizer.apply_chat_template(dialog, tokenize=False) @@ -298,11 +298,11 @@ def respond_to_dialog(self, dialog, prefix=None): outputs = self.model.generate(prompt_tokens, return_dict_in_generate=True, output_scores=True, max_new_tokens=512, temperature=self.temperature, do_sample=True) - # Print the generated model output + # log.info the generated model output generated_output = self.tokenizer.decode(outputs.sequences[0][prompt_length:]) inference_pair['output'] 
= generated_output - print('INFERENCE PAIR\n', inference_pair) + log.info('INFERENCE PAIR\n %s', inference_pair) return generated_output, inference_pair @@ -422,7 +422,7 @@ def aligned_decision_maker(self, question, choices, target_kdmas, n_positive_sam if not good_parse: reasoning, answer_idx, parse_method = Llama2SingleKDMAADM.bert_similarity_parse(high_response, shuffled_choices) - print('CHOSEN ANSWER IDX', answer_idx, shuffled_choices) + log.info('CHOSEN ANSWER IDX %s %s', answer_idx, shuffled_choices) assert answer_idx is not None, f'Failed to parse answer index from generated output: {low_response}' responses.append({ @@ -594,10 +594,10 @@ def parse_generated_output(generated_output, n_choices): @staticmethod def bert_similarity_parse(generated_output, choices): - print('BERT SIMILARITY PARSE') + log.info('BERT SIMILARITY PARSE') force_choice_func = build_force_choice_func('bert') answer_idx, _ = force_choice_func(generated_output, choices) - print('ANSWER IDX', answer_idx, type(answer_idx)) + log.info('ANSWER IDX %s %s', answer_idx, type(answer_idx)) return generated_output, answer_idx, 'bert_similarity' @staticmethod From f303ce1879f23779f7e81fbabcfc0e033deb8bb2 Mon Sep 17 00:00:00 2001 From: Christopher Date: Fri, 1 Mar 2024 15:30:09 -0500 Subject: [PATCH 2/4] In progress --- .../algorithms/llama_2_single_kdma_adm.py | 50 ++++++++++++++++++- align_system/evaluation/adm_evaluator.py | 2 +- 2 files changed, 50 insertions(+), 2 deletions(-) diff --git a/align_system/algorithms/llama_2_single_kdma_adm.py b/align_system/algorithms/llama_2_single_kdma_adm.py index cd106eb5..d911dbd3 100644 --- a/align_system/algorithms/llama_2_single_kdma_adm.py +++ b/align_system/algorithms/llama_2_single_kdma_adm.py @@ -3,6 +3,7 @@ import random import os import pathlib +import random from align_system.algorithms.lib.aligned_decision_maker import AlignedDecisionMaker from jinja2.exceptions import TemplateError @@ -422,7 +423,7 @@ def aligned_decision_maker(self, question, 
choices, target_kdmas, n_positive_sam if not good_parse: reasoning, answer_idx, parse_method = Llama2SingleKDMAADM.bert_similarity_parse(high_response, shuffled_choices) - log.info('CHOSEN ANSWER IDX %s %s', answer_idx, shuffled_choices) + log.explain('CHOSEN ANSWER IDX %s %s', answer_idx, shuffled_choices) assert answer_idx is not None, f'Failed to parse answer index from generated output: {low_response}' responses.append({ @@ -749,12 +750,59 @@ def run_aligned_decision_maker_with_voting( break return reasoning, answer_idx, responses, inference_pairs + + def format_single_incontext_prompt(self, sample): + prompt = sample['scenario'] + if sample['state'] is not None: + prompt += f'\n{sample["state"]}' + + choices = sample['choices'] + + labels = kwargs.get('labels', {}) + + alignment_target = None + if target_kdma_values is not None: + target_kdma = next(iter(next(iter(filter(lambda x: len(x) > 0, labels))))) # get the frist key of the first label that is not empty + + for label in labels: + assert len(label) == 0 or (target_kdma in label and len(label) == 1), f'All labels must have the same KDMA: labels={labels}' + + alignment_target = { + target_kdma: target_kdma_values[target_kdma] + } + + + #TODO: add prompt completetion here for choices as well. + def __call__(self, sample, target_kdma_values, **kwargs): + """ Build the prompt and send to the LLM to ask for a single KDMA + + + """ prompt = sample['scenario'] if sample['state'] is not None: prompt += f'\n{sample["state"]}' + if 'incontext' in kwargs: + possible_samples = [] + + for sam in kwargs['dataset']: + if sam['probe_id'] != sample['probe_id']: + possible_samples.append(sam) + + if len(possible_samples) < kwargs['incontext']['number']: + raise(f'Not enough possible incontext samples to learn from here. 
+ Only {len(possible_samples)} samples while asking for + {kwargs['incontext']['number']} in context samples') + + if kwargs['incontext']['method'] == 'random': + chosen_sample = random.sample(possible_samples, kwargs['incontext']['number']) + else: + raise(f'"{kwargs['incontext']['method']}" is not a valid incontext method. Please use "random", ') + + + if 'retriever' in kwargs: # retriever_prompt = "How would you treat the following injuries: {}".format(prompt) retriever_prompt = "{} {}".format(prompt, sample['probe']) diff --git a/align_system/evaluation/adm_evaluator.py b/align_system/evaluation/adm_evaluator.py index ff7eda9f..4a742e4f 100644 --- a/align_system/evaluation/adm_evaluator.py +++ b/align_system/evaluation/adm_evaluator.py @@ -11,7 +11,7 @@ def generate_outputs(dataset, adm, target_kdma_values, **kwargs): }) continue - outputs.append(adm(input_, target_kdma_values, labels=label, **kwargs)) + outputs.append(adm(input_, target_kdma_values, labels=label, dataset=dataset, **kwargs)) return outputs From 805ae3800da631e4777f8a1dd0bb5e9905a0fd03 Mon Sep 17 00:00:00 2001 From: Christopher Date: Fri, 1 Mar 2024 19:03:37 -0500 Subject: [PATCH 3/4] First working cut of incontext learning --- .../algorithms/llama_2_single_kdma_adm.py | 58 ++++++++++--------- 1 file changed, 32 insertions(+), 26 deletions(-) diff --git a/align_system/algorithms/llama_2_single_kdma_adm.py b/align_system/algorithms/llama_2_single_kdma_adm.py index d911dbd3..bb0edb23 100644 --- a/align_system/algorithms/llama_2_single_kdma_adm.py +++ b/align_system/algorithms/llama_2_single_kdma_adm.py @@ -114,6 +114,7 @@ def __init__(self, device='cuda', hf_model='meta-llama/Llama-2-7b-chat-hf', prec self.hf_model = hf_model self.temperature = temperature self.chat_template = kwargs.get('chat_template', None) + self.dataset = [] assert precision in ['full', 'half'], "precision must be either 'full' or 'half'." 
self.precision = torch.float32 if precision == 'full' else torch.float16 @@ -403,6 +404,7 @@ def aligned_decision_maker(self, question, choices, target_kdmas, n_positive_sam shuffled_choices, system_message=system_message) + if not logged_aligned_dialog: log.debug("[bold]*ALIGNED DIALOG*[/bold]", extra={"markup": True}) @@ -751,25 +753,17 @@ def run_aligned_decision_maker_with_voting( return reasoning, answer_idx, responses, inference_pairs - def format_single_incontext_prompt(self, sample): + def format_single_incontext_prompt(self, sample, labels): prompt = sample['scenario'] if sample['state'] is not None: prompt += f'\n{sample["state"]}' - choices = sample['choices'] - - labels = kwargs.get('labels', {}) - - alignment_target = None - if target_kdma_values is not None: - target_kdma = next(iter(next(iter(filter(lambda x: len(x) > 0, labels))))) # get the frist key of the first label that is not empty - - for label in labels: - assert len(label) == 0 or (target_kdma in label and len(label) == 1), f'All labels must have the same KDMA: labels={labels}' - - alignment_target = { - target_kdma: target_kdma_values[target_kdma] - } + for choice, label in zip(sample['choices'],labels): + level = 'high' if list(label.values())[0] > 5 else 'low' + attribute = list(label.keys())[0].replace('_', ' ') + prompt += f' If you had a {level} {attribute}, you would select {choice}.' + + return prompt #TODO: add prompt completetion here for choices as well. @@ -787,21 +781,33 @@ def __call__(self, sample, target_kdma_values, **kwargs): if 'incontext' in kwargs: possible_samples = [] + #sam has both info in first element and labels in second element for sam in kwargs['dataset']: - if sam['probe_id'] != sample['probe_id']: + if sam[0]['probe_id'] != sample['probe_id']: + possible_samples.append(sam) - if len(possible_samples) < kwargs['incontext']['number']: - raise(f'Not enough possible incontext samples to learn from here. 
- Only {len(possible_samples)} samples while asking for - {kwargs['incontext']['number']} in context samples') - - if kwargs['incontext']['method'] == 'random': - chosen_sample = random.sample(possible_samples, kwargs['incontext']['number']) - else: - raise(f'"{kwargs['incontext']['method']}" is not a valid incontext method. Please use "random", ') - + if len(possible_samples) < kwargs['incontext']['number']: + raise RuntimeError(f'Not enough possible incontext samples to learn from here.' + f'Only {len(possible_samples)} samples while asking for' + f'{kwargs["incontext"]["number"]} in context samples') + + if kwargs['incontext']['method'] == 'random': + chosen_sample = random.sample(possible_samples, kwargs['incontext']['number']) + else: + raise(f'"{kwargs["incontext"]["method"]}" is not a valid incontext method. Please use "random", ') + + incontext_prompt_start = ' Here are some examples of similar problems with their attributes. ' + + + extra_prompts = [incontext_prompt_start] + for cs, cl in chosen_sample: + extra_prompts.append(self.format_single_incontext_prompt(cs, cl)) + + extra_prompts.append(' Given these similar examples, please answer the question for the following scenario. 
 ') + extra_prompts = ''.join(extra_prompts) + prompt = extra_prompts + prompt if 'retriever' in kwargs: # retriever_prompt = "How would you treat the following injuries: {}".format(prompt) From a7754ec474dae4c07f305a3ac5799e542782d4b3 Mon Sep 17 00:00:00 2001 From: Christopher Date: Wed, 6 Mar 2024 13:32:35 -0500 Subject: [PATCH 4/4] updating incontext for saying example --- align_system/algorithms/llama_2_single_kdma_adm.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/align_system/algorithms/llama_2_single_kdma_adm.py b/align_system/algorithms/llama_2_single_kdma_adm.py index bb0edb23..926b8d06 100644 --- a/align_system/algorithms/llama_2_single_kdma_adm.py +++ b/align_system/algorithms/llama_2_single_kdma_adm.py @@ -801,8 +801,10 @@ def __call__(self, sample, target_kdma_values, **kwargs): extra_prompts = [incontext_prompt_start] + ci = 1 for cs, cl in chosen_sample: - extra_prompts.append(self.format_single_incontext_prompt(cs, cl)) + extra_prompts.append(f' Example {ci}' + self.format_single_incontext_prompt(cs, cl)) + ci += 1 extra_prompts.append(' Given these similar examples, please answer the question for the following scenario. ')