1 change: 1 addition & 0 deletions pyproject.toml
@@ -43,6 +43,7 @@ dependencies = [
"sortedcontainers",
"word2number",
"transformers",
"tinker",
]

[project.scripts]
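The new tinker dependency supplies the request/response dataclasses used throughout the test and engine changes below. A minimal sketch of the two types a caller constructs (illustrative values, not part of the PR):

from tinker import types

prompt = types.ModelInput.from_ints([101, 2023, 2003])  # token ids for the prompt
params = types.SamplingParams(temperature=0.7, max_tokens=8)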
88 changes: 88 additions & 0 deletions tests/common/vllm_test.py
@@ -1219,3 +1219,91 @@ async def test_generate(self):
response.prompt_length, 40960
) # If not long enough, please add more files to prompt
self.assertGreater(response.logprobs.shape[0], 1000)


class TestTinkerAPI(RayUnittestBaseAysnc):
"""Test the Tinker API integration with the vLLM engine."""

def setUp(self):
self.config = get_template_config()
self.config.mode = "explore"
self.config.model.model_path = get_model_path()
self.config.explorer.rollout_model.engine_type = "vllm"
self.config.explorer.rollout_model.engine_num = 1
self.config.explorer.rollout_model.tensor_parallel_size = 1
self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE
self.config.explorer.rollout_model.enable_openai_api = True

self.config.check_and_update()
self.engines, self.auxiliary_engines = create_inference_models(self.config)
self.model_wrapper = ModelWrapper(self.engines[0], engine_type="vllm", enable_history=True)

async def test_tinker_api(self):
from tinker import types
from transformers import AutoTokenizer

engine = self.engines[0]
tokenizer = AutoTokenizer.from_pretrained(self.config.model.model_path)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What is your name?"},
]
result_dict = tokenizer.apply_chat_template(
messages,
chat_template=CHAT_TEMPLATE,
add_generation_prompt=False,
padding=False,
truncation=True,
return_tensors="pt",
add_special_tokens=False,
return_assistant_tokens_mask=True,
return_dict=True,
)
prompt = types.ModelInput.from_ints(
result_dict["input_ids"][0].tolist(),
)
# sample api without prompt logprobs
num_samples = 4
response = await engine.sample.remote(
prompt=prompt,
num_samples=num_samples,
sampling_params=types.SamplingParams(temperature=0.7), # no limit on length
)
self.assertEqual(len(response.sequences), num_samples)
for sequence in response.sequences:
self.assertEqual(len(sequence.tokens), len(sequence.logprobs))
self.assertEqual(sequence.stop_reason, "stop")
self.assertIsNone(response.prompt_logprobs)
self.assertIsNone(response.topk_prompt_logprobs)
# sample api with prompt logprobs
num_samples = 2
topk_prompt_logprobs = 3
response = await engine.sample.remote(
prompt=prompt,
num_samples=num_samples,
sampling_params=types.SamplingParams(temperature=0.7, max_tokens=8),
include_prompt_logprobs=True,
topk_prompt_logprobs=topk_prompt_logprobs,
)
self.assertEqual(len(response.sequences), num_samples)
for sequence in response.sequences:
self.assertEqual(len(sequence.tokens), len(sequence.logprobs))
self.assertEqual(sequence.stop_reason, "length")
self.assertEqual(len(response.prompt_logprobs), len(prompt.to_ints()))
self.assertIsNone(response.prompt_logprobs[0])
self.assertEqual(len(response.topk_prompt_logprobs), len(prompt.to_ints()))
self.assertIsNone(response.topk_prompt_logprobs[0])
for topk_logprobs in response.topk_prompt_logprobs[1:]:
self.assertIsNotNone(topk_logprobs)
self.assertEqual(len(topk_logprobs), topk_prompt_logprobs)
# compute_logprob api
response = await engine.sample.remote(
prompt=prompt,
num_samples=1,
sampling_params=types.SamplingParams(max_tokens=1),
include_prompt_logprobs=True,
)
self.assertEqual(len(response.sequences), 1)
self.assertEqual(response.sequences[0].stop_reason, "length")
self.assertEqual(len(prompt.to_ints()), len(response.prompt_logprobs))
self.assertIsNone(response.topk_prompt_logprobs)
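For reference, the prompt_logprobs and topk_prompt_logprobs lists asserted above are index-aligned with the prompt tokens, so a caller can walk them together. A minimal consumption sketch, assuming response comes from a sample(...) call with include_prompt_logprobs=True and topk_prompt_logprobs > 0 (as in the second call above); the print is illustrative:

for position, (token_id, token_logprob, topk) in enumerate(
    zip(prompt.to_ints(), response.prompt_logprobs, response.topk_prompt_logprobs)
):
    if token_logprob is None:
        # vLLM reports no logprob for the first prompt token, hence the leading None.
        continue
    # topk is a list of (token_id, logprob) pairs for this position.
    print(position, token_id, token_logprob, topk)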
88 changes: 87 additions & 1 deletion trinity/common/models/vllm_model.py
@@ -3,7 +3,7 @@
import asyncio
import os
from collections import defaultdict
from typing import Any, Dict, List, Optional, Sequence
from typing import Any, Dict, List, Optional, Sequence, Tuple

import numpy as np
import ray
@@ -402,6 +402,92 @@ async def logprobs( # type: ignore [override]
dtype=torch.float32,
)

async def sample(
self,
prompt: Any,
num_samples: int,
sampling_params: Any,
include_prompt_logprobs: bool = False,
topk_prompt_logprobs: int = 0,
lora_request: Optional[Any] = None,
) -> Any:
"""Tinker compatible sampling interface.

Args:
prompt (ModelInput): The input prompt.
num_samples (int): The number of samples to generate.
sampling_params (SamplingParams): The sampling parameters.
include_prompt_logprobs (bool): Whether to include prompt logprobs.
topk_prompt_logprobs (int): The top-k prompt logprobs to include.
lora_request (LoRARequest, optional): The LoRA request. Defaults to None.
Returns:
SampleResponse: The sample response.
"""
from tinker.types import SampledSequence, SampleResponse

params = {
"max_tokens": sampling_params.max_tokens
if sampling_params.max_tokens is not None
else self.config.max_response_tokens,
"seed": sampling_params.seed if sampling_params.seed is not None else self.config.seed,
"top_k": sampling_params.top_k,
"top_p": sampling_params.top_p,
"temperature": sampling_params.temperature,
"n": num_samples,
"prompt_logprobs": (topk_prompt_logprobs if include_prompt_logprobs else None),
# in vLLM, 0 means only return the chosen token's logprob
"logprobs": 0,
}
if sampling_params.stop is not None:
params["stop"] = sampling_params.stop
req_output = await self._generate_internal(
prompt={"prompt_token_ids": prompt.to_ints()},
lora_request=lora_request,
**params,
)
sequences = []
# vLLM's prompt_logprobs output does not include a value for the first token.
# Initialize with [None] to align with the prompt tokens.
topk_prompt_logprobs_list: List[Optional[List[Tuple[int, float]]]] = [None]
prompt_logprobs: List[Optional[float]] = [None]

# collect prompt logprobs
if include_prompt_logprobs:
for logprob_dict in req_output.prompt_logprobs[1:]:
prompt_logprobs.append(next(iter(logprob_dict.values())).logprob)
if topk_prompt_logprobs > 0:
# collect top-k prompt logprobs
# logprob_dict: {token_id: Logprob(logprob, rank, ...), ...}
logprob_items = list(logprob_dict.items())
# sort by Logprob.rank
logprob_items_sorted = sorted(logprob_items, key=lambda x: x[1].rank)
# pick topk
topk = logprob_items_sorted[:topk_prompt_logprobs]
# record as (token_id, logprob)
topk_prompt_logprobs_list.append(
[(token_id, logprob.logprob) for token_id, logprob in topk]
)
# collect response sequences
for seq_output in req_output.outputs:
seq = SampledSequence(
stop_reason="length" if seq_output.finish_reason == "length" else "stop",
tokens=seq_output.token_ids,
logprobs=[
next(iter(logprob_dict.values())).logprob
for logprob_dict in seq_output.logprobs
],
)
sequences.append(seq)
return SampleResponse(
sequences=sequences,
prompt_logprobs=prompt_logprobs if include_prompt_logprobs else None,
topk_prompt_logprobs=(
topk_prompt_logprobs_list
if include_prompt_logprobs and topk_prompt_logprobs > 0
else None
),
)

async def _generate_internal(self, prompt: Any, lora_request=None, **kwargs) -> Any:
# Send the request to the LLM engine.
self.request_id += 1
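Putting the pieces together, a compact usage sketch of the new Tinker-compatible sampling interface (mirroring the test above; the engine handle, tokenizer, and prompt text are placeholders, and the call must run inside an async context):

from tinker import types

async def sample_once(engine, tokenizer):
    # engine: a Ray actor handle created via create_inference_models(config);
    # tokenizer: an AutoTokenizer for the same model (both as in the test above).
    prompt = types.ModelInput.from_ints(tokenizer("What is your name?")["input_ids"])
    response = await engine.sample.remote(
        prompt=prompt,
        num_samples=2,
        sampling_params=types.SamplingParams(temperature=0.7, max_tokens=16),
        include_prompt_logprobs=True,
        topk_prompt_logprobs=3,
    )
    for sequence in response.sequences:
        # Each sequence carries token ids, per-token logprobs, and a Tinker-style
        # stop_reason ("stop" or "length").
        print(sequence.stop_reason, len(sequence.tokens), len(sequence.logprobs))
    return response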