
SGLang on ECS: efficiently serving leading OSS LLMs on AWS
Benefits of SGLang as a runtime environment for OSS LLMs. Sharing a Docker image to run it on ECS. Reasoning inference examples with Qwen's QwQ-32B.
# SGLang frontend language example. The imports and backend setup below are added
# to make the snippet self-contained; adjust the endpoint URL to your own deployment.
from sglang import RuntimeEndpoint, assistant, function, gen, set_default_backend
from sglang.utils import print_highlight

# assumes an SGLang server is already running locally on port 30000
set_default_backend(RuntimeEndpoint("http://localhost:30000"))

@function
def tip_suggestion(s):
    s += assistant(
        "Here are two tips for staying healthy: "
        "1. Balanced Diet. 2. Regular Exercise.\n\n"
    )
    # fork the prompt state so both tips can be expanded in parallel
    forks = s.fork(2)
    for i, f in enumerate(forks):
        f += assistant(
            f"Now, expand tip {i+1} into a paragraph:\n"
            + gen("detailed_tip", max_tokens=256, stop="\n\n")
        )
    s += assistant("Tip 1:" + forks[0]["detailed_tip"] + "\n")
    s += assistant("Tip 2:" + forks[1]["detailed_tip"] + "\n")
    s += assistant(
        "To summarize the above two tips, I can say:\n" + gen("summary", max_tokens=512)
    )

state = tip_suggestion.run()
print_highlight(state["summary"])
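In this example, fork(2) duplicates the prompt state so that both tip expansions can be generated in parallel and then merged back into a single state for the final summary. This kind of prompt-level parallelism, combined with prefix caching of the shared prompt, is one of the efficiencies that SGLang's frontend language brings at inference time.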
docker pull didierdurand/lic-sglang:amzn2023-latest
The image that we publish is built directly on GitHub via this GitHub Action. You can see the corresponding executions on this page.
- The image is based on Amazon Linux 2023, but it can be adapted to other flavors of Linux: Ubuntu, RedHat, etc.
- The SGLang project is still in active development, with some features or parameters yet to be added. So, at build time we copy into the image a bash script, customize_sglang.sh, that allows for customization. For example, as of this writing, we update an HTTP timeout parameter in the source code via
sed 's/timeout_keep_alive=5,/timeout_keep_alive=500,/' -i $FILE
with FILE='/usr/local/lib/python3.12/site-packages/sglang/srt/entrypoints/http_server.py'. You can add your own customizations in this script.
- We also copy into the image a start_sglang.sh script that dynamically builds the SGLang start command from env variables received from the docker run command (a sketch of such a script is shown after the Dockerfile below). Different models have different requirements for the various parameters proposed by SGLang. Keeping the launch parameters external to the image means the same Docker image can be used for multiple LLMs.
- It is unsustainable to include the weights of the LLM in the image: they are most often too big (60+ GB for QwQ-32B, for example) and they would tie the image to a specific LLM. A live pull from Hugging Face at each start would also add too much latency. So, we use a Docker bind mount at container start to link the /home/model directory of the image to an external directory of the host server where the model weights are stored, in our case on a fast AWS EBS volume.
- We define multiple Docker ENV variables to collect the parameters required to issue the right start command for SGLang. Those variables are populated via the --env option of docker run (see the doc for details and the sample command after this list).
- The final || sleep infinity is a trick to keep the container up and running even if the SGLang start command fails for any reason. It allows connecting to the container via docker exec -it <container-id> /bin/bash to debug the problem.
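To make this concrete, a container start combining the bind mount and the env variables can look like the following sketch. The host model directory and the parameter values are illustrative and must be adapted to your setup; the env variable names match those defined in the Dockerfile below.

# illustrative: mount the host directory holding Qwen/QwQ-32B into /home/model
docker run --gpus all -p 30000:30000 \
  -v /home/ec2-user/models:/home/model \
  --env SGL_MODEL="Qwen/QwQ-32B" \
  --env SGL_TP_SIZE=4 \
  --env SGL_LOG_LEVEL="info" \
  --env SGL_PARAMS="--enable-metrics --trust-remote-code --enable-p2p-check" \
  didierdurand/lic-sglang:amzn2023-latest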
FROM public.ecr.aws/amazonlinux/amazonlinux:2023
# to extend / customize SGLang at build time
ARG CUSTOMIZE_SGLANG="customize_sglang.sh"
# to extend / customize SGLang launch steps and parameters
ARG START_SGLANG="start_sglang.sh"
# model dir must be created at image build time to allow volume bind mounts on container start
ARG SGL_MODEL_DIR="/home/model"
# to create directory for accessing model weights via local storage (Docker volume)
WORKDIR ${SGL_MODEL_DIR}
# to create directory for SGLang extensions
WORKDIR "/home/sglang"
# versions of components
ARG CUDA_VERSION="124"
ARG PYTHON_VERSION="3.12"
ARG TORCH_VERSION="2.5"
ARG SGL_VERSION="0.4.3.post4"
ARG SGL_LINKS="https://flashinfer.ai/whl/cu${CUDA_VERSION}/torch${TORCH_VERSION}/flashinfer-python"
ARG SGL_TP_SIZE=2
ARG SGL_HOST="0.0.0.0"
ARG SGL_PORT=30000
# debug, info, warning, error
ARG SGL_LOG_LEVEL="info"
# install tools
RUN yum update -y \
&& yum install -y awscli wget findutils which grep sed git patch \
&& yum install -y kernel-headers kernel-devel python${PYTHON_VERSION}-devel \
&& yum clean all
# install Python & sglang
RUN yum install -y python${PYTHON_VERSION} \
&& yum clean all \
&& python${PYTHON_VERSION} -m ensurepip --upgrade \
&& python${PYTHON_VERSION} -m pip install --upgrade pip \
&& python${PYTHON_VERSION} -m pip install --upgrade --no-cache-dir "sglang[all]==${SGL_VERSION}" --find-links ${SGL_LINKS}
# to be able to know the build versions at runtime
RUN echo "cuda version: ${CUDA_VERSION}" >> sglang_versions.txt \
&& echo "python version: ${PYTHON_VERSION}" >> sglang_versions.txt \
&& echo "torch version: ${TORCH_VERSION}" >> sglang_versions.txt \
&& echo "sglang version: ${SGL_VERSION}" >> sglang_versions.txt
COPY "extend/"${CUSTOMIZE_SGLANG} ${CUSTOMIZE_SGLANG}
COPY "extend/"${START_SGLANG} ${START_SGLANG}
RUN ls -lh \
&& bash ${CUSTOMIZE_SGLANG}
# turn needed build args into runtime env vars
# set up python version
ENV PYTHON_VERSION=${PYTHON_VERSION}
# communication parameters
ENV SGL_PORT=${SGL_PORT}
ENV SGL_HOST=${SGL_HOST}
# SGLang parameters
ENV SGL_TP_SIZE=${SGL_TP_SIZE}
ENV SGL_LOG_LEVEL=${SGL_LOG_LEVEL}
ENV SGL_PARAMS=""
# model info
ENV SGL_MODEL=""
ENV SGL_MODEL_DIR=${SGL_MODEL_DIR}
EXPOSE ${SGL_PORT}
CMD ["bash", "-c", "bash start_sglang.sh || sleep infinity"]
In the launch logs below, lines like
Load weight end. type=Qwen2ForCausalLM, dtype=torch.bfloat16, avail mem=6.09 GB, mem usage=15.54 GB.
show that the total model weights (60+ GB) are loaded in 4 equal parts, one per GPU.
python3.12 -m sglang.launch_server --model Qwen/QwQ-32B --model-path /home/model/Qwen/QwQ-32B --host 0.0.0.0 --port 30000 --tensor-parallel-size 4 --log-level info --enable-metrics --trust-remote-code --enable-p2p-check
INFO 03-07 12:29:02 __init__.py:190] Automatically detected platform cuda.
[2025-03-08 12:29:04] server_args=ServerArgs(model_path='/home/model/Qwen/QwQ-32B', tokenizer_path='/home/model/Qwen/QwQ-32B', tokenizer_mode='auto', skip_tokenizer_init=False, load_format='auto', trust_remote_code=True, dtype='auto', kv_cache_dtype='auto', quantization=None, quantization_param_path=None, context_length=None, device='cuda', served_model_name='/home/model/Qwen/QwQ-32B', chat_template=None, is_embedding=False, revision=None, host='0.0.0.0', port=30000, mem_fraction_static=0.85, max_running_requests=None, max_total_tokens=None, chunked_prefill_size=2048, max_prefill_tokens=16384, schedule_policy='fcfs', schedule_conservativeness=1.0, cpu_offload_gb=0, tp_size=4, stream_interval=1, stream_output=False, random_seed=811642422, constrained_json_whitespace_pattern=None, watchdog_timeout=300, dist_timeout=None, download_dir=None, base_gpu_id=0, gpu_id_step=1, log_level='info', log_level_http=None, log_requests=False, log_requests_level=0, show_time_cost=False, enable_metrics=True, decode_log_interval=40, api_key=None, file_storage_path='sglang_storage', enable_cache_report=False, reasoning_parser=None, dp_size=1, load_balance_method='round_robin', ep_size=1, dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{}', lora_paths=None, max_loras_per_batch=8, lora_backend='triton', attention_backend='flashinfer', sampling_backend='flashinfer', grammar_backend='outlines', speculative_algorithm=None, speculative_draft_model_path=None, speculative_num_steps=5, speculative_eagle_topk=4, speculative_num_draft_tokens=8, speculative_accept_threshold_single=1.0, speculative_accept_threshold_acc=1.0, speculative_token_map=None, enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, disable_radix_cache=False, disable_cuda_graph=False, disable_cuda_graph_padding=False, enable_nccl_nvls=False, disable_outlines_disk_cache=False, disable_custom_all_reduce=False, disable_mla=False, disable_overlap_schedule=False, enable_mixed_chunk=False, enable_dp_attention=False, enable_ep_moe=False, enable_torch_compile=False, torch_compile_max_bs=32, cuda_graph_max_bs=80, cuda_graph_bs=None, torchao_config='', enable_nan_detection=False, enable_p2p_check=True, triton_attention_reduce_in_fp32=False, triton_attention_num_kv_splits=8, num_continuous_decode_steps=1, delete_ckpt_after_loading=False, enable_memory_saver=False, allow_auto_truncate=False, enable_custom_logit_processor=False, tool_call_parser=None, enable_hierarchical_cache=False, enable_flashinfer_mla=False, flashinfer_mla_disable_ragged=False, warmups=None, debug_tensor_dump_output_folder=None, debug_tensor_dump_input_file=None, debug_tensor_dump_inject=False)
INFO 03-07 12:29:08 __init__.py:190] Automatically detected platform cuda.
INFO 03-07 12:29:08 __init__.py:190] Automatically detected platform cuda.
INFO 03-07 12:29:08 __init__.py:190] Automatically detected platform cuda.
INFO 03-07 12:29:08 __init__.py:190] Automatically detected platform cuda.
INFO 03-07 12:29:09 __init__.py:190] Automatically detected platform cuda.
[2025-03-08 12:29:11 TP3] Init torch distributed begin.
[2025-03-08 12:29:11 TP1] Init torch distributed begin.
[2025-03-08 12:29:11 TP0] Init torch distributed begin.
[2025-03-08 12:29:11 TP2] Init torch distributed begin.
[2025-03-08 12:29:12 TP0] sglang is using nccl==2.21.5
[2025-03-08 12:29:12 TP1] sglang is using nccl==2.21.5
[2025-03-08 12:29:12 TP2] sglang is using nccl==2.21.5
[2025-03-08 12:29:12 TP3] sglang is using nccl==2.21.5
[2025-03-08 12:29:12 TP0] Custom allreduce is disabled because it's not supported on more than two PCIe-only GPUs. To silence this warning, specify disable_custom_all_reduce=True explicitly.
[2025-03-08 12:29:12 TP1] Custom allreduce is disabled because it's not supported on more than two PCIe-only GPUs. To silence this warning, specify disable_custom_all_reduce=True explicitly.
[2025-03-08 12:29:12 TP2] Custom allreduce is disabled because it's not supported on more than two PCIe-only GPUs. To silence this warning, specify disable_custom_all_reduce=True explicitly.
[2025-03-08 12:29:12 TP3] Custom allreduce is disabled because it's not supported on more than two PCIe-only GPUs. To silence this warning, specify disable_custom_all_reduce=True explicitly.
[2025-03-08 12:29:13 TP2] Init torch distributed ends. mem usage=0.13 GB
[2025-03-08 12:29:13 TP3] Init torch distributed ends. mem usage=0.13 GB
[2025-03-08 12:29:13 TP1] Init torch distributed ends. mem usage=0.13 GB
[2025-03-08 12:29:13 TP0] Init torch distributed ends. mem usage=0.13 GB
[2025-03-08 12:29:13 TP3] Load weight begin. avail mem=21.63 GB
[2025-03-08 12:29:13 TP0] Load weight begin. avail mem=21.63 GB
[2025-03-08 12:29:13 TP1] Load weight begin. avail mem=21.63 GB
[2025-03-08 12:29:13 TP2] Load weight begin. avail mem=21.63 GB
[2025-03-08 12:29:13 TP3] The following error message 'operation scheduled before its operands' can be ignored.
[2025-03-08 12:29:13 TP1] The following error message 'operation scheduled before its operands' can be ignored.
[2025-03-08 12:29:13 TP0] The following error message 'operation scheduled before its operands' can be ignored.
[2025-03-08 12:29:13 TP2] The following error message 'operation scheduled before its operands' can be ignored.
Loading safetensors checkpoint shards: 0% Completed | 0/14 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 7% Completed | 1/14 [00:00<00:04, 2.93it/s]
Loading safetensors checkpoint shards: 14% Completed | 2/14 [00:00<00:05, 2.00it/s]
Loading safetensors checkpoint shards: 21% Completed | 3/14 [00:01<00:05, 1.86it/s]
Loading safetensors checkpoint shards: 29% Completed | 4/14 [00:02<00:05, 1.77it/s]
Loading safetensors checkpoint shards: 36% Completed | 5/14 [00:02<00:05, 1.77it/s]
Loading safetensors checkpoint shards: 43% Completed | 6/14 [00:03<00:04, 1.75it/s]
Loading safetensors checkpoint shards: 50% Completed | 7/14 [00:03<00:04, 1.74it/s]
Loading safetensors checkpoint shards: 57% Completed | 8/14 [00:04<00:03, 1.74it/s]
Loading safetensors checkpoint shards: 64% Completed | 9/14 [00:05<00:02, 1.71it/s]
Loading safetensors checkpoint shards: 71% Completed | 10/14 [00:05<00:02, 1.69it/s]
Loading safetensors checkpoint shards: 79% Completed | 11/14 [00:05<00:01, 2.03it/s]
Loading safetensors checkpoint shards: 86% Completed | 12/14 [00:06<00:00, 2.05it/s]
Loading safetensors checkpoint shards: 93% Completed | 13/14 [00:06<00:00, 1.95it/s]
Loading safetensors checkpoint shards: 100% Completed | 14/14 [00:07<00:00, 1.87it/s]
Loading safetensors checkpoint shards: 100% Completed | 14/14 [00:07<00:00, 1.85it/s]
[2025-03-08 12:29:21 TP1] Load weight end. type=Qwen2ForCausalLM, dtype=torch.bfloat16, avail mem=6.09 GB, mem usage=15.54 GB.
[2025-03-08 12:29:21 TP2] Load weight end. type=Qwen2ForCausalLM, dtype=torch.bfloat16, avail mem=6.09 GB, mem usage=15.54 GB.
[2025-03-08 12:29:21 TP3] Load weight end. type=Qwen2ForCausalLM, dtype=torch.bfloat16, avail mem=6.09 GB, mem usage=15.54 GB.
[2025-03-08 12:29:21 TP0] Load weight end. type=Qwen2ForCausalLM, dtype=torch.bfloat16, avail mem=6.09 GB, mem usage=15.54 GB.
[2025-03-08 12:29:21 TP0] KV Cache is allocated. #tokens: 46611, K size: 1.42 GB, V size: 1.42 GB
[2025-03-08 12:29:21 TP1] KV Cache is allocated. #tokens: 46611, K size: 1.42 GB, V size: 1.42 GB
[2025-03-08 12:29:21 TP2] KV Cache is allocated. #tokens: 46611, K size: 1.42 GB, V size: 1.42 GB
[2025-03-08 12:29:21 TP3] KV Cache is allocated. #tokens: 46611, K size: 1.42 GB, V size: 1.42 GB
[2025-03-08 12:29:21 TP0] Memory pool end. avail mem=2.09 GB
[2025-03-08 12:29:21 TP1] Memory pool end. avail mem=2.09 GB
[2025-03-08 12:29:21 TP2] Memory pool end. avail mem=2.09 GB
[2025-03-08 12:29:21 TP3] Memory pool end. avail mem=2.09 GB
[2025-03-08 12:29:21 TP0] Capture cuda graph begin. This can take up to several minutes. avail mem=1.45 GB
0%| | 0/13 [00:00<?, ?it/s][2025-03-08 12:29:21 TP1] Capture cuda graph begin. This can take up to several minutes. avail mem=1.45 GB
[2025-03-08 12:29:21 TP2] Capture cuda graph begin. This can take up to several minutes. avail mem=1.45 GB
[2025-03-08 12:29:21 TP3] Capture cuda graph begin. This can take up to several minutes. avail mem=1.45 GB
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:07<00:00, 1.80it/s]
[2025-03-08 12:29:28 TP0] Capture cuda graph end. Time elapsed: 7.23 s. avail mem=0.53 GB. mem usage=0.93 GB.
[2025-03-08 12:29:28 TP3] Capture cuda graph end. Time elapsed: 7.22 s. avail mem=0.53 GB. mem usage=0.93 GB.
[2025-03-08 12:29:28 TP2] Capture cuda graph end. Time elapsed: 7.24 s. avail mem=0.53 GB. mem usage=0.93 GB.
[2025-03-08 12:29:28 TP1] Capture cuda graph end. Time elapsed: 7.25 s. avail mem=0.53 GB. mem usage=0.93 GB.
[2025-03-08 12:29:29 TP0] max_total_num_tokens=46611, chunked_prefill_size=2048, max_prefill_tokens=16384, max_running_requests=2049, context_len=131072
[2025-03-08 12:29:29 TP3] max_total_num_tokens=46611, chunked_prefill_size=2048, max_prefill_tokens=16384, max_running_requests=2049, context_len=131072
[2025-03-08 12:29:29 TP2] max_total_num_tokens=46611, chunked_prefill_size=2048, max_prefill_tokens=16384, max_running_requests=2049, context_len=131072
[2025-03-08 12:29:29 TP1] max_total_num_tokens=46611, chunked_prefill_size=2048, max_prefill_tokens=16384, max_running_requests=2049, context_len=131072
[2025-03-08 12:29:29] INFO: Started server process [3207]
[2025-03-08 12:29:29] INFO: Waiting for application startup.
[2025-03-08 12:29:29] INFO: Application startup complete.
[2025-03-08 12:29:29] INFO: Uvicorn running on http://0.0.0.0:30000 (Press CTRL+C to quit)
[2025-03-08 12:29:30] INFO: 127.0.0.1:49534 - "GET /get_model_info HTTP/1.1" 200 OK
[2025-03-08 12:29:30 TP0] Prefill batch. #new-seq: 1, #new-token: 6, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-03-08 12:29:32] INFO: 10.0.2.141:37662 - "GET /health HTTP/1.1" 200 OK
[2025-03-08 12:29:33] INFO: 127.0.0.1:49540 - "POST /generate HTTP/1.1" 200 OK
[2025-03-08 12:29:33] The server is fired up and ready to roll!
[2025-03-08 12:29:33] INFO: 10.0.0.207:24382 - "GET /health HTTP/1.1" 200 OK
[2025-03-08 12:29:37] INFO: 10.0.2.141:42304 - "GET /health HTTP/1.1" 200 OK
[2025-03-08 12:29:38] INFO: 10.0.0.207:24394 - "GET /health HTTP/1.1" 200 OK
[2025-03-08 12:29:42] INFO: 10.0.2.141:42306 - "GET /health HTTP/1.1" 200 OK
[2025-03-08 12:29:43] INFO: 10.0.0.207:39328 - "GET /health HTTP/1.1" 200 OK
[2025-03-08 12:29:47] INFO: 10.0.2.141:18968 - "GET /health HTTP/1.1" 200 OK
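Once the server logs "The server is fired up and ready to roll!", it can serve requests. As an illustration, a reasoning prompt can be sent through SGLang's OpenAI-compatible chat endpoint; this is a sketch where the host and served model name follow the launch command above and the other parameters are up to you.

# illustrative request to the OpenAI-compatible endpoint exposed by SGLang
curl -s http://localhost:30000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "/home/model/Qwen/QwQ-32B",
    "messages": [{"role": "user", "content": "How many letters R in word strawberry?"}],
    "max_tokens": 2048
  }'

The first example below is this classic letter-counting question; excerpts of QwQ-32B's reasoning trace follow.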
How many letters R in word 'strawberry' ?
- S
- T
- R
- A
- W
- B
- E
- R
- R
- Y
- The third letter is R.
- Then after E (7th), the next two letters are R and R. So that's two more. So total 1 + 2 = 3. So three R's. Hmm, but I've heard sometimes people might think there are two. Maybe a common mistake? Let me check again. Let me spell it slowly: S-T-R-A-W-B-E-R-R-Y. So after the B and E comes two R's in a row, then Y. So that's two R's after the E, plus the one after T. So total three. Therefore, the answer is 2? Wait no, three. Wait, maybe I'm confused. Let me count the Rs only:
- S – no
- T – no
- R – yes (1)
- A – no
- W – no
- B – no
- E – no
- R – yes (2)
- R – yes (3)
- Y – no
- S
- T
- R
- A
- W
- B
- E
- R
- R
- Y
A man has 53 socks in his drawer: 21 identical blue, 15 identical black and 17 identical red. The lights are out, and he is completely in the dark. How many socks must he take out to make 100 percent certain he has at least one pair of black socks?
- Blue socks: 21
- Red socks: 17
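For reference, the answer the model has to reach is 40: in the worst case, all 21 blue and all 17 red socks come out first (38 socks without a single black one), so 38 + 2 = 40 socks guarantee at least one pair of black socks.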
Any opinions in this post are those of the individual author and may not reflect the opinions of AWS.