Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions lightllm/server/api_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,10 +644,10 @@ def make_argument_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--sampling_backend",
type=str,
choices=["triton", "sglang_kernel"],
choices=["triton", "flashinfer"],
default="triton",
help="""sampling used impl. 'triton' is use torch and triton kernel,
sglang_kernel use sglang_kernel impl""",
flashinfer use flashinfer sampling impl""",
)
parser.add_argument(
"--penalty_counter_mode",
Expand Down
2 changes: 1 addition & 1 deletion lightllm/server/core/objs/start_args_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ class StartArgs:
default="None", metadata={"choices": ["None", "int8kv", "int4kv", "fp8kv_sph", "fp8kv_spt", "fp8kv_dsa"]}
)
llm_kv_quant_group_size: int = field(default=8)
sampling_backend: str = field(default="triton", metadata={"choices": ["triton", "sglang_kernel"]})
sampling_backend: str = field(default="triton", metadata={"choices": ["triton", "flashinfer"]})
penalty_counter_mode: str = field(
default="gpu_counter", metadata={"choices": ["cpu_counter", "pin_mem_counter", "gpu_counter"]}
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,9 @@ def _top_p_top_k_sample(
b_top_ks: torch.Tensor,
exist_req_use_random_seed: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
if get_env_start_args().sampling_backend == "triton":
sampling_backend = get_env_start_args().sampling_backend

if sampling_backend == "triton":
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The `flashinfer` sampling backend does not support per-request random seeds (`exist_req_use_random_seed`), so seeded generation will no longer be deterministic when a request supplies a custom seed. To preserve correctness, fall back to the `triton` sampling implementation whenever `exist_req_use_random_seed` is `True`.

Suggested change
if sampling_backend == "triton":
if sampling_backend == "triton" or exist_req_use_random_seed:

probs_sort, probs_idx = _top_p_top_k(probs, b_top_ps, b_top_ks)
if not exist_req_use_random_seed:
sampled_index = torch.multinomial(probs_sort, num_samples=1, replacement=True)
Expand All @@ -124,8 +126,8 @@ def _top_p_top_k_sample(
next_token_logprobs = torch.log(torch.gather(probs_sort, dim=1, index=sampled_index))
return next_token_ids.view(-1), next_token_logprobs.view(-1)

elif get_env_start_args().sampling_backend == "sglang_kernel":
from sgl_kernel import top_k_top_p_sampling_from_probs
elif sampling_backend == "flashinfer":
from flashinfer.sampling import top_k_top_p_sampling_from_probs

batch_next_token_ids = top_k_top_p_sampling_from_probs(
probs,
Expand Down
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@ frozendict==2.4.6
atomics==1.0.3
easydict==1.13
hypercorn==0.18.0
flashinfer-python==0.6.8.post1
flashinfer-cubin==0.6.8.post1
flashinfer-python==0.6.12
flashinfer-cubin==0.6.12
sglang-kernel==0.4.2.post1
httpx==0.28.1
librosa==0.11.0
Expand Down
14 changes: 12 additions & 2 deletions test/benchmark/service/benchmark_qps.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,11 @@ def get_custom_input_data(data_path, output_len, tokenizer, range_ratio):


model_name = []
sampling_config = {
"temperature": 1.0,
"top_p": 0.9,
"top_k": -1,
}


# Minimal fix: one retry on transient network errors.
Expand All @@ -123,7 +128,9 @@ async def async_post_stream_openai(url, prompt, max_new_tokens, session):
"max_tokens": max_new_tokens,
"ignore_eos": True,
"stream": True,
"temperature": 0.0,
"temperature": sampling_config["temperature"],
"top_p": sampling_config["top_p"],
"top_k": sampling_config["top_k"],
"best_of": 1,
}
headers = {"Content-Type": "application/json"}
Expand Down Expand Up @@ -166,9 +173,12 @@ async def async_post_stream_lightllm(url, prompt, max_new_tokens, session):
data = {
"inputs": text_input,
"parameters": {
"do_sample": False,
"do_sample": True,
"ignore_eos": True,
"max_new_tokens": max_new_tokens,
"temperature": sampling_config["temperature"],
"top_p": sampling_config["top_p"],
"top_k": sampling_config["top_k"],
"add_special_tokens": False,
},
}
Expand Down
Loading