diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 2db6c67e7..434d7c191 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -644,10 +644,10 @@ def make_argument_parser() -> argparse.ArgumentParser: parser.add_argument( "--sampling_backend", type=str, - choices=["triton", "sglang_kernel"], + choices=["triton", "flashinfer"], default="triton", help="""sampling used impl. 'triton' is use torch and triton kernel, - sglang_kernel use sglang_kernel impl""", + flashinfer use flashinfer sampling impl""", ) parser.add_argument( "--penalty_counter_mode", diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index 6d0ee0746..2d623a4eb 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -149,7 +149,7 @@ class StartArgs: default="None", metadata={"choices": ["None", "int8kv", "int4kv", "fp8kv_sph", "fp8kv_spt", "fp8kv_dsa"]} ) llm_kv_quant_group_size: int = field(default=8) - sampling_backend: str = field(default="triton", metadata={"choices": ["triton", "sglang_kernel"]}) + sampling_backend: str = field(default="triton", metadata={"choices": ["triton", "flashinfer"]}) penalty_counter_mode: str = field( default="gpu_counter", metadata={"choices": ["cpu_counter", "pin_mem_counter", "gpu_counter"]} ) diff --git a/lightllm/server/router/model_infer/mode_backend/generic_post_process.py b/lightllm/server/router/model_infer/mode_backend/generic_post_process.py index 41e89da9a..5b29ea051 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_post_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_post_process.py @@ -114,7 +114,9 @@ def _top_p_top_k_sample( b_top_ks: torch.Tensor, exist_req_use_random_seed: bool, ) -> Tuple[torch.Tensor, torch.Tensor]: - if get_env_start_args().sampling_backend == "triton": + sampling_backend = get_env_start_args().sampling_backend + + if sampling_backend == "triton": probs_sort, probs_idx = _top_p_top_k(probs, b_top_ps, b_top_ks) if not exist_req_use_random_seed: sampled_index = torch.multinomial(probs_sort, num_samples=1, replacement=True) @@ -124,8 +126,8 @@ def _top_p_top_k_sample( next_token_logprobs = torch.log(torch.gather(probs_sort, dim=1, index=sampled_index)) return next_token_ids.view(-1), next_token_logprobs.view(-1) - elif get_env_start_args().sampling_backend == "sglang_kernel": - from sgl_kernel import top_k_top_p_sampling_from_probs + elif sampling_backend == "flashinfer": + from flashinfer.sampling import top_k_top_p_sampling_from_probs batch_next_token_ids = top_k_top_p_sampling_from_probs( probs, diff --git a/requirements.txt b/requirements.txt index f124ce76f..c58fee551 100644 --- a/requirements.txt +++ b/requirements.txt @@ -80,8 +80,8 @@ frozendict==2.4.6 atomics==1.0.3 easydict==1.13 hypercorn==0.18.0 -flashinfer-python==0.6.8.post1 -flashinfer-cubin==0.6.8.post1 +flashinfer-python==0.6.12 +flashinfer-cubin==0.6.12 sglang-kernel==0.4.2.post1 httpx==0.28.1 librosa==0.11.0 diff --git a/test/benchmark/service/benchmark_qps.py b/test/benchmark/service/benchmark_qps.py index a9083091e..43f60b91d 100644 --- a/test/benchmark/service/benchmark_qps.py +++ b/test/benchmark/service/benchmark_qps.py @@ -108,6 +108,11 @@ def get_custom_input_data(data_path, output_len, tokenizer, range_ratio): model_name = [] +sampling_config = { + "temperature": 1.0, + "top_p": 0.9, + "top_k": -1, +} # Minimal fix: one retry on transient network errors. @@ -123,7 +128,9 @@ async def async_post_stream_openai(url, prompt, max_new_tokens, session): "max_tokens": max_new_tokens, "ignore_eos": True, "stream": True, - "temperature": 0.0, + "temperature": sampling_config["temperature"], + "top_p": sampling_config["top_p"], + "top_k": sampling_config["top_k"], "best_of": 1, } headers = {"Content-Type": "application/json"} @@ -166,9 +173,12 @@ async def async_post_stream_lightllm(url, prompt, max_new_tokens, session): data = { "inputs": text_input, "parameters": { - "do_sample": False, + "do_sample": True, "ignore_eos": True, "max_new_tokens": max_new_tokens, + "temperature": sampling_config["temperature"], + "top_p": sampling_config["top_p"], + "top_k": sampling_config["top_k"], "add_special_tokens": False, }, }