diff --git a/.gitignore b/.gitignore index 0c20b5e..1bbaaba 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,4 @@ *.DS_Store -__pycache__/ \ No newline at end of file +__pycache__/ +benchmark/data/*context/* +test/* \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 239e606..adba792 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,7 @@ openai = "^1.3.5" anthropic = "^0.7.4" datasets = "^2.15.0" plotly = "^5.18.0" +vllm = "^0.2.6" [tool.poetry.scripts] refchecker-cli = "refchecker.cli:main" diff --git a/refchecker/cli.py b/refchecker/cli.py index 756842f..d867f11 100644 --- a/refchecker/cli.py +++ b/refchecker/cli.py @@ -3,7 +3,7 @@ from argparse import ArgumentParser, RawTextHelpFormatter from tqdm import tqdm -from .extractor import Claude2Extractor, GPT4Extractor +from .extractor import Claude2Extractor, GPT4Extractor, MistralExtractor from .checker import Claude2Checker, GPT4Checker, NLIChecker from .retriever import GoogleRetriever from .aggregator import strict_agg, soft_agg, major_agg @@ -12,78 +12,106 @@ def get_args(): parser = ArgumentParser(formatter_class=RawTextHelpFormatter) parser.add_argument( - "mode", nargs="?", choices=["extract", "check", "extract-check"], - help="extract: Extract triplets from provided responses.\n" - "check: Check whether the provided triplets are factual.\n" - "extract-check: Extract triplets and check whether they are factual." + "mode", + nargs="?", + choices=["extract", "check", "extract-check"], + help="extract: Extract triplets from provided responses.\ncheck: Check whether the provided triplets are factual.\nextract-check: Extract triplets and check whether they are factual.", ) parser.add_argument( - "--input_path", type=str, required=True, - help="Input path to the json file." + "--input_path", type=str, required=True, help="Input path to the json file." ) parser.add_argument( - "--output_path", type=str, required=True, - help="Output path to the result json file." + "--output_path", + type=str, + required=True, + help="Output path to the result json file.", ) parser.add_argument( - "--cache_dir", type=str, default="./.cache", - help="Path to the cache directory. Default: ./.cache" + "--cache_dir", + type=str, + default="./.cache", + help="Path to the cache directory. Default: ./.cache", ) parser.add_argument( - '--extractor_name', type=str, default="claude2", - choices=["gpt4", "claude2"], - help="Model used for extracting triplets. Default: claude2." + "--extractor_name", + type=str, + default="claude2", + choices=["gpt4", "claude2", "mistral", "mixtral"], + help="Model used for extracting triplets. Default: claude2.", ) parser.add_argument( - '--extractor_max_new_tokens', type=int, default=500, - help="Max generated tokens of the extractor, set a larger value for longer documents. Default: 500" + "--extractor_max_new_tokens", + type=int, + default=500, + help="Max generated tokens of the extractor, set a larger value for longer documents. Default: 500", ) parser.add_argument( - "--checker_name", type=str, default="claude2", + "--checker_name", + type=str, + default="claude2", choices=["gpt4", "claude2", "nli"], - help="Model used for checking whether the triplets are factual. " - "Default: claude2." + help="Model used for checking whether the triplets are factual. Default: claude2.", ) parser.add_argument( - "--retriever_name", type=str, default="google", choices=["google"], - help="Model used for retrieving reference (currently only google is" - " supported). Default: google." + "--retriever_name", + type=str, + default="google", + choices=["google"], + help="Model used for retrieving reference (currently only google is supported). Default: google.", ) parser.add_argument( - "--aggregator_name", type=str, default="soft", + "--aggregator_name", + type=str, + default="soft", choices=["strict", "soft", "major"], - help="Aggregator used for aggregating the results from multiple " - "triplets. Default: soft.\n" - "* strict: If any of the triplets is Contradiction, the response" - " is Contradiction.\nIf all of the triplets are Entailment, the " - "response is Entailment. Otherwise, the\nresponse is Neutral.\n" - "* soft: The ratio of each category is calculated.\n" - "* major: The category with the most votes is selected." + help="Aggregator used for aggregating the results from multiple triplets. Default: soft.\n* strict: If any of the triplets is Contradiction, the response is Contradiction.\nIf all of the triplets are Entailment, the response is Entailment. Otherwise, the\nresponse is Neutral.\n* soft: The ratio of each category is calculated.\n* major: The category with the most votes is selected.", ) parser.add_argument( - "--openai_key", type=str, default="", - help="Path to the openai api key file. Required if openAI models are" - " used." + "--openai_key", + type=str, + default="", + help="Path to the openai api key file. Required if openAI models are" " used.", ) parser.add_argument( - "--anthropic_key", type=str, default="", - help="Path to the Anthropic api key file. Required if the Anthropic " - "Claude2 api is used." + "--anthropic_key", + type=str, + default="", + help="Path to the Anthropic api key file. Required if the Anthropic Claude2 api is used.", ) parser.add_argument( - "--aws_bedrock_region", type=str, default="", - help="AWS region where the Amazon Bedrock api is deployed. Required if " - "the Amazon Bedrock api is used." + "--aws_bedrock_region", + type=str, + default="", + help="AWS region where the Amazon Bedrock api is deployed. Required if the Amazon Bedrock api is used.", ) parser.add_argument( - "--use_retrieval", action="store_true", - help="Whether to use retrieval to find the reference for checking. " - "Required if the reference\nfield in input data is not provided." + "--use_retrieval", + action="store_true", + help="Whether to use retrieval to find the reference for checking. Required if the reference\nfield in input data is not provided.", ) parser.add_argument( - "--serper_api_key", type=str, default="", - help="Path to the serper api key file. Required if the google retriever" - " is used." + "--serper_api_key", + type=str, + default="", + help="Path to the serper api key file. Required if the google retriever is used.", + ) + parser.add_argument( + "--local_llm_checkpoint_path", + type=str, + default=None, + help="Specify the local LLM checkpoint path if you use one other than the official release. By default, the official release of the specified LLM is used.", + ) + parser.add_argument( + "--extractor_ngpus", + type=int, + default=None, + help="Specify the number of GPUs you want to use in launching a local model. By default, 1 is used for small models and up to all are used for larger ones.", + ) + parser.add_argument( + "--nli_device", + type=int, + default=None, + help="Specify the device in using NLI model as checker. By default uses 0.", ) return parser.parse_args() @@ -124,21 +152,29 @@ def extract(args): extractor = Claude2Extractor() elif args.extractor_name == "gpt4": extractor = GPT4Extractor() + elif args.extractor_name in ["mixtral", "mistral"]: + extractor = MistralExtractor( + model_path=args.local_llm_checkpoint_path, + use_gpu_num=args.extractor_ngpus, + model_name=args.extractor_name, + ) else: raise NotImplementedError # load data with open(args.input_path, "r") as fp: input_data = json.load(fp) - + # extract triplets - print('Extracting') + print("Extracting") output_data = [] for item in tqdm(input_data): assert "response" in item, "response field is required" response = item["response"] question = item.get("question", None) - triplets = extractor.extract_claim_triplets(response, question, max_new_tokens=args.extractor_max_new_tokens) + triplets = extractor.extract_claim_triplets( + response, question, max_new_tokens=args.extractor_max_new_tokens + ) out_item = {**item, **{"triplets": triplets}} output_data.append(out_item) with open(args.output_path, "w") as fp: @@ -152,17 +188,17 @@ def check(args): elif args.checker_name == "gpt4": checker = GPT4Checker() elif args.checker_name == "nli": - checker = NLIChecker() + checker = NLIChecker(device=args.nli_device) else: raise NotImplementedError - + retriever = None if args.use_retrieval: if args.retriever_name == "google": retriever = GoogleRetriever(args.cache_dir) else: raise NotImplementedError - + if args.aggregator_name == "strict": agg_fn = strict_agg elif args.aggregator_name == "soft": @@ -171,13 +207,13 @@ def check(args): agg_fn = major_agg else: raise NotImplementedError - + # load data with open(args.input_path, "r") as fp: input_data = json.load(fp) - + # check triplets - print('Checking') + print("Checking") output_data = [] for item in tqdm(input_data): assert "triplets" in item, "triplets field is required" @@ -186,21 +222,19 @@ def check(args): reference = retriever.retrieve(item["response"]) item["reference"] = reference else: - assert "reference" in item, \ - "reference field is required if retriever is not used." + assert ( + "reference" in item + ), "reference field is required if retriever is not used." reference = item["reference"] question = item.get("question", None) - results = [ - checker.check(t, reference, question=question) - for t in triplets - ] + results = [checker.check(t, reference, question=question) for t in triplets] agg_results = agg_fn(results) out_item = { **item, **{ "Y": agg_results, "ys": results, - } + }, } output_data.append(out_item) with open(args.output_path, "w") as fp: diff --git a/refchecker/extractor/__init__.py b/refchecker/extractor/__init__.py index dc1fc28..778c3b2 100644 --- a/refchecker/extractor/__init__.py +++ b/refchecker/extractor/__init__.py @@ -1,2 +1,3 @@ from .claude2_extractor import Claude2Extractor from .gpt4_extractor import GPT4Extractor +from .mistral_extractor import MistralExtractor diff --git a/refchecker/extractor/mistral_extractor.py b/refchecker/extractor/mistral_extractor.py new file mode 100644 index 0000000..8432881 --- /dev/null +++ b/refchecker/extractor/mistral_extractor.py @@ -0,0 +1,143 @@ +from .extractor_base import ExtractorBase + +import torch +from ..utils import get_response_from_mistral +from typing import Optional, Literal + +MISTRAL_KG_EXTRACTION_PROMPT_Q = """Given a question and a candidate answer to the question, please extract a KG from the candidate answer condition on the question and represent the KG with triples formatted with ("head", "relation", "tail"), each triplet in a line. +Please note that this is an extraction task, so do not care about whether the content of the candidate answer is factual or not, just extract the triplets from it. + +Here are some in-context examples of extraction: + +### Question: +Given these paragraphs about the Tesla bot, what is its alias? + +### Candidate Answer: +Optimus (or Tesla Bot) is a robotic humanoid under development by Tesla, Inc. It was announced at the company's Artificial Intelligence (AI) Day event on August 19, 2021. + +### KG: +("Optimus", "is", "robotic humanoid") +("Optimus", "under development by", "Tesla, Inc.") +("Optimus", "also known as", "Tesla Bot") +("Tesla, Inc.", "announced", "Optimus") +("Announcement of Optimus", "occured at", "Artificial Intelligence (AI) Day event") +("Artificial Intelligence (AI) Day event", "held on", "August 19, 2021") +("Artificial Intelligence (AI) Day event", "organized by", "Tesla, Inc.") + +### Question: +here is some text about Andre Weiss, how many years was Andre at University of Dijon in Paris? + +### Candidate Answer: +11 years + +### KG: +("Andre Weiss at University of Dijon in Paris", "duration", "11 years") + +Now generate the KG for the following candidate answer based on the provided question: + +### Question: +{q} + +### Candidate Answer: +{a} + +### KG: +""" + +MISTRAL_KG_EXTRACTION_PROMPT = """Given an input text, please extract a KG from the text and represent the KG with triples formatted with ("head", "relation", "tail"), each triplet in a line. Please note that this is an extraction task, so do not care about whether the content of the candidate answer is factual or not, just extract the triplets from it. + +Here are some in-context examples of extraction: + +### Input: +Optimus (or Tesla Bot) is a robotic humanoid under development by Tesla, Inc. It was announced at the company's Artificial Intelligence (AI) Day event on August 19, 2021. + +### KG: +("Optimus", "is", "robotic humanoid") +("Optimus", "under development by", "Tesla, Inc.") +("Optimus", "also known as", "Tesla Bot") +("Tesla, Inc.", "announced", "Optimus") +("Announcement of Optimus", "occured at", "Artificial Intelligence (AI) Day event") +("Artificial Intelligence (AI) Day event", "held on", "August 19, 2021") +("Artificial Intelligence (AI) Day event", "organized by", "Tesla, Inc.") + +### Input: +Question: here is some text about Andre Weiss, how many years was Andre at University of Dijon in Paris? +Answer: 11 years + +### KG: +("Andre Weiss at University of Dijon in Paris", "duration", "11 years") + +Now generate the KG for the following input text based on the provided question: + +### Input Text: +{input_text} + +### KG: +""" + + +class MistralExtractor(ExtractorBase): + def __init__( + self, + claim_format: str = "triplet", + model_name: Literal["mixtral", "mistral"] = "mixtral", + model_path: Optional[str] = None, + use_gpu_num: Optional[int] = None, + ) -> None: + super().__init__(claim_format=claim_format) + if self.claim_format == "triplet": + self.prompt_temp_wq = MISTRAL_KG_EXTRACTION_PROMPT_Q + self.prompt_temp = MISTRAL_KG_EXTRACTION_PROMPT + + self.model_name = model_name + self.model_path = model_path + self.use_gpu_num = use_gpu_num + + def extract_claim_triplets(self, response, question=None, max_new_tokens=500): + if question is None: + prompt = self.prompt_temp.format(input_text=response) + else: + prompt = self.prompt_temp_wq.format(q=question, a=response) + mistral_response = get_response_from_mistral( + prompt, + max_new_tokens=max_new_tokens, + model_name=self.model_name, + model_path=self.model_path, + use_gpu_num=self.use_gpu_num, + ) + if mistral_response and len(mistral_response): + kg_str = None + if "###" in mistral_response: + kg_str = mistral_response[: mistral_response.index("###")] + else: + kg_str = mistral_response + triplets = self._parse_claim_triplets(kg_str) + return triplets + return [] + + +if __name__ == "__main__": + import json + + extractor = MistralExtractor() + example_for_test = { + "question": "full time student how many hours", + "response": 'Based on the provided passages, the number of hours required to be considered a full-time student can vary depending on the context. However, some common requirements mentioned are:\n\n- Passage 0: Full-time status is usually considered as a schedule of 12 or more semester or quarter hours, but there is no data on how many credits full-time students are actually taking.\n\n- Passage 1: For undergraduate students in the summer, 12 hours is considered full-time.\n\n- Passage 2: For fall and spring semesters, a full-time college student completes at least 12 semester hours. Some schools may require 15 semester hours. In the summer, completing 6 semester hours is considered full-time.\n\n- Passage 3: For graduate students, a normal full-time load is nine graduate-level semester hours. But for those holding graduate assistantships, full-time status is six semester hours.\n\n- Passage 4: According to the University Bulletin, graduate students taking 9 or more credit hours per semester (6 credits in the summer) are considered full-time.\n\n- Passage 5: The normal load for full-time students is 3 courses (9 credits).\n\n- Passage 8: For graduate students, 9 hours is considered full-time.\n\nBased on the information provided, the answer to the question "full time student how many hours?" is not explicitly stated or consistent across all passages. It can range from 12 hours for undergraduate students in the summer to 9 or more credit hours for graduate students. Some schools may require 15 semester hours for full-time status.', + } + + print( + json.dumps( + extractor.extract_claim_triplets( + response=example_for_test["response"], + question=example_for_test["question"], + ), + indent=4, + ) + ) + + print( + json.dumps( + extractor.extract_claim_triplets(response=example_for_test["response"]), + indent=4, + ) + ) diff --git a/refchecker/utils.py b/refchecker/utils.py index 98e428f..d9f3b7d 100644 --- a/refchecker/utils.py +++ b/refchecker/utils.py @@ -11,6 +11,9 @@ import anthropic from anthropic import HUMAN_PROMPT, AI_PROMPT +from vllm import LLM, SamplingParams +import torch +from typing import List, Tuple, Union, Optional, Literal # Setup spaCy NLP nlp = None @@ -22,8 +25,11 @@ bedrock = None anthropic_client = None +# Setup Local LLM API +vllm_global_dict = {"mixtral": None, "mistral": None} -def sentencize(text): + +def sentencize(text: str): """Split text into sentences""" global nlp if not nlp: @@ -32,7 +38,7 @@ def sentencize(text): return [sent for sent in doc.sents] -def split_text(text, segment_len=200): +def split_text(text: str, segment_len: int = 200): """Split text into segments according to sentence boundaries.""" segments, seg = [], [] sents = [[token.text for token in sent] for sent in sentencize(text)] @@ -44,7 +50,7 @@ def split_text(text, segment_len=200): if len(seg) > segment_len: # split into chunks of segment_len seg = [ - " ".join(seg[i:i+segment_len]) + " ".join(seg[i : i + segment_len]) for i in range(0, len(seg), segment_len) ] segments.extend(seg) @@ -57,38 +63,31 @@ def split_text(text, segment_len=200): def get_openai_model_response( - prompt, - temperature=0, - model='gpt-3.5-turbo', - n_choices=1, - max_new_tokens=500 + prompt, temperature=0, model="gpt-3.5-turbo", n_choices=1, max_new_tokens=500 ): global openai_client if not openai_client: openai_client = openai.OpenAI() - + if not prompt or len(prompt) == 0: return None while True: try: if isinstance(prompt, str): - messages = [{ - 'role': 'user', - 'content': prompt - }] + messages = [{"role": "user", "content": prompt}] elif isinstance(prompt, list): messages = prompt else: return None res_choices = openai_client.chat.completions.create( - model=model, - messages=messages, - temperature=temperature, - n=n_choices, - max_tokens=max_new_tokens - ).choices + model=model, + messages=messages, + temperature=temperature, + n=n_choices, + max_tokens=max_new_tokens, + ).choices if n_choices == 1: response = res_choices[0].message.content else: @@ -97,7 +96,11 @@ def get_openai_model_response( if response and len(response) > 0: return response except Exception as e: - if isinstance(e, OpenAIRateLimitError) or isinstance(e, OpenAIAPIError) or isinstance(e, OpenAITimeout): + if ( + isinstance(e, OpenAIRateLimitError) + or isinstance(e, OpenAIAPIError) + or isinstance(e, OpenAITimeout) + ): time.sleep(10) continue print(type(e), e) @@ -106,26 +109,22 @@ def get_openai_model_response( def get_claude2_response(prompt, temperature=0, max_new_tokens=500): - if os.environ.get('aws_bedrock_region'): + if os.environ.get("aws_bedrock_region"): global bedrock if not bedrock: bedrock = boto3.client( - service_name='bedrock-runtime', - region_name=os.environ.get('aws_bedrock_region') + service_name="bedrock-runtime", + region_name=os.environ.get("aws_bedrock_region"), ) return _get_bedrock_claude_completion( - prompt=prompt, - temperature=temperature, - max_new_tokens=max_new_tokens + prompt=prompt, temperature=temperature, max_new_tokens=max_new_tokens ) else: global anthropic_client if not anthropic_client: anthropic_client = anthropic.Anthropic() return _get_anthropic_claude_completion( - prompt=prompt, - temperature=temperature, - max_new_tokens=max_new_tokens + prompt=prompt, temperature=temperature, max_new_tokens=max_new_tokens ) @@ -134,23 +133,27 @@ def _get_bedrock_claude_completion(prompt, temperature=0, max_new_tokens=300): return None while True: try: - body = json.dumps({ - "prompt": f"\n\nHuman: {prompt} \n\nAssistant:", - "max_tokens_to_sample": max_new_tokens, - "temperature": temperature, - "top_p": 0.9, - }) - modelId = 'anthropic.claude-v2' - accept = 'application/json' - contentType = 'application/json' - - response = bedrock.invoke_model(body=body, modelId=modelId, accept=accept, contentType=contentType) - - response_body = json.loads(response.get('body').read()) + body = json.dumps( + { + "prompt": f"\n\nHuman: {prompt} \n\nAssistant:", + "max_tokens_to_sample": max_new_tokens, + "temperature": temperature, + "top_p": 0.9, + } + ) + modelId = "anthropic.claude-v2" + accept = "application/json" + contentType = "application/json" + + response = bedrock.invoke_model( + body=body, modelId=modelId, accept=accept, contentType=contentType + ) + + response_body = json.loads(response.get("body").read()) # text - return response_body.get('completion') + return response_body.get("completion") except Exception as e: - if e.response['Error']['Code'] == 'ThrottlingException': + if e.response["Error"]["Code"] == "ThrottlingException": time.sleep(10) continue print(type(e), e) @@ -161,7 +164,7 @@ def _get_bedrock_claude_completion(prompt, temperature=0, max_new_tokens=300): def _get_anthropic_claude_completion(prompt, temperature=0, max_new_tokens=300): if not prompt or len(prompt) == 0: return None - + completion = anthropic_client.completions.create( model="claude-2", max_tokens_to_sample=max_new_tokens, @@ -169,3 +172,70 @@ def _get_anthropic_claude_completion(prompt, temperature=0, max_new_tokens=300): prompt=f"{HUMAN_PROMPT} {prompt}{AI_PROMPT}", ) return completion.completion + + +def _get_response_from_local_llm( + llm: LLM, prompt=None, temperature=0, max_new_tokens=300 +): + """Receive prompt or prompt token ids as input and give the output of llm.""" + outputs = llm.generate( + prompt, + sampling_params=SamplingParams( + temperature=temperature, max_tokens=max_new_tokens + ), + use_tqdm=False, + ) + llm_output = outputs[0].outputs[0].text + return llm_output + + +def _clear_llm_dict(): + """Clear unused LLMs to save memory""" + global vllm_global_dict + for key in vllm_global_dict: + vllm_global_dict[key] = None + torch.cuda.empty_cache() + + +def get_response_from_mistral( + prompt: str, + temperature=0, + max_new_tokens=300, + model_name: Literal["mistral", "mixtral"] = "mixtral", + model_path: Optional[str] = None, + use_gpu_num: Optional[int] = None, +): + if vllm_global_dict[model_name] is None: + # Clear GPUs to save memory + _clear_llm_dict() + # Specify args + if model_name == "mixtral": + model_path = ( + model_path if model_path else "mistralai/Mixtral-8x7B-Instruct-v0.1" + ) + tp_size = use_gpu_num if use_gpu_num else torch.cuda.device_count() + elif model_name == "mistral": + model_path = ( + model_path if model_path else "mistralai/Mistral-7B-Instruct-v0.2" + ) + tp_size = use_gpu_num if use_gpu_num else 1 + + # model setup + llm = LLM( + model_path, + tensor_parallel_size=tp_size, + trust_remote_code=True, + ) + tokenizer = llm.get_tokenizer() + tokenizer.pad_token_id = tokenizer.eos_token_id + llm.set_tokenizer(tokenizer) + vllm_global_dict[model_name] = llm + prompt = f"[INST] {prompt} [/INST]" + llm: LLM = vllm_global_dict[model_name] + + return _get_response_from_local_llm( + llm, + prompt=prompt, + temperature=temperature, + max_new_tokens=max_new_tokens, + )