From b7fb1a0b359c0685b13ca0bed1f88863285dceba Mon Sep 17 00:00:00 2001 From: void-b583x2-NULL Date: Thu, 28 Dec 2023 09:47:48 +0000 Subject: [PATCH 1/6] mixtral extractor test --- refchecker/extractor/mixtral_extractor.py | 214 ++++++++++++++++++++++ 1 file changed, 214 insertions(+) create mode 100644 refchecker/extractor/mixtral_extractor.py diff --git a/refchecker/extractor/mixtral_extractor.py b/refchecker/extractor/mixtral_extractor.py new file mode 100644 index 0000000..9d88f02 --- /dev/null +++ b/refchecker/extractor/mixtral_extractor.py @@ -0,0 +1,214 @@ +from .extractor_base import ExtractorBase + +import torch +from vllm import LLM, SamplingParams +from typing import List, Tuple, Union + +one_digit_tensor = torch.ones((1, 1), dtype=torch.long) + +MISTRAL_KG_EXTRACTION_PROMPT_Q = """Given a question and a candidate answer to the question, please extract a KG from the candidate answer condition on the question and represent the KG with triples formatted with ("head", "relation", "tail"), each triplet in a line. +Please note that this is an EXTRACTION task, so DO NOT care about whether the content of the candidate answer is factual or not, just extract the triplets from it. + +Here are some in-context examples of extraction: + +### Question: +Given these paragraphs about the Tesla bot, what is its alias? + +### Candidate Answer: +Optimus (or Tesla Bot) is a robotic humanoid under development by Tesla, Inc. It was announced at the company's Artificial Intelligence (AI) Day event on August 19, 2021. + +### KG: +("Optimus", "is", "robotic humanoid") +("Optimus", "under development by", "Tesla, Inc.") +("Optimus", "also known as", "Tesla Bot") +("Tesla, Inc.", "announced", "Optimus") +("Announcement of Optimus", "occured at", "Artificial Intelligence (AI) Day event") +("Artificial Intelligence (AI) Day event", "held on", "August 19, 2021") +("Artificial Intelligence (AI) Day event", "organized by", "Tesla, Inc.") + +### Question: +here is some text about Andre Weiss, how many years was Andre at University of Dijon in Paris? + +### Candidate Answer: +11 years + +### KG: +("Andre Weiss at University of Dijon in Paris", "duration", "11 years") + +Now generate the KG for the following candidate answer based on the provided question: + +### Question: +{q} + +### Candidate Answer: +{a} + +### KG: +""" + +MISTRAL_KG_EXTRACTION_PROMPT = """Given an input text, please extract a KG from the text and represent the KG with triples formatted with ("head", "relation", "tail"), each triplet in a line. Please note that this is an EXTRACTION task, so DO NOT care about whether the content of the candidate answer is factual or not, just extract the triplets from it. + +Here are some in-context examples of extraction: + +### Input: +Optimus (or Tesla Bot) is a robotic humanoid under development by Tesla, Inc. It was announced at the company's Artificial Intelligence (AI) Day event on August 19, 2021. + +### KG: +("Optimus", "is", "robotic humanoid") +("Optimus", "under development by", "Tesla, Inc.") +("Optimus", "also known as", "Tesla Bot") +("Tesla, Inc.", "announced", "Optimus") +("Announcement of Optimus", "occured at", "Artificial Intelligence (AI) Day event") +("Artificial Intelligence (AI) Day event", "held on", "August 19, 2021") +("Artificial Intelligence (AI) Day event", "organized by", "Tesla, Inc.") + +### Input: +Question: here is some text about Andre Weiss, how many years was Andre at University of Dijon in Paris? +Answer: 11 years + +### KG: +("Andre Weiss at University of Dijon in Paris", "duration", "11 years") + +Now generate the KG for the following input text based on the provided question: + +### Input Text: +{input_text} + +### KG: +""" + + +class MistralClaimExtractor(ExtractorBase): + def __init__( + self, + claim_format: str = "triplet", + model="mistralai/Mixtral-8x7B-Instruct-v0.1", + load_format="safetensors", + tensor_parallel_size=-1, + temperature=0.0, + max_new_tokens=2048, + ) -> None: + self.llm = LLM( + model, + load_format=load_format, + tensor_parallel_size=torch.cuda.device_count() + if tensor_parallel_size < 0 + else tensor_parallel_size, + trust_remote_code=True, + ) + self.sampling_params = SamplingParams( + temperature=temperature, max_tokens=max_new_tokens + ) + self.tokenizer = self.llm.get_tokenizer() + self.tokenizer.pad_token_id = self.tokenizer.eos_token_id + self.llm.set_tokenizer(self.tokenizer) + self.claim_format = claim_format + + if self.claim_format == "triplet": + self.prompt_temp_wq = MISTRAL_KG_EXTRACTION_PROMPT_Q + self.prompt_temp = MISTRAL_KG_EXTRACTION_PROMPT + + def _mistral_encode_conversation( + self, prompt, answer_history: List[Tuple[str, str]] = [] + ): + """Encode conversations for Mixtral. + answer_hisory: [(u_1, a_1), (u_2, a_2), ... (u_k, a_k)] + Assume all strings are .strip()-ed and does not include anything like [INST][/INST] + """ + tokenizer = self.tokenizer + final_prompted_tensors = [] + + for i in range(len(answer_history)): + user_prompt, model_output = answer_history[i] + for_tokenize_prompt = f"[INST] {user_prompt} [/INST] {model_output}" + tokenized_id = tokenizer.encode( + for_tokenize_prompt, add_special_tokens=False, return_tensors="pt" + ) + final_prompted_tensors.append(tokenized_id) + final_prompted_tensors[-1] = torch.cat( + (final_prompted_tensors[-1], one_digit_tensor * tokenizer.eos_token_id), + dim=-1, + ) + + # last prompt + for_tokenize_prompt = f"[INST] {prompt} [/INST]" + tokenized_id = tokenizer.encode( + for_tokenize_prompt, add_special_tokens=False, return_tensors="pt" + ) + final_prompted_tensors.append(tokenized_id) + + # BOS + final_prompted_tensors[0] = torch.cat( + (one_digit_tensor * tokenizer.bos_token_id, final_prompted_tensors[0]), + dim=-1, + ) + + final_prompted_tensor = torch.cat(final_prompted_tensors, dim=-1) + return final_prompted_tensor + + def _get_response_from_mistral( + self, raw_prompt: str, conversation_history: List[str] = [] + ): + llm = self.llm + outputs = llm.generate( + prompt_token_ids=self._mistral_encode_conversation( + raw_prompt, conversation_history + ).tolist(), + sampling_params=self.sampling_params, + use_tqdm=False, + ) + llm_output = outputs[0].outputs[0].text + conversation_history = conversation_history.copy() + conversation_history.append((raw_prompt, llm_output)) + return llm_output, conversation_history + + def extract_claim_triplets(self, response, question=None): + if question is None: + prompt = self.prompt_temp.format(input_text=response) + else: + prompt = self.prompt_temp_wq.format(q=question, a=response) + mistral_response, _ = self._get_response_from_mistral(prompt) + if mistral_response and len(mistral_response): + kg_str = None + if "###" in mistral_response: + kg_str = mistral_response[: mistral_response.index("###")] + else: + kg_str = mistral_response + triplets = self._parse_claim_triplets(kg_str) + return triplets + return [] + + +if __name__ == "__main__": + import json + + extractor = MistralClaimExtractor() + example_for_test_q = { + "id": "192218", + "src": "msmarco/msmarco_chatgpt_answers.json", + "question": "full time student how many hours", + "response": 'Based on the provided passages, the number of hours required to be considered a full-time student can vary depending on the context. However, some common requirements mentioned are:\n\n- Passage 0: Full-time status is usually considered as a schedule of 12 or more semester or quarter hours, but there is no data on how many credits full-time students are actually taking.\n\n- Passage 1: For undergraduate students in the summer, 12 hours is considered full-time.\n\n- Passage 2: For fall and spring semesters, a full-time college student completes at least 12 semester hours. Some schools may require 15 semester hours. In the summer, completing 6 semester hours is considered full-time.\n\n- Passage 3: For graduate students, a normal full-time load is nine graduate-level semester hours. But for those holding graduate assistantships, full-time status is six semester hours.\n\n- Passage 4: According to the University Bulletin, graduate students taking 9 or more credit hours per semester (6 credits in the summer) are considered full-time.\n\n- Passage 5: The normal load for full-time students is 3 courses (9 credits).\n\n- Passage 8: For graduate students, 9 hours is considered full-time.\n\nBased on the information provided, the answer to the question "full time student how many hours?" is not explicitly stated or consistent across all passages. It can range from 12 hours for undergraduate students in the summer to 9 or more credit hours for graduate students. Some schools may require 15 semester hours for full-time status.', + } + + example_for_test = example_for_test_q.copy() + example_for_test[ + "input" + ] = f"Question: {example_for_test_q['question']}\nAnswer: {example_for_test_q['response']}" + + print( + json.dumps( + extractor.extract_claim_triplets( + example_for_test_q["response"], example_for_test_q["question"] + ), + ensure_ascii=False, + indent=4, + ) + ) + + print( + json.dumps( + extractor.extract_claim_triplets(example_for_test["input"]), + ensure_ascii=False, + indent=4, + ) + ) From 2b98cd6b8d1cf265becf2c6ecca9257d0711dfa5 Mon Sep 17 00:00:00 2001 From: rudongyu Date: Fri, 29 Dec 2023 02:35:26 +0000 Subject: [PATCH 2/6] add vllm dependency --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 239e606..adba792 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,7 @@ openai = "^1.3.5" anthropic = "^0.7.4" datasets = "^2.15.0" plotly = "^5.18.0" +vllm = "^0.2.6" [tool.poetry.scripts] refchecker-cli = "refchecker.cli:main" From bbeba88b93b82318ef39ee78a269ebd3c9328d9c Mon Sep 17 00:00:00 2001 From: rudongyu Date: Fri, 29 Dec 2023 03:03:21 +0000 Subject: [PATCH 3/6] mclean api and modify code related to extractor addition --- refchecker/cli.py | 6 +- refchecker/extractor/__init__.py | 1 + refchecker/extractor/mixtral_extractor.py | 73 ++++++++++------------- 3 files changed, 38 insertions(+), 42 deletions(-) diff --git a/refchecker/cli.py b/refchecker/cli.py index 756842f..cf377d0 100644 --- a/refchecker/cli.py +++ b/refchecker/cli.py @@ -3,7 +3,7 @@ from argparse import ArgumentParser, RawTextHelpFormatter from tqdm import tqdm -from .extractor import Claude2Extractor, GPT4Extractor +from .extractor import Claude2Extractor, GPT4Extractor, MixtralExtractor from .checker import Claude2Checker, GPT4Checker, NLIChecker from .retriever import GoogleRetriever from .aggregator import strict_agg, soft_agg, major_agg @@ -31,7 +31,7 @@ def get_args(): ) parser.add_argument( '--extractor_name', type=str, default="claude2", - choices=["gpt4", "claude2"], + choices=["gpt4", "claude2", "mixtral"], help="Model used for extracting triplets. Default: claude2." ) parser.add_argument( @@ -124,6 +124,8 @@ def extract(args): extractor = Claude2Extractor() elif args.extractor_name == "gpt4": extractor = GPT4Extractor() + elif args.extractor_name == "mixtral": + extractor = MixtralExtractor() else: raise NotImplementedError diff --git a/refchecker/extractor/__init__.py b/refchecker/extractor/__init__.py index dc1fc28..1d47fdc 100644 --- a/refchecker/extractor/__init__.py +++ b/refchecker/extractor/__init__.py @@ -1,2 +1,3 @@ from .claude2_extractor import Claude2Extractor from .gpt4_extractor import GPT4Extractor +from .mixtral_extractor import MixtralExtractor diff --git a/refchecker/extractor/mixtral_extractor.py b/refchecker/extractor/mixtral_extractor.py index 9d88f02..b78d05b 100644 --- a/refchecker/extractor/mixtral_extractor.py +++ b/refchecker/extractor/mixtral_extractor.py @@ -2,11 +2,11 @@ import torch from vllm import LLM, SamplingParams -from typing import List, Tuple, Union +from typing import List, Tuple one_digit_tensor = torch.ones((1, 1), dtype=torch.long) -MISTRAL_KG_EXTRACTION_PROMPT_Q = """Given a question and a candidate answer to the question, please extract a KG from the candidate answer condition on the question and represent the KG with triples formatted with ("head", "relation", "tail"), each triplet in a line. +MIXTRAL_KG_EXTRACTION_PROMPT_Q = """Given a question and a candidate answer to the question, please extract a KG from the candidate answer condition on the question and represent the KG with triples formatted with ("head", "relation", "tail"), each triplet in a line. Please note that this is an EXTRACTION task, so DO NOT care about whether the content of the candidate answer is factual or not, just extract the triplets from it. Here are some in-context examples of extraction: @@ -46,7 +46,7 @@ ### KG: """ -MISTRAL_KG_EXTRACTION_PROMPT = """Given an input text, please extract a KG from the text and represent the KG with triples formatted with ("head", "relation", "tail"), each triplet in a line. Please note that this is an EXTRACTION task, so DO NOT care about whether the content of the candidate answer is factual or not, just extract the triplets from it. +MIXTRAL_KG_EXTRACTION_PROMPT = """Given an input text, please extract a KG from the text and represent the KG with triples formatted with ("head", "relation", "tail"), each triplet in a line. Please note that this is an EXTRACTION task, so DO NOT care about whether the content of the candidate answer is factual or not, just extract the triplets from it. Here are some in-context examples of extraction: @@ -78,37 +78,30 @@ """ -class MistralClaimExtractor(ExtractorBase): +class MixtralExtractor(ExtractorBase): def __init__( self, claim_format: str = "triplet", - model="mistralai/Mixtral-8x7B-Instruct-v0.1", - load_format="safetensors", tensor_parallel_size=-1, - temperature=0.0, - max_new_tokens=2048, ) -> None: + super().__init__(claim_format=claim_format) self.llm = LLM( - model, - load_format=load_format, + "mistralai/Mixtral-8x7B-Instruct-v0.1", + load_format="safetensors", tensor_parallel_size=torch.cuda.device_count() if tensor_parallel_size < 0 else tensor_parallel_size, trust_remote_code=True, ) - self.sampling_params = SamplingParams( - temperature=temperature, max_tokens=max_new_tokens - ) self.tokenizer = self.llm.get_tokenizer() self.tokenizer.pad_token_id = self.tokenizer.eos_token_id self.llm.set_tokenizer(self.tokenizer) - self.claim_format = claim_format if self.claim_format == "triplet": - self.prompt_temp_wq = MISTRAL_KG_EXTRACTION_PROMPT_Q - self.prompt_temp = MISTRAL_KG_EXTRACTION_PROMPT + self.prompt_temp_wq = MIXTRAL_KG_EXTRACTION_PROMPT_Q + self.prompt_temp = MIXTRAL_KG_EXTRACTION_PROMPT - def _mistral_encode_conversation( + def _mixtral_encode_conversation( self, prompt, answer_history: List[Tuple[str, str]] = [] ): """Encode conversations for Mixtral. @@ -146,15 +139,19 @@ def _mistral_encode_conversation( final_prompted_tensor = torch.cat(final_prompted_tensors, dim=-1) return final_prompted_tensor - def _get_response_from_mistral( - self, raw_prompt: str, conversation_history: List[str] = [] + # TODO move this to utils for general LLMs based on vllm interface + def _get_response_from_mixtral( + self, raw_prompt: str, conversation_history: List[str] = [], + max_new_tokens: int = 500 ): llm = self.llm outputs = llm.generate( - prompt_token_ids=self._mistral_encode_conversation( + prompt_token_ids=self._mixtral_encode_conversation( raw_prompt, conversation_history ).tolist(), - sampling_params=self.sampling_params, + sampling_params=SamplingParams( + temperature=0., max_tokens=max_new_tokens + ), use_tqdm=False, ) llm_output = outputs[0].outputs[0].text @@ -162,18 +159,20 @@ def _get_response_from_mistral( conversation_history.append((raw_prompt, llm_output)) return llm_output, conversation_history - def extract_claim_triplets(self, response, question=None): + def extract_claim_triplets(self, response, question=None, max_new_tokens=500): if question is None: prompt = self.prompt_temp.format(input_text=response) else: prompt = self.prompt_temp_wq.format(q=question, a=response) - mistral_response, _ = self._get_response_from_mistral(prompt) - if mistral_response and len(mistral_response): + mixtral_response, _ = self._get_response_from_mixtral( + prompt, max_new_tokens=max_new_tokens + ) + if mixtral_response and len(mixtral_response): kg_str = None - if "###" in mistral_response: - kg_str = mistral_response[: mistral_response.index("###")] + if "###" in mixtral_response: + kg_str = mixtral_response[: mixtral_response.index("###")] else: - kg_str = mistral_response + kg_str = mixtral_response triplets = self._parse_claim_triplets(kg_str) return triplets return [] @@ -182,33 +181,27 @@ def extract_claim_triplets(self, response, question=None): if __name__ == "__main__": import json - extractor = MistralClaimExtractor() - example_for_test_q = { - "id": "192218", - "src": "msmarco/msmarco_chatgpt_answers.json", + extractor = MixtralExtractor() + example_for_test = { "question": "full time student how many hours", "response": 'Based on the provided passages, the number of hours required to be considered a full-time student can vary depending on the context. However, some common requirements mentioned are:\n\n- Passage 0: Full-time status is usually considered as a schedule of 12 or more semester or quarter hours, but there is no data on how many credits full-time students are actually taking.\n\n- Passage 1: For undergraduate students in the summer, 12 hours is considered full-time.\n\n- Passage 2: For fall and spring semesters, a full-time college student completes at least 12 semester hours. Some schools may require 15 semester hours. In the summer, completing 6 semester hours is considered full-time.\n\n- Passage 3: For graduate students, a normal full-time load is nine graduate-level semester hours. But for those holding graduate assistantships, full-time status is six semester hours.\n\n- Passage 4: According to the University Bulletin, graduate students taking 9 or more credit hours per semester (6 credits in the summer) are considered full-time.\n\n- Passage 5: The normal load for full-time students is 3 courses (9 credits).\n\n- Passage 8: For graduate students, 9 hours is considered full-time.\n\nBased on the information provided, the answer to the question "full time student how many hours?" is not explicitly stated or consistent across all passages. It can range from 12 hours for undergraduate students in the summer to 9 or more credit hours for graduate students. Some schools may require 15 semester hours for full-time status.', } - example_for_test = example_for_test_q.copy() - example_for_test[ - "input" - ] = f"Question: {example_for_test_q['question']}\nAnswer: {example_for_test_q['response']}" - print( json.dumps( extractor.extract_claim_triplets( - example_for_test_q["response"], example_for_test_q["question"] + response=example_for_test["response"], + question=example_for_test["question"] ), - ensure_ascii=False, indent=4, ) ) print( json.dumps( - extractor.extract_claim_triplets(example_for_test["input"]), - ensure_ascii=False, + extractor.extract_claim_triplets( + response=example_for_test["response"] + ), indent=4, ) ) From 04eb3f84d021425dfeb3705d0f74a2e623061da6 Mon Sep 17 00:00:00 2001 From: void-b583x2-NULL Date: Wed, 17 Jan 2024 08:49:27 +0000 Subject: [PATCH 4/6] LLM API detached into utils.py --- .gitignore | 4 +- refchecker/cli.py | 158 +++++++++------ refchecker/extractor/__init__.py | 2 +- ...tral_extractor.py => mistral_extractor.py} | 116 +++-------- refchecker/utils.py | 185 +++++++++++++----- 5 files changed, 265 insertions(+), 200 deletions(-) rename refchecker/extractor/{mixtral_extractor.py => mistral_extractor.py} (59%) diff --git a/.gitignore b/.gitignore index 0c20b5e..1bbaaba 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,4 @@ *.DS_Store -__pycache__/ \ No newline at end of file +__pycache__/ +benchmark/data/*context/* +test/* \ No newline at end of file diff --git a/refchecker/cli.py b/refchecker/cli.py index cf377d0..d867f11 100644 --- a/refchecker/cli.py +++ b/refchecker/cli.py @@ -3,7 +3,7 @@ from argparse import ArgumentParser, RawTextHelpFormatter from tqdm import tqdm -from .extractor import Claude2Extractor, GPT4Extractor, MixtralExtractor +from .extractor import Claude2Extractor, GPT4Extractor, MistralExtractor from .checker import Claude2Checker, GPT4Checker, NLIChecker from .retriever import GoogleRetriever from .aggregator import strict_agg, soft_agg, major_agg @@ -12,78 +12,106 @@ def get_args(): parser = ArgumentParser(formatter_class=RawTextHelpFormatter) parser.add_argument( - "mode", nargs="?", choices=["extract", "check", "extract-check"], - help="extract: Extract triplets from provided responses.\n" - "check: Check whether the provided triplets are factual.\n" - "extract-check: Extract triplets and check whether they are factual." + "mode", + nargs="?", + choices=["extract", "check", "extract-check"], + help="extract: Extract triplets from provided responses.\ncheck: Check whether the provided triplets are factual.\nextract-check: Extract triplets and check whether they are factual.", ) parser.add_argument( - "--input_path", type=str, required=True, - help="Input path to the json file." + "--input_path", type=str, required=True, help="Input path to the json file." ) parser.add_argument( - "--output_path", type=str, required=True, - help="Output path to the result json file." + "--output_path", + type=str, + required=True, + help="Output path to the result json file.", ) parser.add_argument( - "--cache_dir", type=str, default="./.cache", - help="Path to the cache directory. Default: ./.cache" + "--cache_dir", + type=str, + default="./.cache", + help="Path to the cache directory. Default: ./.cache", ) parser.add_argument( - '--extractor_name', type=str, default="claude2", - choices=["gpt4", "claude2", "mixtral"], - help="Model used for extracting triplets. Default: claude2." + "--extractor_name", + type=str, + default="claude2", + choices=["gpt4", "claude2", "mistral", "mixtral"], + help="Model used for extracting triplets. Default: claude2.", ) parser.add_argument( - '--extractor_max_new_tokens', type=int, default=500, - help="Max generated tokens of the extractor, set a larger value for longer documents. Default: 500" + "--extractor_max_new_tokens", + type=int, + default=500, + help="Max generated tokens of the extractor, set a larger value for longer documents. Default: 500", ) parser.add_argument( - "--checker_name", type=str, default="claude2", + "--checker_name", + type=str, + default="claude2", choices=["gpt4", "claude2", "nli"], - help="Model used for checking whether the triplets are factual. " - "Default: claude2." + help="Model used for checking whether the triplets are factual. Default: claude2.", ) parser.add_argument( - "--retriever_name", type=str, default="google", choices=["google"], - help="Model used for retrieving reference (currently only google is" - " supported). Default: google." + "--retriever_name", + type=str, + default="google", + choices=["google"], + help="Model used for retrieving reference (currently only google is supported). Default: google.", ) parser.add_argument( - "--aggregator_name", type=str, default="soft", + "--aggregator_name", + type=str, + default="soft", choices=["strict", "soft", "major"], - help="Aggregator used for aggregating the results from multiple " - "triplets. Default: soft.\n" - "* strict: If any of the triplets is Contradiction, the response" - " is Contradiction.\nIf all of the triplets are Entailment, the " - "response is Entailment. Otherwise, the\nresponse is Neutral.\n" - "* soft: The ratio of each category is calculated.\n" - "* major: The category with the most votes is selected." + help="Aggregator used for aggregating the results from multiple triplets. Default: soft.\n* strict: If any of the triplets is Contradiction, the response is Contradiction.\nIf all of the triplets are Entailment, the response is Entailment. Otherwise, the\nresponse is Neutral.\n* soft: The ratio of each category is calculated.\n* major: The category with the most votes is selected.", ) parser.add_argument( - "--openai_key", type=str, default="", - help="Path to the openai api key file. Required if openAI models are" - " used." + "--openai_key", + type=str, + default="", + help="Path to the openai api key file. Required if openAI models are" " used.", ) parser.add_argument( - "--anthropic_key", type=str, default="", - help="Path to the Anthropic api key file. Required if the Anthropic " - "Claude2 api is used." + "--anthropic_key", + type=str, + default="", + help="Path to the Anthropic api key file. Required if the Anthropic Claude2 api is used.", ) parser.add_argument( - "--aws_bedrock_region", type=str, default="", - help="AWS region where the Amazon Bedrock api is deployed. Required if " - "the Amazon Bedrock api is used." + "--aws_bedrock_region", + type=str, + default="", + help="AWS region where the Amazon Bedrock api is deployed. Required if the Amazon Bedrock api is used.", ) parser.add_argument( - "--use_retrieval", action="store_true", - help="Whether to use retrieval to find the reference for checking. " - "Required if the reference\nfield in input data is not provided." + "--use_retrieval", + action="store_true", + help="Whether to use retrieval to find the reference for checking. Required if the reference\nfield in input data is not provided.", ) parser.add_argument( - "--serper_api_key", type=str, default="", - help="Path to the serper api key file. Required if the google retriever" - " is used." + "--serper_api_key", + type=str, + default="", + help="Path to the serper api key file. Required if the google retriever is used.", + ) + parser.add_argument( + "--local_llm_checkpoint_path", + type=str, + default=None, + help="Specify the local LLM checkpoint path if you use one other than the official release. By default, the official release of the specified LLM is used.", + ) + parser.add_argument( + "--extractor_ngpus", + type=int, + default=None, + help="Specify the number of GPUs you want to use in launching a local model. By default, 1 is used for small models and up to all are used for larger ones.", + ) + parser.add_argument( + "--nli_device", + type=int, + default=None, + help="Specify the device in using NLI model as checker. By default uses 0.", ) return parser.parse_args() @@ -124,23 +152,29 @@ def extract(args): extractor = Claude2Extractor() elif args.extractor_name == "gpt4": extractor = GPT4Extractor() - elif args.extractor_name == "mixtral": - extractor = MixtralExtractor() + elif args.extractor_name in ["mixtral", "mistral"]: + extractor = MistralExtractor( + model_path=args.local_llm_checkpoint_path, + use_gpu_num=args.extractor_ngpus, + model_name=args.extractor_name, + ) else: raise NotImplementedError # load data with open(args.input_path, "r") as fp: input_data = json.load(fp) - + # extract triplets - print('Extracting') + print("Extracting") output_data = [] for item in tqdm(input_data): assert "response" in item, "response field is required" response = item["response"] question = item.get("question", None) - triplets = extractor.extract_claim_triplets(response, question, max_new_tokens=args.extractor_max_new_tokens) + triplets = extractor.extract_claim_triplets( + response, question, max_new_tokens=args.extractor_max_new_tokens + ) out_item = {**item, **{"triplets": triplets}} output_data.append(out_item) with open(args.output_path, "w") as fp: @@ -154,17 +188,17 @@ def check(args): elif args.checker_name == "gpt4": checker = GPT4Checker() elif args.checker_name == "nli": - checker = NLIChecker() + checker = NLIChecker(device=args.nli_device) else: raise NotImplementedError - + retriever = None if args.use_retrieval: if args.retriever_name == "google": retriever = GoogleRetriever(args.cache_dir) else: raise NotImplementedError - + if args.aggregator_name == "strict": agg_fn = strict_agg elif args.aggregator_name == "soft": @@ -173,13 +207,13 @@ def check(args): agg_fn = major_agg else: raise NotImplementedError - + # load data with open(args.input_path, "r") as fp: input_data = json.load(fp) - + # check triplets - print('Checking') + print("Checking") output_data = [] for item in tqdm(input_data): assert "triplets" in item, "triplets field is required" @@ -188,21 +222,19 @@ def check(args): reference = retriever.retrieve(item["response"]) item["reference"] = reference else: - assert "reference" in item, \ - "reference field is required if retriever is not used." + assert ( + "reference" in item + ), "reference field is required if retriever is not used." reference = item["reference"] question = item.get("question", None) - results = [ - checker.check(t, reference, question=question) - for t in triplets - ] + results = [checker.check(t, reference, question=question) for t in triplets] agg_results = agg_fn(results) out_item = { **item, **{ "Y": agg_results, "ys": results, - } + }, } output_data.append(out_item) with open(args.output_path, "w") as fp: diff --git a/refchecker/extractor/__init__.py b/refchecker/extractor/__init__.py index 1d47fdc..778c3b2 100644 --- a/refchecker/extractor/__init__.py +++ b/refchecker/extractor/__init__.py @@ -1,3 +1,3 @@ from .claude2_extractor import Claude2Extractor from .gpt4_extractor import GPT4Extractor -from .mixtral_extractor import MixtralExtractor +from .mistral_extractor import MistralExtractor diff --git a/refchecker/extractor/mixtral_extractor.py b/refchecker/extractor/mistral_extractor.py similarity index 59% rename from refchecker/extractor/mixtral_extractor.py rename to refchecker/extractor/mistral_extractor.py index b78d05b..09bfd52 100644 --- a/refchecker/extractor/mixtral_extractor.py +++ b/refchecker/extractor/mistral_extractor.py @@ -1,12 +1,10 @@ from .extractor_base import ExtractorBase import torch -from vllm import LLM, SamplingParams -from typing import List, Tuple +from ..utils import get_response_from_mistral +from typing import Optional, Literal -one_digit_tensor = torch.ones((1, 1), dtype=torch.long) - -MIXTRAL_KG_EXTRACTION_PROMPT_Q = """Given a question and a candidate answer to the question, please extract a KG from the candidate answer condition on the question and represent the KG with triples formatted with ("head", "relation", "tail"), each triplet in a line. +MISTRAL_KG_EXTRACTION_PROMPT_Q = """Given a question and a candidate answer to the question, please extract a KG from the candidate answer condition on the question and represent the KG with triples formatted with ("head", "relation", "tail"), each triplet in a line. Please note that this is an EXTRACTION task, so DO NOT care about whether the content of the candidate answer is factual or not, just extract the triplets from it. Here are some in-context examples of extraction: @@ -46,7 +44,7 @@ ### KG: """ -MIXTRAL_KG_EXTRACTION_PROMPT = """Given an input text, please extract a KG from the text and represent the KG with triples formatted with ("head", "relation", "tail"), each triplet in a line. Please note that this is an EXTRACTION task, so DO NOT care about whether the content of the candidate answer is factual or not, just extract the triplets from it. +MISTRAL_KG_EXTRACTION_PROMPT = """Given an input text, please extract a KG from the text and represent the KG with triples formatted with ("head", "relation", "tail"), each triplet in a line. Please note that this is an EXTRACTION task, so DO NOT care about whether the content of the candidate answer is factual or not, just extract the triplets from it. Here are some in-context examples of extraction: @@ -78,101 +76,41 @@ """ -class MixtralExtractor(ExtractorBase): +class MistralExtractor(ExtractorBase): def __init__( self, claim_format: str = "triplet", - tensor_parallel_size=-1, + model_name: Literal["mixtral", "mistral"] = "mixtral", + model_path: Optional[str] = None, + use_gpu_num: Optional[int] = None, ) -> None: super().__init__(claim_format=claim_format) - self.llm = LLM( - "mistralai/Mixtral-8x7B-Instruct-v0.1", - load_format="safetensors", - tensor_parallel_size=torch.cuda.device_count() - if tensor_parallel_size < 0 - else tensor_parallel_size, - trust_remote_code=True, - ) - self.tokenizer = self.llm.get_tokenizer() - self.tokenizer.pad_token_id = self.tokenizer.eos_token_id - self.llm.set_tokenizer(self.tokenizer) - if self.claim_format == "triplet": - self.prompt_temp_wq = MIXTRAL_KG_EXTRACTION_PROMPT_Q - self.prompt_temp = MIXTRAL_KG_EXTRACTION_PROMPT - - def _mixtral_encode_conversation( - self, prompt, answer_history: List[Tuple[str, str]] = [] - ): - """Encode conversations for Mixtral. - answer_hisory: [(u_1, a_1), (u_2, a_2), ... (u_k, a_k)] - Assume all strings are .strip()-ed and does not include anything like [INST][/INST] - """ - tokenizer = self.tokenizer - final_prompted_tensors = [] - - for i in range(len(answer_history)): - user_prompt, model_output = answer_history[i] - for_tokenize_prompt = f"[INST] {user_prompt} [/INST] {model_output}" - tokenized_id = tokenizer.encode( - for_tokenize_prompt, add_special_tokens=False, return_tensors="pt" - ) - final_prompted_tensors.append(tokenized_id) - final_prompted_tensors[-1] = torch.cat( - (final_prompted_tensors[-1], one_digit_tensor * tokenizer.eos_token_id), - dim=-1, - ) - - # last prompt - for_tokenize_prompt = f"[INST] {prompt} [/INST]" - tokenized_id = tokenizer.encode( - for_tokenize_prompt, add_special_tokens=False, return_tensors="pt" - ) - final_prompted_tensors.append(tokenized_id) - - # BOS - final_prompted_tensors[0] = torch.cat( - (one_digit_tensor * tokenizer.bos_token_id, final_prompted_tensors[0]), - dim=-1, - ) + self.prompt_temp_wq = MISTRAL_KG_EXTRACTION_PROMPT_Q + self.prompt_temp = MISTRAL_KG_EXTRACTION_PROMPT - final_prompted_tensor = torch.cat(final_prompted_tensors, dim=-1) - return final_prompted_tensor - - # TODO move this to utils for general LLMs based on vllm interface - def _get_response_from_mixtral( - self, raw_prompt: str, conversation_history: List[str] = [], - max_new_tokens: int = 500 - ): - llm = self.llm - outputs = llm.generate( - prompt_token_ids=self._mixtral_encode_conversation( - raw_prompt, conversation_history - ).tolist(), - sampling_params=SamplingParams( - temperature=0., max_tokens=max_new_tokens - ), - use_tqdm=False, - ) - llm_output = outputs[0].outputs[0].text - conversation_history = conversation_history.copy() - conversation_history.append((raw_prompt, llm_output)) - return llm_output, conversation_history + self.model_name = model_name + self.model_path = model_path + self.use_gpu_num = use_gpu_num def extract_claim_triplets(self, response, question=None, max_new_tokens=500): if question is None: prompt = self.prompt_temp.format(input_text=response) else: prompt = self.prompt_temp_wq.format(q=question, a=response) - mixtral_response, _ = self._get_response_from_mixtral( - prompt, max_new_tokens=max_new_tokens + mistral_response = get_response_from_mistral( + prompt, + max_new_tokens=max_new_tokens, + model_name=self.model_name, + model_path=self.model_path, + use_gpu_num=self.use_gpu_num, ) - if mixtral_response and len(mixtral_response): + if mistral_response and len(mistral_response): kg_str = None - if "###" in mixtral_response: - kg_str = mixtral_response[: mixtral_response.index("###")] + if "###" in mistral_response: + kg_str = mistral_response[: mistral_response.index("###")] else: - kg_str = mixtral_response + kg_str = mistral_response triplets = self._parse_claim_triplets(kg_str) return triplets return [] @@ -181,7 +119,7 @@ def extract_claim_triplets(self, response, question=None, max_new_tokens=500): if __name__ == "__main__": import json - extractor = MixtralExtractor() + extractor = MistralExtractor() example_for_test = { "question": "full time student how many hours", "response": 'Based on the provided passages, the number of hours required to be considered a full-time student can vary depending on the context. However, some common requirements mentioned are:\n\n- Passage 0: Full-time status is usually considered as a schedule of 12 or more semester or quarter hours, but there is no data on how many credits full-time students are actually taking.\n\n- Passage 1: For undergraduate students in the summer, 12 hours is considered full-time.\n\n- Passage 2: For fall and spring semesters, a full-time college student completes at least 12 semester hours. Some schools may require 15 semester hours. In the summer, completing 6 semester hours is considered full-time.\n\n- Passage 3: For graduate students, a normal full-time load is nine graduate-level semester hours. But for those holding graduate assistantships, full-time status is six semester hours.\n\n- Passage 4: According to the University Bulletin, graduate students taking 9 or more credit hours per semester (6 credits in the summer) are considered full-time.\n\n- Passage 5: The normal load for full-time students is 3 courses (9 credits).\n\n- Passage 8: For graduate students, 9 hours is considered full-time.\n\nBased on the information provided, the answer to the question "full time student how many hours?" is not explicitly stated or consistent across all passages. It can range from 12 hours for undergraduate students in the summer to 9 or more credit hours for graduate students. Some schools may require 15 semester hours for full-time status.', @@ -191,7 +129,7 @@ def extract_claim_triplets(self, response, question=None, max_new_tokens=500): json.dumps( extractor.extract_claim_triplets( response=example_for_test["response"], - question=example_for_test["question"] + question=example_for_test["question"], ), indent=4, ) @@ -199,9 +137,7 @@ def extract_claim_triplets(self, response, question=None, max_new_tokens=500): print( json.dumps( - extractor.extract_claim_triplets( - response=example_for_test["response"] - ), + extractor.extract_claim_triplets(response=example_for_test["response"]), indent=4, ) ) diff --git a/refchecker/utils.py b/refchecker/utils.py index 98e428f..974105c 100644 --- a/refchecker/utils.py +++ b/refchecker/utils.py @@ -11,6 +11,9 @@ import anthropic from anthropic import HUMAN_PROMPT, AI_PROMPT +from vllm import LLM, SamplingParams +import torch +from typing import List, Tuple, Union, Optional, Literal # Setup spaCy NLP nlp = None @@ -22,8 +25,11 @@ bedrock = None anthropic_client = None +# Setup Local LLM API +vllm_global_dict = {"mixtral": None, "mistral": None} -def sentencize(text): + +def sentencize(text: str): """Split text into sentences""" global nlp if not nlp: @@ -32,7 +38,7 @@ def sentencize(text): return [sent for sent in doc.sents] -def split_text(text, segment_len=200): +def split_text(text: str, segment_len: int = 200): """Split text into segments according to sentence boundaries.""" segments, seg = [], [] sents = [[token.text for token in sent] for sent in sentencize(text)] @@ -44,7 +50,7 @@ def split_text(text, segment_len=200): if len(seg) > segment_len: # split into chunks of segment_len seg = [ - " ".join(seg[i:i+segment_len]) + " ".join(seg[i : i + segment_len]) for i in range(0, len(seg), segment_len) ] segments.extend(seg) @@ -57,38 +63,31 @@ def split_text(text, segment_len=200): def get_openai_model_response( - prompt, - temperature=0, - model='gpt-3.5-turbo', - n_choices=1, - max_new_tokens=500 + prompt, temperature=0, model="gpt-3.5-turbo", n_choices=1, max_new_tokens=500 ): global openai_client if not openai_client: openai_client = openai.OpenAI() - + if not prompt or len(prompt) == 0: return None while True: try: if isinstance(prompt, str): - messages = [{ - 'role': 'user', - 'content': prompt - }] + messages = [{"role": "user", "content": prompt}] elif isinstance(prompt, list): messages = prompt else: return None res_choices = openai_client.chat.completions.create( - model=model, - messages=messages, - temperature=temperature, - n=n_choices, - max_tokens=max_new_tokens - ).choices + model=model, + messages=messages, + temperature=temperature, + n=n_choices, + max_tokens=max_new_tokens, + ).choices if n_choices == 1: response = res_choices[0].message.content else: @@ -97,7 +96,11 @@ def get_openai_model_response( if response and len(response) > 0: return response except Exception as e: - if isinstance(e, OpenAIRateLimitError) or isinstance(e, OpenAIAPIError) or isinstance(e, OpenAITimeout): + if ( + isinstance(e, OpenAIRateLimitError) + or isinstance(e, OpenAIAPIError) + or isinstance(e, OpenAITimeout) + ): time.sleep(10) continue print(type(e), e) @@ -106,26 +109,22 @@ def get_openai_model_response( def get_claude2_response(prompt, temperature=0, max_new_tokens=500): - if os.environ.get('aws_bedrock_region'): + if os.environ.get("aws_bedrock_region"): global bedrock if not bedrock: bedrock = boto3.client( - service_name='bedrock-runtime', - region_name=os.environ.get('aws_bedrock_region') + service_name="bedrock-runtime", + region_name=os.environ.get("aws_bedrock_region"), ) return _get_bedrock_claude_completion( - prompt=prompt, - temperature=temperature, - max_new_tokens=max_new_tokens + prompt=prompt, temperature=temperature, max_new_tokens=max_new_tokens ) else: global anthropic_client if not anthropic_client: anthropic_client = anthropic.Anthropic() return _get_anthropic_claude_completion( - prompt=prompt, - temperature=temperature, - max_new_tokens=max_new_tokens + prompt=prompt, temperature=temperature, max_new_tokens=max_new_tokens ) @@ -134,23 +133,27 @@ def _get_bedrock_claude_completion(prompt, temperature=0, max_new_tokens=300): return None while True: try: - body = json.dumps({ - "prompt": f"\n\nHuman: {prompt} \n\nAssistant:", - "max_tokens_to_sample": max_new_tokens, - "temperature": temperature, - "top_p": 0.9, - }) - modelId = 'anthropic.claude-v2' - accept = 'application/json' - contentType = 'application/json' - - response = bedrock.invoke_model(body=body, modelId=modelId, accept=accept, contentType=contentType) - - response_body = json.loads(response.get('body').read()) + body = json.dumps( + { + "prompt": f"\n\nHuman: {prompt} \n\nAssistant:", + "max_tokens_to_sample": max_new_tokens, + "temperature": temperature, + "top_p": 0.9, + } + ) + modelId = "anthropic.claude-v2" + accept = "application/json" + contentType = "application/json" + + response = bedrock.invoke_model( + body=body, modelId=modelId, accept=accept, contentType=contentType + ) + + response_body = json.loads(response.get("body").read()) # text - return response_body.get('completion') + return response_body.get("completion") except Exception as e: - if e.response['Error']['Code'] == 'ThrottlingException': + if e.response["Error"]["Code"] == "ThrottlingException": time.sleep(10) continue print(type(e), e) @@ -161,7 +164,7 @@ def _get_bedrock_claude_completion(prompt, temperature=0, max_new_tokens=300): def _get_anthropic_claude_completion(prompt, temperature=0, max_new_tokens=300): if not prompt or len(prompt) == 0: return None - + completion = anthropic_client.completions.create( model="claude-2", max_tokens_to_sample=max_new_tokens, @@ -169,3 +172,95 @@ def _get_anthropic_claude_completion(prompt, temperature=0, max_new_tokens=300): prompt=f"{HUMAN_PROMPT} {prompt}{AI_PROMPT}", ) return completion.completion + + +def _get_response_from_local_llm( + llm: LLM, prompt=None, prompt_token_ids=None, temperature=0, max_new_tokens=300 +): + """Receive prompt or prompt token ids as input and give the output of llm.""" + outputs = llm.generate( + prompt, + prompt_token_ids=prompt_token_ids, + sampling_params=SamplingParams( + temperature=temperature, max_tokens=max_new_tokens + ), + use_tqdm=False, + ) + llm_output = outputs[0].outputs[0].text + return llm_output + + +def _mistral_encode_conversation(tokenizer, prompt: str): + """Encode prompts for Mistral/Mixtral. + Assume all strings are .strip()-ed and does not include anything like [INST][/INST] + """ + final_prompted_tensors = [] + one_digit_tensor = torch.ones((1, 1), dtype=torch.long) + + # prompt + for_tokenize_prompt = f"[INST] {prompt} [/INST]" + tokenized_id = tokenizer.encode( + for_tokenize_prompt, add_special_tokens=False, return_tensors="pt" + ) + final_prompted_tensors.append(tokenized_id) + + # BOS + final_prompted_tensors[0] = torch.cat( + (one_digit_tensor * tokenizer.bos_token_id, final_prompted_tensors[0]), + dim=-1, + ) + + final_prompted_tensor = torch.cat(final_prompted_tensors, dim=-1) + return final_prompted_tensor + + +def _clear_llm_dict(): + """Clear unused LLMs to save memory""" + global vllm_global_dict + for key in vllm_global_dict: + vllm_global_dict[key] = None + torch.cuda.empty_cache() + + +def get_response_from_mistral( + prompt: str, + temperature=0, + max_new_tokens=300, + model_name: Literal["mistral", "mixtral"] = "mixtral", + model_path: Optional[str] = None, + use_gpu_num: Optional[int] = None, +): + if vllm_global_dict[model_name] is None: + # Clear GPUs to save memory + _clear_llm_dict() + # Specify args + if model_name == "mixtral": + model_path = ( + model_path if model_path else "mistralai/Mixtral-8x7B-Instruct-v0.1" + ) + tp_size = use_gpu_num if use_gpu_num else torch.cuda.device_count() + elif model_name == "mistral": + model_path = ( + model_path if model_path else "mistralai/Mistral-7B-Instruct-v0.2" + ) + tp_size = use_gpu_num if use_gpu_num else 2 + + # model setup + llm = LLM( + model_path, + tensor_parallel_size=tp_size, + trust_remote_code=True, + ) + tokenizer = llm.get_tokenizer() + tokenizer.pad_token_id = tokenizer.eos_token_id + llm.set_tokenizer(tokenizer) + vllm_global_dict[model_name] = llm + + llm: LLM = vllm_global_dict[model_name] + mistral_encoded_tensor = _mistral_encode_conversation(llm.get_tokenizer(), prompt) + return _get_response_from_local_llm( + llm, + prompt_token_ids=mistral_encoded_tensor.tolist(), + temperature=temperature, + max_new_tokens=max_new_tokens, + ) From 25ac95ae8bed04238fe332c246f9b110a15b3e2f Mon Sep 17 00:00:00 2001 From: void-b583x2-NULL Date: Thu, 18 Jan 2024 08:48:10 +0000 Subject: [PATCH 5/6] shorten Mistral encode --- refchecker/utils.py | 33 ++++----------------------------- 1 file changed, 4 insertions(+), 29 deletions(-) diff --git a/refchecker/utils.py b/refchecker/utils.py index 974105c..2b65bf7 100644 --- a/refchecker/utils.py +++ b/refchecker/utils.py @@ -175,12 +175,11 @@ def _get_anthropic_claude_completion(prompt, temperature=0, max_new_tokens=300): def _get_response_from_local_llm( - llm: LLM, prompt=None, prompt_token_ids=None, temperature=0, max_new_tokens=300 + llm: LLM, prompt=None, temperature=0, max_new_tokens=300 ): """Receive prompt or prompt token ids as input and give the output of llm.""" outputs = llm.generate( prompt, - prompt_token_ids=prompt_token_ids, sampling_params=SamplingParams( temperature=temperature, max_tokens=max_new_tokens ), @@ -190,30 +189,6 @@ def _get_response_from_local_llm( return llm_output -def _mistral_encode_conversation(tokenizer, prompt: str): - """Encode prompts for Mistral/Mixtral. - Assume all strings are .strip()-ed and does not include anything like [INST][/INST] - """ - final_prompted_tensors = [] - one_digit_tensor = torch.ones((1, 1), dtype=torch.long) - - # prompt - for_tokenize_prompt = f"[INST] {prompt} [/INST]" - tokenized_id = tokenizer.encode( - for_tokenize_prompt, add_special_tokens=False, return_tensors="pt" - ) - final_prompted_tensors.append(tokenized_id) - - # BOS - final_prompted_tensors[0] = torch.cat( - (one_digit_tensor * tokenizer.bos_token_id, final_prompted_tensors[0]), - dim=-1, - ) - - final_prompted_tensor = torch.cat(final_prompted_tensors, dim=-1) - return final_prompted_tensor - - def _clear_llm_dict(): """Clear unused LLMs to save memory""" global vllm_global_dict @@ -255,12 +230,12 @@ def get_response_from_mistral( tokenizer.pad_token_id = tokenizer.eos_token_id llm.set_tokenizer(tokenizer) vllm_global_dict[model_name] = llm - + prompt = f"[INST] {prompt} [/INST]" llm: LLM = vllm_global_dict[model_name] - mistral_encoded_tensor = _mistral_encode_conversation(llm.get_tokenizer(), prompt) + return _get_response_from_local_llm( llm, - prompt_token_ids=mistral_encoded_tensor.tolist(), + prompt=prompt, temperature=temperature, max_new_tokens=max_new_tokens, ) From 8e89248ed079530361d250f01eefba78ff7838cb Mon Sep 17 00:00:00 2001 From: void-b583x2-NULL Date: Wed, 24 Jan 2024 06:01:26 +0000 Subject: [PATCH 6/6] minor changes --- refchecker/extractor/mistral_extractor.py | 4 ++-- refchecker/utils.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/refchecker/extractor/mistral_extractor.py b/refchecker/extractor/mistral_extractor.py index 09bfd52..8432881 100644 --- a/refchecker/extractor/mistral_extractor.py +++ b/refchecker/extractor/mistral_extractor.py @@ -5,7 +5,7 @@ from typing import Optional, Literal MISTRAL_KG_EXTRACTION_PROMPT_Q = """Given a question and a candidate answer to the question, please extract a KG from the candidate answer condition on the question and represent the KG with triples formatted with ("head", "relation", "tail"), each triplet in a line. -Please note that this is an EXTRACTION task, so DO NOT care about whether the content of the candidate answer is factual or not, just extract the triplets from it. +Please note that this is an extraction task, so do not care about whether the content of the candidate answer is factual or not, just extract the triplets from it. Here are some in-context examples of extraction: @@ -44,7 +44,7 @@ ### KG: """ -MISTRAL_KG_EXTRACTION_PROMPT = """Given an input text, please extract a KG from the text and represent the KG with triples formatted with ("head", "relation", "tail"), each triplet in a line. Please note that this is an EXTRACTION task, so DO NOT care about whether the content of the candidate answer is factual or not, just extract the triplets from it. +MISTRAL_KG_EXTRACTION_PROMPT = """Given an input text, please extract a KG from the text and represent the KG with triples formatted with ("head", "relation", "tail"), each triplet in a line. Please note that this is an extraction task, so do not care about whether the content of the candidate answer is factual or not, just extract the triplets from it. Here are some in-context examples of extraction: diff --git a/refchecker/utils.py b/refchecker/utils.py index 2b65bf7..d9f3b7d 100644 --- a/refchecker/utils.py +++ b/refchecker/utils.py @@ -218,7 +218,7 @@ def get_response_from_mistral( model_path = ( model_path if model_path else "mistralai/Mistral-7B-Instruct-v0.2" ) - tp_size = use_gpu_num if use_gpu_num else 2 + tp_size = use_gpu_num if use_gpu_num else 1 # model setup llm = LLM(