diff --git a/.buildkite/cases/comprehensive-cases.txt b/.buildkite/cases/comprehensive-cases.txt index 170e79b1833..bc3f7975450 100644 --- a/.buildkite/cases/comprehensive-cases.txt +++ b/.buildkite/cases/comprehensive-cases.txt @@ -3,3 +3,6 @@ local_disk.yaml local_cpu_mla.yaml pd.yaml multi_device.yaml +async.yaml +p2p.yaml +layerwise.yaml \ No newline at end of file diff --git a/.buildkite/configs/async.yaml b/.buildkite/configs/async.yaml new file mode 100644 index 00000000000..f9ea31af3e9 --- /dev/null +++ b/.buildkite/configs/async.yaml @@ -0,0 +1,28 @@ +workload: + type: long_doc_qa + max-inflight-requests: 20 + sleep-time-after-warmup: 20 + expected-latency-gain: 1.5 + num-documents: 20 + repeat-count: 1 + hit-miss-ratio: 2:2 + +docker: + env: + - "LMCACHE_CHUNK_SIZE=256" + - "LMCACHE_LOCAL_CPU=False" + - "LMCACHE_MAX_LOCAL_CPU_SIZE=70" + - "LMCACHE_MAX_LOCAL_DISK_SIZE=70" + - "LMCACHE_LOCAL_DISK=\"file:///local/end-to-end-tests/local/\"" + - "LMCACHE_ENABLE_ASYNC_LOADING=True" + - "LMCACHE_EXTRA_CONFIG={\"lookup_backoff_time\": 0.01, \"use_odirect\": True}" + - "LMCACHE_SAVE_UNFULL_CHUNK=False" + +vllm: + model: "meta-llama/Llama-3.1-8B-Instruct" + args: + - "--load-format" + - "dummy" + - "--no-enable-prefix-caching" + - "--kv-transfer-config" + - "{\"kv_connector\":\"LMCacheConnectorV1\",\"kv_role\":\"kv_both\"}" diff --git a/.buildkite/configs/layerwise.yaml b/.buildkite/configs/layerwise.yaml new file mode 100644 index 00000000000..3611059037f --- /dev/null +++ b/.buildkite/configs/layerwise.yaml @@ -0,0 +1,22 @@ +workload: + type: long_doc_qa + max-inflight-requests: 20 + expected-latency-gain: 3 + +docker: + env: + - "LMCACHE_CHUNK_SIZE=256" + - "LMCACHE_LOCAL_CPU=True" + - "LMCACHE_MAX_LOCAL_CPU_SIZE=5" + - "LMCACHE_USE_LAYERWISE=true" + +vllm: + model: "meta-llama/Llama-3.2-1B-Instruct" + args: + - "--load-format" + - "dummy" + - "--no-enable-prefix-caching" + - "--kv-transfer-config" + - "{\"kv_connector\":\"LMCacheConnectorV1\",\"kv_role\":\"kv_both\"}" + - "--compilation-config" + - "{\"cudagraph_mode\":\"PIECEWISE\"}" diff --git a/.buildkite/configs/local_cpu.yaml b/.buildkite/configs/local_cpu.yaml index 5b030852042..37142473d71 100644 --- a/.buildkite/configs/local_cpu.yaml +++ b/.buildkite/configs/local_cpu.yaml @@ -1,7 +1,7 @@ workload: type: long_doc_qa max-inflight-requests: 20 - expected-latency-gain: 3.7 + expected-latency-gain: 3.6 docker: env: diff --git a/.buildkite/configs/p2p.yaml b/.buildkite/configs/p2p.yaml new file mode 100644 index 00000000000..937d6a390b8 --- /dev/null +++ b/.buildkite/configs/p2p.yaml @@ -0,0 +1,63 @@ +workload: + type: long_doc_qa + num-documents: 20 + max-inflight-requests: 2 + repeat-count: 1 + expected-latency: 4 + +feature: + type: p2p + +docker1: + env: + - "LMCACHE_MAX_LOCAL_CPU_SIZE=60" + - "LMCACHE_ENABLE_ASYNC_LOADING=True" + - "LMCACHE_ENABLE_P2P=True" + - "LMCACHE_P2P_HOST=localhost" + - "LMCACHE_P2P_INIT_PORTS=8200" + - "LMCACHE_P2P_LOOKUP_PORTS=8201" + - "LMCACHE_TRANSFER_CHANNEL=nixl" + - "LMCACHE_ENABLE_CONTROLLER=True" + - "LMCACHE_LMCACHE_INSTANCE_ID=lmcache_instance_1" + - "LMCACHE_LMCACHE_WORKER_PORTS=8500" + - "LMCACHE_EXTRA_CONFIG={\"lookup_backoff_time\": 0.001}" + - "LMCACHE_SAVE_UNFULL_CHUNK=False" + - "PYTHONHASHSEED=123" + pull-port: 8300 + reply-port: 8400 + +docker2: + env: + - "LMCACHE_MAX_LOCAL_CPU_SIZE=60" + - "LMCACHE_ENABLE_ASYNC_LOADING=True" + - "LMCACHE_ENABLE_P2P=True" + - "LMCACHE_P2P_HOST=localhost" + - "LMCACHE_P2P_INIT_PORTS=8202" + - "LMCACHE_P2P_LOOKUP_PORTS=8203" + - "LMCACHE_TRANSFER_CHANNEL=nixl" + - "LMCACHE_ENABLE_CONTROLLER=True" + - "LMCACHE_LMCACHE_INSTANCE_ID=lmcache_instance_2" + - "LMCACHE_LMCACHE_WORKER_PORTS=8501" + - "LMCACHE_EXTRA_CONFIG={\"lookup_backoff_time\": 0.001}" + - "LMCACHE_SAVE_UNFULL_CHUNK=False" + - "PYTHONHASHSEED=123" + pull-port: 8300 + reply-port: 8400 + +vllm1: + model: "meta-llama/Llama-3.1-8B-Instruct" + args: + - "--load-format" + - "dummy" + - "--no-enable-prefix-caching" + - "--kv-transfer-config" + - "{\"kv_connector\":\"LMCacheConnectorV1\",\"kv_role\":\"kv_both\"}" + +vllm2: + model: "meta-llama/Llama-3.1-8B-Instruct" + args: + - "--load-format" + - "dummy" + - "--no-enable-prefix-caching" + - "--kv-transfer-config" + - "{\"kv_connector\":\"LMCacheConnectorV1\",\"kv_role\":\"kv_both\"}" diff --git a/.buildkite/scripts/vllm-integration-tests.sh b/.buildkite/scripts/vllm-integration-tests.sh index 0cce059ed11..498a5a3e905 100755 --- a/.buildkite/scripts/vllm-integration-tests.sh +++ b/.buildkite/scripts/vllm-integration-tests.sh @@ -272,6 +272,123 @@ run_pd_lmcache() { sleep 10 } +run_p2p_lmcache() { + local docker1="$1" + local vllm1="$2" + local docker2="$3" + local vllm2="$4" + local cfg_name="$5" + LOGFILE1="/tmp/build_${BUILD_ID}_${cfg_name}1.log" + LOGFILE2="/tmp/build_${BUILD_ID}_${cfg_name}2.log" + + ########## Instance 1 ########## + # docker args + docker1_args=( + --runtime nvidia + --network host + --gpus "device=0" + --volume ~/.cache/huggingface:/root/.cache/huggingface + --env VLLM_USE_FLASHINFER_SAMPLER=0 + --env HF_TOKEN="$HF_TOKEN" + --env UCX_TLS=tcp + --ipc host + --shm-size 4G + ) + while IFS= read -r e; do + [[ -n $e ]] && docker1_args+=(--env "$e") + done < <(yq -r '.env[]?' <<<"$docker1") + pull=$(yq -er '."pull-port"' <<<"$docker1" 2>/dev/null) + docker1_args+=(--env "LMCACHE_CONTROLLER_PULL_URL=localhost:$pull") + reply=$(yq -er '."reply-port"' <<<"$docker1" 2>/dev/null) + docker1_args+=(--env "LMCACHE_CONTROLLER_REPLY_URL=localhost:$reply") + + # vllm args + vllm1_model="$(yq -r '.model' <<<"$vllm1")" + mapfile -t vllm1_cli_args < <(yq -r '.args // [] | .[]' <<<"$vllm1") + cmd_args1=( + lmcache/vllm-openai:build-latest + "$vllm1_model" + ) + cmd_args1+=("${vllm1_cli_args[@]}") + cmd_args1+=("--port" "$PORT1") + + ##### Controller part start ##### + if [ ! -d ".venv" ]; then + UV_PYTHON=python3 uv -q venv + fi + source .venv/bin/activate + uv pip install -r "$ORIG_DIR/requirements/build.txt" > /dev/null 2>&1 + uv pip install torch==2.7.1 httpx fastapi uvicorn > /dev/null 2>&1 + uv pip install -e "$ORIG_DIR" --no-build-isolation > /dev/null 2>&1 + # Start controller + PYTHONHASHSEED=123 lmcache_controller \ + --host localhost \ + --port "$PORT" \ + --monitor-ports "{\"pull\": ${pull}, \"reply\": ${reply}}" \ + > "/tmp/build_${BUILD_ID}_${cfg_name}_controller.log" 2>&1 & + sleep 10 + ##### Controller part end ##### + + # Start docker + CID1=$( + docker run -d \ + "${docker1_args[@]}" \ + "${cmd_args1[@]}" + ) + + # Health check + wait_for_openai_api_server "$PORT1" "$vllm1_model" "$CID1" + + # Logging + touch "$LOGFILE1" + docker logs -f "$CID1" >>"$LOGFILE1" 2>&1 & + + ########## Instance 2 ########## + # docker args + docker2_args=( + --runtime nvidia + --network host + --gpus "device=1" + --volume ~/.cache/huggingface:/root/.cache/huggingface + --env VLLM_USE_FLASHINFER_SAMPLER=0 + --env HF_TOKEN="$HF_TOKEN" + --env UCX_TLS=tcp + --ipc host + --shm-size 4G + ) + while IFS= read -r e; do + [[ -n $e ]] && docker2_args+=(--env "$e") + done < <(yq -r '.env[]?' <<<"$docker2") + pull=$(yq -er '."pull-port"' <<<"$docker2" 2>/dev/null) + docker2_args+=(--env "LMCACHE_CONTROLLER_PULL_URL=localhost:$pull") + reply=$(yq -er '."reply-port"' <<<"$docker2" 2>/dev/null) + docker2_args+=(--env "LMCACHE_CONTROLLER_REPLY_URL=localhost:$reply") + + # vllm args + vllm2_model="$(yq -r '.model' <<<"$vllm2")" + mapfile -t vllm2_cli_args < <(yq -r '.args // [] | .[]' <<<"$vllm2") + cmd_args2=( + lmcache/vllm-openai:build-latest + "$vllm2_model" + ) + cmd_args2+=("${vllm2_cli_args[@]}") + cmd_args2+=("--port" "$PORT2") + + # Start docker + CID2=$( + docker run -d \ + "${docker2_args[@]}" \ + "${cmd_args2[@]}" + ) + + # Health check + wait_for_openai_api_server "$PORT2" "$vllm2_model" "$CID2" + + # Logging + touch "$LOGFILE2" + docker logs -f "$CID2" >>"$LOGFILE2" 2>&1 & +} + usage() { echo "Usage: $0 [OPTIONS]" echo " " @@ -315,6 +432,7 @@ test_vllmopenai_server_with_lmcache_integrated() { run_long_doc_qa() { local workload_config="$1" + local port="$2" echo "→ Running long_doc_qa with customed workload config:" printf '%s\n' "$workload_config" @@ -349,7 +467,7 @@ run_long_doc_qa() { uv -q pip install openai pandas matplotlib python3 "$ORIG_DIR/benchmarks/long_doc_qa/long_doc_qa.py" \ "${workload_args[@]}" \ - --port="$PORT" \ + --port="$port" \ --output="response.txt" } @@ -433,6 +551,15 @@ for cfg_name in "${CONFIG_NAMES[@]}"; do decoder_vllm_args="$(yq '.["vllm-decoder"]' "$cfg_file")" run_pd_lmcache "$prefiller_docker_args" "$prefiller_vllm_args" "$decoder_docker_args" "$decoder_vllm_args" "$cfg_name" model="$(yq -r '.["vllm-prefiller"].model' "$cfg_file")" + elif [[ "$feature_type" == "p2p" ]]; then + PORT1=$(find_available_port 8177) + docker1_args="$(yq '.["docker1"]' "$cfg_file")" + vllm1_args="$(yq '.["vllm1"]' "$cfg_file")" + PORT2=$(find_available_port 8277) + docker2_args="$(yq '.["docker2"]' "$cfg_file")" + vllm2_args="$(yq '.["vllm2"]' "$cfg_file")" + run_p2p_lmcache "$docker1_args" "$vllm1_args" "$docker2_args" "$vllm2_args" "$cfg_name" + model="$(yq -r '.["vllm1"].model' "$cfg_file")" elif [[ -z "$feature_type" ]]; then docker_args="$(yq '.docker' "$cfg_file")" vllm_args="$(yq '.vllm' "$cfg_file")" @@ -446,7 +573,13 @@ for cfg_name in "${CONFIG_NAMES[@]}"; do test_vllmopenai_server_with_lmcache_integrated "$model" elif [ "$test_mode" = "long_doc_qa" ]; then workload_yaml="$(yq "(.workload * {\"model\": \"$model\"}) | del(.type)" "$cfg_file")" - run_long_doc_qa "$workload_yaml" + if [[ "$feature_type" == "p2p" ]]; then + tmp_workload_yaml=$(jq 'del(."expected-latency")' <<< "$workload_yaml") + run_long_doc_qa "$tmp_workload_yaml" "$PORT1" + run_long_doc_qa "$workload_yaml" "$PORT2" + else + run_long_doc_qa "$workload_yaml" "$PORT" + fi fi cleanup 0 diff --git a/.github/ISSUE_TEMPLATE/blank_issue.md b/.github/ISSUE_TEMPLATE/blank_issue.md index 1eb28d33a19..ec9dd7cce71 100644 --- a/.github/ISSUE_TEMPLATE/blank_issue.md +++ b/.github/ISSUE_TEMPLATE/blank_issue.md @@ -6,7 +6,7 @@ labels: '' assignees: '' --- **Label** -Please label your issue so that it can easily be easily categorized under [LMCache Onboarding](https://github.com/LMCache/LMCache/issues/627) +Please label your issue so that it can easily be easily categorized under [LMCache Onboarding](https://github.com/LMCache/LMCache/issues/1882) **Summary** A concise overview of the issue you want to raise. diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md index 029df2539a3..1acfc3494e7 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -7,7 +7,7 @@ assignees: '' --- **Label** -Please label your issue with "bug" and any other relevant labels so that it can easily be easily categorized under [LMCache Onboarding](https://github.com/LMCache/LMCache/issues/627) +Please label your issue with "bug" and any other relevant labels so that it can easily be easily categorized under [LMCache Onboarding](https://github.com/LMCache/LMCache/issues/1882) **Describe the bug** A clear and concise description of what the bug is. diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md index 2f6e79c5154..d5c5ea0fd56 100644 --- a/.github/ISSUE_TEMPLATE/feature_request.md +++ b/.github/ISSUE_TEMPLATE/feature_request.md @@ -7,7 +7,7 @@ assignees: '' --- **Label** -Please label your issue with "new feature" and any other relevant labels so that it can easily be easily categorized under [LMCache Onboarding](https://github.com/LMCache/LMCache/issues/627) +Please label your issue with "new feature" and any other relevant labels so that it can easily be easily categorized under [LMCache Onboarding](https://github.com/LMCache/LMCache/issues/1882) **Is your feature request related to a problem? Please describe.** A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 8130df1afe9..2d1b52e3a01 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -1,40 +1,12 @@ -FILL IN THE PR DESCRIPTION HERE + -FIX #xxxx (*link existing issues this PR will resolve*) +**What this PR does / why we need it**: -**PLEASE READ THE CHECKLIST BELOW AND FILL IN THE DESCRIPTION ABOVE** +**Special notes for your reviewers**: ---- +**If applicable**: -
- - PR Checklist (Click to Expand) - -

Thank you for your contribution to LMCache! Before submitting the pull request, please ensure the PR meets the following criteria. This helps us maintain the code quality and improve the efficiency of the review process.

- -

PR Title and Classification

-

Please try to classify PRs for easy understanding of the type of changes. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:

- -

Note: If the PR spans more than one category, please include all relevant prefixes.

- -

Code Quality

- -

The PR need to meet the following code quality standards:

- - - -

What to Expect for the Reviews

- -We aim to address all PRs in a timely manner. If no one reviews your PR within 5 days, please @-mention one of KuntaiDu, ApostaC or YaoJiayi. - -
+- [ ] this PR contains user facing changes - docs added +- [ ] this PR contains unit tests diff --git a/.github/workflows/automerge-labeler.yml b/.github/workflows/automerge-labeler.yml new file mode 100644 index 00000000000..d3b49a55784 --- /dev/null +++ b/.github/workflows/automerge-labeler.yml @@ -0,0 +1,17 @@ +name: Label auto-merge PRs + +on: + pull_request_target: + types: [ auto_merge_enabled, auto_merge_disabled ] + +permissions: + pull-requests: write + +jobs: + add_remove_labels: + runs-on: ubuntu-latest + steps: + - uses: ubuntudroid/automerge-labeler@v1 + with: + token: ${{ secrets.GITHUB_TOKEN }} + label: 'full' diff --git a/.github/workflows/build_doc.yml b/.github/workflows/build_doc.yml index d66af4ed803..ef4b1f44981 100644 --- a/.github/workflows/build_doc.yml +++ b/.github/workflows/build_doc.yml @@ -53,7 +53,7 @@ jobs: rm -rf output/dev - name: Upload doc artifacts to GHA - uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 + uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0 with: name: doc-artifacts path: output/ @@ -69,7 +69,7 @@ jobs: egress-policy: audit # TODO: change to 'egress-policy: block' after couple of runs - name: Fetch doc artifacts - uses: actions/download-artifact@634f93cb2916e3fdff6788551b99b062d0335ce0 # v5.0.0 + uses: actions/download-artifact@018cc2cf5baa6db3ef3c5f8a56943fffe632ef53 # v6.0.0 with: name: doc-artifacts path: output diff --git a/.github/workflows/code_quality_checks.yml b/.github/workflows/code_quality_checks.yml index 16cb6c0a17c..f1df04c88cc 100644 --- a/.github/workflows/code_quality_checks.yml +++ b/.github/workflows/code_quality_checks.yml @@ -1,6 +1,7 @@ name: Code Quality on: + workflow_call: pull_request: push: branches: [dev] diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml index ed0d7485b38..ccbf598b8ba 100644 --- a/.github/workflows/codeql.yml +++ b/.github/workflows/codeql.yml @@ -103,7 +103,7 @@ jobs: # Initializes the CodeQL tools for scanning. - name: Initialize CodeQL - uses: github/codeql-action/init@181d5eefc20863364f96762470ba6f862bdef56b # v3.29.2 + uses: github/codeql-action/init@f443b600d91635bebf5b0d9ebc620189c0d6fba5 # v4.30.8 with: languages: ${{ matrix.language }} build-mode: ${{ matrix.build-mode }} @@ -131,6 +131,6 @@ jobs: exit 1 - name: Perform CodeQL Analysis - uses: github/codeql-action/analyze@181d5eefc20863364f96762470ba6f862bdef56b # v3.29.2 + uses: github/codeql-action/analyze@f443b600d91635bebf5b0d9ebc620189c0d6fba5 # v4.30.8 with: category: "/language:${{matrix.language}}" diff --git a/.github/workflows/nightly_build.yml b/.github/workflows/nightly_build.yml index 06912e874ab..51549598e2e 100644 --- a/.github/workflows/nightly_build.yml +++ b/.github/workflows/nightly_build.yml @@ -35,7 +35,7 @@ jobs: astral.sh:443 - name: Login to DockerHub - uses: docker/login-action@184bdaa0721073962dff0199f1fb9940f07167d1 # v3.5.0 + uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # v3.6.0 with: username: ${{ vars.DOCKERHUB_USERNAME }} password: ${{ secrets.DOCKERHUB_TOKEN }} diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 7ea008366f7..cec8ca9a576 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -90,11 +90,20 @@ jobs: python -m cibuildwheel --output-dir dist - name: Upload release artifacts to GHA - uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 + uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0 with: name: release-artifacts path: dist/ + # Run tests and code quality checks before publishing + test: + name: Run tests + uses: ./.github/workflows/test.yml + + code-quality: + name: Run code quality checks + uses: ./.github/workflows/code_quality_checks.yml + # Push to Test PyPI when: # - a new GitHub release is published # - a PR is merged into dev branch (push only trigger) @@ -110,7 +119,7 @@ jobs: # see https://docs.pypi.org/trusted-publishers/ id-token: write runs-on: ubuntu-latest - needs: build-artifacts + needs: [build-artifacts, test, code-quality] steps: - name: Harden Runner @@ -127,7 +136,7 @@ jobs: rekor.sigstore.dev:443 - name: Fetch release artifacts - uses: actions/download-artifact@634f93cb2916e3fdff6788551b99b062d0335ce0 # v5.0.0 + uses: actions/download-artifact@018cc2cf5baa6db3ef3c5f8a56943fffe632ef53 # v6.0.0 with: name: release-artifacts path: dist @@ -151,7 +160,7 @@ jobs: contents: write runs-on: ubuntu-latest - needs: build-artifacts + needs: [build-artifacts, test, code-quality] steps: - name: Harden Runner @@ -170,7 +179,7 @@ jobs: rekor.sigstore.dev:443 - name: Fetch release artifacts - uses: actions/download-artifact@634f93cb2916e3fdff6788551b99b062d0335ce0 # v5.0.0 + uses: actions/download-artifact@018cc2cf5baa6db3ef3c5f8a56943fffe632ef53 # v6.0.0 with: name: release-artifacts path: dist @@ -220,7 +229,7 @@ jobs: layers.nvcr.io:443 - name: Login to DockerHub - uses: docker/login-action@184bdaa0721073962dff0199f1fb9940f07167d1 # v3.5.0 + uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # v3.6.0 with: username: ${{ vars.DOCKERHUB_USERNAME }} password: ${{ secrets.DOCKERHUB_TOKEN }} @@ -258,7 +267,7 @@ jobs: run: | docker build \ --tag lmcache/vllm-openai:lightweight --tag lmcache/vllm-openai:${{ env.LATEST_TAG }}-lightweight \ - --file docker/Dockerfile.lightweight + --file docker/Dockerfile.lightweight . - name: Push lmcache/vllm-openai:lightweight image to DockerHub run: | diff --git a/.github/workflows/scorecard.yml b/.github/workflows/scorecard.yml index c2cfd7eb61a..a1477e73c90 100644 --- a/.github/workflows/scorecard.yml +++ b/.github/workflows/scorecard.yml @@ -58,7 +58,7 @@ jobs: persist-credentials: false - name: "Run analysis" - uses: ossf/scorecard-action@05b42c624433fc40578a4040d5cf5e36ddca8cde # v2.4.2 + uses: ossf/scorecard-action@4eaacf0543bb3f2c246792bd56e8cdeffafb205a # v2.4.3 with: results_file: results.sarif results_format: sarif @@ -83,7 +83,7 @@ jobs: # Upload the results as artifacts (optional). Commenting out will disable uploads of run results in SARIF # format to the repository Actions tab. - name: "Upload artifact" - uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 + uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0 with: name: SARIF file path: results.sarif @@ -92,6 +92,6 @@ jobs: # Upload the results to GitHub's code scanning dashboard (optional). # Commenting out will disable upload of results to your repo's Code Scanning dashboard - name: "Upload to code-scanning" - uses: github/codeql-action/upload-sarif@v3 + uses: github/codeql-action/upload-sarif@v4 with: sarif_file: results.sarif diff --git a/.github/workflows/stale_bot.yml b/.github/workflows/stale_bot.yml index 2afdb65c9eb..70671f05294 100644 --- a/.github/workflows/stale_bot.yml +++ b/.github/workflows/stale_bot.yml @@ -30,7 +30,7 @@ jobs: api.github.com:443 - name: "Stale Action" - uses: actions/stale@3a9db7e6a41a89f618792c92c0e97cc736e1b13f # v10.0.0 + uses: actions/stale@5f858e3efba33a5ca4407a664cc011ad407f2008 # v10.1.0 with: stale-issue-label: 'stale' stale-issue-message: > diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index a62a0a04004..22c2617dfd7 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -1,6 +1,7 @@ name: Test on: + workflow_call: workflow_dispatch: push: branches: @@ -38,8 +39,6 @@ jobs: strategy: matrix: python: - # Disable 3.9 until code supports it in https://github.com/LMCache/LMCache/pull/1584 - # - "3.9" - "3.10" - "3.11" - "3.12" @@ -49,7 +48,7 @@ jobs: steps: - name: "Harden Runner" - uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0 + uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1 with: egress-policy: audit # TODO: change to 'egress-policy: block' after couple of runs @@ -63,7 +62,7 @@ jobs: uses: ./.github/actions/free-disk-space - name: Setup Python ${{ matrix.python }} - uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 + uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0 with: python-version: ${{ matrix.python }} cache: pip @@ -74,11 +73,11 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip + python -m pip install vllm python -m pip install -r requirements/test.txt python -m pip install -r requirements/common.txt - python -m pip install torch==2.7.1 torchaudio==2.7.1 torchvision==0.22.1 - - name: "Run non-CUDA unit tests (v1/storage_backend)" + - name: "Run non-CUDA unit tests" run: | - pytest tests/v1/storage_backend/ + pytest --ignore=tests/disagg --ignore=tests/v1/test_nixl_storage.py --ignore=tests/v1/multiprocess/test_cache_server.py diff --git a/.gitignore b/.gitignore index 85ee11dd2b1..e6d10e57f8e 100644 --- a/.gitignore +++ b/.gitignore @@ -79,6 +79,9 @@ lmcache/experimental/tests /examples/offline_inference/buggy_example.py /examples/test_example +# benchmark results +*.csv + # disk cache /remote_disk /local_disk diff --git a/README.md b/README.md index 77714598c2f..64384847672 100644 --- a/README.md +++ b/README.md @@ -26,7 +26,7 @@ | [**Blog**](https://blog.lmcache.ai/) | [**Documentation**](https://docs.lmcache.ai/) -| [**Join Slack**](https://join.slack.com/t/lmcacheworkspace/shared_invite/zt-3bgx768yd-H8WkOTmPtbxVYJ5nuZ4dmA) +| [**Join Slack**](https://join.slack.com/t/lmcacheworkspace/shared_invite/zt-3g8e6xzz8-KzS_HI8bPERGFK5PTB~MYg) | [**Interest Form**](https://forms.gle/MHwLiYDU6kcW3dLj7) | [**Roadmap**](https://github.com/LMCache/LMCache/issues/1253) @@ -47,6 +47,7 @@ By combining LMCache with vLLM, developers achieve 3-10x delay savings and GPU c * High performance CPU KVCache offloading * Disaggregated prefill * P2P KVCache sharing +- [x] Integration with SGLang for KV cache offloading - [x] LMCache is supported in the [vLLM production stack](https://github.com/vllm-project/production-stack/), [llm-d](https://github.com/llm-d/llm-d/), and [KServe](https://github.com/kserve/kserve) - [x] Stable support for non-prefix KV caches - [x] Storage support as follows: @@ -131,6 +132,13 @@ If you use LMCache for your research, please cite our papers: booktitle = {Proceedings of the Twentieth European Conference on Computer Systems}, pages = {94–109}, } + +@article{cheng2025lmcache, + title={LMCache: An Efficient KV Cache Layer for Enterprise-Scale LLM Inference}, + author={Cheng, Yihua and Liu, Yuhan and Yao, Jiayi and An, Yuwei and Chen, Xiaokun and Feng, Shaoting and Huang, Yuyang and Shen, Samuel and Du, Kuntai and Jiang, Junchen}, + journal={arXiv preprint arXiv:2510.09665}, + year={2025} +} ``` ## Socials diff --git a/benchmarks/long_doc_qa/long_doc_qa.py b/benchmarks/long_doc_qa/long_doc_qa.py index 9757ab8ad1c..2de5e2cc820 100644 --- a/benchmarks/long_doc_qa/long_doc_qa.py +++ b/benchmarks/long_doc_qa/long_doc_qa.py @@ -487,12 +487,10 @@ async def main(args): query_mean_ttft = benchmark_df["ttft"].mean() CSI = "\x1b[" RESET = CSI + "0m" + print(f"Warmup round mean TTFT: {warmup_mean_ttft:.3f}s") + print(f"Warmup round time: {warmup_end_time - warmup_start_time:.3f}s") + print(f"Warmup round prompt count: {len(warmup_df)}") print(f"{CSI}36;1m\n=== BENCHMARK RESULTS ==={RESET}") - print(f"{CSI}32mWarmup round mean TTFT: {warmup_mean_ttft:.3f}s{RESET}") - print( - f"{CSI}33mWarmup round time: {warmup_end_time - warmup_start_time:.3f}s{RESET}" - ) - print(f"{CSI}35mWarmup round prompt count: {len(warmup_df)}{RESET}") print(f"{CSI}32mQuery round mean TTFT: {query_mean_ttft:.3f}s{RESET}") print( f"{CSI}33mQuery round time: " diff --git a/benchmarks/long_doc_qa/long_doc_qa_recommender.py b/benchmarks/long_doc_qa/long_doc_qa_recommender.py index a900e59e217..568517a2caf 100644 --- a/benchmarks/long_doc_qa/long_doc_qa_recommender.py +++ b/benchmarks/long_doc_qa/long_doc_qa_recommender.py @@ -42,9 +42,7 @@ def get_tensor_parallel_recommendation(model_name: str): usable_per_gpu_memory = ( per_gpu_memory * 0.9 - intermediate_buffer - minimum_kv_cache_buffer ) - print( - "Estimated usable gpu memory for model weights per gpu: {usable_per_gpu_memory}" - ) + print(f"Usable gpu memory for model weights per gpu: {usable_per_gpu_memory}") initial_tp = math.ceil(total_model_weights_gb / usable_per_gpu_memory) # round up to a power of 2 return 2 ** math.ceil(math.log2(initial_tp)) @@ -158,6 +156,7 @@ def main(model_name: str): f"but {model_name} requires {tp} tensor parallelism to run on your hardware" ) return + print("This will take a while...") per_gpu_kv_cache_GiB, tokens_in_prefix_cache = get_prefix_cache_token_size( model_name, tp ) @@ -186,9 +185,7 @@ def main(model_name: str): def build_argument_parser(): parser = argparse.ArgumentParser() - parser.add_argument( - "--model", type=str, default="meta-llama/Meta-Llama-3.1-8B-Instruct" - ) + parser.add_argument("--model", type=str, default="Qwen/Qwen3-8B") return parser diff --git a/benchmarks/multi_round_qa/multi-round-qa.py b/benchmarks/multi_round_qa/multi-round-qa.py index bda9ea14009..cd02618cab8 100644 --- a/benchmarks/multi_round_qa/multi-round-qa.py +++ b/benchmarks/multi_round_qa/multi-round-qa.py @@ -42,6 +42,9 @@ class WorkloadConfig: # Whether to include user id in request header enable_user_id: bool + # Whether strictly cap active sessions at num_users + enforce_strict_concurrent_users: bool = False + @dataclass class UserConfig: @@ -374,6 +377,10 @@ def __init__( if self.use_sharegpt: self._load_sharegpt_data() + self.enforce_strict_concurrent_users = ( + workload_config.enforce_strict_concurrent_users + ) + def _load_sharegpt_data(self): with open("ShareGPT.json", "r", encoding="utf-8") as file: self.sharegpt_data = json.load(file) @@ -419,6 +426,19 @@ def _remove_finished_sessions(self): self.session_summaries.append(session.summary()) self.sessions = [s for s in self.sessions if not s.finished] + def _can_join_user(self, timestamp: float) -> bool: + # No new user session if gap_between_users time interval not meets + if timestamp - self.last_user_join <= self.gap_between_users: + return False + + # No user seession if active user count is less than configured + if ( + self.enforce_strict_concurrent_users + and len(self.sessions) >= self.workload_config.num_users + ): + return False + return True + def step(self, timestamp: float, executor: RequestExecutor): if self.need_ramp_up: self._ramp_up(timestamp, self.ramp_up_time) @@ -426,7 +446,8 @@ def step(self, timestamp: float, executor: RequestExecutor): if self.start_time is None: self.start_time = timestamp - if timestamp - self.last_user_join > self.gap_between_users: + # Check if can join new user session + if self._can_join_user(timestamp): self._create_user_session() self.last_user_join = timestamp logger.info( @@ -635,6 +656,11 @@ def parse_arguments(): action="store_true", help="Does not send requests to the endpoint (server)", ) + parser.add_argument( + "--enforce-strict-concurrent-users", + action="store_true", + help="Strictly enforce concurrent users count to match --num-users", + ) args = parser.parse_args() return args @@ -688,6 +714,7 @@ def main(): qps=args.qps, model=args.model, enable_user_id=args.request_with_user_id, + enforce_strict_concurrent_users=args.enforce_strict_concurrent_users, ) manager = UserSessionManager( diff --git a/csrc/mem_kernels.cu b/csrc/mem_kernels.cu index b88fa1379f8..7f89616d35b 100644 --- a/csrc/mem_kernels.cu +++ b/csrc/mem_kernels.cu @@ -171,6 +171,56 @@ key_value_offset(const int k_or_v, const int layer_idx, const int token_idx, token_idx * scalars_per_token + scalar_offset; } +template +__global__ void single_layer_kv_transfer_sgl_kernel( + // scalar_t* __restrict__ lmc_key_cache, // [num_tokens, + // num_heads*head_size] scalar_t* __restrict__ lmc_value_cache, // + // [num_tokens, num_heads*head_size] + scalar_t* __restrict__ lmc_key_value_cache, // [num_tokens, 2, + // num_heads*head_size] + // or + // [2, num_tokens, + // num_heads*head_size] + scalar_t* __restrict__ sgl_key_cache, // [num_blocks, block_size, + // num_heads, head_size] + scalar_t* __restrict__ sgl_value_cache, // [num_blocks, block_size, + // num_heads, head_size] + const int64_t* __restrict__ slot_mapping, // [num_tokens] + const int block_stride_in_64bit, const int lmc_stride, + const int lmc_value_offset, const int num_heads, + const int head_size_in_64bit, const int block_size, const bool direction) { + const int64_t token_idx = blockIdx.x; + const int64_t slot_idx = slot_mapping[token_idx]; + + if (slot_idx < 0) { + return; + } + + const int64_t block_idx = slot_idx / block_size; + const int64_t block_offset = slot_idx % block_size; + const int n = num_heads * head_size_in_64bit; + + for (int i = threadIdx.x; i < n; i += blockDim.x) { + const int64_t lmc_key_idx = token_idx * lmc_stride + i; + const int64_t lmc_value_idx = lmc_key_idx + lmc_value_offset; + + const int head_idx = i / head_size_in_64bit; + const int head_offset = i % head_size_in_64bit; + const int64_t sgl_key_value_idx = + block_idx * block_stride_in_64bit + + block_offset * num_heads * head_size_in_64bit + + head_idx * head_size_in_64bit + head_offset; + + if (direction) { + lmc_key_value_cache[lmc_key_idx] = sgl_key_cache[sgl_key_value_idx]; + lmc_key_value_cache[lmc_value_idx] = sgl_value_cache[sgl_key_value_idx]; + } else { + sgl_key_cache[sgl_key_value_idx] = lmc_key_value_cache[lmc_key_idx]; + sgl_value_cache[sgl_key_value_idx] = lmc_key_value_cache[lmc_value_idx]; + } + } +} + /** * Quickly load KV cache between vLLM paged memory and offloading buffer * slot_id = slot_mapping[block.x] @@ -627,3 +677,71 @@ void reshape_and_cache_back_flash( block_stride_in_64bit, key_value_stride, num_heads, head_size_in_64bit, block_size, key_layer_offset, value_layer_offset); } + +void single_layer_kv_transfer_sgl( + // torch::Tensor& lmc_key_cache, // [num_tokens, num_heads*head_size] + // key/value must be on gpu/pinned cpu + // torch::Tensor& lmc_value_cache, // [num_tokens, num_heads*head_size] + + torch::Tensor& lmc_key_value_cache, // [num_tokens, 2, num_heads*head_size] + // or + // [2, num_tokens, num_heads*head_size] + + torch::Tensor& + sgl_key_cache, // [num_blocks, block_size, num_heads, head_size] + torch::Tensor& + sgl_value_cache, // [num_blocks, block_size, num_heads, head_size] + // key_cache/value_cache must be on gpu + torch::Tensor& slot_mapping, // [num_tokens] + const bool direction, // false: LMCache to PagedBuffer, true: PagedBuffer + // to LMCache + const bool token_major // true: lmc_key_value_cache is + // [num_tokens, 2, num_heads*head_size] + // false: lmc_key_value_cache is + // [2, num_tokens, num_heads*head_size] +) { + // int64_t* lmc_key_cache_ptr = get_kernel_ptr(lmc_key_cache); int64_t* lmc_value_cache_ptr = + // get_kernel_ptr(lmc_value_cache); + int64_t* lmc_key_value_cache_ptr = + get_kernel_ptr(lmc_key_value_cache); + + int64_t* sgl_key_cache_ptr = + get_kernel_ptr(sgl_key_cache); + int64_t* sgl_value_cache_ptr = + get_kernel_ptr(sgl_value_cache); + + const int64_t* slot_mapping_ptr = + get_kernel_ptr(slot_mapping); + + int elements_per_entry = 8 / sgl_key_cache.element_size(); + + int num_tokens = slot_mapping.size(0); + int num_heads = sgl_key_cache.size(2); + int head_size_in_64bit = sgl_key_cache.size(3) / elements_per_entry; + + int block_size = sgl_key_cache.size(1); + + int lmc_stride; + int lmc_value_offset; + if (token_major) { + lmc_stride = lmc_key_value_cache.stride(0) / elements_per_entry; + lmc_value_offset = lmc_key_value_cache.stride(1) / elements_per_entry; + } else { + lmc_stride = lmc_key_value_cache.stride(1) / elements_per_entry; + lmc_value_offset = lmc_key_value_cache.stride(0) / elements_per_entry; + } + + int block_stride_in_64bit = sgl_key_cache.stride(0) / elements_per_entry; + TORCH_CHECK(sgl_key_cache.stride(0) == sgl_value_cache.stride(0)); + + dim3 grid(num_tokens); + dim3 block(std::min(num_heads * head_size_in_64bit, 128)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(sgl_key_cache)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + lmc::single_layer_kv_transfer_sgl_kernel<<>>( + lmc_key_value_cache_ptr, sgl_key_cache_ptr, sgl_value_cache_ptr, + slot_mapping_ptr, block_stride_in_64bit, lmc_stride, lmc_value_offset, + num_heads, head_size_in_64bit, block_size, direction); +} \ No newline at end of file diff --git a/csrc/mem_kernels.cuh b/csrc/mem_kernels.cuh index f6850b38d3f..a7c9318cc68 100644 --- a/csrc/mem_kernels.cuh +++ b/csrc/mem_kernels.cuh @@ -26,6 +26,13 @@ void single_layer_kv_transfer(torch::Tensor& lmc_key_value_cache, const bool token_major = false, const bool vllm_two_major = false); +void single_layer_kv_transfer_sgl(torch::Tensor& lmc_key_value_cache, + torch::Tensor& sgl_key_cache, + torch::Tensor& sgl_value_cache, + torch::Tensor& slot_mapping, + const bool direction, + const bool token_major = false); + void load_and_reshape_flash(torch::Tensor& key_value, torch::Tensor& key_cache, torch::Tensor& value_cache, torch::Tensor& slot_mapping, const int layer_idx); diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index 5c431e51c01..5712a219c32 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -16,6 +16,7 @@ PYBIND11_MODULE(c_ops, m) { m.def("multi_layer_kv_transfer_unilateral", &multi_layer_kv_transfer_unilateral); m.def("single_layer_kv_transfer", &single_layer_kv_transfer); + m.def("single_layer_kv_transfer_sgl", &single_layer_kv_transfer_sgl); m.def("load_and_reshape_flash", &load_and_reshape_flash); m.def("reshape_and_cache_back_flash", &reshape_and_cache_back_flash); m.def("encode_fast_new", &encode_cuda_new); diff --git a/docker/Dockerfile b/docker/Dockerfile index fa3ffac77aa..33893f610e6 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -18,6 +18,7 @@ ARG CUDA_VERSION ARG PYTHON_VERSION=3.12 ARG UBUNTU_VERSION ENV DEBIAN_FRONTEND=noninteractive +ENV PATH="/opt/venv/bin:${PATH}" # Install Python and other dependencies RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \ @@ -36,39 +37,6 @@ RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \ WORKDIR /workspace -# Install and setup nixl -RUN apt-get update -y && \ - apt-get -y install \ - ninja-build \ - pybind11-dev \ - python${PYTHON_VERSION}-dev \ - cmake -RUN export LD_LIBRARY_PATH=/usr/local/cuda/compat/lib.real:$LD_LIBRARY_PATH -RUN export NIXL_PLUGIN_DIR=/usr/local/nixl/lib/x86_64-linux-gnu/plugins -RUN cd /workspace -RUN git clone https://github.com/ai-dynamo/nixl && \ - cd nixl && \ - git checkout b1c22edd8fe10e2e5221c107ee51200fce6f09a8 -RUN cd /workspace/nixl -RUN source /opt/venv/bin/activate -RUN . /opt/venv/bin/activate && \ - uv pip install meson -RUN cd /workspace/nixl && \ - . /opt/venv/bin/activate && \ - rm -rf build && \ - mkdir build && \ - uv run meson setup build/ --prefix=/usr/local/nixl && \ - cd build && \ - ninja && \ - ninja install -RUN echo "/usr/local/nixl/lib/x86_64-linux-gnu" > /etc/ld.so.conf.d/nixl.conf -RUN echo "/usr/local/nixl/lib/x86_64-linux-gnu/plugins" >> /etc/ld.so.conf.d/nixl.conf -RUN ldconfig -RUN cd /workspace/nixl/ && \ - . /opt/venv/bin/activate && \ - uv build --wheel --out-dir /tmp/dist && \ - uv pip install /tmp/dist/nixl-0.3.0-cp312-cp312-linux_x86_64.whl - # Install runtime dependencies COPY ./requirements/common.txt common.txt COPY ./requirements/cuda.txt cuda.txt @@ -119,11 +87,15 @@ RUN --mount=type=cache,target=/root/.cache/ccache,id=ccache \ --mount=type=cache,target=/root/.cache/uv,id=uv-cache,sharing=locked \ . /opt/venv/bin/activate && \ if [ "$VLLM_VERSION" = "nightly" ]; then \ - uv pip install --prerelease=allow 'vllm[runai,tensorizer]' \ + uv pip install --prerelease=allow \ + 'vllm[runai,tensorizer,flashinfer]' \ + 'triton-kernels @ git+https://github.com/triton-lang/triton.git@v3.5.0#subdirectory=python/triton_kernels' \ --extra-index-url https://wheels.vllm.ai/nightly \ --index-strategy unsafe-best-match ; \ else \ - uv pip install --prerelease=allow "vllm[runai,tensorizer]==${VLLM_VERSION}" ; \ + uv pip install --prerelease=allow \ + "vllm[runai,tensorizer,flashinfer]==${VLLM_VERSION}" \ + 'triton-kernels @ git+https://github.com/triton-lang/triton.git@v3.5.0#subdirectory=python/triton_kernels' ; \ fi && \ python3 -c 'import torch; print("TORCH=", torch.__version__)' && \ python3 setup.py bdist_wheel --dist-dir=dist_lmcache && \ @@ -141,7 +113,7 @@ FROM base AS image-release # It is imperative that LMCache uses the same torch version as the # vLLM stable release. RUN . /opt/venv/bin/activate && \ - uv pip install --prerelease=allow vllm[runai,tensorizer] && \ + uv pip install --prerelease=allow vllm[runai,tensorizer,flashinfer] && \ uv pip install lmcache --verbose WORKDIR /workspace diff --git a/docs/source/_static/basic_codepath.svg b/docs/source/_static/basic_codepath.svg new file mode 100644 index 00000000000..d910b87d98e --- /dev/null +++ b/docs/source/_static/basic_codepath.svg @@ -0,0 +1,107 @@ + + + + + + + + GPU Model Runner + + + + start_load_kv() + + + + + + + + wait_for_layer_load() + + + + Attention 1 + + + + save_kv_layer() + + + + + + + + + + + + + + + wait_for_layer_load() + + + + Attention 2 + + + + save_kv_layer() + + + + + + + + + + + + + + + + + + + + + + wait_for_layer_load() + + + + Attention N + + + + save_kv_layer() + + + + + + + + + + + wait_for_save() + + + + + + + + + + + + + + \ No newline at end of file diff --git a/docs/source/_static/custom.js b/docs/source/_static/custom.js index e69de29bb2d..12553491b8e 100644 --- a/docs/source/_static/custom.js +++ b/docs/source/_static/custom.js @@ -0,0 +1,17 @@ +document.addEventListener("DOMContentLoaded", function () { + var script = document.createElement("script"); + script.type = "module"; + script.id = "runllm-widget-script" + + script.src = "https://widget.runllm.com"; + + script.setAttribute("version", "stable"); + script.setAttribute("crossorigin", "true"); + script.setAttribute("runllm-keyboard-shortcut", "Mod+j"); + script.setAttribute("runllm-name", "LMCache Assistant"); + script.setAttribute("runllm-position", "BOTTOM_RIGHT"); + script.setAttribute("runllm-assistant-id", "1185"); + + script.async = true; + document.head.appendChild(script); +}); diff --git a/docs/source/_static/full_layerwise_diagram.svg b/docs/source/_static/full_layerwise_diagram.svg new file mode 100644 index 00000000000..90c9973bb19 --- /dev/null +++ b/docs/source/_static/full_layerwise_diagram.svg @@ -0,0 +1,256 @@ + + + + + + GPU Model Runner + + + + + + + KV Cache 1 + + + + KV Cache 2 + + + + + + + + + KV Cache N + + + + + + + start_load_kv() + + 1 + + + + + + + + wait_for_layer_load() + + 2 + + + + Attention 1 + + + + save_kv_layer() + + 3 + + + + + + + + + + + + wait_for_layer_load() + + 2 + + + Attention 2 + + + save_kv_layer() + + 3 + + + + + + + + + + + + + + + + wait_for_layer_load() + + 2 + + + Attention N + + + save_kv_layer() + + 3 + + + + + + + + wait_for_save() + + 4 + + + + CacheEngine + + + Note: Multi-layer Cache Engine Key uses split_layers(N) + + + + Retrieval Generator (N + 2 yields) + retrieve_layer() called by start_load_kv() + • Layer-by-layer memory allocation + • Contains GPU Load Generator (batched_to_gpu) + • Calls StorageManager.layerwise_batched_get() + + + + Storage Generator (N + 1 yields) + store_layer() - FIRST save_kv_layer() only + • Upfront CPU memory allocation (all layers) + • Contains GPU Store Generator (batched_from_gpu) + • Calls StorageManager.batched_put() + + + + StorageManager + + + + layerwise_batched_get() [async → .result()] + + + batched_put(memory_objects) + + + + LayerwiseGPUConnector + VLLMPagedMemLayerwiseGPUConnector + + + + Batch: Req 1 (First) | Req 2 | ... | Req N (Last) + + + + Load GPU Buffer + (use_gpu: true) + + + Store GPU Buffer + (use_gpu: true) + + + + CUDA Streams + load_stream | store_stream | current_stream + + + + Synchronization Details + First Req: store_stream.wait_stream(current_stream) + sync prevents buffer override + Last Req: current_stream.wait_stream(load_stream) prevents forward pass until KV loaded + + + + + + + 1 + start_load_kv() Operations + • Initialize Retrieval Generator + • Call retrieve_layer() - 1st next() (setup) + • Call retrieve_layer() - 2nd next() (layer 0) + + + + + 2 + wait_for_layer_load() Operations + • Call retrieve_layer() - next() (layer i) + • StorageManager.layerwise_batched_get() + • GPU Load Gen: batched_to_gpu() next() + + + + + 3 + save_kv_layer() Operations + • FIRST call: Create Storage Generator + • Call store_layer() - next() (layer i) + • GPU Store Gen: batched_from_gpu() next() + + + + + 4 + wait_for_save() Operations + • Call store_layer() - final next() + • StorageManager.batched_put() + • GPU Store Gen: final next() + cleanup + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/docs/source/_static/kv_cache_calculator.html b/docs/source/_static/kv_cache_calculator.html new file mode 100644 index 00000000000..08710c1bc85 --- /dev/null +++ b/docs/source/_static/kv_cache_calculator.html @@ -0,0 +1,489 @@ + + + + + + KV Cache Size Calculator + + + +
+

KV Cache Size Calculator

+ +
Loading model configurations...
+ + +
+ + + + + diff --git a/docs/source/_static/modelconfig.json b/docs/source/_static/modelconfig.json new file mode 100644 index 00000000000..b74ba949277 --- /dev/null +++ b/docs/source/_static/modelconfig.json @@ -0,0 +1,131 @@ +{ + "meta-llama/Llama-3.1-8B-Instruct": { + "hidden_size": 4096, + "num_attention_heads": 32, + "num_hidden_layers": 32, + "num_key_value_heads": 8 + }, + "meta-llama/Llama-3.1-70B-Instruct": { + "hidden_size": 8192, + "num_attention_heads": 64, + "num_hidden_layers": 80, + "num_key_value_heads": 8 + }, + "mistralai/Mistral-7B-Instruct-v0.2": { + "hidden_size": 4096, + "num_attention_heads": 32, + "num_hidden_layers": 32, + "num_key_value_heads": 8 + }, + "mistralai/Mistral-Large-Instruct-2407": { + "hidden_size": 12288, + "num_attention_heads": 96, + "num_hidden_layers": 88, + "num_key_value_heads": 8 + }, + "lmsys/longchat-7b-16k": { + "hidden_size": 4096, + "num_attention_heads": 32, + "num_hidden_layers": 32, + "num_key_value_heads": 32 + }, + "Sao10K/L3-8B-Lunaris-v1": { + "hidden_size": 4096, + "num_attention_heads": 32, + "num_hidden_layers": 32, + "num_key_value_heads": 8 + }, + "meta-llama/Llama-3.2-3B-Instruct": { + "hidden_size": 3072, + "num_attention_heads": 24, + "num_hidden_layers": 28, + "num_key_value_heads": 8 + }, + "deepseek-ai/DeepSeek-V3": { + "hidden_size": 7168, + "num_attention_heads": 128, + "num_hidden_layers": 61, + "num_key_value_heads": 128, + "kv_lora_rank": 512, + "qk_rope_head_dim": 64 + }, + "deepseek-ai/DeepSeek-R1": { + "hidden_size": 7168, + "num_attention_heads": 128, + "num_hidden_layers": 61, + "num_key_value_heads": 128, + "kv_lora_rank": 512, + "qk_rope_head_dim": 64 + }, + "meta-llama/Llama-3.1-405B": { + "hidden_size": 16384, + "num_attention_heads": 128, + "num_hidden_layers": 126, + "num_key_value_heads": 8 + }, + "meta-llama/Llama-3.2-1B-Instruct": { + "hidden_size": 2048, + "num_attention_heads": 32, + "num_hidden_layers": 16, + "num_key_value_heads": 8 + }, + "Qwen/Qwen3-32B": { + "hidden_size": 5120, + "num_attention_heads": 64, + "num_hidden_layers": 64, + "num_key_value_heads": 8, + "head_dim": 128 + }, + "Qwen/Qwen3-14B": { + "hidden_size": 5120, + "num_attention_heads": 40, + "num_hidden_layers": 40, + "num_key_value_heads": 8, + "head_dim": 128 + }, + "Qwen/Qwen3-8B": { + "hidden_size": 4096, + "num_attention_heads": 32, + "num_hidden_layers": 36, + "num_key_value_heads": 8, + "head_dim": 128 + }, + "Qwen/Qwen3-4B": { + "hidden_size": 2560, + "num_attention_heads": 32, + "num_hidden_layers": 36, + "num_key_value_heads": 8, + "head_dim": 128 + }, + "Qwen/Qwen3-0.6B": { + "hidden_size": 1024, + "num_attention_heads": 16, + "num_hidden_layers": 28, + "num_key_value_heads": 8, + "head_dim": 128 + }, + "Qwen/Qwen2.5-7B-Instruct": { + "hidden_size": 3584, + "num_attention_heads": 28, + "num_hidden_layers": 28, + "num_key_value_heads": 4 + }, + "Qwen/Qwen2.5-3B-Instruct": { + "hidden_size": 2048, + "num_attention_heads": 16, + "num_hidden_layers": 36, + "num_key_value_heads": 2 + }, + "Qwen/Qwen2.5-0.5B": { + "hidden_size": 896, + "num_attention_heads": 14, + "num_hidden_layers": 24, + "num_key_value_heads": 2 + }, + "Qwen/Qwen-7B": { + "hidden_size": 4096, + "num_attention_heads": 32, + "num_hidden_layers": 32, + "num_key_value_heads": 32 + } +} diff --git a/docs/source/api_reference/configurations.rst b/docs/source/api_reference/configurations.rst index 3b1252521c6..efcf8083687 100644 --- a/docs/source/api_reference/configurations.rst +++ b/docs/source/api_reference/configurations.rst @@ -79,6 +79,47 @@ Basic cache settings that control the core functionality of LMCache. * - extra_config - LMCACHE_EXTRA_CONFIG={"key": value, ...} - Additional configuration as JSON dict. For NUMA manual mode, include "gpu_to_numa_mapping": {gpu_id: numa_node, ...}. Default: {} + +Lazy Memory Allocator Configurations +------------------------------------ + +Settings for the lazy memory allocator that enables gradual memory allocation to reduce startup time and initial memory footprint. + +.. note:: + + The lazy memory allocator is designed for scenarios with large CPU memory configurations. It starts with a small initial allocation and gradually expands as needed, reducing startup wait time and avoiding unnecessary memory consumption when the full capacity is not required. + + **Key characteristics:** + + - **One-time expansion**: Memory expands until target size is reached, then stops + - **No shrinking**: Once allocated, memory is never released back to the system + - **Automatic activation**: Only activates when ``max_local_cpu_size`` exceeds ``lazy_memory_safe_size`` + +.. list-table:: + :header-rows: 1 + :widths: 30 30 40 + + * - YAML Config Name + - Environment Variable + - Description + * - enable_lazy_memory_allocator + - LMCACHE_ENABLE_LAZY_MEMORY_ALLOCATOR + - Whether to enable lazy memory allocator. Values: true/false. Default: false + * - lazy_memory_initial_ratio + - LMCACHE_LAZY_MEMORY_INITIAL_RATIO + - Initial memory allocation ratio (0.0-1.0). Determines the fraction of max_local_cpu_size to allocate at startup. Default: 0.2 (20%) + * - lazy_memory_expand_trigger_ratio + - LMCACHE_LAZY_MEMORY_EXPAND_TRIGGER_RATIO + - Memory usage ratio (0.0-1.0) that triggers expansion. When used memory exceeds this ratio of current capacity, expansion begins. Default: 0.5 (50%) + * - lazy_memory_step_ratio + - LMCACHE_LAZY_MEMORY_STEP_RATIO + - Memory expansion step ratio (0.0-1.0). Each expansion adds this fraction of max_local_cpu_size. Default: 0.1 (10%) + * - lazy_memory_safe_size + - LMCACHE_LAZY_MEMORY_SAFE_SIZE + - Threshold in GB above which lazy allocator activates. If max_local_cpu_size ≤ this value, lazy allocator is disabled regardless of enable_lazy_memory_allocator setting. Default: 0.0 + * - reserve_local_cpu_size + - LMCACHE_RESERVE_LOCAL_CPU_SIZE + - Reserved system memory in GB that should not be allocated by LMCache. Used to prevent out-of-memory conditions. Default: 0.0 Cache Blending Configurations ----------------------------- diff --git a/docs/source/community/meetings.rst b/docs/source/community/meetings.rst index 70dfb0f2fc0..11d40331f24 100644 --- a/docs/source/community/meetings.rst +++ b/docs/source/community/meetings.rst @@ -4,6 +4,7 @@ Community meetings LMCache hosts regular community meetings to discuss updates, address new feature requests, and feedback from the community. If you are interested in contributing to the LMCache projects (core LMCache or Production Stack), we encourage you to join the meetings. +We also host a monthly "Office Hour" on select topics. Meeting schedule ----------------- @@ -21,7 +22,7 @@ Please find the meeting invite link below: - **Meeting link**: `Zoom link `_ - **Calendar Invite**: `Google Calendar `__ -- **Slack Channel**: `#lmcache `_ +- **Slack Channel**: `#lmcache `_ vLLM Production Stack Project +++++++++++++++++++++++++++++++ @@ -37,3 +38,12 @@ Please find the meeting invite link below: .. note:: The Zoom meeting link is the same for both LMCache and Production Stack community meetings. Meeting notes are available here: `Meeting notes `_. + +LMCache Office Hours +++++++++++++++++++++ + +Held monthly on the second Thursday at 2PM ET, 11AM PT. + +- Topic announced on the LMCache #office-hour Slack channel +- Request a calendar invite via `this form `_ + diff --git a/docs/source/conf.py b/docs/source/conf.py index 1442732a545..675008515a2 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -36,6 +36,7 @@ "sphinxcontrib.mermaid", # "sphinx_copybutton", "sphinx_multiversion", + "sphinxcontrib.images", ] copybutton_prompt_text = r"^(\$ |>>> |\# )" @@ -73,6 +74,7 @@ def add_line(self, line: str, source: str, *lineno: int) -> None: html_theme = "sphinxawesome_theme" html_static_path = ["_static"] html_css_files = ["custom.css", "scroll.css"] +html_js_files = ["custom.js"] html_favicon = "assets/lmcache-logo.png" html_permalinks_icon = "#" pygments_style = "sphinx" @@ -115,6 +117,11 @@ def add_line(self, line: str, source: str, *lineno: int) -> None: }, ) +images_config = { + "default_image_width": "80%", + "default_image_target": "_blank", +} + html_theme_options = asdict(theme_options) # more_options = { diff --git a/docs/source/developer_guide/contributing.rst b/docs/source/developer_guide/contributing.rst index 3ade0cf3120..0ef13e3ba0d 100644 --- a/docs/source/developer_guide/contributing.rst +++ b/docs/source/developer_guide/contributing.rst @@ -8,7 +8,7 @@ Thank you for your interest in contributing to LMCache! We welcome and accept al - Suggest or implement new features - Improve documentation or contribute a how-to guide -A comprehensive list of good first issues can be found in the issue `[Onboarding]: Welcoming contributors with good first issues! `_. +A comprehensive list of good first issues can be found in the issue `[Onboarding][Q4] Welcoming contributors with good first issues! `_. If you'd like to support our community further, then answering queries, offering PR reviews, and assisting others are also impactful ways to contribute and take LMCache further. @@ -44,9 +44,26 @@ To contribute to this repo, you'll use the Fork and Pull model common in many op - Run unit tests and fix any broken tests - Submit a pull request with detailed descriptions -When your contribution is ready, you can create a pull request. Pull requests are often referred to as "PRs". In general, we follow the standard `GitHub pull request `_ process. Follow the template to provide details about your pull request to the maintainers. It's best to break your contribution into smaller PRs with incremental changes, and include a good description of the changes. We require new unit tests to be contributed with any new functionality added. +When your contribution is ready, you can create a pull request. Pull requests are often referred to as "PRs". In general, we follow the standard `GitHub pull request `_ process. Follow the template to provide details about your pull request to the maintainers. -Before sending pull requests, make sure your changes pass code quality checks and unit tests. These checks will run with the pull request builds. Alternatively, you can run the checks manually on your local machine `as specified in Development <#development>`_ . +Please try to classify PRs for easy understanding of the type of changes. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following: + +- [Bugfix] for bug fixes +- [Build] for build fixes and improvements +- [CI] for continuous integration fixes and iimprovements +- [Core] for changes in the core LMCache logic (e.g., ``LMCacheEngine``, ``Backend`` etc.) +- [Doc] for documentation fixes and improvements +- [Misc] for PRs that do not fit the above categories. Please use this sparingly +- [Model] for adding a new model or improving an existing model. Model name should appear in the title +- [Test] for unit tests + +.. note:: + + If the PR spans more than one category, please include all relevant prefixes + +It's best to break your contribution into smaller PRs with incremental changes, and include a good description of the changes. We require new unit tests to be contributed with any new functionality added and docs if user facing changes. + +Before sending pull requests, make sure your changes pass code quality checks and unit tests. These checks will run when the pull request builds. Alternatively, you can run the checks manually on your local machine `as specified in Development <#development>`_ . DCO and Signed-off-by ^^^^^^^^^^^^^^^^^^^^^ @@ -135,12 +152,7 @@ It will run automatically when you add a commit. You can also run it manually on Unit tests ^^^^^^^^^^ -.. note:: - The Unit tests require `NVIDIA Inference Xfer Library (NIXL) `_ to be installed. Please follow the details in the NIXL GitHub repo to install. - The NIXL unit tests also require `vLLM `_ and `msgpack `_. - If you are unable to install NIXL you can circumvent the NIXL unit tests by using the following pytest flags: `--ignore=tests/disagg --ignore=tests/v1/test_pos_kernels.py --ignore=tests/v1/test_nixl_storage.py`. - -When making changes, run the tests before pushing the changes. Running unit tests ensures your contributions do not break exiting code. We use the `pytest `_ framework to run unit tests. The framework is setup to run all files in the `tests `_ directory which have a prefix or posfix of "test". +When making changes, run the tests before pushing the changes. Running unit tests ensures your contributions do not break existing code. We use the `pytest `_ framework to run unit tests. The framework is setup to run all files in the `tests `_ directory which have a prefix or posfix of "test". Running unit tests is as simple as: @@ -148,11 +160,9 @@ Running unit tests is as simple as: pytest -Alternatively, running unit tests (minus NIXL tests) is as follows: - -.. code-block:: bash +.. note:: - pytest --ignore=tests/disagg --ignore=tests/v1/test_pos_kernels.py --ignore=tests/v1/test_nixl_storage.py + ``vLLM`` (``pip install vllm``) and the dependencies in ``requirements/test.txt`` need to be installed prior to running unit tests. By default, all tests found within the tests directory are run. However, specific unit tests can run by passing filenames, classes and/or methods to `pytest`. The following example invokes a single test method "test_lm_connector" that is declared in the "tests/test_connector.py" file: @@ -160,9 +170,14 @@ By default, all tests found within the tests directory are run. However, specifi pytest tests/test_connector.py::test_lm_connector -.. warning:: +.. note:: + + Some unit tests require a NVIDIA GPU. This means that on a non Linux NVIDIA GPU system, the full suite of tests will not be run (tests requiring CUDA will be skipped). The Buildkite continuous integration (CI) system executes a full run of all the tests. + +.. note:: - Currently, unit tests do not run on non Linux NVIDIA GPU platforms. If you don't have access to this platform to run unit tests locally, rely on the continuous integration system to run the tests for now. + The `NVIDIA Inference Xfer Library (NIXL) `_ unit tests require NIXL to be to be installed. This is not installed by LMCache by and therefore requires you to install it separately. Please follow the details in the NIXL GitHub repo to install. + If the NIXL package is not installed, the NIXL unit tests are skipped. Building the docs ^^^^^^^^^^^^^^^^^ diff --git a/docs/source/developer_guide/usage/basic_check.rst b/docs/source/developer_guide/usage/basic_check.rst new file mode 100644 index 00000000000..6bdac4095cd --- /dev/null +++ b/docs/source/developer_guide/usage/basic_check.rst @@ -0,0 +1,139 @@ +Basic Check Tool +================ + +The LMCache Basic Check Tool is a testing and validation utility that helps you verify your LMCache installation, configuration, and functionality. It provides multiple testing modes to validate different components of the LMCache system. + +Overview +-------- + +The basic check tool (``lmcache.v1.basic_check``) is designed to: + +* Test remote backend connectivity and functionality +* Validate storage manager operations +* Generate test keys for performance testing +* Verify configuration settings +* Provide diagnostic information for troubleshooting + +Available Check Modes +--------------------- + +The tool supports several check modes, each targeting specific functionality: + +test_remote +~~~~~~~~~~~ + +Tests the remote backend functionality including: + +* Connection establishment to remote backends (fs, etc.) +* put/get operations with data integrity validation +* put/get/exists operations with performance reports + +**Usage:** + +.. code-block:: bash + + python -m lmcache.v1.basic_check --mode test_remote + +test_storage_manager +~~~~~~~~~~~~~~~~~~~~ + +Tests the storage manager operations including: + +* Configuration validation +* batched_put/get operations with data integrity validation +* batched_put/get/contains operations with performance reports + +**Usage:** + +.. code-block:: bash + + python -m lmcache.v1.basic_check --mode test_storage_manager + +gen (Key Generation) +~~~~~~~~~~~~~~~~~~~~ + +Generates test keys for performance testing and benchmarking: + +* Configurable number of keys and concurrency levels +* Memory-efficient batch processing +* Progress tracking and performance metrics +* Offset support for distributed testing + +**Usage:** + +.. code-block:: bash + + python -m lmcache.v1.basic_check --mode gen --num-keys 1000 --concurrency 16 + +Command Line Interface +---------------------- + +Basic Usage +~~~~~~~~~~~ + +.. code-block:: bash + + python -m lmcache.v1.basic_check --mode [OPTIONS] + +List Available Modes +~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: bash + + python -m lmcache.v1.basic_check --mode list + +Command Line Options +~~~~~~~~~~~~~~~~~~~~ + +.. option:: --mode MODE + + **Required.** Operation mode to run. Use ``list`` to see available modes. + +.. option:: --model MODEL + + Model name for testing, just a part of key of persist kv-cache. Default: ``/lmcache_test_model/`` + +.. option:: --num-keys NUM + + Number of keys to generate (gen mode only). Default: 100 + +.. option:: --concurrency NUM + + Concurrency level for operations (gen mode only). Default: 16 + +.. option:: --offset NUM + + Offset for key generation (gen mode only). Default: 0 + +Configuration +------------- + +The basic check tool uses your existing LMCache configuration. You can specify configuration in several ways: + +Environment Variable +~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: bash + + export LMCACHE_CONFIG_PATH=/path/to/config.yaml + python -m lmcache.v1.basic_check --mode test_remote + +Example Configuration +~~~~~~~~~~~~~~~~~~~~~ + +Here's an example configuration optimized for basic checks: + +.. code-block:: yaml + + # Basic cache settings + chunk_size: 256 + local_cpu: true + max_local_cpu_size: 1.0 # 1GB for basic checks + + # Remote backend (optional) + remote_url: "file:///tmp/lmcache_basic_check" + +Examples +-------- + +The ``examples/basic_check/`` directory contains comprehensive examples: diff --git a/docs/source/developer_guide/usage/index.rst b/docs/source/developer_guide/usage/index.rst index ecb2b2bc4d4..1a500c909db 100644 --- a/docs/source/developer_guide/usage/index.rst +++ b/docs/source/developer_guide/usage/index.rst @@ -4,4 +4,5 @@ Usage Data Module .. toctree:: :maxdepth: 1 + basic_check usage_stats_collection diff --git a/docs/source/disaggregated_prefill/nixl/1p1d.rst b/docs/source/disaggregated_prefill/nixl/1p1d.rst index 94023f32bc6..8a6692544a8 100644 --- a/docs/source/disaggregated_prefill/nixl/1p1d.rst +++ b/docs/source/disaggregated_prefill/nixl/1p1d.rst @@ -216,7 +216,7 @@ For comprehensive performance testing, use vLLM's benchmark tool: .. code-block:: bash - python benchmark_serving.py --port 9100 --seed $(date +%s) \ + vllm bench serve --port 9100 --seed $(date +%s) \ --model meta-llama/Llama-3.1-8B-Instruct \ --dataset-name random --random-input-len 7500 --random-output-len 200 \ --num-prompts 30 --burstiness 100 --request-rate 1 --ignore-eos diff --git a/docs/source/disaggregated_prefill/nixl/xpyd.rst b/docs/source/disaggregated_prefill/nixl/xpyd.rst index 6f45404497d..d66c9c07f1d 100644 --- a/docs/source/disaggregated_prefill/nixl/xpyd.rst +++ b/docs/source/disaggregated_prefill/nixl/xpyd.rst @@ -267,7 +267,7 @@ For comprehensive performance testing, use vLLM's benchmark tool: .. code-block:: bash - python benchmark_serving.py --port 9100 --seed $(date +%s) \ + vllm bench serve --port 9100 --seed $(date +%s) \ --model meta-llama/Llama-3.1-8B-Instruct \ --dataset-name random --random-input-len 7500 --random-output-len 200 \ --num-prompts 30 --burstiness 100 --request-rate 1 --ignore-eos diff --git a/docs/source/getting_started/benchmarking.rst b/docs/source/getting_started/benchmarking.rst index 66d02cb21ac..a86953507ea 100644 --- a/docs/source/getting_started/benchmarking.rst +++ b/docs/source/getting_started/benchmarking.rst @@ -13,19 +13,16 @@ You can also choose the number of times to repeat prompts and the mode of repeti LMCache provides a simple Long Doc QA Recommender that helps you deploy LMCache and generate the appropriate traffic through Long Doc QA. It will also help you determine the tensor parallelism and the amount of CPU RAM to deploy LMCache with based on the specifications of your hardware. -First set your ``HF_TOKEN`` environment variable with access to the model you want to benchmark. Then run the recommendation script: - .. code-block:: bash python benchmarks/long_doc_qa/long_doc_qa_recommender.py --model -Example #1: ------------ +Example: +--------- .. code-block:: bash - # default is meta-llama/Meta-Llama-3.1-8B-Instruct - python benchmarks/long_doc_qa/long_doc_qa_recommender.py + python benchmarks/long_doc_qa/long_doc_qa_recommender.py --model Qwen/Qwen3-8B .. code-block:: text @@ -35,7 +32,7 @@ Example #1: ----------------- PYTHONHASHSEED=0 \ - vllm serve meta-llama/Meta-Llama-3.1-8B-Instruct \ + vllm serve Qwen/Qwen3-8B \ --tensor-parallel-size 1 \ --load-format dummy @@ -45,7 +42,7 @@ Example #1: PYTHONHASHSEED=0 \ LMCACHE_MAX_LOCAL_CPU_SIZE=66 \ - vllm serve meta-llama/Meta-Llama-3.1-8B-Instruct \ + vllm serve Qwen/Qwen3-8B \ --tensor-parallel-size 1 \ --load-format dummy \ --kv-transfer-config \ @@ -56,106 +53,35 @@ Example #1: ---------------------------------------- python benchmarks/long_doc_qa/long_doc_qa.py \ - --model meta-llama/Meta-Llama-3.1-8B-Instruct \ - --num-documents 51 \ + --model Qwen/Qwen3-8B \ + --num-documents 46 \ --document-length 10000 \ --output-len 100 \ --repeat-count 1 \ --repeat-mode tile \ --max-inflight-requests 4 -Llama 8B vLLM Metrics: -^^^^^^^^^^^^^^^^^^^^^^^ - -.. code-block:: text - - === BENCHMARK RESULTS === - Warmup round mean TTFT: 0.751s - Warmup round time: 24.915s - Warmup round prompt count: 51 - Query round mean TTFT: 0.753s - Query round time: 24.628s - Query round prompt count: 51 - -Llama 8B LMCache Metrics: -^^^^^^^^^^^^^^^^^^^^^^^^^^ +Qwen 8B vLLM Metrics: +^^^^^^^^^^^^^^^^^^^^^ .. code-block:: text === BENCHMARK RESULTS === - Warmup round mean TTFT: 0.832s - Warmup round time: 26.027s - Warmup round prompt count: 51 - Query round mean TTFT: 0.214s - Query round time: 14.564s - Query round prompt count: 51 - -The warmup round is the first time the model sees the documents. The query round is the second time the model sees the documents. Without offloading, even with KV Cache reuse, there is no improvement in TTFT nor throughput. With offloading, we can see significant performance improvements to the query round. - -Example #2: ------------ - -.. code-block:: bash - - python benchmarks/long_doc_qa/long_doc_qa_recommender.py --model meta-llama/Llama-3.1-70B-Instruct - -.. code-block:: text - - 1. vLLM Deployment: - ----------------- - - PYTHONHASHSEED=0 \ - vllm serve meta-llama/Llama-3.1-70B-Instruct \ - --tensor-parallel-size 4 \ - --load-format dummy - - - 2. LMCache Deployment: - -------------------- + Query round mean TTFT: 0.757s + Query round time: 23.467s + Query round prompt count: 46 - PYTHONHASHSEED=0 \ - LMCACHE_MAX_LOCAL_CPU_SIZE=40 \ - vllm serve meta-llama/Llama-3.1-70B-Instruct \ - --tensor-parallel-size 4 \ - --load-format dummy \ - --kv-transfer-config \ - '{"kv_connector": "LMCacheConnectorV1", "kv_role": "kv_both"}' - - - 3. Multi-Round QA Workload Generation: - ---------------------------------------- - - python benchmarks/long_doc_qa/long_doc_qa.py \ - --model meta-llama/Llama-3.1-70B-Instruct \ - --num-documents 50 \ - --document-length 10000 \ - --output-len 100 \ - --repeat-count 1 \ - --repeat-mode tile \ - --max-inflight-requests 4 - -Llama 70B vLLM Metrics: -^^^^^^^^^^^^^^^^^^^^^^^ +Qwen 8B LMCache Metrics: +^^^^^^^^^^^^^^^^^^^^^^^^ .. code-block:: text === BENCHMARK RESULTS === - Warmup round mean TTFT: 1.797s - Warmup round time: 54.903s - Warmup round prompt count: 50 - Query round mean TTFT: 1.798s - Query round time: 54.974s - Query round prompt count: 50 + Query round mean TTFT: 0.185s + Query round time: 13.789s + Query round prompt count: 46 -Llama 70B LMCache Metrics: -^^^^^^^^^^^^^^^^^^^^^^^^^^ +From this example, we can see a **75%** reduction in TTFT (0.757s → 0.185s), **41%** reduction in total inference time (23.467s → 13.789s) via offloading with **LMCache**. -.. code-block:: text - - === BENCHMARK RESULTS === - Warmup round mean TTFT: 1.881s - Warmup round time: 56.673s - Warmup round prompt count: 50 - Query round mean TTFT: 0.174s - Query round time: 26.223s - Query round prompt count: 50 +.. note:: + The warmup round is the first time the model sees the documents. The query round is the second time the model sees the documents. Without offloading, even with KV Cache reuse, there is no improvement in TTFT nor throughput. With offloading, we can see significant performance improvements to the query round. diff --git a/docs/source/getting_started/faq.rst b/docs/source/getting_started/faq.rst index 8027b1752f9..62b9b5c3898 100644 --- a/docs/source/getting_started/faq.rst +++ b/docs/source/getting_started/faq.rst @@ -1,4 +1,41 @@ FAQ === -Coming soon... \ No newline at end of file +What are the KV cache sizes for popular models? And why is LMCache important? +----------------------------------------------------------------------------- + +You can calculate KV cache sizes using our :doc:`KV cache calculator `. We also provide a reference table below with KV cache information for some popular models. + +As shown in the table, after loading Qwen/Qwen3-32B for example, there is only enough space in the spare GPU RAM to hold 275,760 tokens for KV caches. This supports only 6.73 concurrent users if each prompt is 40,960 tokens long. Once this capacity is exceeded, the KV cache must be evicted, and when the same user returns, their request needs to be re-prefilled, which takes significantly longer. + +**LMCache is designed to extend this virtual memory capacity**, enabling you to store more KV caches and avoid costly re-prefilling operations. + +**KV Cache Sizes for Popular Models** + +.. list-table:: + :header-rows: 1 + :widths: 30 20 20 15 15 + + * - Model + - KV Cache Size per 1000 tokens + - Spare GPU RAM for KV cache + - Context length + - Number of full-length prompts that can be stored in GPU + * - Qwen/Qwen3-8B + - 0.1373 GB + - 50.32 GB (or 366,400 tokens) + - 40,960 tokens + - 8.95x + * - Qwen/Qwen3-32B (tp=2 on H100) + - 0.2441 GB + - 33.66 GB × 2 (or 275,760 tokens) + - 40,960 tokens + - 6.73x + * - meta-llama/Llama-3.1-70B (tp=4 on H100) + - 0.3052 GB + - 32.06 GB × 4 (or 420,208 tokens) + - 131,072 tokens + - 3.21x + +.. note:: + You may also find this `VRAM Calculator `_ useful for calculating the estimated spare GPU RAM for different models and configurations. \ No newline at end of file diff --git a/docs/source/getting_started/kv_cache_calculator.rst b/docs/source/getting_started/kv_cache_calculator.rst new file mode 100644 index 00000000000..964144b9306 --- /dev/null +++ b/docs/source/getting_started/kv_cache_calculator.rst @@ -0,0 +1,18 @@ +KV Cache Size Calculator +======================== + +Use this interactive calculator to: + +- **Calculate KV Cache Size**: Estimate the memory required for caching a specific number of tokens +- **Calculate Maximum Tokens**: Find out how many tokens you can cache given your available GPU RAM + +This helps you plan memory requirements for your LMCache deployment. + +.. raw:: html + +
+ +
diff --git a/docs/source/getting_started/quickstart.rst b/docs/source/getting_started/quickstart.rst new file mode 100644 index 00000000000..a580b61b327 --- /dev/null +++ b/docs/source/getting_started/quickstart.rst @@ -0,0 +1,92 @@ +.. _quickstart: + +Quickstart +========== + +This guide will help you get LMCache up and running quickly within 2 minutes. You'll see LMCache in action with a complete end-to-end example. + +(Terminal 1) Install LMCache +---------------------------- + +First, install LMCache with these three commands: + +.. code-block:: bash + + uv venv --python 3.12 + source .venv/bin/activate + uv pip install lmcache vllm + +Start vLLM with LMCache using a single command: + +.. code-block:: bash + + # The chunk size here is only for illustration purpose, use default one (256) later + LMCACHE_CHUNK_SIZE=8 \ + vllm serve Qwen/Qwen3-8B \ + --port 8000 --kv-transfer-config \ + '{"kv_connector":"LMCacheConnectorV1", "kv_role":"kv_both"}' + +.. note:: + If you want to customize configurations further, you can create a configuration file. See the :doc:`../api_reference/configurations` page to learn about all available options. + +(Terminal 2) Test LMCache in Action +----------------------------------- + +Now let's see LMCache working! Open a new terminal and send your first request: + +.. code-block:: bash + + curl http://localhost:8000/v1/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "Qwen/Qwen3-8B", + "prompt": "Qwen3 is the latest generation of large language models in Qwen series, offering a comprehensive suite of dense and mixture-of-experts", + "max_tokens": 100, + "temperature": 0.7 + }' + +You should see LMCache logs like this: + +.. code-block:: text + + (EngineCore_DP0 pid=458469) [2025-09-30 00:08:43,982] LMCache INFO: Stored 27 out of total 27 tokens. size: 0.0037 gb, cost 1.8470 ms, throughput: 2.0075 GB/s; offload_time: 1.7962 ms, put_time: 0.0509 ms + +**What this means:** The 27 tokens from your prompt are being stored in CPU RAM because this is the first time the system processes this text. LMCache is caching the KV cache for future reuse. + +Now send a second request with a prefix that overlaps with the first: + +.. code-block:: bash + + curl http://localhost:8000/v1/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "Qwen/Qwen3-8B", + "prompt": "Qwen3 is the latest generation of large language models in Qwen series, offering a comprehensive suite of dense and mixture-of-experts (MoE) models", + "max_tokens": 100, + "temperature": 0.7 + }' + +You should see logs like this: + +.. code-block:: text + + Reqid: cmpl-6709d8795d3c4464b01999c9f3fffede-0, Total tokens 32, LMCache hit tokens: 24, need to load: 8 + (EngineCore_DP0 pid=494270) [2025-09-30 01:12:36,502] LMCache INFO: Retrieved 8 out of total 8 out of total 24 tokens. size: 0.0011 gb, cost 0.5547 ms, throughput: 1.9808 GB/s; + (EngineCore_DP0 pid=494270) [2025-09-30 01:12:36,509] LMCache INFO: Storing KV cache for 8 out of 32 tokens (skip_leading_tokens=24) + (EngineCore_DP0 pid=494270) [2025-09-30 01:12:36,510] LMCache INFO: Stored 8 out of total 8 tokens. size: 0.0011 gb, cost 0.4274 ms, throughput: 2.5702 GB/s; offload_time: 0.4013 ms, put_time: 0.0262 ms + +**What this means:** + +- **Total tokens 32**: The new prompt has 32 tokens after tokenization +- **LMCache hit tokens: 24**: 24 tokens were found in the cache (24 is a multiple of 8, our chunk size in this example) +- **Need to load: 8**: vLLM has automatic prefix caching enabled with block size 16. Although there are 24 hit tokens, 16 are already in GPU RAM managed by vLLM, so LMCache only needs to load 24-16=8 tokens +- **Why 24 hit tokens instead of 27?** LMCache hashes every 8 tokens incrementally (8, 16, 24, 27). When the new request comes in, it checks every 8-token chunk, so it uses the 24-token hash instead of checking the 27-token hash +- **Stored another 8 tokens**: The new 8 tokens form a complete chunk that gets hashed and stored in CPU RAM for future use + +🎉 **Congratulations!** You've just seen LMCache automatically cache and reuse KV caches, reducing computation for overlapping text. + +Next Steps +---------- + +- **Performance Testing**: Try the :doc:`benchmarking` section to experience LMCache's performance benefits with more comprehensive examples +- **More Examples**: Explore the :doc:`quickstart/index` section for detailed examples including KV cache sharing across instances and disaggregated prefill diff --git a/docs/source/getting_started/quickstart/disaggregated_prefill.rst b/docs/source/getting_started/quickstart/disaggregated_prefill.rst index 1f81e3c0ca5..93297033e53 100644 --- a/docs/source/getting_started/quickstart/disaggregated_prefill.rst +++ b/docs/source/getting_started/quickstart/disaggregated_prefill.rst @@ -162,13 +162,13 @@ Send requests to the proxy server (port 9000) using either the completions or ch "max_tokens": 100 }' -You can also test the setup with the following command, which runs the `benchmark_serving.py `_ from vLLM. +You can also test the setup with the following command, which runs vLLM's serving benchmark: .. code-block:: bash git clone https://github.com/vllm-project/vllm.git cd vllm/benchmarks - python benchmark_serving.py --port 9000 --seed $(date +%s) \ + vllm bench serve --port 9000 --seed $(date +%s) \ --model meta-llama/Llama-3.1-8B-Instruct \ --dataset-name random --random-input-len 5000 --random-output-len 200 \ --num-prompts 50 --burstiness 100 --request-rate 1 diff --git a/docs/source/getting_started/quickstart/index.rst b/docs/source/getting_started/quickstart/index.rst index 4a9019fceab..2e9d33358f9 100644 --- a/docs/source/getting_started/quickstart/index.rst +++ b/docs/source/getting_started/quickstart/index.rst @@ -1,7 +1,7 @@ .. _quickstart_examples: -Quickstart Examples -=================== +More Examples +============= This section provides quick examples to help you get started with LMCache's key features. diff --git a/docs/source/getting_started/quickstart/offload_kv_cache.rst b/docs/source/getting_started/quickstart/offload_kv_cache.rst index b3ae550d2bf..d94ed058871 100644 --- a/docs/source/getting_started/quickstart/offload_kv_cache.rst +++ b/docs/source/getting_started/quickstart/offload_kv_cache.rst @@ -376,3 +376,29 @@ LMCache now supports offloading KV cache to the following destinations: - :doc:`InfiniStore <../../kv_cache/storage_backends/infinistore>` - :doc:`Redis <../../kv_cache/storage_backends/redis>` - :doc:`ValKey <../../kv_cache/storage_backends/valkey>` + +Troubleshooting +--------------- + +If you encounter the following error: + +.. code-block:: text + + (EngineCore_DP0 pid=55437) ERROR 10-04 14:44:47 [core.py:708] RuntimeError: + Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method + +You can resolve this issue using one of the following methods: + +- Set ``VLLM_WORKER_MULTIPROC_METHOD=spawn`` in the environment variables. +- Or update the Python code to guard usage of vllm behind a if ``__name__ == '__main__':`` block. + +.. code-block:: python + + if __name__ == '__main__': + from vllm import LLM, SamplingParams + from vllm.config import KVTransferConfig + from lmcache.v1.cache_engine import LMCacheEngineBuilder + from lmcache.integration.vllm.utils import ENGINE_NAME + main() + +For details, please refer to the `vLLM Troubleshooting Guide: Python multiprocessing `_. \ No newline at end of file diff --git a/docs/source/getting_started/quickstart/share_kv_cache.rst b/docs/source/getting_started/quickstart/share_kv_cache.rst index b565f7b46b8..6a9d24e94a4 100644 --- a/docs/source/getting_started/quickstart/share_kv_cache.rst +++ b/docs/source/getting_started/quickstart/share_kv_cache.rst @@ -101,6 +101,9 @@ The second request will automatically retrieve and reuse the KV cache from the f P2P KV cache sharing -------------------- +.. note:: + This section is outdated. Please refer to :doc:`../../kv_cache/p2p_sharing` for the latest example. + This section demonstrates how to share KV cache across multiple vLLM instances using peer-to-peer transfer. Setup P2P sharing diff --git a/docs/source/index.rst b/docs/source/index.rst index 98ec6578e32..2db40642dcf 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -46,7 +46,7 @@ Welcome to LMCache! For more information, check out the following: * `LMCache blogs `_ -* `Join LMCache slack workspace `_ +* `Join LMCache slack workspace `_ * Our papers: * `CacheGen: KV Cache Compression and Streaming for Fast Large Language Model Serving `_ @@ -71,8 +71,10 @@ Documentation :caption: Getting Started getting_started/installation + getting_started/quickstart getting_started/quickstart/index getting_started/benchmarking + getting_started/kv_cache_calculator getting_started/troubleshoot getting_started/faq @@ -85,6 +87,7 @@ Documentation kv_cache/storage_backends/index kv_cache/caching_policies + kv_cache/p2p_sharing :raw-html:`
` @@ -111,6 +114,7 @@ Documentation kv_cache_optimizations/compression/index kv_cache_optimizations/blending + kv_cache_optimizations/layerwise :raw-html:`
` @@ -162,5 +166,5 @@ Documentation community/meetings community/blogs -raw-html:`
` +:raw-html:`
` diff --git a/docs/source/kv_cache/async_loading.rst b/docs/source/kv_cache/async_loading.rst new file mode 100644 index 00000000000..9a55374f11d --- /dev/null +++ b/docs/source/kv_cache/async_loading.rst @@ -0,0 +1,129 @@ +Async Loading in LMCache +========================= + +.. contents:: Table of Contents + :local: + :depth: 2 + +Overview +-------- + +This document explains the principle, benefits, differences from vLLM PR `19330 `_, and limitations of the LMCache ``async_loading`` feature. +It focuses on LMCache v1 integration with vLLM and the internal storage pipeline. + +Key change of components in this feature include: + +- LMCache async lookup client/server (ZMQ-based) +- Storage manager orchestrating backends and concurrency +- Cache engine async API entrypoints +- vLLM adapter integration points + +Principle and Theory +-------------------- + +At a high level, ``async_loading`` decouples scheduler-side lookup from worker-side prefetch/retrieval, allowing overlap between I/O and computation while preserving prefix-based correctness. + +- The scheduler sends lookup requests with token chunk hashes and offsets. +- Worker-side servers perform tiered ``batched_async_contains`` over available backends and eagerly launch non-blocking batched get operations for hit prefixes. +- Completion is tracked via an ``EventManager`` to safely deliver loaded memory objects back to the requesting path. +- A weighted semaphore with an ``AsyncSerializer`` prevents allocator deadlocks by shaping concurrency according to chunk budget. + +The following Mermaid sequence diagram illustrates the end-to-end flow: + +.. mermaid:: + + sequenceDiagram + autonumber + participant S as Scheduler (vLLM) + participant LC as LMCacheAsyncLookupClient + participant WS as LMCacheAsyncLookupServer (Worker) + participant SM as StorageManager + participant BE as Backends (LocalCPU/LocalDisk/FSConnector) + participant EM as EventManager + + S->>LC: lookup(token_ids, lookup_id, request_configs) + note right of LC: Hashes + offsets via TokenDatabase + LC->>WS: ZMQ PUSH multipart [lookup_id, hashes, offsets, configs] + WS->>SM: async_lookup_and_prefetch(lookup_id, keys, cum_chunk_lengths) + SM->>BE: batched_async_contains(lookup_id, keys, pin=True) + alt prefix hit across tiers + BE-->>SM: num_hit_chunks (per tier) + SM->>BE: batched_get_non_blocking(lookup_id, hit_prefix) + BE-->>SM: Future[List[MemoryObj]] + SM->>EM: add_event(EventType.LOADING, lookup_id, gather_all) + SM-->>WS: send_response_to_scheduler(lookup_id, retrieved_length) + WS-->>LC: ZMQ PUSH [lookup_id, num_hit_tokens] + else cache miss + SM-->>WS: send_response_to_scheduler(lookup_id, 0) + WS-->>LC: ZMQ PUSH [lookup_id, 0] + end + + +Architecture (Worker Side) +-------------------------- + +.. mermaid:: + :align: center + + flowchart LR + subgraph Worker + direction TB + A["LMCacheAsyncLookupServer
ZMQ PULL/PUSH"] + B["StorageManager
Async loop (thread)"] + C["AsyncSerializer
WeightedSemaphore"] + D["EventManager
EventType.LOADING"] + end + + subgraph Backends + E["LocalCPUBackend
contains/get"] + F["LocalDiskBackend
async contains/get"] + G["FSConnector
remote FS"] + end + + A --> B + B --> C + B --> D + B -.contains/get.-> E + B -.contains/get.-> F + B -.contains/get.-> G + + style E fill:#dff,stroke:#333,stroke-width:1px + style F fill:#ffd,stroke:#333,stroke-width:1px + style G fill:#dfd,stroke:#333,stroke-width:1px + + +Benefits +-------- + +- Performance overlap + - **I/O–Compute Overlap**: Decoupling lookup/prefetch from loading enables fetching KV chunks while vLLM continues scheduling/computation. +- Robustness and error handling + - **Event-driven Synchronization**: ``EventManager`` ensures safe hand-off of futures and avoids race conditions between threads and the async loop. + - **Backpressure & Deadlock Avoidance**: ``AsyncSerializer`` with a weighted semaphore caps concurrent chunk retrievals based on allocator budget, preventing starvation or allocator lockups. + - **Graceful Miss Path**: Immediate response with ``None`` hit tokens when nothing is retrievable; worker returns quickly without stalling the scheduler. + +Comparison with vLLM Load Failure Recovery feature +--------------------------------------------------- + +The `VLLM_PR_19330 `_ introduces a fault recovery mechanism for vLLM's KV connector infrastructure that enables graceful handling of KV cache load failures by automatically detecting failed block loads and rescheduling only affected requests for recomputation from a valid prefix. +By contrast, LMCache’s ``async_loading`` is an externalized caching layer with its own client/server, storage backends, and concurrency control. + +Limitations +----------- + +- Only works with vllm merged `VLLM_PR_23620 `_ +- Backend support constraint: This feature currently requires backends that implement ``batched_async_contains``; limited to a few backends, e.g.: + - ``LocalCpuBackend`` + - ``LocalDiskBackend`` + - ``S3Connector`` + - ``FSConnector`` + - ``RedisConnector/RedisClusterConnector`` +- save_unfull_chunk: Automatically disabled in async mode for correctness in prefix chunking. + +Future Work +----------- + +- Introduce a default ``batched_async_contains`` implementation, so all backends can support ``async_loading``. +- Refactor ``AsyncSerializer`` to support being enabled together with ``save_unfull_chunk`` and ``PDBackend``. +- Add metrics and observability to track the number of asynchronous lookup requests and the number of occupied ``MemoryObj`` instances. +- Improve the lookup framework by passing vLLM prefix cache hit tokens so that async lookup can skip loading parts already hit in vLLM. diff --git a/docs/source/kv_cache/p2p_sharing.rst b/docs/source/kv_cache/p2p_sharing.rst new file mode 100644 index 00000000000..adebad6328a --- /dev/null +++ b/docs/source/kv_cache/p2p_sharing.rst @@ -0,0 +1,146 @@ +.. _p2p_sharing: + +P2P KV Cache Sharing +==================== + +P2P (Peer-to-Peer) KV cache sharing enables direct cache transfer between multiple serving engine instances without requiring a centralized cache server. This approach provides high-performance cache sharing with reduced latency and improved scalability, especially beneficial in distributed inference scenarios. + +LMCache supports P2P sharing through a controller-based architecture using NIXL (NVIDIA Inference Xfer Library) for optimized data transfer between instances. + +Prerequisites +------------- + +- **Multi-GPU Setup**: Your server should have at least 2 GPUs +- **NIXL**: Install from `NIXL `_ +- **LMCache**: Install from :ref:`installation_guide` + +Configuration +------------- + +Create two configuration files for the P2P sharing setup. + +The only difference between the two configurations is the ``lmcache_instance_id`` and the ``p2p_init_ports`` and ``p2p_lookup_ports`` and ``lmcache_worker_ports``. + +**Instance 1 Configuration (example1.yaml)**: + +.. code-block:: yaml + + chunk_size: 256 + local_cpu: True + max_local_cpu_size: 5 + enable_async_loading: True + + # P2P configurations + enable_p2p: True + p2p_host: "localhost" + p2p_init_ports: 8200 + p2p_lookup_ports: 8201 + transfer_channel: "nixl" + + # Controller configurations + enable_controller: True + lmcache_instance_id: "lmcache_instance_1" + controller_pull_url: "localhost:8300" + controller_reply_url: "localhost:8400" + lmcache_worker_ports: 8500 + + extra_config: + lookup_backoff_time: 0.001 + +**Instance 2 Configuration (example2.yaml)**: + +.. code-block:: yaml + + chunk_size: 256 + local_cpu: True + max_local_cpu_size: 5 + enable_async_loading: True + + # P2P configurations + enable_p2p: True + p2p_host: "localhost" + p2p_init_ports: 8202 + p2p_lookup_ports: 8203 + transfer_channel: "nixl" + + # Controller configurations + enable_controller: True + lmcache_instance_id: "lmcache_instance_2" + controller_pull_url: "localhost:8300" + controller_reply_url: "localhost:8400" + lmcache_worker_ports: 8501 + + extra_config: + lookup_backoff_time: 0.001 + +Setup and Usage +--------------- + +**Step 1: Start the LMCache Controller** + +.. code-block:: bash + + PYTHONHASHSEED=123 lmcache_controller --host localhost --port 9000 --monitor-ports '{"pull": 8300, "reply": 8400}' + +Make sure that the 8300 and 8400 ports are set up in **controller_pull_url** and **controller_reply_url** in the configuration files. +Port 9000 is the controller main port, which is arbitrary and can be changed. + +**Step 2: Start vLLM Engines with LMCache Workers** + +Start vLLM engine 1 at port 8010: + +.. code-block:: bash + + PYTHONHASHSEED=123 UCX_TLS=rc CUDA_VISIBLE_DEVICES=0 LMCACHE_CONFIG_FILE=example1.yaml \ + vllm serve meta-llama/Meta-Llama-3.1-8B-Instruct \ + --gpu-memory-utilization 0.8 \ + --port 8010 \ + --kv-transfer-config '{"kv_connector":"LMCacheConnectorV1", "kv_role":"kv_both"}' + +Start vLLM engine 2 at port 8011: + +.. code-block:: bash + + PYTHONHASHSEED=123 UCX_TLS=rc CUDA_VISIBLE_DEVICES=1 LMCACHE_CONFIG_FILE=example2.yaml \ + vllm serve meta-llama/Meta-Llama-3.1-8B-Instruct \ + --gpu-memory-utilization 0.8 \ + --port 8011 \ + --kv-transfer-config '{"kv_connector":"LMCacheConnectorV1", "kv_role":"kv_both"}' + +**Step 3: Test P2P Cache Sharing** + +Send a request to vLLM engine 1 to populate the cache: + +.. code-block:: bash + + curl -X POST http://localhost:8010/v1/completions \ + -H "Content-Type: application/json" \ + -d "{ + \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\", + \"prompt\": \"$(printf 'Explain the significance of KV cache in language models.%.0s' {1..100})\", + \"max_tokens\": 10 + }" + +Send the same request to vLLM engine 2 to demonstrate cache retrieval from **engine 1**: + +.. code-block:: bash + + curl -X POST http://localhost:8011/v1/completions \ + -H "Content-Type: application/json" \ + -d "{ + \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\", + \"prompt\": \"$(printf 'Explain the significance of KV cache in language models.%.0s' {1..100})\", + \"max_tokens\": 10 + }" + +Expected Output +--------------- + +When the second request successfully retrieves cache from the first instance, you should see logs similar to: + +.. code-block:: bash + + (EngineCore_DP0 pid=2577584)[2025-09-21 00:00:11,706] LMCache INFO:[0m Established connection to peer_init_url localhost:8200. The peer_lookup_url: localhost:8201 (p2p_backend.py:278:lmcache.v1.storage_backend.p2p_backend) + (EngineCore_DP0 pid=2577584)[2025-09-21 00:00:11,792] LMCache INFO: Retrieved 1002 out of total 1002 out of total 1002 tokens. size: 0.1223 gb, cost 60.3595 ms, throughput: 2.0264 GB/s; (cache_engine.py:496:lmcache.v1.cache_engine) + +These logs indicate successful P2P connection establishment and high-throughput cache retrieval. \ No newline at end of file diff --git a/docs/source/kv_cache/storage_backends/gds.rst b/docs/source/kv_cache/storage_backends/gds.rst index a0184f452d8..799198c3705 100644 --- a/docs/source/kv_cache/storage_backends/gds.rst +++ b/docs/source/kv_cache/storage_backends/gds.rst @@ -76,7 +76,7 @@ Setup Example - vllm and lmcache installed (:doc:`Installation Guide <../../getting_started/installation>`) -- Hugging Face access to ``meta-llama/Llama-3.1-70B-Instruct`` +- Hugging Face access to ``meta-llama/Llama-3.1-8B-Instruct`` .. code-block:: bash @@ -127,7 +127,7 @@ and then comment out the ``LMCACHE_CONFIG_FILE`` below: LMCACHE_CONFIG_FILE="gds-backend.yaml" \ LMCACHE_USE_EXPERIMENTAL=True \ vllm serve \ - meta-llama/Llama-3.1-70B-Instruct \ + meta-llama/Llama-3.1-8B-Instruct \ --max-model-len 65536 \ --kv-transfer-config \ '{"kv_connector":"LMCacheConnectorV1", "kv_role":"kv_both"}' diff --git a/docs/source/kv_cache/storage_backends/infinistore.rst b/docs/source/kv_cache/storage_backends/infinistore.rst index e28fe2f8bb4..e28456325e4 100644 --- a/docs/source/kv_cache/storage_backends/infinistore.rst +++ b/docs/source/kv_cache/storage_backends/infinistore.rst @@ -1,47 +1,203 @@ InfiniStore =========== -Coming soon... - - .. _infinistore-overview: Overview -------- -`InfiniStore `_ is an open-source high-performance KV store and one of the remote KV storage options LMCache supports. - -Infinistore supports RDMA and NVLink. LMCache's infinistore connector only uses RDMA transport. - -InfiniStore Explanation: ------------------------- +`InfiniStore `_ is an open-source high-performance KV store. It's designed to support LLM Inference clusters, whether the cluster is in prefill-decoding disaggregation mode or not. InfiniStore provides high-performance and low-latency KV cache transfer and KV cache reuse among inference nodes in the cluster. There are two major scenarios how InfiniStore supports: -Prefill-Decoding disaggregation clusters: in such mode inference workloads are separated into two node pools: prefill nodes and decoding nodes. InfiniStore enables KV cache transfer among these two types of nodes, and also KV cache reuse. -Non-disaggregated clusters: in such mode prefill and decoding workloads are mixed on every node. Infinistore serves as an extra large KV cache pool in addition to GPU cache and local CPU cache, and also enables cross-node KV cache reuse. +* Prefill-Decoding disaggregation clusters: in such mode inference workloads are separated into two node pools: prefill nodes and decoding nodes. InfiniStore enables KV cache transfer among these two types of nodes, and also KV cache reuse. +* Non-disaggregated clusters: in such mode prefill and decoding workloads are mixed on every node. InfiniStore serves as an extra large KV cache pool in addition to GPU cache and local CPU cache, and also enables cross-node KV cache reuse. .. image:: ../../assets/InfiniStore-usage.png :alt: InfiniStore Usage Diagram +For more details, please refer to the `InfiniStore Documentation `_. + +InfiniStore supports both RDMA and TCP for transport. LMCache’s InfiniStore connector only uses the RDMA transport. + + +Quick Start +----------- + +Install InfiniStore via pip: + +.. code-block:: bash + + pip install infinistore + +This package includes the InfiniStore server and the Python bindings. + +To build InfiniStore from source, follow the instructions in the `GitHub repository `_. + +Setup and Deployment +~~~~~~~~~~~~~~~~~~~~ + +**Prerequisites:** + +- Machine with at least one GPU for vLLM inference +- RDMA-capable network hardware and drivers +- Python 3.8+ with pip +- vLLM and LMCache installed + +**Step 1: Start InfiniStore Server** + +For InfiniBand based RDMA: + +.. code-block:: bash + + infinistore --service-port 12345 --dev-name mlx5_0 --link-type IB + +For RoCE based RDMA: + +.. code-block:: bash + + infinistore --service-port 12345 --dev-name mlx5_0 --link-type Ethernet + +You can also specify the ``--hint-gid-index`` option to set the GID index for the InfiniStore server. This is useful when you are in a k8s managed environment. + +**Step 2: Create Configuration File** + +Create your ``infinistore-config.yaml``: + +.. code-block:: yaml + + chunk_size: 256 + remote_url: "infinistore://127.0.0.1:12345/?device=mlx5_1" + remote_serde: "naive" + local_cpu: False + max_local_cpu_size: 5 + +**Step 3: Start vLLM with InfiniStore** + +.. code-block:: bash + + LMCACHE_CONFIG_FILE="infinistore-config.yaml" \ + vllm serve \ + Qwen/Qwen2.5-7B-Instruct \ + --seed 42 \ + --max-model-len 16384 \ + --gpu-memory-utilization 0.8 \ + --kv-transfer-config \ + '{"kv_connector":"LMCacheConnectorV1", "kv_role":"kv_both"}' + +**Step 4: Verify the Setup** + +Test the integration with a sample request: + +.. code-block:: bash + + curl -X POST "http://localhost:8000/v1/completions" \ + -H "Content-Type: application/json" \ + -d '{ + "model": "Qwen/Qwen2.5-7B-Instruct", + "prompt": "The future of AI is", + "max_tokens": 100, + "temperature": 0.7 + }' + +**Debugging Tips:** + +1. **Enable verbose logging:** + + .. code-block:: bash + + infinistore --log-level=debug + +2. **Check server status:** -.. _infinistore-prerequisites: + .. code-block:: bash -Minimum Viable Example: ------------------------- + # Check if the server is running + ps aux | grep infinistore + netstat -tlnp | grep -E "12345" -To use InfiniStore as a remote RDMA-based backend for LMCache, you should have: +Query TTFT Improvement +~~~~~~~~~~~~~~~~~~~~~~~ -- Two bare metal machines on the same rack or data center network. Each machine must have a Mellanox RDMA-capable NIC, e.g., mlx5_0. +Once the OpenAI compatible server is running, let's query it twice and see the TTFT improvement. -This minimal viable example will use OCI BM.GPU4.8 for LMCache + vLLM and BM.HPC2.36 for an InfiniStore backend. +Run vLLM's serving benchmark twice with the following parameters: -Step 1: Create the InfiniStore server +.. code-block:: bash -Set up networking on the + vllm bench serve \ + --backend vllm \ + --model Qwen/Qwen2.5-7B-Instruct \ + --num-prompts 50 \ + --port 8000 \ + --host 127.0.0.1 \ + --dataset-name random \ + --random-input-len 8192 \ + --random-output-len 128 \ + --seed 42 +**Example Output:** +For the first run, you might see: +.. code-block:: text + ============ Serving Benchmark Result ============ + Successful requests: 50 + Benchmark duration (s): 80.97 + Total input tokens: 409544 + Total generated tokens: 6273 + Request throughput (req/s): 0.62 + Output token throughput (tok/s): 77.48 + Total Token throughput (tok/s): 5135.74 + ---------------Time to First Token---------------- + Mean TTFT (ms): 36203.54 + Median TTFT (ms): 34598.91 + P99 TTFT (ms): 76010.91 + -----Time per Output Token (excl. 1st token)------ + Mean TPOT (ms): 290.30 + Median TPOT (ms): 346.25 + P99 TPOT (ms): 412.24 + ---------------Inter-token Latency---------------- + Mean ITL (ms): 290.30 + Median ITL (ms): 386.78 + P99 ITL (ms): 449.83 +For the second run, you should see a significant reduction in TTFT: + +.. code-block:: text + + ============ Serving Benchmark Result ============ + Successful requests: 50 + Benchmark duration (s): 15.14 + Total input tokens: 409544 + Total generated tokens: 6273 + Request throughput (req/s): 3.30 + Output token throughput (tok/s): 414.22 + Total Token throughput (tok/s): 27457.55 + ---------------Time to First Token---------------- + Mean TTFT (ms): 2880.53 + Median TTFT (ms): 3118.50 + P99 TTFT (ms): 12027.24 + -----Time per Output Token (excl. 1st token)------ + Mean TPOT (ms): 73.81 + Median TPOT (ms): 71.12 + P99 TPOT (ms): 91.24 + ---------------Inter-token Latency---------------- + Mean ITL (ms): 73.81 + Median ITL (ms): 63.86 + P99 ITL (ms): 565.44 + +TTFT Improvement: 33.323 seconds (12.6x faster). + +**Tips:** + +- If you want to run vLLM's serving benchmark multiple times, you'll need to either restart the vLLM LMCache server and the InfiniStore server, or change the ``--seed`` parameter to a different value each time, since you've already warmed up LMCache. +- The benchmark result here was produced by running an L40 with 48GB of GPU memory with ``--gpu-memory-utilization 0.8``. You can adjust the GPU memory utilization and increase the max model length to use more of the long context. LMCache TTFT improvement becomes more pronounced as the context length increases! + + +Additional Resources +-------------------- +- `InfiniStore Documentation `_ +- `GitHub Repository `_ diff --git a/docs/source/kv_cache/storage_backends/mooncake.rst b/docs/source/kv_cache/storage_backends/mooncake.rst index fe4890ed938..13e812ad98a 100644 --- a/docs/source/kv_cache/storage_backends/mooncake.rst +++ b/docs/source/kv_cache/storage_backends/mooncake.rst @@ -53,19 +53,13 @@ Setup and Deployment **Step 1: Start Infrastructure Services** -Start the metadata server: - -.. code-block:: bash - - # HTTP metadata server (recommended for development) - mooncake_http_metadata_server - -Start the Mooncake master service: +Start the Mooncake master service (with built‑in HTTP metadata server): .. code-block:: bash # Master service (use -v=1 for verbose logging) - mooncake_master + # The flag enables the integrated HTTP metadata server + mooncake_master --enable_http_metadata_server=1 Expected output: @@ -81,28 +75,35 @@ Create your ``mooncake-config.yaml``: .. code-block:: yaml - chunk_size: 256 - remote_url: "mooncakestore://127.0.0.1:50051/" - remote_serde: "naive" + # LMCache Configuration local_cpu: False - max_local_cpu_size: 5 + remote_url: "mooncakestore://localhost:50051/" + max_local_cpu_size: 2 # small local buffer + numa_mode: "auto" # reduce tail latency with multi-NUMA/multi-NIC + pre_caching_hash_algorithm: sha256_cbor_64bit + # Mooncake Configuration (via extra_config) extra_config: + use_exists_sync: true + save_chunk_meta: False # Enable chunk metadata optimization local_hostname: "localhost" - metadata_server: "http://127.0.0.1:8080/metadata" - protocol: "tcp" + metadata_server: "http://localhost:8080/metadata" + protocol: "rdma" + device_name: "" # leave empty; autodetect device(s) + global_segment_size: 21474836480 # 20 GiB per worker master_server_address: "localhost:50051" - global_segment_size: 3355443200 - local_buffer_size: 1073741824 - transfer_timeout: 1 + local_buffer_size: 0 # rely on LMCache local_cpu as the buffer + mooncake_prefer_local_alloc: true # prefer local segment if available **Step 3: Start vLLM with Mooncake** .. code-block:: bash + # If you see persistent misses (no Mooncake hits), make sure + # PYTHONHASHSEED is fixed across processes (e.g., export PYTHONHASHSEED=0). LMCACHE_CONFIG_FILE="mooncake-config.yaml" \ vllm serve \ - meta-llama/Llama-3.1-70B-Instruct \ + meta-llama/Llama-3.1-8B-Instruct \ --max-model-len 65536 \ --kv-transfer-config \ '{"kv_connector":"LMCacheConnectorV1", "kv_role":"kv_both"}' @@ -116,7 +117,7 @@ Test the integration with a sample request: curl -X POST "http://localhost:8000/v1/completions" \ -H "Content-Type: application/json" \ -d '{ - "model": "meta-llama/Llama-3.1-70B-Instruct", + "model": "meta-llama/Llama-3.1-8B-Instruct", "prompt": "The future of AI is", "max_tokens": 100, "temperature": 0.7 @@ -159,7 +160,7 @@ Configuration - Number of tokens per KV chunk * - ``remote_url`` - Required - - Mooncake store connection URL (format: ``mooncakestore://host:port/``) + - Mooncake store connection URL (format: ``mooncakestore://host:port/``). * - ``remote_serde`` - "naive" - Serialization method for remote storage @@ -169,6 +170,12 @@ Configuration * - ``max_local_cpu_size`` - Required - Maximum local CPU cache size in GB (required even when local_cpu is False) + * - ``numa_mode`` + - "auto" + - NUMA binding mode. "auto" is recommended on multi‑NIC/multi‑NUMA systems to reduce tail latency. + * - ``pre_caching_hash_algorithm`` + - "sha256_cbor_64bit" + - Hash used for pre-caching keying. For cross‑process consistency, fix ``PYTHONHASHSEED`` (e.g., export ``PYTHONHASHSEED=0``). **Mooncake Parameters (via extra_config):** @@ -184,28 +191,40 @@ Configuration - Hostname/IP of the local node for Mooncake client identification * - ``metadata_server`` - Required - - Address of metadata coordination server (etcd/Redis/HTTP format) + - HTTP metadata server address. When starting master with ``--enable_http_metadata_server=1``, it exposes this endpoint. * - ``master_server_address`` - Required - Mooncake master service address (host:port format) * - ``protocol`` - - "tcp" - - Communication protocol ("rdma" for high performance, "tcp" for compatibility) + - "rdma" + - Communication protocol ("rdma" for high performance; "tcp" for compatibility) * - ``device_name`` - "" - - RDMA device specification (e.g., "erdma_0,erdma_1" or "mlx5_0,mlx5_1") + - RDMA device specification (e.g., "erdma_0,erdma_1" or "mlx5_0,mlx5_1"). Leave empty for autodetection in most setups. * - ``global_segment_size`` - - 3355443200 - - **Memory size contributed by each vLLM worker** in bytes (~3.1GB) + - 21474836480 + - **Memory size contributed by each vLLM worker** in bytes (e.g., 20 GiB recommended) * - ``local_buffer_size`` - - 1073741824 - - Local buffer allocation size in bytes (~1GB) + - 0 + - Local buffer size in bytes used by Mooncake. Behavior depends on ``save_chunk_meta``: + - When ``save_chunk_meta: False`` (recommended), LMCache uses its local CPU backend for zero‑copy RDMA, so Mooncake's ``local_buffer_size`` can be ``0``. + - When ``save_chunk_meta: True``, Mooncake uses its own local buffer; set this to a proper value (e.g., several GiB). + - Note: Some RDMA NICs have memory registration limits; registering LMCache's large CPU buffer can fail on constrained devices. In those cases, consider enabling ``save_chunk_meta: True`` and sizing ``local_buffer_size`` instead. * - ``transfer_timeout`` - 1 - Timeout for transfer operations in seconds * - ``storage_root_dir`` - "" - The root directory for persistence (e.g., "/mnt/mooncake") + * - ``save_chunk_meta`` + - False + - Whether to save chunk metadata alongside data. Set to ``False`` to enable the optimized zero‑copy path in LMCache. + * - ``use_exists_sync`` + - False + - Use synchronous existence checks to avoid async scheduling overhead in hot paths. + * - ``mooncake_prefer_local_alloc`` + - False + - Prefer allocating on the local segment when possible. .. important:: **Understanding global_segment_size**: This parameter defines the amount of memory each vLLM worker contributes to the distributed memory pool. @@ -213,10 +232,18 @@ Configuration Adjust this value based on your available system memory and expected cache requirements. +.. tip:: + If you consistently get misses (no Mooncake hits), ensure all processes use the same hashing seed: ``export PYTHONHASHSEED=0``. This keeps pre‑caching keys consistent across processes. + +.. note:: + RDMA device(s) usually do not need to be specified; leaving ``device_name`` empty works for most deployments. + Additional Resources -------------------- - `Mooncake Store Architecture `_ +- `Mooncake Store Deployment Guide `_ +- `Mooncake Store Python API Reference `_ - `Transfer Engine Documentation `_ - `Build Instructions `_ - `GitHub Repository `_ diff --git a/docs/source/kv_cache/storage_backends/nixl.rst b/docs/source/kv_cache/storage_backends/nixl.rst index 112bc6c2cac..636accd5172 100644 --- a/docs/source/kv_cache/storage_backends/nixl.rst +++ b/docs/source/kv_cache/storage_backends/nixl.rst @@ -27,24 +27,51 @@ Passed in through ``LMCACHE_CONFIG_FILE=lmcache-config.yaml`` ``LMCACHE_USE_EXPERIMENTAL`` MUST be set. -Example ``lmcache-config.yaml``: +Example ``lmcache-config.yaml`` for POSIX backend: .. code-block:: yaml chunk_size: 256 nixl_buffer_size: 1073741824 # 1GB nixl_buffer_device: cpu - extra_config: {enable_nixl_storage: true, nixl_backend: POSIX, \ - nixl_file_pool_size: 64, nixl_path: /mnt/nixl/cache/} + extra_config: + enable_nixl_storage: true + nixl_backend: POSIX + nixl_pool_size: 64 + nixl_path: /mnt/nixl/cache/ Key settings: - ``nixl_buffer_size``: buffer size for NIXL transfers. -- ``nixl_file_pool_size``: number of files opened at init time for nixl backend. +- ``nixl_pool_size``: number of descriptors opened at init time for nixl backend. - ``nixl_path``: directory under which the storage files will be saved (e.g. /mnt/nixl/). Needed for NIXL backends that store to file. -- ``nixl_backend``: configuration of which nixl backend to use for storage. Options are: ["GDS", "GDS_MT", "POSIX", "HF3FS"]. +- ``nixl_buffer_device``: dictates where the memory managed by NIXL should be on. "cpu" or "cuda" is supported for "GDS" and "GDS_MT" backends - for "POSIX", "HF3FS" & "OBJ", must be "cpu". -- ``nixl_buffer_device``: dictates where the memory managed by NIXL should be on. "cpu" or "cuda" is supported for "GDS" and "GDS_MT" backends - for "POSIX" and "HF3FS", must be "cpu". +- ``nixl_backend``: configuration of which nixl backend to use for storage. + + .. note:: + + Supported backends are: ["GDS", "GDS_MT", "POSIX", "HF3FS", "OBJ"]. + + Backend specific params should be provided via ``extra_config.nixl_backend_params``. Please refer to NIXL documentation for specifics. + +Example ``lmcache-config.yaml`` for OBJ backend using S3 API: + +.. code-block:: yaml + + chunk_size: 256 + nixl_buffer_size: 1073741824 # 1GB + nixl_buffer_device: cpu + extra_config: + enable_nixl_storage: true + nixl_backend: OBJ + nixl_pool_size: 64 + nixl_path: /mnt/nixl/cache/ + nixl_backend_params: + access_key: + secret_key: + bucket: + region: diff --git a/docs/source/kv_cache/storage_backends/sagemaker_hyperpod.rst b/docs/source/kv_cache/storage_backends/sagemaker_hyperpod.rst new file mode 100644 index 00000000000..f46090b7356 --- /dev/null +++ b/docs/source/kv_cache/storage_backends/sagemaker_hyperpod.rst @@ -0,0 +1,71 @@ +SageMaker Hyperpod Backend +=========================== + +Prerequisites +------------- + +Create an Amazon SageMaker HyperPod cluster with tiered storage enabled by following the instructions at: + +https://docs.aws.amazon.com/sagemaker/latest/dg/managed-tier-checkpointing-setup.html + +This enables the ai-toolkit daemon that provides shared memory access for LMCache. + +Example Configuration +--------------------- + +.. code-block:: yaml + + chunk_size: 256 + local_cpu: True + max_local_cpu_size: 5 + remote_url: "sagemaker-hyperpod://$NODE_IP:9200" + +Configuration Parameters +------------------------ + +SageMaker Hyperpod-Specific (in extra_config) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +* **sagemaker_hyperpod_bucket**: Bucket name for KV storage namespace (default: "lmcache") +* **sagemaker_hyperpod_shared_memory_name**: Name of shared memory segment (default: "shared_memory"). Set to None to disable shared memory. +* **sagemaker_hyperpod_max_concurrent_requests**: Maximum concurrent HTTP requests allowed in-flight at any moment (application-level throttling, default: 100, minimum: 1). This limit is per LMCache engine instance. With multiple workers (e.g., high TP), each worker creates its own engine with separate limits. +* **sagemaker_hyperpod_max_connections**: Maximum total TCP connections in the connection pool per LMCache engine across all daemons (default: 256, minimum: 1). For typical single-daemon setups, this effectively limits connections from one engine to one daemon. With N workers per node, total connections to the daemon = N × this value. +* **sagemaker_hyperpod_max_connections_per_host**: Maximum TCP connections per LMCache engine to a single daemon address (IP:port) (default: 128, minimum: 1). "Host" refers to the daemon's network address, not the client machine. For today's typical single-daemon setup, this has similar effect as max_connections. This parameter enables future multi-daemon configurations where one engine connects to multiple daemons for load balancing. With N workers per node connecting to the same daemon, total connections = N × this value. Reduce proportionally for high TP setups (e.g., set to 16 for 8 workers to achieve ~128 total connections). +* **sagemaker_hyperpod_timeout_ms**: Timeout for lease acquisition requests in milliseconds (default: 5000, minimum: 100) +* **sagemaker_hyperpod_lease_ttl_s**: Server-side lease timeout in seconds (default: 30.0) +* **sagemaker_hyperpod_put_stream_chunk_bytes**: Chunk size for streaming PUT requests in bytes (default: 65536, minimum: 1024) +* **sagemaker_hyperpod_use_https**: Enable HTTPS instead of HTTP (default: False). **Note**: Ignored if ``remote_url`` already contains ``http://`` or ``https://`` protocol. +* **save_chunk_meta**: Whether to save chunk metadata with data (set False for performance) + +Kubernetes Deployment Requirements +----------------------------------- + +Environment Variable for Node IP +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Add the ``NODE_IP`` environment variable to resolve the local node's IP address: + +.. code-block:: yaml + + env: + - name: NODE_IP + valueFrom: + fieldRef: + fieldPath: status.hostIP + +/dev/shm Volume Configuration +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +SageMaker Hyperpod requires /dev/shm for high-performance shared memory operations: + +.. code-block:: yaml + + volumeMounts: + - name: dshm + mountPath: /dev/shm/shared_memory + subPath: shared_memory + + volumes: + - name: dshm + hostPath: + path: /dev/shm diff --git a/docs/source/kv_cache/storage_backends/valkey.rst b/docs/source/kv_cache/storage_backends/valkey.rst index 9a9711b90de..ba714ff5690 100644 --- a/docs/source/kv_cache/storage_backends/valkey.rst +++ b/docs/source/kv_cache/storage_backends/valkey.rst @@ -1,4 +1,74 @@ -ValKey +Valkey ====== -Coming soon... +Overview +-------- + +Valkey is an open source (BSD) high-performance key/value datastore and is a supported option for remote KV Cache offloading in LMCache. +Some other remote backends are :doc:`Mooncake <./mooncake>`, :doc:`Redis <./redis>`, and :doc:`InfiniStore <./infinistore>`. + +Prerequisites +------------- + +To use this connector, you need valkey-glide 2.0 or higher. Valkey Connector currently uses pipelining, which generally results in better RTT compared to Redis Connector. +Pipelining will also be implemented to the Redis Connector in the future. + +.. code-block:: shell + + # Install Valkey-GLIDE (Minimum 2.0.0 or higher) + $ pip install valkey-glide + +Example Configurations +---------------------- + +Basic Valkey Configuration (Standalone mode) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: yaml + + chunk_size: 256 + remote_url: "valkey://:6379" + remote_serde: "naive" + extra_config: + valkey_username: "Your username" + valkey_password: "Your password" + +Standalone-mode Valkey Configuration with database +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: yaml + + chunk_size: 256 + remote_url: "valkey://:6379" + remote_serde: "naive" + extra_config: + valkey_username: "Your username" + valkey_password: "Your password" + valkey_database: 0 + + +Cluster-mode Valkey Configuration (Endpoint) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: yaml + + chunk_size: 256 + remote_url: "valkey://:6379" + remote_serde: "naive" + extra_config: + valkey_mode: "cluster" + valkey_username: "Your username" + valkey_password: "Your password" + +Cluster-mode Valkey Configuration (Nodes) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: yaml + + chunk_size: 256 + remote_url: "valkey://:, :, ... :" + remote_serde: "naive" + extra_config: + valkey_mode: "cluster" + valkey_username: "Your username" + valkey_password: "Your password" diff --git a/docs/source/kv_cache/storage_backends/weka.rst b/docs/source/kv_cache/storage_backends/weka.rst index 035f142aad9..35c2264cc55 100644 --- a/docs/source/kv_cache/storage_backends/weka.rst +++ b/docs/source/kv_cache/storage_backends/weka.rst @@ -79,7 +79,7 @@ Setup Example - vllm and lmcache installed (:doc:`Installation Guide <../../getting_started/installation>`) -- Hugging Face access to ``meta-llama/Llama-3.1-70B-Instruct`` +- Hugging Face access to ``meta-llama/Llama-3.1-8B-Instruct`` .. code-block:: bash @@ -131,7 +131,7 @@ and then comment out the ``LMCACHE_CONFIG_FILE`` below: LMCACHE_CONFIG_FILE="weka-offload.yaml" \ LMCACHE_USE_EXPERIMENTAL=True \ vllm serve \ - meta-llama/Llama-3.1-70B-Instruct \ + meta-llama/Llama-3.1-8B-Instruct \ --max-model-len 65536 \ --kv-transfer-config \ '{"kv_connector":"LMCacheConnectorV1", "kv_role":"kv_both"}' diff --git a/docs/source/kv_cache_optimizations/layerwise.rst b/docs/source/kv_cache_optimizations/layerwise.rst new file mode 100644 index 00000000000..80c1844a536 --- /dev/null +++ b/docs/source/kv_cache_optimizations/layerwise.rst @@ -0,0 +1,111 @@ +Layerwise KV Transfer +===================== + +The storage and loading of KV Cache on a layer granularity is a key optimization that allows for forward pass to "stagger" through its computation as each layer's KV Cache is received instead of only waiting to begin after the entire loading + +CacheBlend is implemented on top of the layerwise codepath in order to pipeline recompute and loading to mask the latency of loading KV Cache. + +.. image:: /_static/basic_codepath.svg + :alt: Basic Codepath + :class: scalable + +.. raw:: html + +
+ + + +
+ Click to open full-size +
+
+ +Architecture Overview +--------------------- + +**CacheEngine** + The main orchestrator containing two primary generators: + + * **Retrieval Generator** (N + 2 yields): Handles layer-by-layer KV cache loading with on-demand memory allocation + * **Storage Generator** (N + 1 yields): Manages layer-by-layer KV cache saving with upfront CPU memory allocation + +**LayerwiseGPUConnector** + Manages GPU-CPU memory transfers with dedicated CUDA streams: + + * **Load GPU Buffer**: Temporary GPU memory for CPU→GPU transfers (``use_gpu: true``) + * **Store GPU Buffer**: Temporary GPU memory for GPU→CPU transfers (``use_gpu: true``) + * **Nested Generators**: ``batched_to_gpu()`` and ``batched_from_gpu()`` handle actual memory operations + +**StorageManager** + Handles persistent storage operations: + + * ``layerwise_batched_get()``: Asynchronous retrieval with ``.result()`` for request-level concurrency + * ``batched_put()``: Stores memory objects to persistent backends + +Execution Flow +~~~~~~~~~~~~~~ + +The layerwise pipeline follows a numbered execution sequence: + +**1. start_load_kv()** + * Initializes Retrieval Generator via ``lmcache_engine.retrieve_layer()`` + * Performs setup (1st ``next()``) and loads layer 0 (2nd ``next()``) + * Creates ``layerwise_retrievers`` list for ongoing layer processing + +**2. wait_for_layer_load()** (repeated for each layer) + * Advances Retrieval Generator via ``next()`` to process layer i + * Triggers ``StorageManager.layerwise_batched_get()`` for async cache retrieval + * Calls GPU Load Generator's ``batched_to_gpu()`` to transfer memory objects to GPU + * **Last request in batch**: Synchronizes ``current_stream.wait_stream(load_stream)`` + +**3. save_kv_layer()** (repeated for each layer) + * **First call only**: Creates Storage Generator with upfront CPU memory allocation + * Advances Storage Generator via ``next()`` to process layer i + * Calls GPU Store Generator's ``batched_from_gpu()`` to transfer GPU data to CPU + * **First request in batch**: Synchronizes ``store_stream.wait_stream(current_stream)`` + +**4. wait_for_save()** + * Finalizes Storage Generator with last ``next()`` call + * Completes all ``StorageManager.batched_put()`` operations + * Performs GPU Store Generator cleanup + +Key Optimizations +~~~~~~~~~~~~~~~~~ + +**Pipelined Memory Operations** + The system overlaps layer N+1 computation with layer N storage. + +**Stream Synchronization** + Three CUDA streams coordinate operations: + + * ``current_stream``: vLLM's forward pass computation + * ``load_stream``: KV cache loading operations + * ``store_stream``: KV cache storing operations + +**Batch-Level Coordination** + Multiple requests are processed together with specialized synchronization: + + * **First request**: Provides store stream synchronization to prevent GPU buffer corruption + * **Last request**: Provides load stream synchronization to ensure KV cache availability + +**Memory Allocation Strategies** + * **Retrieval**: Layer-by-layer allocation + * **Storage**: Upfront allocation for all layers + +**Cache Key Management** + Multi-layer cache engine keys use ``split_layers(N)`` to create per-layer kubernetes_deployment + +Configuration +~~~~~~~~~~~~~ + +Enable layerwise caching by setting: + +.. code-block:: yaml + + use_layerwise: true + +The system automatically selects appropriate layerwise GPU connectors based on configuration: + +* ``VLLMPagedMemLayerwiseGPUConnector``: For standard layerwise operations +* ``VLLMBufferLayerwiseGPUConnector``: When blending is enabled diff --git a/docs/source/production/kubernetes_deployment.rst b/docs/source/production/kubernetes_deployment.rst index 519a7dad2b6..cdf9c56e371 100644 --- a/docs/source/production/kubernetes_deployment.rst +++ b/docs/source/production/kubernetes_deployment.rst @@ -1,4 +1,37 @@ -Kubernetes deployment +Kubernetes Deployment ===================== -Coming soon... +For Kubernetes deployment of vLLM with LMCache integration, we recommend using the `vLLM Production Stack `_ project. This is a specialized production-ready implementation for K8S-native cluster-wide deployment for vllm & lmcache. + +For a quick start guide, please refer to the official `documentation `_ + +and replace the Helm values file with (`values-05-cpu-offloading.yaml `_): + +.. code-block:: yaml + + servingEngineSpec: + runtimeClassName: "" + modelSpec: + - name: "mistral" + repository: "lmcache/vllm-openai" + tag: "latest" + modelURL: "mistralai/Mistral-7B-Instruct-v0.2" + replicaCount: 1 + requestCPU: 10 + requestMemory: "40Gi" + requestGPU: 1 + pvcStorage: "50Gi" + pvcAccessMode: + - ReadWriteOnce + vllmConfig: + maxModelLen: 32000 + + lmcacheConfig: + enabled: true + cpuOffloadingBufferSize: "20" + + hf_token: + +OR + +refer to a detailed `step-by-step tutorial `_ on how to offload KV cache with LMCache in the production stack. diff --git a/examples/agents/prefix_analysis.py b/examples/agents/prefix_analysis.py new file mode 100644 index 00000000000..6bc1f4a7628 --- /dev/null +++ b/examples/agents/prefix_analysis.py @@ -0,0 +1,471 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 + +# Standard +from collections import OrderedDict +from typing import List, Optional, Tuple, Union +import argparse +import json + +# Third Party +from tqdm import tqdm +from transformers import AutoTokenizer +import matplotlib.pyplot as plt +import torch + +# Constants +DEFAULT_TOKENIZER = "meta-llama/Llama-3.1-8B" +DEFAULT_TOKENS_PER_GB = 8200 # Default for Llama-3.1; More details here: https://docs.lmcache.ai/getting_started/kv_cache_calculator.html +DEFAULT_POOL_SIZES_GB: List[Union[int, float, str]] = [ + 1, + 2, + 4, + 8, + 16, + 32, + 50, + 100, + 200, + 500, + "unlimited", +] + + +class LRUTokenPool: + """ + Token pool with LRU eviction policy based on token count limit. + """ + + def __init__(self, max_tokens: float) -> None: + self.max_tokens = max_tokens + self.current_tokens = 0 + self.requests: OrderedDict[int, List[int]] = OrderedDict() + + def longest_prefix_len(self, tokens: List[int]) -> Tuple[int, int]: + """ + Find longest prefix match and update LRU ordering. + For request i (1-indexed): + y[i] = y[i-1] + (len(tokens[i]) - max_shared_prefix(tokens[i], any previous)) + """ + best_len = 0 + best_id = -1 + + for req_id, req_tokens in self.requests.items(): + common_len = 0 + for i in range(min(len(tokens), len(req_tokens))): + if tokens[i] == req_tokens[i]: + common_len += 1 + else: + break + + if common_len > best_len: + best_len = common_len + best_id = req_id + + # Update LRU ordering + if best_id != -1: + self.requests.move_to_end(best_id) + + return best_len, best_id + + def longest_common_substring( + self, + request_id: int, + token_tensor: torch.Tensor, + tokens: List[int], + *, + chunk_len: int = 4, + stride_r: int = 4, + chunk_batch: int = 512, + ) -> Tuple[int, float]: + """ + For token_tensor[request_id], chunk it and check whether each chunk + appears contiguously in any previous request (token_tensor[:request_id]). + Returns (total_tokens_matched, elapsed_seconds). + """ + assert token_tensor.ndim == 2, "Expected [N, T] tensor" + N, T = token_tensor.shape + assert 0 <= request_id < N, "request_id out of range" + + if request_id == 0 or T < chunk_len: + return 0, 0 + + r = token_tensor[request_id] # [T] + r = r[: len(tokens)] + Xprev = token_tensor[:request_id] # [request_id, T] + + # Sliding windows for previous rows + Xw = Xprev.unfold(dimension=1, size=chunk_len, step=1) # [R, W, L] + # Chunks of r + r_chunks = r.unfold(dimension=0, size=chunk_len, step=stride_r) # [C, L] + if r_chunks.numel() == 0: + return 0, 0 + + total_matched_chunks = 0 + + # Process in mini-batches to control memory + for b in range(0, r_chunks.size(0), chunk_batch): + rc = r_chunks[b : b + chunk_batch] # [B, L] + eq = Xw[:, :, None, :] == rc[None, None, :, :] + full = eq.all(dim=-1) # [R, W, B] + # Count how many unique chunks matched (across all previous rows) + matched_chunk_indices = torch.unique(full.nonzero(as_tuple=True)[2]) + total_matched_chunks += matched_chunk_indices.numel() + + total_tokens_matched = total_matched_chunks * chunk_len + + return total_tokens_matched, 0 + + def add_request( + self, + request_id: int, + tokens: List[int], + token_tensor: Optional[torch.Tensor] = None, + ) -> None: + """ + Add a request to the pool, evicting LRU entries if necessary. + """ + # Evict until we have space + while self.current_tokens + len(tokens) > self.max_tokens and self.requests: + old_id, old_tokens = self.requests.popitem(last=False) + self.current_tokens -= len(old_tokens) + + # substring matching case + if token_tensor is not None: + token_tensor[old_id, :] = 0 + + # Add new request + self.requests[request_id] = tokens + self.current_tokens += len(tokens) + + +def load_and_tokenize_inputs( + jsonl_path: str, tokenizer_name: str = DEFAULT_TOKENIZER +) -> Tuple[List[List[int]], torch.Tensor]: + """ + Load and tokenize inputs from a JSONL file. + + Returns: + Tuple of (tokenized_sequences_list, tokenized_sequences_tensor) + - tokenized_sequences_list: List of token lists + - tokenized_sequences_tensor: Padded 2D tensor (sequences, tokens) + Sequences are padded with 0s to match the longest sequence. + """ + print(f"Loading tokenizer: {tokenizer_name}") + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + + print(f"Reading and tokenizing inputs from: {jsonl_path}") + tokenized_sequences = [] + + with open(jsonl_path, "r", encoding="utf-8") as f: + lines = f.readlines() + + for line in tqdm(lines, desc="Tokenizing"): + try: + data = json.loads(line.strip()) + input_text = data.get("input", "") + tokens = tokenizer.encode(input_text) + tokenized_sequences.append(tokens) + except Exception as e: + print(f"Warning: Failed to process line: {e}") + tokenized_sequences.append([]) + + if tokenized_sequences: + max_length = max(len(seq) for seq in tokenized_sequences) + num_sequences = len(tokenized_sequences) + + # Create padded tensor (pad with 0s) + tokenized_tensor = torch.zeros((num_sequences, max_length), dtype=torch.long) + for i, seq in enumerate(tokenized_sequences): + if seq: + tokenized_tensor[i, : len(seq)] = torch.tensor(seq, dtype=torch.long) + else: + tokenized_tensor = torch.tensor([], dtype=torch.long) + + return tokenized_sequences, tokenized_tensor + + +def calculate_hit_rate( + token_sequences: List[List[int]], + pool_size: Optional[int] = None, + token_tensor: Optional[torch.Tensor] = None, + method: str = "prefix", +) -> float: + # Use float('inf') for unlimited case to avoid eviction + max_tokens = float("inf") if pool_size is None else pool_size + pool = LRUTokenPool(max_tokens) + + total_tokens = 0 + hit_tokens = 0 + + total_lcs_time_s = 0.0 + lcs_calls = 0 + + for idx, tokens in tqdm(list(enumerate(token_sequences))): + total_tokens += len(tokens) + + if method == "prefix": + if idx > 0: + common, _ = pool.longest_prefix_len(tokens) + hit_tokens += common + pool.add_request(idx, tokens) + elif method == "substring" and token_tensor is not None: + if idx > 0: + common, elapsed = pool.longest_common_substring( + idx, token_tensor, tokens + ) + hit_tokens += common + total_lcs_time_s += elapsed + lcs_calls += 1 + pool.add_request(idx, tokens, token_tensor) + else: + raise ValueError(f"Invalid method: {method}") + + if method == "substring": + avg_ms = (total_lcs_time_s / lcs_calls * 1000.0) if lcs_calls > 0 else 0.0 + print( + f" [Timing] longest_common_substring: total {total_lcs_time_s:.3f}s, " + f"calls {lcs_calls}, avg {avg_ms:.2f} ms" + ) + + return hit_tokens / total_tokens if total_tokens > 0 else 0.0 + + +def analyze_hit_rates_across_pool_sizes( + token_sequences: List[List[int]], + pool_sizes_gb: List[Union[int, float, str]], + tokens_per_gb: int, + token_tensor: Optional[torch.Tensor] = None, +) -> Tuple[List[float], List[float], List[str]]: + print("\nAnalyzing hit rates across pool sizes...") + print("=" * 60) + + prefix_hit_rates = [] + substring_hit_rates = [] + x_labels = [] + + for size_gb in pool_sizes_gb: + if size_gb == "unlimited": + size_tokens = None + x_labels.append("∞") + pool_desc = "unlimited" + token_desc = "" + else: + size_tokens = int(size_gb * tokens_per_gb) + x_labels.append(str(int(size_gb))) + pool_desc = f"{size_gb}GB" + token_desc = f" ({size_tokens:,} tokens)" + + print(f"Testing pool size: {pool_desc}{token_desc}") + + # For every pool size round, we should start from fresh + tensor_copy = token_tensor.clone() if token_tensor is not None else None + + prefix_hit_rate = calculate_hit_rate( + token_sequences, size_tokens, tensor_copy, method="prefix" + ) + prefix_hit_rates.append(prefix_hit_rate) + print(f" Prefix: {prefix_hit_rate:.4f} ({prefix_hit_rate * 100:.2f}%)") + + substring_hit_rate = calculate_hit_rate( + token_sequences, size_tokens, tensor_copy, method="substring" + ) + substring_hit_rates.append(substring_hit_rate) + print( + f" Substring: {substring_hit_rate:.4f} ({substring_hit_rate * 100:.2f}%)\n" + ) + + print("=" * 60) + return prefix_hit_rates, substring_hit_rates, x_labels + + +def plot_hit_rates( + prefix_hit_rates: List[float], + substring_hit_rates: List[float], + x_labels: List[str], + output_path: str, +) -> None: + """ + Generate and save the hit rate vs pool size plot comparing both methods. + """ + plt.figure(figsize=(12, 7)) + + # Plot prefix + plt.plot( + range(len(prefix_hit_rates)), + prefix_hit_rates, + marker="o", + linewidth=2, + markersize=8, + color="#2E86AB", + label="Prefix Matching", + ) + + # Plot substring + plt.plot( + range(len(substring_hit_rates)), + substring_hit_rates, + marker="s", + linewidth=2, + markersize=8, + color="#A23B72", + label="Substring Matching", + ) + + plt.xlabel("Pool Size (GB)", fontsize=12, fontweight="bold") + plt.ylabel("Hit Rate", fontsize=12, fontweight="bold") + plt.title( + "Cache Hit Rate vs Pool Size: Prefix vs Substring Matching", + fontsize=14, + fontweight="bold", + ) + plt.xticks(range(len(x_labels)), x_labels, rotation=45) + plt.grid(True, alpha=0.3, linestyle="--") + + # Set y-axis limit based on max of both methods + max_rate = max(max(prefix_hit_rates), max(substring_hit_rates)) + plt.ylim(0, min(1.0, max_rate * 1.1)) + plt.legend(loc="best", fontsize=10) + + # Annotate prefix matching rates + for i, (rate, label) in enumerate(zip(prefix_hit_rates, x_labels, strict=False)): + plt.annotate( + f"{rate * 100:.1f}%", + xy=(i, rate), + xytext=(0, 8), + textcoords="offset points", + ha="center", + fontsize=8, + color="#2E86AB", + ) + + # Annotate substring matching rates + for i, (rate, label) in enumerate(zip(substring_hit_rates, x_labels, strict=False)): + plt.annotate( + f"{rate * 100:.1f}%", + xy=(i, rate), + xytext=(0, -15), + textcoords="offset points", + ha="center", + fontsize=8, + color="#A23B72", + ) + + plt.tight_layout() + plt.savefig(output_path, dpi=150, bbox_inches="tight") + print(f"Plot saved to: {output_path}") + + +def parse_arguments() -> argparse.Namespace: + """Parse command-line arguments.""" + parser = argparse.ArgumentParser( + description="Analyze prefix cache hit rates across different pool sizes", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + %(prog)s -i trace.jsonl + %(prog)s -i trace.jsonl -o custom_output.png + %(prog)s -i trace.jsonl --pool-sizes 1 2 4 8 16 unlimited + """, + ) + + parser.add_argument( + "-i", + "--input", + type=str, + required=True, + help="Path to input JSONL file (trace.jsonl)", + ) + + parser.add_argument( + "-o", + "--output", + type=str, + default="prefix_cache_hit_rate.png", + help="Path to output plot file (PNG) (default: prefix_cache_hit_rate.png)", + ) + + parser.add_argument( + "--tokenizer", + type=str, + default=DEFAULT_TOKENIZER, + help=f"HuggingFace tokenizer model name (default: {DEFAULT_TOKENIZER})", + ) + + parser.add_argument( + "--tokens-per-gb", + type=int, + default=DEFAULT_TOKENS_PER_GB, + help=f"Conversion factor from GB to tokens " + f"(default: {DEFAULT_TOKENS_PER_GB}). " + "This should be adjusted when using a different tokenizer.", + ) + + parser.add_argument( + "--pool-sizes", + nargs="+", + default=None, + help='Pool sizes in GB to test (space-separated, can include "unlimited"). ' + f"Default: {' '.join(map(str, DEFAULT_POOL_SIZES_GB))}", + ) + + return parser.parse_args() + + +def parse_pool_sizes( + pool_sizes_input: Optional[List[str]], +) -> List[Union[int, float, str]]: + if pool_sizes_input is None: + return DEFAULT_POOL_SIZES_GB + + parsed_sizes: List[Union[int, float, str]] = [] + for size in pool_sizes_input: + if size.lower() == "unlimited": + parsed_sizes.append("unlimited") + else: + try: + parsed_sizes.append(float(size)) + except ValueError: + raise ValueError( + f"Invalid pool size: {size}. Must be a number or 'unlimited'" + ) from None + + return parsed_sizes + + +def main() -> None: + args = parse_arguments() + + # Parse pool sizes + pool_sizes_gb = parse_pool_sizes(args.pool_sizes) + + print("Configuration:") + print(f" Input: {args.input}") + print(f" Output: {args.output}") + print(f" Tokenizer: {args.tokenizer}") + print(f" Tokens per GB: {args.tokens_per_gb}") + print(f" Pool sizes: {pool_sizes_gb}\n") + + # Load and tokenize inputs + token_sequences, token_tensor = load_and_tokenize_inputs(args.input, args.tokenizer) + print(f"Loaded {len(token_sequences)} requests") + print(f"Token tensor shape: {token_tensor.shape} (padded with 0s)") + print(f"First sequence: {token_tensor[0]}") + + # Analyze hit rates using both methods + prefix_hit_rates, substring_hit_rates, x_labels = ( + analyze_hit_rates_across_pool_sizes( + token_sequences, + pool_sizes_gb, + args.tokens_per_gb, + token_tensor, + ) + ) + + # Generate comparison plot + plot_hit_rates(prefix_hit_rates, substring_hit_rates, x_labels, args.output) + print("\nAnalysis complete!") + + +if __name__ == "__main__": + main() diff --git a/examples/agents/requirements.txt b/examples/agents/requirements.txt new file mode 100644 index 00000000000..3c231c63827 --- /dev/null +++ b/examples/agents/requirements.txt @@ -0,0 +1,4 @@ +transformers +matplotlib +tqdm +torch \ No newline at end of file diff --git a/examples/basic_check/README.md b/examples/basic_check/README.md new file mode 100644 index 00000000000..4595818d54c --- /dev/null +++ b/examples/basic_check/README.md @@ -0,0 +1,47 @@ +# LMCache Basic Check Examples + +This is an introduce of examples for the LMCache Basic Check Tool. + +## Example Usage Patterns + +### Testing Storage Manager +```bash +# Basic test +python -m lmcache.v1.basic_check --mode test_storage_manager + +# With custom model +python -m lmcache.v1.basic_check --mode test_storage_manager --model /my_model/ +``` + +### Testing Remote Backend +```bash +# Basic remote test +python -m lmcache.v1.basic_check --mode test_remote + +``` + +### Key Generation +```bash +# Generate 100 keys with 8 concurrent workers +python -m lmcache.v1.basic_check --mode gen --num-keys 100 --concurrency 8 + +# Generate keys with offset (useful for distributed testing) +python -m lmcache.v1.basic_check --mode gen --num-keys 100 --concurrency 8 --offset 1000 +``` + +## Configuration + +Use the provided example configuration: + +```bash +# Copy to default location +cp example_config.yaml ~/.lmcache/config.yaml + +# Or set environment variable +export LMCACHE_CONFIG_PATH=$(pwd)/example_config.yaml +``` + +## Documentation + +For comprehensive documentation, see: +- [Detailed Usage Documentation](../../docs/source/developer_guide/usage/basic_check.rst) diff --git a/examples/basic_check/example_config.yaml b/examples/basic_check/example_config.yaml new file mode 100644 index 00000000000..c2424762731 --- /dev/null +++ b/examples/basic_check/example_config.yaml @@ -0,0 +1,13 @@ +# LMCache Basic Check Example Configuration +# This configuration file provides settings optimized for basic check operations + +# Basic cache settings +chunk_size: 256 +local_cpu: False +max_local_cpu_size: 1 +save_unfull_chunk: False + +remote_url: "fs://host:0/tmp/lmcache_fs_test" + +extra_config: + save_chunk_meta: False \ No newline at end of file diff --git a/examples/blend_kv_v1/blend.py b/examples/blend_kv_v1/blend.py index c34d2a04d67..720efff81af 100644 --- a/examples/blend_kv_v1/blend.py +++ b/examples/blend_kv_v1/blend.py @@ -69,7 +69,7 @@ def build_llm_with_lmcache(lmcache_connector: str, model: str): model=model, kv_transfer_config=ktc, max_model_len=32648, - gpu_memory_utilization=0.8, + gpu_memory_utilization=0.7, enable_prefix_caching=False, enforce_eager=True, ) @@ -145,11 +145,15 @@ def main(): with build_llm_with_lmcache(lmcache_connector, model) as llm: # Define the shared prompt and specific prompts warmup_prompt = tokenizer.encode("Nice to meet you" * 500)[1:] - sys_prompt = tokenizer.encode("You are a very helpful assistant.") + sys_prompt = [1, 733, 16289, 28793] + tokenizer.encode( + "You are a very helpful assistant. " + "Please answer the question with instructions." + ) chunk1_prompt = tokenizer.encode("Hello, how are you?" * 500)[1:] chunk2_prompt = tokenizer.encode("Hello, what's up?" * 500)[1:] chunk3_prompt = tokenizer.encode("Hi, what are you up to?" * 500)[1:] blend_special_str = tokenizer.encode(os.getenv("LMCACHE_BLEND_SPECIAL_STR"))[1:] + first_prompt = ( sys_prompt + blend_special_str @@ -160,6 +164,7 @@ def main(): + chunk3_prompt + blend_special_str + tokenizer.encode("Hello, my name is")[1:] + + [733, 28748, 16289, 28793] ) second_prompt = ( @@ -172,18 +177,20 @@ def main(): + chunk3_prompt + blend_special_str + tokenizer.encode("Hello, how are you?")[1:] + + [733, 28748, 16289, 28793] ) third_prompt = ( sys_prompt + blend_special_str - + chunk3_prompt + + chunk2_prompt + blend_special_str + chunk1_prompt + blend_special_str - + chunk2_prompt + + chunk3_prompt + blend_special_str + tokenizer.encode("Hello, what's up?")[1:] + + [733, 28748, 16289, 28793] ) sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1) diff --git a/examples/cache_interface/README.md b/examples/cache_interface/README.md index 6933530df95..c162d41ccec 100644 --- a/examples/cache_interface/README.md +++ b/examples/cache_interface/README.md @@ -9,7 +9,11 @@ This will use the port 8000 for 1 vllm. 1. Start the vllm engine at port 8000: ```bash -CUDA_VISIBLE_DEVICES=0 LMCACHE_USE_EXPERIMENTAL=True LMCACHE_CONFIG_FILE=example.yaml vllm serve meta-llama/Meta-Llama-3.1-8B-Instruct --max-model-len 4096 --gpu-memory-utilization 0.8 --port 8000 --kv-transfer-config '{"kv_connector":"LMCacheConnector", "kv_role":"kv_both"}' +CUDA_VISIBLE_DEVICES=0 LMCACHE_USE_EXPERIMENTAL=True LMCACHE_CONFIG_FILE=example.yaml vllm serve meta-llama/Meta-Llama-3.1-8B-Instruct \ + --max-model-len 4096 \ + --gpu-memory-utilization 0.8 \ + --port 8000 \ + --kv-transfer-config '{"kv_connector":"LMCacheConnectorV1", "kv_role":"kv_both"}' ``` diff --git a/examples/disagg_prefill/1p1d/README.md b/examples/disagg_prefill/1p1d/README.md index 7c1bcc6333b..a05ec6ad5ed 100644 --- a/examples/disagg_prefill/1p1d/README.md +++ b/examples/disagg_prefill/1p1d/README.md @@ -28,7 +28,7 @@ Press `Ctrl+C` to stop the servers. #### Example benchmark command -If you have vLLM [benchmark_serving.py](https://github.com/vllm-project/vllm/blob/main/benchmarks/benchmark_serving.py), you can run the following command to benchmark the serving performance of the disaggregated prefill setup: +If you have vLLM's serving benchmark tool, you can run the following command to benchmark the serving performance of the disaggregated prefill setup: ```bash vllm bench serve --port 9100 --seed $(date +%s) \ diff --git a/examples/disagg_prefill/xpyd/README.md b/examples/disagg_prefill/xpyd/README.md index e514067604d..64f983a9d1b 100644 --- a/examples/disagg_prefill/xpyd/README.md +++ b/examples/disagg_prefill/xpyd/README.md @@ -62,7 +62,7 @@ python disagg_proxy_server.py \ ``` #### Example benchmark command -If you have vLLM [benchmark_serving.py](https://github.com/vllm-project/vllm/blob/main/benchmarks/benchmark_serving.py), you can run the following command to benchmark the serving performance of the disaggregated prefill setup: +If you have vLLM's serving benchmark tool, you can run the following command to benchmark the serving performance of the disaggregated prefill setup: ```bash vllm bench serve --port 9100 --seed $(date +%s) \ diff --git a/examples/kubernetes/health_probe.py b/examples/kubernetes/health_probe.py index 5886662ddc9..ac86b2e643e 100644 --- a/examples/kubernetes/health_probe.py +++ b/examples/kubernetes/health_probe.py @@ -40,7 +40,12 @@ def main(): msg = ClientMetaMessage( ClientCommand.HEALTH, key=CacheEngineKey( - fmt="", model_name="", world_size=0, worker_id=0, chunk_hash="" + fmt="", + model_name="", + world_size=0, + worker_id=0, + chunk_hash=0, + dtype=torch.float16, ), length=0, fmt=MemoryFormat(1), diff --git a/examples/kv_cache_reuse/remote_backends/fs/example.yaml b/examples/kv_cache_reuse/remote_backends/fs/example.yaml new file mode 100644 index 00000000000..07bf5628705 --- /dev/null +++ b/examples/kv_cache_reuse/remote_backends/fs/example.yaml @@ -0,0 +1,9 @@ +chunk_size: 256 +local_cpu: False +max_local_cpu_size: 1 +save_unfull_chunk: False + +remote_url: "fs://host:0/tmp/lmcache_fs_test" + +extra_config: + save_chunk_meta: False diff --git a/lmcache/cache_engine.py b/lmcache/cache_engine.py index b6f05a464c8..f3c488baf15 100644 --- a/lmcache/cache_engine.py +++ b/lmcache/cache_engine.py @@ -53,6 +53,7 @@ def _make_key(self, chunk_hash: int, fmt: str) -> CacheEngineKey: self.metadata.world_size, self.metadata.worker_id, chunk_hash, + self.metadata.kv_dtype, ) def _num_tokens_in_kv( diff --git a/lmcache/config.py b/lmcache/config.py index 8f9f566717a..de3c819af39 100644 --- a/lmcache/config.py +++ b/lmcache/config.py @@ -33,6 +33,8 @@ class LMCacheEngineMetadata: kv_shape: tuple[int, int, int, int, int] """ whether use MLA""" use_mla: bool = False + """ the role of the current instance (e.g., 'scheduler', 'worker') """ + role: Optional[str] = None """ the first rank of the distributed setting """ # TODO(baoloongmao): first_rank should be configurable first_rank = 0 diff --git a/lmcache/integration/sglang/sglang_adapter.py b/lmcache/integration/sglang/sglang_adapter.py index 37c36e056a7..43a60d13dbd 100644 --- a/lmcache/integration/sglang/sglang_adapter.py +++ b/lmcache/integration/sglang/sglang_adapter.py @@ -1,12 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # Standard from dataclasses import dataclass -from typing import Any, List +from typing import Any, List, Optional import uuid # Third Party from sglang.srt.configs.model_config import ModelConfig import torch +import torch.distributed as dist # First Party from lmcache.config import LMCacheEngineMetadata @@ -208,12 +209,27 @@ def __init__( rank: int, k_pool: List[torch.Tensor], v_pool: List[torch.Tensor], + tp_group: Optional[torch.distributed.ProcessGroup] = None, ): super().__init__(sgl_config, tp_size, rank, k_pool, v_pool) self._lmcache_chunk_size = self.lmcache_engine.config.chunk_size self.layerwise_retrievers: List[Any] = [] self.layer_load_layer: List[int] = [] self.kvcaches = [k_pool, v_pool] + self.tp_group = tp_group + self.lookup_id_list: List[str] = [] + + @torch.no_grad() + def global_min_tokens( + self, local_tokens: int, tp_group: dist.ProcessGroup, device: torch.device + ): + # If tensor parallel size is 1, no need for all_reduce + if self.tp_size == 1: + return local_tokens + + t = torch.tensor([local_tokens], dtype=torch.int32, device=device) + dist.all_reduce(t, op=dist.ReduceOp.MIN, group=tp_group) + return int(t.item()) def load_kv_layerwise(self, layer_id: int) -> None: if len(self.layerwise_retrievers) == 0: @@ -230,6 +246,8 @@ def load_kv_layerwise(self, layer_id: int) -> None: for i in sorted(indices_to_remove, reverse=True): del self.layerwise_retrievers[i] del self.layer_load_layer[i] + self.lmcache_engine.lookup_unpin(self.lookup_id_list[i]) + del self.lookup_id_list[i] return @@ -243,26 +261,38 @@ def start_load_kv(self, load_metadata: LoadMetadata) -> int: load_mask = torch.ones_like(token_ids, dtype=torch.bool) load_mask[:offset] = False - layerwise_retriever = self.lmcache_engine.retrieve_layer( + lookup_id = str(uuid.uuid4()) + retrieve_token_num = self.lmcache_engine.lookup( token_ids, - mask=load_mask, + lookup_id=lookup_id, + pin=True, + ) + + retrieve_token_num = self.global_min_tokens( + retrieve_token_num, self.tp_group, torch.device(f"cuda:{self.rank}") + ) + + layerwise_retriever = self.lmcache_engine.retrieve_layer( + token_ids[:retrieve_token_num], + mask=load_mask[:retrieve_token_num], kvcaches=self.kvcaches, - slot_mapping=slot_mapping, + slot_mapping=slot_mapping[:retrieve_token_num], sync=False, ) - retrieve_token_num = next(layerwise_retriever) + next(layerwise_retriever) # Load First Layer next(layerwise_retriever) if retrieve_token_num is None: return 0 - retrieve_token_num = retrieve_token_num.item() self.layerwise_retrievers.append(layerwise_retriever) self.layer_load_layer.append(1) - return retrieve_token_num + self.lookup_id_list.append(lookup_id) + + return retrieve_token_num - offset def store_kv(self, store_metadata: StoreMetadata) -> None: slot_mapping = store_metadata.kv_indices.to(torch.int64).cuda() @@ -284,4 +314,4 @@ def store_kv(self, store_metadata: StoreMetadata) -> None: for _ in range(self.sgl_config.num_hidden_layers): next(layerwise_storer) - self.lmcache_engine.lookup_unpin([lookup_id]) + self.lmcache_engine.lookup_unpin(lookup_id) diff --git a/lmcache/integration/vllm/lmcache_connector_v1.py b/lmcache/integration/vllm/lmcache_connector_v1.py index 3de80124f55..1e2b0a5f4a9 100644 --- a/lmcache/integration/vllm/lmcache_connector_v1.py +++ b/lmcache/integration/vllm/lmcache_connector_v1.py @@ -114,6 +114,10 @@ def get_finished( """ return self._lmcache_engine.get_finished(finished_req_ids) + def get_block_ids_with_load_errors(self) -> set[int]: + """Return block IDs that failed to load during the last interval.""" + return self._lmcache_engine.get_block_ids_with_load_errors() + # ============================== # Scheduler-side methods # ============================== diff --git a/lmcache/integration/vllm/utils.py b/lmcache/integration/vllm/utils.py index 4692f00643f..2cf5f23686c 100644 --- a/lmcache/integration/vllm/utils.py +++ b/lmcache/integration/vllm/utils.py @@ -130,8 +130,13 @@ def create_lmcache_metadata( tuple: (LMCacheEngineMetadata, LMCacheEngineConfig) """ # Third Party - from vllm.utils import get_kv_cache_torch_dtype - + # Try to import from old location before merged https://github.com/vllm-project/vllm/pull/26908 + try: + # Third Party + from vllm.utils.torch_utils import get_kv_cache_torch_dtype + except ImportError: + # Third Party + from vllm.utils import get_kv_cache_torch_dtype # First Party from lmcache.config import LMCacheEngineMetadata diff --git a/lmcache/integration/vllm/vllm_adapter.py b/lmcache/integration/vllm/vllm_adapter.py index 965debf610c..fdc550b1b18 100644 --- a/lmcache/integration/vllm/vllm_adapter.py +++ b/lmcache/integration/vllm/vllm_adapter.py @@ -35,13 +35,16 @@ ParallelConfig, SchedulerConfig, ) +from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors -from vllm.utils import cdiv, round_down # First Party from lmcache.integration.vllm.utils import ENGINE_NAME from lmcache.logging import init_logger -from lmcache.utils import _lmcache_nvtx_annotate + +# Use LMCache's own math utilities instead of vllm's +# (avoids dependency on vllm internal changes like https://github.com/vllm-project/vllm/pull/27188) +from lmcache.utils import _lmcache_nvtx_annotate, cdiv, round_down from lmcache.v1.cache_engine import LMCacheEngineBuilder # FIXME(Jiayi): temporarily comment this out @@ -49,7 +52,8 @@ logger = init_logger(__name__) -LMCACHE_CUDA_STREAM = torch.cuda.Stream() +if current_platform.is_cuda_alike(): + LMCACHE_CUDA_STREAM = torch.cuda.Stream() SUPPORTED_BACKEND_METADATA = ( FlashAttentionMetadata, diff --git a/lmcache/integration/vllm/vllm_v1_adapter.py b/lmcache/integration/vllm/vllm_v1_adapter.py index abfc4ea6d8f..da1dba89e52 100644 --- a/lmcache/integration/vllm/vllm_v1_adapter.py +++ b/lmcache/integration/vllm/vllm_v1_adapter.py @@ -1,9 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 # Standard from dataclasses import dataclass, field +from types import SimpleNamespace from typing import TYPE_CHECKING, Any, Generator, Optional, Union import os -import uuid # Third Party from vllm.config import ( @@ -15,11 +15,27 @@ KVConnectorRole, ) from vllm.distributed.parallel_state import ( + get_pp_group, get_tensor_model_parallel_rank, get_tp_group, ) from vllm.sampling_params import SamplingParams -from vllm.utils import cdiv, get_kv_cache_torch_dtype + +# First Party +# Use LMCache's own math utilities instead of vllm's +# (avoids dependency on vllm internal changes like https://github.com/vllm-project/vllm/pull/27188) +from lmcache.utils import cdiv + +# Try to import from old location before merged https://github.com/vllm-project/vllm/pull/26908 +try: + # Third Party + from vllm.utils.torch_utils import get_kv_cache_torch_dtype +except ImportError: + # Third Party + from vllm.utils import get_kv_cache_torch_dtype + +# Third Party +from vllm.platforms import current_platform from vllm.v1.core.sched.output import SchedulerOutput from vllm.version import __version__ as VLLM_VERSION import torch @@ -41,6 +57,7 @@ from lmcache.v1.compute.blend import LMCBlenderBuilder from lmcache.v1.config import LMCacheEngineConfig, _validate_and_set_config_value from lmcache.v1.gpu_connector import ( + GPUConnectorInterface, VLLMBufferLayerwiseGPUConnector, VLLMPagedMemGPUConnectorV2, VLLMPagedMemLayerwiseGPUConnector, @@ -52,6 +69,7 @@ ) from lmcache.v1.offload_server.zmq_server import ZMQOffloadServer from lmcache.v1.plugin.plugin_launcher import PluginLauncher +from lmcache.v1.xpu_connector import VLLMPagedMemXPUConnectorV2 if TYPE_CHECKING: # Third Party @@ -320,11 +338,12 @@ def from_request_tracker( # NOTE(vladnosiv): for the input_token_len chunk prefill, # we are required to discard partial chunks, # as new tokens will be added in the next iteration. - num_tokens_to_save = ( - (input_token_len // lmcache_chunk_size * lmcache_chunk_size) - if not is_last_prefill or discard_partial_chunks - else input_token_len - ) + if not is_last_prefill or discard_partial_chunks: + num_tokens_to_save = ( + input_token_len // lmcache_chunk_size * lmcache_chunk_size + ) + else: + num_tokens_to_save = input_token_len # If we need to save, update the number of saved tokens if not skip_save: @@ -428,6 +447,7 @@ def _calculate_draft_layers(vllm_config, model_config): def _init_lmcache_engine( lmcache_config: LMCacheEngineConfig, vllm_config: "VllmConfig", + role: str, ) -> LMCacheEngine: """Initialize the LMCache engine by the given model config and parallel config. This function will check the environment variable @@ -475,10 +495,21 @@ def _init_lmcache_engine( ) # Change current device. - num_gpus = torch.cuda.device_count() + if current_platform.is_cuda_alike(): + logger.info("CUDA device is available. Using CUDA for LMCache engine.") + torch_dev = torch.cuda + dev_name = "cuda" + elif current_platform.is_xpu(): + logger.info("XPU device is available. Using XPU for LMCache engine.") + torch_dev = torch.xpu + dev_name = "xpu" + else: + raise RuntimeError("Unsupported device platform for LMCache engine.") + + num_gpus = torch_dev.device_count() local_rank = parallel_config.rank % num_gpus - torch.cuda.set_device(local_rank) - device = torch.device(f"cuda:{local_rank}") + torch_dev.set_device(local_rank) + device = torch.device(f"{dev_name}:{local_rank}") metadata = LMCacheEngineMetadata( model_config.model, parallel_config.world_size, @@ -487,21 +518,24 @@ def _init_lmcache_engine( kv_dtype, kv_shape, use_mla, + role, ) use_gpu = need_gpu_interm_buffer(lmcache_config) - vllm_gpu_connector: Union[ - VLLMBufferLayerwiseGPUConnector, - VLLMPagedMemGPUConnectorV2, - VLLMPagedMemLayerwiseGPUConnector, - ] + vllm_gpu_connector: Optional[GPUConnectorInterface] if use_mla and lmcache_config.use_layerwise: raise ValueError("layerwise MLA connector is not supported yet") # When use_mla is True, num_kv_head is 1 hidden_dim_size = num_kv_head * head_size - if lmcache_config.use_layerwise: + if role == "scheduler": + vllm_gpu_connector = None + # Create a dummy tpg object with broadcast and broadcast_object methods + tpg = SimpleNamespace() + tpg.broadcast = lambda tensor, src: tensor + tpg.broadcast_object = lambda obj, src: obj + elif lmcache_config.use_layerwise: if lmcache_config.enable_blending: # Use layerwise connector for blending vllm_gpu_connector = VLLMBufferLayerwiseGPUConnector( @@ -521,8 +555,16 @@ def _init_lmcache_engine( dtype=kv_dtype, device=device, ) + tpg = get_tp_group() else: - vllm_gpu_connector = VLLMPagedMemGPUConnectorV2( + if current_platform.is_cuda_alike(): + connector_cls = VLLMPagedMemGPUConnectorV2 + elif current_platform.is_xpu(): + connector_cls = VLLMPagedMemXPUConnectorV2 + else: + raise RuntimeError("No supported connector found for the current platform.") + + vllm_gpu_connector = connector_cls( hidden_dim_size, num_layer, use_gpu=use_gpu, @@ -531,7 +573,7 @@ def _init_lmcache_engine( device=device, use_mla=use_mla, ) - tpg = get_tp_group() + tpg = get_tp_group() engine = LMCacheEngineBuilder.get_or_create( ENGINE_NAME, lmcache_config, @@ -540,14 +582,19 @@ def _init_lmcache_engine( tpg.broadcast, tpg.broadcast_object, ) - + if role == "scheduler" and lmcache_config.enable_scheduler_bypass_lookup: + assert engine.save_only_first_rank or lmcache_config.get_extra_config_value( + "remote_enable_mla_worker_id_as0", metadata.use_mla + ), ( + "enable_scheduler_bypass_lookup is only supported with " + "save_only_first_rank or remote_enable_mla_worker_id_as0" + ) return engine @dataclass class LMCacheConnectorMetadata(KVConnectorMetadata): requests: list[ReqMeta] = field(default_factory=list) - lookup_requests_in_step: list[str] = field(default_factory=list) @_lmcache_nvtx_annotate def add_request(self, req_meta: ReqMeta) -> None: @@ -568,6 +615,7 @@ def __init__( ): self._parent = parent self._vllm_config = vllm_config + self.device = vllm_config.device_config.device self.kv_role = vllm_config.kv_transfer_config.kv_role self.worker_count = vllm_config.parallel_config.tensor_parallel_size config = lmcache_get_or_create_config() @@ -597,23 +645,35 @@ def __init__( ] = [] self._stats_monitor = LMCStatsMonitor.GetOrCreate() if role == KVConnectorRole.SCHEDULER: + self.lmcache_engine: Optional[LMCacheEngine] = None + # Check if bypass lookup is enabled for scheduler + if config.enable_scheduler_bypass_lookup: + # Create LMCacheEngine for scheduler when bypass is enabled + self.lmcache_engine = _init_lmcache_engine( + config, + vllm_config, + role="scheduler", + ) # Create lookup client using factory self.lookup_client = LookupClientFactory.create_lookup_client( - vllm_config, config + vllm_config, config, self.lmcache_engine ) self._unfinished_requests: dict[str, Request] = {} - self._lookup_requests_in_step: list[str] = [] self.lmcache_engine = None else: self.lmcache_engine = _init_lmcache_engine( config, vllm_config, + role="worker", ) self.use_layerwise = config.use_layerwise self.enable_blending = config.enable_blending if self.enable_blending: + assert self.lmcache_engine.gpu_connector is not None, ( + "GPU connector must be available for blending" + ) self.blender = LMCBlenderBuilder.get_or_create( ENGINE_NAME, self.lmcache_engine, @@ -674,6 +734,9 @@ def __init__( self._requests_priority: dict[str, int] = {} + # Track block IDs associated with failed load attempts. + self._invalid_block_ids: set[int] = set() + # TODO(baoloongmao): Internal api server & plugin framework support dp > 1 if vllm_config.parallel_config.data_parallel_rank_local == 0: # Start internal API server if enabled @@ -822,12 +885,9 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: tokens = request.token_ids # TODO: have a pre-allocated buffer to hold the slot_mappings - slot_mapping = request.slot_mapping.cuda() + slot_mapping = request.slot_mapping.to(self.device) assert len(tokens) == len(slot_mapping) - self._stats_monitor.update_interval_vllm_hit_tokens( - request.load_spec.vllm_cached_tokens - ) token_mask = torch.ones(len(tokens), dtype=torch.bool) masked_token_count = ( request.load_spec.vllm_cached_tokens @@ -871,6 +931,7 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: slot_mapping=slot_mapping[:lmcache_cached_tokens], request_configs=request.request_configs, req_id=request.req_id, + skip_contains_check=True, ) # Check the result @@ -888,6 +949,95 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: num_retrieved_tokens, num_expected_tokens, ) + """ + Report failed block IDs in case of partial failure. + """ + missing_blocks = self.record_failed_blocks( + request.req_id, + token_mask[:lmcache_cached_tokens], + ret_token_mask, + slot_mapping[:lmcache_cached_tokens], + ) + self._invalid_block_ids.update(missing_blocks) + + self._stats_monitor.update_interval_vllm_hit_tokens( + request.load_spec.vllm_cached_tokens + ) + self._stats_monitor.update_interval_prompt_tokens(len(tokens)) + + def record_failed_blocks( + self, + request_id: str, + expected_mask: torch.Tensor, + ret_mask: torch.Tensor, + slot_mapping: torch.Tensor, + ) -> set[int]: + """Record block IDs associated with failed load attempts. + + Args: + request_id: request id from vLLM. + expected_mask: Boolean tensor indicating which tokens were expected to + be loaded from LMCache. True means the token should be loaded, + False means the token is already cached in vLLM and does not need + to be loaded from LMCache. + ret_mask: Boolean tensor indicating which tokens were actually + successfully retrieved from LMCache. True means the token was + successfully loaded. For example, if 256 tokens are expected to be + loaded, but only 192 tokens are successfully loaded, then the + ret_mask will be a tensor of 256 items like [T, T, ..., F, F, ...] + where the first 192 elements are True and the last 64 elements + are False. + slot_mapping: Tensor indicating slot IDs for each token. The block + ID is computed by dividing the slot ID by the block size. + + Example: + expected_mask = [F, T, T, T] meaning the 1st is in vLLM cache + ret_mask = [F, T, F, F] meaning failure from loading the 3rd + missing_mask = expected_mask & ~ret_mask = [F, F, T, T] + missing_indices = [2, 3] + then missing_blocks is calculated from slot_mapping and missing_indices + + Returns: + set[int]: Set of block IDs that failed to load. + """ + + if expected_mask.numel() == 0: + return set() + + expected_mask_cpu = expected_mask.to(device="cpu", dtype=torch.bool) + ret_mask_cpu = ret_mask.to(device="cpu", dtype=torch.bool) + + if ret_mask_cpu.shape[0] != expected_mask_cpu.shape[0]: + logger.debug("expected_mask_cpu.shape[0] != ret_mask_cpu.shape[0]") + return set() + + missing_mask = expected_mask_cpu & ~ret_mask_cpu + if not torch.any(missing_mask): + return set() + + missing_indices = torch.nonzero(missing_mask, as_tuple=False).view(-1) + if missing_indices.numel() == 0: + return set() + + slot_mapping_cpu = slot_mapping.to(device="cpu", dtype=torch.long) + if slot_mapping_cpu.shape[0] > missing_mask.shape[0]: + slot_mapping_cpu = slot_mapping_cpu[: missing_mask.shape[0]] + + missing_blocks_tensor = torch.unique( + slot_mapping_cpu[missing_indices] // self._block_size + ) + missing_blocks = {int(block.item()) for block in missing_blocks_tensor} + + if not missing_blocks: + return set() + + logger.warning( + "Request %s failed to load %d tokens across %d blocks", + request_id, + missing_indices.numel(), + len(missing_blocks), + ) + return missing_blocks @_lmcache_nvtx_annotate def wait_for_layer_load(self, layer_name: str) -> None: @@ -968,7 +1118,7 @@ def save_kv_layer( assert len(slot_mapping) == len(token_ids) # TODO: have a pre-allocated buffer to hold the slot_mappings - slot_mapping = slot_mapping.cuda() + slot_mapping = slot_mapping.to(self.device) if self.kv_role == "kv_producer": skip_leading_tokens = 0 @@ -1022,8 +1172,6 @@ def wait_for_save(self): connector_metadata = self._parent._get_connector_metadata() assert isinstance(connector_metadata, LMCacheConnectorMetadata) - self.lmcache_engine.lookup_unpin(connector_metadata.lookup_requests_in_step) - if self.kv_role == "kv_consumer": # Don't do save if the role is kv_consumer return @@ -1031,6 +1179,10 @@ def wait_for_save(self): if self.use_layerwise: for layerwise_storer in self.layerwise_storers: next(layerwise_storer) + + # unpin the kv caches according to req_id + for request in connector_metadata.requests: + self.lmcache_engine.lookup_unpin(request.req_id) return assert len(self.kv_caches) > 0 @@ -1039,6 +1191,9 @@ def wait_for_save(self): assert self.lmcache_engine is not None for request in connector_metadata.requests: + # unpin the kv caches according to req_id + self.lmcache_engine.lookup_unpin(request.req_id) + save_spec = request.save_spec if ( save_spec is None or not save_spec.can_save @@ -1052,7 +1207,7 @@ def wait_for_save(self): assert len(slot_mapping) == len(token_ids) # TODO: have a pre-allocated buffer to hold the slot_mappings - slot_mapping = slot_mapping.cuda() + slot_mapping = slot_mapping.to(self.device) skip_leading_tokens = save_spec.skip_leading_tokens if self.kv_role == "kv_producer": @@ -1086,13 +1241,14 @@ def wait_for_save(self): if request.disagg_spec: request.disagg_spec.is_last_prefill = True else: - token_len = len(token_ids) - aligned_token_len = ( - token_len // self._lmcache_chunk_size * self._lmcache_chunk_size - ) - token_ids = token_ids[:aligned_token_len] - store_mask = store_mask[:aligned_token_len] - slot_mapping = slot_mapping[:aligned_token_len] + if not self.enable_blending: + token_len = len(token_ids) + aligned_token_len = ( + token_len // self._lmcache_chunk_size * self._lmcache_chunk_size + ) + token_ids = token_ids[:aligned_token_len] + store_mask = store_mask[:aligned_token_len] + slot_mapping = slot_mapping[:aligned_token_len] self.lmcache_engine.store( token_ids, @@ -1104,10 +1260,13 @@ def wait_for_save(self): request_configs=request.request_configs, ) - # NOTE(Jiayi): We assume all tokens are saved - save_spec.skip_leading_tokens = len(token_ids) - if request.disagg_spec: - request.disagg_spec.num_transferred_tokens = len(token_ids) + # Update skip_leading_tokens only on last rank to ensure + # each PP stage stores its own KV cache + if get_pp_group().is_last_rank: + # NOTE(Jiayi): We assume all tokens are saved + save_spec.skip_leading_tokens = len(token_ids) + if request.disagg_spec: + request.disagg_spec.num_transferred_tokens = len(token_ids) @_lmcache_nvtx_annotate def get_finished( @@ -1115,6 +1274,11 @@ def get_finished( ) -> tuple[Optional[set[str]], Optional[set[str]]]: return None, None + def get_block_ids_with_load_errors(self) -> set[int]: + invalid_blocks = self._invalid_block_ids.copy() + self._invalid_block_ids.clear() + return invalid_blocks + ################### # Scheduler side APIs #################### @@ -1157,12 +1321,8 @@ def get_num_new_matched_tokens( request_configs = extract_request_configs(request.sampling_params) if self.skip_last_n_tokens > 0: token_ids = token_ids[: -self.skip_last_n_tokens] - if self.async_loading: - lookup_id = request.request_id - else: - lookup_id = str(uuid.uuid4()) - self._lookup_requests_in_step.append(lookup_id) + lookup_id = request.request_id num_external_hit_tokens = self.lookup_client.lookup( token_ids, @@ -1171,7 +1331,7 @@ def get_num_new_matched_tokens( ) if num_external_hit_tokens is None: - logger.info( + logger.debug( "Reqid: %s, Total tokens %d, LMCache hit tokens: None.", request.request_id, request.num_tokens, @@ -1220,6 +1380,10 @@ def update_state_after_alloc(self, request: "Request", num_external_tokens: int) if the CacheManager this allocated blocks for us. """ + # Clear local status in lookup client when a new request is + # successfully scheduled. + self.lookup_client.clear_lookup_status(request.request_id) + kv_transfer_params = ( request.kv_transfer_params if hasattr(request, "kv_transfer_params") @@ -1290,10 +1454,6 @@ def build_connector_meta( meta = LMCacheConnectorMetadata() - # set and update lookup requests for unpin - meta.lookup_requests_in_step = self._lookup_requests_in_step - self._lookup_requests_in_step = [] - for finished_req_id in scheduler_output.finished_req_ids: self._request_trackers.pop(finished_req_id, None) self._unfinished_requests.pop(finished_req_id, None) diff --git a/lmcache/observability.py b/lmcache/observability.py index d0675adb162..8fcd62ac544 100644 --- a/lmcache/observability.py +++ b/lmcache/observability.py @@ -32,6 +32,7 @@ class LMCacheStats: interval_lookup_tokens: int interval_lookup_hits: int interval_vllm_hit_tokens: int + interval_prompt_tokens: int interval_remote_read_requests: int interval_remote_read_bytes: int @@ -68,11 +69,29 @@ class LMCacheStats: retrieve_speed: List[float] # Tokens per second store_speed: List[float] # Tokens per second + # P2P transfer metrics + interval_p2p_requests: int + interval_p2p_transferred_tokens: int + p2p_time_to_transfer: List[float] + p2p_transfer_speed: List[float] # Tokens per second + + # request lookup hit rates + # use bucket of interval_lookup_hit_rates to represents non-0 hit requests + # use interval_lookup_0_hit_requests to represents 0 hit requests + interval_lookup_hit_rates: List[float] + interval_lookup_0_hit_requests: int + @dataclass class LookupRequestStats: num_tokens: int hit_tokens: int + is_finished: bool + + def hit_rate(self): + if self.num_tokens == 0: + return 0 + return self.hit_tokens / self.num_tokens @dataclass @@ -113,6 +132,23 @@ def store_speed(self): return self.num_tokens / self.time_to_store() +@dataclass +class P2PTransferRequestStats: + num_tokens: int + start_time: float + end_time: float + + def time_to_transfer(self): + if self.end_time == 0: + return 0 + return self.end_time - self.start_time + + def transfer_speed(self): + if self.time_to_transfer() == 0: + return 0 + return self.num_tokens / self.time_to_transfer() + + class LMCStatsMonitor: def __init__(self): # Interval metrics that will be reset after each log @@ -126,6 +162,14 @@ def __init__(self): self.interval_lookup_tokens = 0 # total requested tokens lookup self.interval_lookup_hits = 0 # total hit tokens lookup self.interval_vllm_hit_tokens = 0 # total hit tokens in vllm + self.interval_prompt_tokens = 0 # total prompt tokens + self.interval_lookup_0_hit_requests = 0 + + # P2P transfer metrics + self.interval_p2p_requests = 0 + self.interval_p2p_transferred_tokens = 0 + self.p2p_requests: Dict[int, P2PTransferRequestStats] = {} + self.p2p_request_id = 0 # remote backends read/write metrics self.interval_remote_read_requests = 0 @@ -158,26 +202,42 @@ def __init__(self): self.retrieve_requests: Dict[int, RetrieveRequestStats] = {} self.store_requests: Dict[int, StoreRequestStats] = {} + self.lookup_requests: Dict[int, LookupRequestStats] = {} self.retrieve_request_id = 0 self.store_request_id = 0 + self.lookup_request_id = 0 @thread_safe - def on_lookup_request(self, num_tokens: int): + def on_lookup_request(self, num_tokens: int) -> int: """ This function is called when a lookup request is sent to the cache. It will record the number of tokens requested. """ + lookup_stats = LookupRequestStats( + num_tokens=num_tokens, + hit_tokens=0, + is_finished=False, + ) self.interval_lookup_requests += 1 self.interval_lookup_tokens += num_tokens + self.lookup_requests[self.lookup_request_id] = lookup_stats + self.lookup_request_id += 1 + return self.lookup_request_id - 1 @thread_safe - def on_lookup_finished(self, num_hit_tokens: int): + def on_lookup_finished(self, request_id: int, num_hit_tokens: int): """ This function is called when a lookup request is finished. It will record the number of tokens hit. """ + assert request_id in self.lookup_requests + lookup_stats = self.lookup_requests[request_id] + lookup_stats.hit_tokens = num_hit_tokens + lookup_stats.is_finished = True self.interval_lookup_hits += num_hit_tokens + if num_hit_tokens == 0: + self.interval_lookup_0_hit_requests += 1 @thread_safe def on_retrieve_request(self, num_tokens: int) -> int: @@ -232,6 +292,26 @@ def on_store_finished(self, request_id: int, num_tokens: int = -1): if num_tokens >= 0: store_stats.num_tokens = num_tokens + @thread_safe + def on_p2p_transfer_request(self, num_tokens: int) -> int: + curr_time = time.time() + self.interval_p2p_requests += 1 + self.p2p_requests[self.p2p_request_id] = P2PTransferRequestStats( + num_tokens=num_tokens, + start_time=curr_time, + end_time=0, + ) + self.p2p_request_id += 1 + return self.p2p_request_id - 1 + + @thread_safe + def on_p2p_transfer_finished(self, request_id: int): + curr_time = time.time() + assert request_id in self.p2p_requests + p2p_stats = self.p2p_requests[request_id] + self.interval_p2p_transferred_tokens += p2p_stats.num_tokens + p2p_stats.end_time = curr_time + @thread_safe def update_local_cache_usage(self, usage: int): self.local_cache_usage_bytes = usage @@ -300,6 +380,10 @@ def update_pinned_memory_objs_count(self, delta: int): def update_interval_vllm_hit_tokens(self, delta: int): self.interval_vllm_hit_tokens += delta + @thread_safe + def update_interval_prompt_tokens(self, delta: int): + self.interval_prompt_tokens += delta + def _clear(self): """ Clear all the distribution stats @@ -314,6 +398,7 @@ def _clear(self): self.interval_lookup_tokens = 0 self.interval_lookup_hits = 0 self.interval_vllm_hit_tokens = 0 + self.interval_prompt_tokens = 0 self.interval_remote_read_requests = 0 self.interval_remote_read_bytes = 0 @@ -333,6 +418,11 @@ def _clear(self): self.interval_local_cpu_evict_keys_count = 0 self.interval_local_cpu_evict_failed_count = 0 + self.interval_p2p_requests = 0 + self.interval_p2p_transferred_tokens = 0 + + self.interval_lookup_0_hit_requests = 0 + new_retrieve_requests = {} for request_id, retrieve_stats in self.retrieve_requests.items(): if retrieve_stats.end_time == 0: @@ -345,6 +435,18 @@ def _clear(self): new_store_requests[request_id] = store_stats self.store_requests = new_store_requests + new_p2p_requests = {} + for request_id, p2p_stats in self.p2p_requests.items(): + if p2p_stats.end_time == 0: + new_p2p_requests[request_id] = p2p_stats + self.p2p_requests = new_p2p_requests + + new_lookup_requests = {} + for request_id, lookup_stats in self.lookup_requests.items(): + if not lookup_stats.is_finished: + new_lookup_requests[request_id] = lookup_stats + self.lookup_requests = new_lookup_requests + @thread_safe def get_stats_and_clear(self) -> LMCacheStats: """ @@ -365,25 +467,41 @@ def get_stats_and_clear(self) -> LMCacheStats: else self.interval_lookup_hits / self.interval_lookup_tokens ) - def filter_out_invalid(stats: List[float]): + def filter_out_zeros(stats: List[float]): return [x for x in stats if x != 0] - time_to_retrieve = filter_out_invalid( + time_to_retrieve = filter_out_zeros( [stats.time_to_retrieve() for stats in self.retrieve_requests.values()] ) - time_to_store = filter_out_invalid( + time_to_store = filter_out_zeros( [stats.time_to_store() for stats in self.store_requests.values()] ) - retrieve_speed = filter_out_invalid( + retrieve_speed = filter_out_zeros( [stats.retrieve_speed() for stats in self.retrieve_requests.values()] ) - store_speed = filter_out_invalid( + store_speed = filter_out_zeros( [stats.store_speed() for stats in self.store_requests.values()] ) + p2p_time_to_transfer = filter_out_zeros( + [stats.time_to_transfer() for stats in self.p2p_requests.values()] + ) + + p2p_transfer_speed = filter_out_zeros( + [stats.transfer_speed() for stats in self.p2p_requests.values()] + ) + + request_lookup_hit_rates = filter_out_zeros( + [ + stats.hit_rate() + for stats in self.lookup_requests.values() + if stats.is_finished + ] + ) + ret = LMCacheStats( interval_retrieve_requests=self.interval_retrieve_requests, interval_store_requests=self.interval_store_requests, @@ -419,6 +537,13 @@ def filter_out_invalid(stats: List[float]): retrieve_speed=retrieve_speed, store_speed=store_speed, interval_vllm_hit_tokens=self.interval_vllm_hit_tokens, + interval_p2p_requests=self.interval_p2p_requests, + interval_p2p_transferred_tokens=self.interval_p2p_transferred_tokens, + p2p_time_to_transfer=p2p_time_to_transfer, + p2p_transfer_speed=p2p_transfer_speed, + interval_lookup_hit_rates=request_lookup_hit_rates, + interval_prompt_tokens=self.interval_prompt_tokens, + interval_lookup_0_hit_requests=self.interval_lookup_0_hit_requests, ) self._clear() return ret @@ -519,6 +644,12 @@ def __init__(self, metadata: LMCacheEngineMetadata): labelnames=labelnames, ) + self.counter_num_prompt_tokens = self._counter_cls( + name="lmcache:num_prompt_tokens", + documentation="Number of prompt tokens in lmcache", + labelnames=labelnames, + ) + self.counter_num_remote_read_requests = self._counter_cls( name="lmcache:num_remote_read_requests", documentation="Total number of requests read from " @@ -563,6 +694,12 @@ def __init__(self, metadata: LMCacheEngineMetadata): labelnames=labelnames, ) + self.counter_lookup_0_hit_requests = self._counter_cls( + name="lmcache:lookup_0_hit_requests", + documentation="Total number of 0 hit lookup requests", + labelnames=labelnames, + ) + self.gauge_retrieve_hit_rate = self._gauge_cls( name="lmcache:retrieve_hit_rate", documentation="Hit rate of lmcache retrieve requests since last log", @@ -710,6 +847,56 @@ def __init__(self, metadata: LMCacheEngineMetadata): buckets=store_speed_buckets, ) + # P2P transfer metrics + p2p_time_buckets = [ + 0.001, # 1ms + 0.005, # 5ms + 0.01, # 10ms + 0.02, # 20ms + 0.04, # 40ms + 0.06, # 60ms + 0.08, # 80ms + 0.1, # 100ms + 0.25, # 250ms + 0.5, # 500ms + 0.75, # 750ms + 1.0, # 1s + 2.5, # 2.5s + 5.0, # 5s + 7.5, # 7.5s + 10.0, # 10s + ] + self.histogram_p2p_time_to_transfer = self._histogram_cls( + name="lmcache:p2p_time_to_transfer", + documentation="Time to transfer via P2P (seconds)", + labelnames=labelnames, + buckets=p2p_time_buckets, + ) + + p2p_speed_buckets = [ + 1, + 8, + 16, + 32, + 64, + 128, + 256, + 512, + 1024, + 2048, + 4096, + 8192, + 16384, + 32768, + 65536, + ] + self.histogram_p2p_transfer_speed = self._histogram_cls( + name="lmcache:p2p_transfer_speed", + documentation="P2P transfer speed (tokens per second)", + labelnames=labelnames, + buckets=p2p_speed_buckets, + ) + remote_time_to_get = [ 1, 5, @@ -785,6 +972,25 @@ def __init__(self, metadata: LMCacheEngineMetadata): buckets=remote_time_to_get_sync, ) + request_cache_hit_rate = [ + 0.1, + 0.2, + 0.3, + 0.4, + 0.5, + 0.6, + 0.7, + 0.8, + 0.9, + 1.0, + ] + self.histogram_request_cache_hit_rate = self._histogram_cls( + name="lmcache:request_cache_hit_rate", + documentation="Request cache hit rate", + labelnames=labelnames, + buckets=request_cache_hit_rate, + ) + # Ping latency metrics: use a gauge to record the latest ping latency self.gauge_remote_ping_latency = self._gauge_cls( name="lmcache:remote_ping_latency", @@ -833,6 +1039,17 @@ def _dynamic_metrics(self, labelnames): multiprocess_mode="livemostrecent", ).labels(**self.labels) + event_statuses = ["ongoing", "done", "not_found"] + for status in event_statuses: + metric_name = f"storage_events_{status}_count" + gauge = self._gauge_cls( + name=f"lmcache:{metric_name}", + documentation=f"The number of {status.replace('_', ' ')} events", + labelnames=labelnames, + multiprocess_mode="sum", + ).labels(**self.labels) + setattr(self, metric_name, gauge) + def _log_gauge(self, gauge, data: Union[int, float]) -> None: # Convenience function for logging to gauge. gauge.labels(**self.labels).set(data) @@ -867,6 +1084,7 @@ def log_prometheus(self, stats: LMCacheStats): self._log_counter(self.counter_num_stored_tokens, stats.interval_stored_tokens) self._log_counter(self.counter_num_lookup_tokens, stats.interval_lookup_tokens) self._log_counter(self.counter_num_lookup_hits, stats.interval_lookup_hits) + self._log_counter(self.counter_num_prompt_tokens, stats.interval_prompt_tokens) self._log_counter( self.counter_num_vllm_hit_tokens, stats.interval_vllm_hit_tokens ) @@ -898,6 +1116,10 @@ def log_prometheus(self, stats: LMCacheStats): self.counter_local_cpu_evict_failed_count, stats.interval_local_cpu_evict_failed_count, ) + self._log_counter( + self.counter_lookup_0_hit_requests, + stats.interval_lookup_0_hit_requests, + ) self._log_gauge(self.gauge_retrieve_hit_rate, stats.retrieve_hit_rate) @@ -917,6 +1139,12 @@ def log_prometheus(self, stats: LMCacheStats): self._log_histogram(self.histogram_store_speed, stats.store_speed) + self._log_histogram( + self.histogram_p2p_time_to_transfer, stats.p2p_time_to_transfer + ) + + self._log_histogram(self.histogram_p2p_transfer_speed, stats.p2p_transfer_speed) + self._log_histogram( self.histogram_remote_time_to_get, stats.interval_remote_time_to_get ) @@ -927,6 +1155,9 @@ def log_prometheus(self, stats: LMCacheStats): self.histogram_remote_time_to_get_sync, stats.interval_remote_time_to_get_sync, ) + self._log_histogram( + self.histogram_request_cache_hit_rate, stats.interval_lookup_hit_rates + ) self._log_gauge( self.gauge_remote_ping_latency, stats.interval_remote_ping_latency ) diff --git a/lmcache/usage_context.py b/lmcache/usage_context.py index b997a313440..229e2e6c601 100644 --- a/lmcache/usage_context.py +++ b/lmcache/usage_context.py @@ -243,6 +243,11 @@ def _get_gpu_info(self): gpu_count = torch.cuda.device_count() gpu_type = device_property.name gpu_memory_per_device = device_property.total_memory + elif torch.xpu.is_available(): + device_property = torch.xpu.get_device_properties(0) + gpu_count = torch.xpu.device_count() + gpu_type = device_property.name + gpu_memory_per_device = device_property.total_memory else: gpu_count = psutil.cpu_count(logical=False) gpu_type = platform.processor() diff --git a/lmcache/utils.py b/lmcache/utils.py index 845045204fd..68745af345f 100644 --- a/lmcache/utils.py +++ b/lmcache/utils.py @@ -3,8 +3,8 @@ from __future__ import annotations # Standard -from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, List, Optional, Tuple +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union import asyncio import hashlib import threading @@ -39,6 +39,18 @@ def decorator(func): # Type definition KVCache = Tuple[Tuple[torch.Tensor, torch.Tensor], ...] + +# Math utility functions +def cdiv(a: int, b: int) -> int: + """Ceiling division.""" + return -(a // -b) + + +def round_down(x: int, y: int) -> int: + """Round down x to the nearest multiple of y.""" + return (x // y) * y + + try: # First Party from lmcache import _version # type: ignore[attr-defined] @@ -62,6 +74,7 @@ class DiskCacheMetadata: size: int # in bytes shape: Optional[torch.Size] = None dtype: Optional[torch.dtype] = None + cached_positions: Optional[torch.Tensor] = None fmt: Optional[MemoryFormat] = None pin_count: int = 0 @@ -91,24 +104,56 @@ def can_evict(self) -> bool: torch.bfloat16: "bfloat16", torch.float: "float", torch.float32: "float", - torch.float64: "double", torch.double: "double", - torch.uint8: "fp8", - torch.float8_e4m3fn: "fp8_e4m3", - torch.float8_e5m2: "fp8_e5m2", + torch.float64: "double", + torch.int8: "int8", + torch.uint8: "uint8", + torch.int16: "int16", + torch.int32: "int32", + torch.int64: "int64", + torch.bool: "bool", } +# FP8 variants (PyTorch ≥2.1) +if hasattr(torch, "float8_e4m3fn"): + TORCH_DTYPE_TO_STR_DTYPE[torch.float8_e4m3fn] = "fp8_e4m3" +if hasattr(torch, "float8_e4m3fnuz"): + TORCH_DTYPE_TO_STR_DTYPE[torch.float8_e4m3fnuz] = "fp8_e4m3" +if hasattr(torch, "float8_e5m2"): + TORCH_DTYPE_TO_STR_DTYPE[torch.float8_e5m2] = "fp8_e5m2" +if hasattr(torch, "float8_e5m2fnuz"): + TORCH_DTYPE_TO_STR_DTYPE[torch.float8_e5m2fnuz] = "fp8_e5m2" + STR_DTYPE_TO_TORCH_DTYPE = {v: k for k, v in TORCH_DTYPE_TO_STR_DTYPE.items()} -@dataclass(order=True) +def parse_cache_key(key_str: str) -> Union[CacheEngineKey, LayerCacheEngineKey]: + """Parse a key string into either a CacheEngineKey or LayerCacheEngineKey. + + Args: + key_str: String in format: + fmt@model@world_size@worker_id@chunk_hash[@layer_id][@tag%value...] + + Returns: + CacheEngineKey if no layer_id, LayerCacheEngineKey if valid layer_id + """ + parts = key_str.strip().split("@") + if len(parts) >= 6 and parts[5].isdigit(): + return LayerCacheEngineKey.from_string(key_str) + return CacheEngineKey.from_string(key_str) + + +@dataclass(slots=True) class CacheEngineKey: fmt: str model_name: str world_size: int worker_id: int chunk_hash: int - request_configs: Optional[dict] = None + dtype: torch.dtype + request_configs: Optional[dict] = field(default_factory=dict) + tags: Optional[tuple] = field(init=False, default=None) + _dtype_str: str = field(init=False, default="") def __post_init__(self): tag_list = None @@ -118,6 +163,9 @@ def __post_init__(self): if tag_list is None: tag_list = [] tag_list.append((k[len("lmcache.tag.") :], v)) + if self.dtype not in TORCH_DTYPE_TO_STR_DTYPE: + raise ValueError(f"Unsupported dtype in CacheEngineKey: {self.dtype}") + self._dtype_str = TORCH_DTYPE_TO_STR_DTYPE[self.dtype] # use tuple to save tags self.tags = None if tag_list is None else tuple(tag_list) @@ -129,6 +177,7 @@ def __hash__(self): self.world_size, self.worker_id, self.chunk_hash, + self._dtype_str, self.tags, ) ) @@ -141,6 +190,7 @@ def __eq__(self, other): and self.world_size == other.world_size and self.worker_id == other.worker_id and self.chunk_hash == other.chunk_hash + and self.dtype == other.dtype and self.tags == other.tags ) @@ -149,7 +199,7 @@ def __eq__(self, other): def to_string(self): s = ( f"{self.fmt}@{self.model_name}@{self.world_size}" - f"@{self.worker_id}@{self.chunk_hash:x}" + f"@{self.worker_id}@{self.chunk_hash:x}@{self._dtype_str}" ) if self.tags is not None and len(self.tags) != 0: tags = [f"{k}%{v}" for k, v in self.tags] @@ -167,6 +217,7 @@ def split_layers(self, num_layers: int) -> List["LayerCacheEngineKey"]: self.world_size, self.worker_id, self.chunk_hash, + self.dtype, self.request_configs, layer_id, ) @@ -181,6 +232,7 @@ def get_first_layer(self) -> "LayerCacheEngineKey": self.world_size, self.worker_id, self.chunk_hash, + self.dtype, self.request_configs, 0, ) @@ -189,12 +241,12 @@ def get_first_layer(self) -> "LayerCacheEngineKey": @staticmethod def from_string(s): parts = s.split("@") - if len(parts) < 5: + if len(parts) < 6: raise ValueError(f"Invalid key string: {s}") request_configs = None - if len(parts) >= 6: + if len(parts) >= 7: request_configs = {} - for kv in parts[5:]: + for kv in parts[6:]: kvs = kv.split("%", 1) if len(kvs) != 2: raise ValueError(f"Invalid key string: {s}") @@ -205,6 +257,7 @@ def from_string(s): int(parts[2]), int(parts[3]), int(parts[4], 16), + STR_DTYPE_TO_TORCH_DTYPE[parts[5]], request_configs, ) @@ -217,6 +270,7 @@ def to_dict(self): "world_size": self.world_size, "worker_id": self.worker_id, "chunk_hash": self.chunk_hash, + "dtype": self._dtype_str, } if self.request_configs is not None and len(self.request_configs) != 0: msg["request_configs"] = [ @@ -240,11 +294,24 @@ def from_dict(d): world_size=d["world_size"], worker_id=d["worker_id"], chunk_hash=d["chunk_hash"], + dtype=STR_DTYPE_TO_TORCH_DTYPE[d["dtype"]], request_configs=request_configs, ) + def with_new_worker_id(self, new_worker_id: int) -> "CacheEngineKey": + # Reconstruct the cache engine key with new worker id + return CacheEngineKey( + self.fmt, + self.model_name, + self.world_size, + new_worker_id, + self.chunk_hash, + self.dtype, + self.request_configs, + ) + -@dataclass(order=True) +@dataclass(slots=True) class LayerCacheEngineKey(CacheEngineKey): """A key for the layer cache engine""" @@ -258,13 +325,14 @@ def __hash__(self): self.world_size, self.worker_id, self.chunk_hash, + self._dtype_str, self.tags, self.layer_id, ) ) def __eq__(self, other): - if super().__eq__(other): + if super(LayerCacheEngineKey, self).__eq__(other): return self.layer_id == other.layer_id return False @@ -272,7 +340,7 @@ def __eq__(self, other): def to_string(self): s = ( f"{self.fmt}@{self.model_name}@{self.world_size}" - f"@{self.worker_id}@{self.chunk_hash:x}@{self.layer_id}" + f"@{self.worker_id}@{self.chunk_hash:x}@{self._dtype_str}@{self.layer_id}" ) if self.tags is not None and len(self.tags) != 0: tags = [f"{k}%{v}" for k, v in self.tags] @@ -290,6 +358,7 @@ def split_layers(self, num_layers: int) -> List["LayerCacheEngineKey"]: self.world_size, self.worker_id, self.chunk_hash, + self.dtype, self.request_configs, layer_id, ) @@ -299,12 +368,12 @@ def split_layers(self, num_layers: int) -> List["LayerCacheEngineKey"]: @staticmethod def from_string(s): parts = s.split("@") - if len(parts) < 6: + if len(parts) < 7: raise ValueError(f"Invalid key string: {s}") request_configs = None - if len(parts) >= 7: + if len(parts) >= 8: request_configs = {} - for kv in parts[6:]: + for kv in parts[7:]: kvs = kv.split("%", 1) if len(kvs) != 2: raise ValueError(f"Invalid key string: {s}") @@ -315,8 +384,9 @@ def from_string(s): int(parts[2]), int(parts[3]), int(parts[4], 16), + STR_DTYPE_TO_TORCH_DTYPE[parts[5]], request_configs, - int(parts[5]), + int(parts[6]), ) diff --git a/lmcache/v1/api_server/__main__.py b/lmcache/v1/api_server/__main__.py index a30a27f13d4..4fdf1432501 100644 --- a/lmcache/v1/api_server/__main__.py +++ b/lmcache/v1/api_server/__main__.py @@ -35,7 +35,10 @@ PinRetMsg, QueryInstMsg, QueryInstRetMsg, + QueryWorkerInfoMsg, + QueryWorkerInfoRetMsg, ) +from lmcache.v1.cache_controller.utils import WorkerInfo logger = init_logger(__name__) @@ -300,6 +303,32 @@ async def check_finish(req: CheckFinishRequest): except Exception as e: raise HTTPException(status_code=500, detail=str(e)) from e + class QueryWorkerInfoRequest(BaseModel): + instance_id: str + worker_ids: Optional[list[int]] + + class QueryWorkerInfoResponse(BaseModel): + event_id: str + worker_infos: list[WorkerInfo] + + @app.post("/query_worker_info", response_model=QueryWorkerInfoResponse) + async def query_worker_info(req: QueryWorkerInfoRequest): + try: + event_id = "QueryWorkerInfo" + str(uuid.uuid4()) + msg = QueryWorkerInfoMsg( + event_id=event_id, + instance_id=req.instance_id, + worker_ids=req.worker_ids, + ) + ret_msg = await lmcache_controller_manager.handle_orchestration_message(msg) + assert not isinstance(ret_msg, ErrorMsg), ret_msg.error + assert isinstance(ret_msg, QueryWorkerInfoRetMsg) + return QueryWorkerInfoResponse( + event_id=ret_msg.event_id, worker_infos=ret_msg.worker_infos + ) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) from e + return app diff --git a/lmcache/v1/basic_check.py b/lmcache/v1/basic_check.py new file mode 100644 index 00000000000..de502b1a484 --- /dev/null +++ b/lmcache/v1/basic_check.py @@ -0,0 +1,75 @@ +# SPDX-License-Identifier: Apache-2.0 +# Standard +import argparse +import asyncio + +# First Party +from lmcache.v1.check import registry + +model_name = "/lmcache_test_model/" + + +def parse_args(): + parser = argparse.ArgumentParser(description="LMCache basic check Tool") + parser.add_argument( + "--mode", + required=True, + help="Operation mode (e.g. test_remote, test_storage_manager). " + "Use 'list' to show available modes", + ) + parser.add_argument("--model", default=model_name, help="model name") + parser.add_argument( + "--num-keys", + type=int, + default=100, + help="Number of keys to generate (gen mode only)", + ) + parser.add_argument( + "--concurrency", + type=int, + default=16, + help="Concurrency level for generation (gen mode only)", + ) + parser.add_argument( + "--offset", + type=int, + default=0, + help="Offset for key generation (gen mode only)", + ) + return parser.parse_args() + + +async def main(): + args = parse_args() + + # List available modes if requested + if args.mode == "list": + registry.load_modes() + print("Available check modes:") + for mode_name in registry.modes: + print(f" - {mode_name}") + return + + # Get the requested mode function + mode_func = registry.get_mode(args.mode) + if not mode_func: + print( + f"Error: Unknown mode '{args.mode}'. " + "Use '--mode list' to see available modes." + ) + return + + # Prepare arguments for the mode function + mode_args = { + "model": args.model, + "num_keys": args.num_keys, + "concurrency": args.concurrency, + "offset": args.offset, + } + + # Execute the mode function + await mode_func(**mode_args) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/lmcache/v1/cache_controller/controller_manager.py b/lmcache/v1/cache_controller/controller_manager.py index c1234dfb503..3b39dfb26b7 100644 --- a/lmcache/v1/cache_controller/controller_manager.py +++ b/lmcache/v1/cache_controller/controller_manager.py @@ -37,6 +37,7 @@ OrchRetMsg, PinMsg, QueryInstMsg, + QueryWorkerInfoMsg, RegisterMsg, WorkerMsg, WorkerReqMsg, @@ -147,6 +148,8 @@ async def handle_orchestration_message(self, msg: OrchMsg) -> OrchRetMsg: # FIXME(Jiayi): This `check_finish` thing # shouldn't be implemented in kv_controller. return await self.kv_controller.check_finish(msg) + elif isinstance(msg, QueryWorkerInfoMsg): + return await self.reg_controller.query_worker_info(msg) else: logger.error(f"Unknown orchestration message type: {msg}") raise RuntimeError(f"Unknown orchestration message type: {msg}") diff --git a/lmcache/v1/cache_controller/controllers/kv_controller.py b/lmcache/v1/cache_controller/controllers/kv_controller.py index 154302a138e..788b848ca9e 100644 --- a/lmcache/v1/cache_controller/controllers/kv_controller.py +++ b/lmcache/v1/cache_controller/controllers/kv_controller.py @@ -219,6 +219,7 @@ async def batched_p2p_lookup( peer_init_url = self.reg_controller.get_distributed_url( instance_id, worker_id ) + assert peer_init_url is not None num_hit_chunks += 1 return BatchedP2PLookupRetMsg( diff --git a/lmcache/v1/cache_controller/controllers/registration_controller.py b/lmcache/v1/cache_controller/controllers/registration_controller.py index b4471fd3bf5..8482d420703 100644 --- a/lmcache/v1/cache_controller/controllers/registration_controller.py +++ b/lmcache/v1/cache_controller/controllers/registration_controller.py @@ -16,6 +16,8 @@ HeartbeatMsg, QueryInstMsg, QueryInstRetMsg, + QueryWorkerInfoMsg, + QueryWorkerInfoRetMsg, RegisterMsg, ) from lmcache.v1.cache_controller.utils import WorkerInfo @@ -34,8 +36,9 @@ def __init__(self): self.worker_mapping: dict[str, list[int]] = {} # Mapping from `(instance_id, worker_id)` -> `distributed_url` - # NOTE(Jiayi): `distributed_url` is used for actual KV cache transfer. - # It's not the lmcache_worker_url + # NOTE(Jiayi): `distributed_url` is used for actual KV cache transfer(p2p), + # It's not the lmcache_worker_url. + # if p2p is not used, distributed_url is None and not registered. self.distributed_url_mapping: dict[tuple[str, int], str] = {} # Mapping from `(instance_id, worker_id)` -> `socket` @@ -71,7 +74,10 @@ def get_distributed_url(self, instance_id: str, worker_id: int) -> Optional[str] """ url = self.distributed_url_mapping.get((instance_id, worker_id)) if url is None: - logger.warning(f"Instance-worker {(instance_id, worker_id)} not registered") + logger.warning( + f"Instance-worker {(instance_id, worker_id)} not registered " + f"or P2P is not used" + ) return url def get_workers(self, instance_id: str) -> list[int]: @@ -102,7 +108,13 @@ async def register(self, msg: RegisterMsg) -> None: port = msg.port url = f"{ip}:{port}" distributed_url = msg.distributed_url - self.distributed_url_mapping[(instance_id, worker_id)] = distributed_url + if distributed_url is not None: + self.distributed_url_mapping[(instance_id, worker_id)] = distributed_url + else: + logger.info( + f"distributed url of {(instance_id, worker_id)} is None, " + f"only register when p2p is used." + ) self.instance_mapping[ip] = instance_id @@ -188,3 +200,24 @@ async def heartbeat(self, msg: HeartbeatMsg) -> None: else: # update worker info self.worker_info_mapping[worker_key].last_heartbeat_time = time.time() + + async def query_worker_info(self, msg: QueryWorkerInfoMsg) -> QueryWorkerInfoRetMsg: + """ + Query worker info. + """ + event_id = msg.event_id + worker_infos = [] + if msg.instance_id not in self.worker_mapping: + logger.warning(f"instance {msg.instance_id} not registered.") + else: + worker_ids = msg.worker_ids + if worker_ids is None or len(worker_ids) == 0: + worker_ids = self.worker_mapping[msg.instance_id] + for worker_id in worker_ids: + worker_key = (msg.instance_id, worker_id) + if worker_key in self.worker_info_mapping: + worker_infos.append(self.worker_info_mapping[worker_key]) + else: + logger.warning(f"worker {worker_key} not registered.") + + return QueryWorkerInfoRetMsg(event_id=event_id, worker_infos=worker_infos) diff --git a/lmcache/v1/cache_controller/executor.py b/lmcache/v1/cache_controller/executor.py index 8d406307606..72677587ee7 100644 --- a/lmcache/v1/cache_controller/executor.py +++ b/lmcache/v1/cache_controller/executor.py @@ -310,7 +310,7 @@ async def move(self, msg: MoveMsg) -> Union[MoveRetMsg, ErrorMsg]: f"Src worker {src_worker_id} not registered for " f"instance {src_instance_id} or " f"dst worker {dst_worker_id} not registered for " - f"instance {dst_instance_id}" + f"instance {dst_instance_id} or P2P is not enabled." ) ) sockets.append(socket) diff --git a/lmcache/v1/cache_controller/message.py b/lmcache/v1/cache_controller/message.py index 6d49d338b96..47d858fe9b0 100644 --- a/lmcache/v1/cache_controller/message.py +++ b/lmcache/v1/cache_controller/message.py @@ -5,6 +5,9 @@ # Third Party import msgspec +# First Party +from lmcache.v1.cache_controller.utils import WorkerInfo + class MsgBase(msgspec.Struct, tag=True): # type: ignore """Base class for all messages""" @@ -38,7 +41,8 @@ class RegisterMsg(WorkerMsg): worker_id: int ip: str port: int - distributed_url: str # URL for actual KV cache transfer + # URL for actual KV cache transfer, only useful when p2p is enabled + distributed_url: Optional[str] def describe(self) -> str: return ( @@ -432,6 +436,17 @@ def describe(self) -> str: return f"Checking finish for event {self.event_id}" +class QueryWorkerInfoMsg(OrchMsg): + """Query worker info message""" + + event_id: str + instance_id: str + worker_ids: Optional[list[int]] + + def describe(self) -> str: + return f"Query worker info of {self.instance_id} : {self.worker_ids}" + + class OrchRetMsg(MsgBase): """Return message from Controller to Ochestrator""" @@ -529,6 +544,16 @@ def describe(self) -> str: return f"Event status: {self.status}" +class QueryWorkerInfoRetMsg(OrchRetMsg): + """Query worker info return message""" + + event_id: str + worker_infos: list[WorkerInfo] + + def describe(self) -> str: + return f"worker infos: {self.worker_infos}" + + class ErrorMsg(MsgBase): """Control Error Message""" @@ -579,4 +604,6 @@ def describe(self) -> str: HeartbeatMsg, BatchedP2PLookupMsg, BatchedP2PLookupRetMsg, + QueryWorkerInfoMsg, + QueryWorkerInfoRetMsg, ] diff --git a/lmcache/v1/cache_controller/utils.py b/lmcache/v1/cache_controller/utils.py index 0f3be8f8f22..7a2ea0436ca 100644 --- a/lmcache/v1/cache_controller/utils.py +++ b/lmcache/v1/cache_controller/utils.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # Standard from dataclasses import dataclass +from typing import Optional @dataclass @@ -9,6 +10,6 @@ class WorkerInfo: worker_id: int ip: str port: int - distributed_url: str + distributed_url: Optional[str] registration_time: float last_heartbeat_time: float diff --git a/lmcache/v1/cache_controller/worker.py b/lmcache/v1/cache_controller/worker.py index 2e4621732bc..9db2d2a1161 100644 --- a/lmcache/v1/cache_controller/worker.py +++ b/lmcache/v1/cache_controller/worker.py @@ -99,9 +99,11 @@ def __init__( self.lmcache_worker_ip = get_ip() self.lmcache_worker_port = lmcache_worker_port - self.p2p_host = config.p2p_host - self.p2p_init_port = config.p2p_init_ports[self.worker_id] - self.p2p_init_url = f"{self.p2p_host}:{self.p2p_init_port}" + self.p2p_init_url = None + if config.enable_p2p: + self.p2p_host = config.p2p_host + self.p2p_init_port = config.p2p_init_ports[self.worker_id] + self.p2p_init_url = f"{self.p2p_host}:{self.p2p_init_port}" self.reply_socket = get_zmq_socket( self.context, diff --git a/lmcache/v1/cache_engine.py b/lmcache/v1/cache_engine.py index 2a5720df066..fe6d8af1cf9 100644 --- a/lmcache/v1/cache_engine.py +++ b/lmcache/v1/cache_engine.py @@ -53,6 +53,12 @@ logger = init_logger(__name__) +# Type aliases for processed chunks +# (cache_key, memory_obj, start_index, end_index) +ProcessedChunk = Tuple[CacheEngineKey, MemoryObj, int, int] +# (list of processed chunks, total kv size) +ProcessTokensInternalResult = Tuple[List[ProcessedChunk], int] + class CacheEngineEndSignal: pass @@ -80,7 +86,7 @@ def __init__( config: LMCacheEngineConfig, metadata: LMCacheEngineMetadata, token_database: TokenDatabase, - gpu_connector: GPUConnectorInterface, + gpu_connector: Optional[GPUConnectorInterface], broadcast_fn: Callable[[torch.Tensor, int], None], broadcast_object_fn: Callable[[Any, int], Any], ): @@ -97,7 +103,7 @@ def __init__( and metadata.use_mla ) - if self.save_only_first_rank: + if self.save_only_first_rank and self.gpu_connector is not None: self.broadcast_stream = ( self.gpu_connector.load_stream if hasattr(self.gpu_connector, "load_stream") @@ -114,19 +120,44 @@ def __init__( from lmcache.v1.cache_controller import LMCacheWorker self.lmcache_worker: Optional[LMCacheWorker] = None - if self.enable_controller: + if self.enable_controller and self.metadata.role != "scheduler": self.lmcache_worker = LMCacheWorker(config, metadata, self) self.async_loading = config.enable_async_loading self.event_manager = EventManager() - self.storage_manager = StorageManager( - config, - metadata, - # self.memory_allocator, - event_manager=self.event_manager, - lmcache_worker=self.lmcache_worker, + self.use_layerwise = config.use_layerwise + + # TODO: support save_only_first_rank when use layerwise + # if use_layerwise is True, all ranks will initialize the storage_manager + # if save_only_first_rank is False, all ranks will initialize + # the storage_manager + # if save_only_first_rank is True, only the first rank and + # lookup server workers will initialize the storage_manager + self.storage_manager = None + lookup_server_worker_ids = self.config.get_lookup_server_worker_ids( + metadata.use_mla, metadata.world_size ) + if ( + self.lmcache_worker is not None + or self.use_layerwise + or not self.save_only_first_rank + or self.metadata.is_first_rank() + or len(lookup_server_worker_ids) == 0 + or self.metadata.worker_id in lookup_server_worker_ids + ): + logger.info( + f"Initialize storage manager on rank {self.metadata.worker_id}, " + f"use layerwise: {self.use_layerwise}," + f"save only first rank: {self.save_only_first_rank}" + ) + self.storage_manager = StorageManager( + config, + metadata, + # self.memory_allocator, + event_manager=self.event_manager, + lmcache_worker=self.lmcache_worker, + ) # HACK: remove this in the future # NOTE (Jiayi): This is currently used to support @@ -134,13 +165,15 @@ def __init__( # at decoder. self.remove_after_retrieve = config.enable_pd and config.pd_role == "receiver" - self.use_layerwise = config.use_layerwise self.num_layers = metadata.kv_shape[0] + self.fmt = None if self.use_layerwise: if config.enable_blending: self.fmt = MemoryFormat.KV_2TD else: self.fmt = MemoryFormat.KV_T2D + if metadata.use_mla: + self.fmt = MemoryFormat.KV_MLA_FMT # NOTE(ApostaC): we haven't support lookup-cache yet self.lookup_cache: dict[CacheEngineKey, Any] = {} @@ -164,11 +197,13 @@ def __init__( def post_init(self, **kwargs) -> None: if "async_lookup_server" in kwargs: - self.async_lookup_server = kwargs.pop("async_lookup_server") - self.storage_manager.post_init(async_lookup_server=self.async_lookup_server) + self.async_lookup_server = kwargs["async_lookup_server"] if not self.post_inited: + if self.storage_manager is not None: + self.storage_manager.post_init(**kwargs) logger.info("Post-initializing LMCacheEngine") - self.gpu_connector.initialize_kvcaches_ptr(**kwargs) + if self.gpu_connector is not None: + self.gpu_connector.initialize_kvcaches_ptr(**kwargs) self.post_inited = True @_lmcache_nvtx_annotate @@ -200,10 +235,16 @@ def store( :raises: ValueError if the number of Falses in the mask is not a multiple of the chunk size. """ + assert self.gpu_connector is not None, ( + "gpu_connector is required for store operation" + ) + if self._is_passive(): logger.debug(f"rank={self.metadata.worker_id} ignore store") return + assert self.storage_manager is not None + if mask is not None: num_to_store_tokens = torch.sum(mask).item() elif tokens is not None: @@ -256,6 +297,7 @@ def store( kv_shape, kv_dtype, busy_loop=self.force_store_wait, + fmt=self.fmt, ) if memory_obj is None: logger.warning( @@ -327,6 +369,10 @@ def store_layer( storage backends. In the last iteration, it puts the memory objects of the last layer to the storage backends. """ + assert self.storage_manager is not None + assert self.gpu_connector is not None, ( + "gpu_connector is required for store_layer operation" + ) if mask is not None: num_to_store_tokens = torch.sum(mask).item() @@ -442,6 +488,10 @@ def retrieve( :raises: ValueError if the number of Falses in the mask is not a multiple of the chunk size. """ + assert self.gpu_connector is not None, ( + "gpu_connector is required for retrieve operation" + ) + tot_kv_size = 0 t = time.perf_counter() @@ -453,7 +503,7 @@ def retrieve( ret_mask = torch.zeros(len(tokens), dtype=torch.bool, device="cpu") - reordered_chunks: List[Tuple[CacheEngineKey, MemoryObj, int, int]] = [] + reordered_chunks: List[ProcessedChunk] = [] if not self._is_passive(): if self.async_loading: reordered_chunks, tot_kv_size = self._async_process_tokens_internal( # noqa: E501 @@ -499,6 +549,7 @@ def retrieve( # TODO(Jiayi): Need to refactor the `remove_after_retrieve` logic. for key, memory_obj, _, _ in reordered_chunks: if self.remove_after_retrieve and not self._is_passive(): + assert self.storage_manager is not None self.storage_manager.remove(key) memory_obj.ref_count_down() @@ -549,6 +600,10 @@ def retrieve_layer( last iteration, it moves the memory objects of the last layer to the GPU. """ + assert self.storage_manager is not None + assert self.gpu_connector is not None, ( + "gpu_connector is required for retrieve_layer operation" + ) if mask is not None: num_required_tokens = torch.sum(mask).item() @@ -565,6 +620,8 @@ def retrieve_layer( request_configs = kwargs.get("request_configs") if request_configs is not None and len(request_configs) != 0: assert isinstance(request_configs, dict) + + location = None for start, end, key in self.token_database.process_tokens( tokens=tokens, mask=mask, @@ -575,7 +632,17 @@ def retrieve_layer( keys_multi_layer = key.split_layers(self.num_layers) # NOTE: Only check the first layer - if not self.storage_manager.contains(keys_multi_layer[0]): + if current_location := self.storage_manager.contains(keys_multi_layer[0]): + if location is None: + location = current_location + else: + # TODO(Jiayi): Support multi-location retrieval in the future + assert location == current_location, ( + "All retrieved keys should be from the same location " + "when use layerwise retrieval." + "Please support multi-location retrieval in the future." + ) + else: break starts.append(start) @@ -588,7 +655,10 @@ def retrieve_layer( # Transpose the keys into layer major format keys_layer_major = [list(row) for row in zip(*keys, strict=False)] - get_generator = self.storage_manager.layerwise_batched_get(keys_layer_major) + get_generator = self.storage_manager.layerwise_batched_get( + keys_layer_major, + location=location, + ) assert isinstance( self.gpu_connector, @@ -677,30 +747,29 @@ def lookup( :return: An int indicating how many prefix tokens are cached. """ + assert self.storage_manager is not None if tokens is not None: - self.stats_monitor.on_lookup_request(len(tokens)) + lookup_request_id = self.stats_monitor.on_lookup_request(len(tokens)) else: assert offsets is not None assert hashes is not None - self.stats_monitor.on_lookup_request(sum(offsets)) + lookup_request_id = self.stats_monitor.on_lookup_request(sum(offsets)) + res = 0 try: - end = 0 - prev_end = 0 - - if pin: - assert lookup_id is not None, "lookup_id is required when pin is True" - - for start, end, key in self.token_database.process_tokens( + chunk_info_iterator = self.token_database.process_tokens( tokens=tokens, hashes=hashes, offsets=offsets, request_configs=request_configs, - ): - assert isinstance(key, CacheEngineKey) + ) + + # TODO: support batched_contains when layerwise is enabled + if self.use_layerwise: + for start, end, key in chunk_info_iterator: + assert isinstance(key, CacheEngineKey) - if self.use_layerwise: # TODO(Jiayi): Optimize by checking only the existence of the key # of one layer key_all_layers = key.split_layers(self.num_layers) @@ -713,34 +782,48 @@ def lookup( found = True if found: if pin: + assert lookup_id is not None, ( + "lookup_id is required when pin is True" + ) self.lookup_pins[lookup_id].extend( # type: ignore key_all_layers ) - prev_end = end + res = end continue - end = prev_end - return prev_end - else: - if self.storage_manager.contains(key, search_range, pin): + return res + else: + chunk_info_list = [] + keys = [] + for chunk_info in chunk_info_iterator: + assert isinstance(chunk_info[2], CacheEngineKey) + chunk_info_list.append(chunk_info) + keys.append(chunk_info[2]) + + batched_contains_res = self.storage_manager.batched_contains( + keys, search_range, pin, True + ) + assert len(batched_contains_res) == len(chunk_info_list) + for (start, end, key), exists in zip( + chunk_info_list, batched_contains_res, strict=False + ): + if exists: if pin: - self.lookup_pins[lookup_id].append( # type: ignore - key + assert lookup_id is not None, ( + "lookup_id is required when pin is True" ) - prev_end = end + self.lookup_pins[lookup_id].append(key) + res = end continue - - end = prev_end - return prev_end + return res # all tokens where found, return the maximal end - return end + return res finally: - self.stats_monitor.on_lookup_finished(end) + self.stats_monitor.on_lookup_finished(lookup_request_id, res) # vllm lookup sets pin to True if pin: self.storage_manager.touch_cache() - # FIXME @_lmcache_nvtx_annotate def move( self, @@ -753,6 +836,7 @@ def move( """ Perform cross-node move of the KV cache. """ + assert self.storage_manager is not None num_tokens = self.lookup( tokens, @@ -825,6 +909,7 @@ def async_lookup_and_prefetch( (2) sync lookup + async retrieval (e.g., disk) (3) async lookup + async retrieval (e.g., p2p) """ + assert self.storage_manager is not None keys: list[CacheEngineKey] = [] cum_chunk_lengths = [0] @@ -858,6 +943,7 @@ def compress( location: str, event_id: str, ) -> int: + assert self.storage_manager is not None if method not in ["cachegen"]: logger.warning(f"Unsupported compression method: {method}.") return 0 @@ -913,6 +999,7 @@ def decompress( location: str, event_id: str, ) -> int: + assert self.storage_manager is not None if method not in ["cachegen"]: logger.warning(f"Unsupported decompression method: {method}.") return 0 @@ -963,11 +1050,11 @@ def decompress( return num_tokens @_lmcache_nvtx_annotate - def lookup_unpin(self, lookup_ids: list[str]) -> None: - for lookup_id in lookup_ids: - if lookup_id in self.lookup_pins: - self.storage_manager.batched_unpin(self.lookup_pins[lookup_id]) - del self.lookup_pins[lookup_id] + def lookup_unpin(self, lookup_id: str) -> None: + if lookup_id in self.lookup_pins: + assert self.storage_manager is not None + self.storage_manager.batched_unpin(self.lookup_pins[lookup_id]) + del self.lookup_pins[lookup_id] @_lmcache_nvtx_annotate def clear( @@ -980,11 +1067,9 @@ def clear( if self.save_only_first_rank: if self.metadata.is_first_rank(): num_removed = self._clear(tokens, locations, request_configs) - self.broadcast_object_fn(num_removed, self.metadata.first_rank) return num_removed else: - num_removed = self.broadcast_object_fn(None, self.metadata.first_rank) - return int(num_removed) + return 0 return self._clear(tokens, locations, request_configs) def _clear( @@ -993,6 +1078,7 @@ def _clear( locations: Optional[List[str]] = None, request_configs: Optional[dict] = None, ) -> int: + assert self.storage_manager is not None assert isinstance(self.storage_manager, StorageManager) # Clear all caches if tokens is None if tokens is None or len(tokens) == 0: @@ -1017,6 +1103,7 @@ def health( Check the health of the cache engine. return: 0 if healthy, otherwise the error code """ + assert self.storage_manager is not None return 0 if self.storage_manager.memcheck() else -1 def close(self) -> None: @@ -1025,7 +1112,8 @@ def close(self) -> None: if self.lmcache_worker is not None: self.lmcache_worker.close() - self.storage_manager.close() + if self.storage_manager is not None: + self.storage_manager.close() logger.info("LMCacheEngine closed.") @@ -1035,7 +1123,7 @@ def _async_process_tokens_internal( mask, ret_mask, **kwargs, - ) -> tuple[list[tuple[CacheEngineKey, MemoryObj, int, int]], int]: + ) -> ProcessTokensInternalResult: """ This function is used to get the memory objects from the event manager. @@ -1051,7 +1139,7 @@ def _async_process_tokens_internal( assert isinstance(request_configs, dict) tot_kv_size = 0 - chunks: list[tuple[CacheEngineKey, MemoryObj, int, int]] = [] + chunks: List[ProcessedChunk] = [] future = self.event_manager.pop_event(EventType.LOADING, kwargs["req_id"]) memory_objs = future.result() @@ -1060,22 +1148,24 @@ def _async_process_tokens_internal( # NOTE(Jiayi): here we assume the retrieved memory_objs have # the same order as the lookup order. # TODO(Jiayi): hashing inside `process_tokens` can be skipped. - for idx, (start, end, key) in enumerate( - self.token_database.process_tokens( - tokens=tokens, - mask=mask, - request_configs=request_configs, - ) + used_indices = set() + for start, end, key in self.token_database.process_tokens( + tokens=tokens, + mask=mask, + request_configs=request_configs, ): assert isinstance(key, CacheEngineKey) + idx = start // self.config.chunk_size memory_obj = memory_objs[idx] chunks.append((key, memory_obj, start, end)) tot_kv_size += memory_obj.get_size() ret_mask[start:end] = True + used_indices.add(idx) # NOTE: free the memory objects that are not hit. - for unused_mem_obj in memory_objs[len(chunks) :]: - unused_mem_obj.ref_count_down() + for idx, unused_mem_obj in enumerate(memory_objs): + if idx not in used_indices: + unused_mem_obj.ref_count_down() return chunks, tot_kv_size @@ -1085,7 +1175,7 @@ def _process_tokens_internal( mask, ret_mask, **kwargs, - ) -> tuple[list[tuple[CacheEngineKey, MemoryObj, int, int]], int]: + ) -> ProcessTokensInternalResult: """Process tokens and populate the reordered lists. This function is used to process tokens and populate the reordered lists. @@ -1096,6 +1186,7 @@ def _process_tokens_internal( ret_mask: Output mask updated with cache hit positions **kwargs: Additional keyword arguments """ + assert self.storage_manager is not None tot_kv_size = 0 # location -> [(CacheEngineKey, start, end)] @@ -1103,12 +1194,20 @@ def _process_tokens_internal( list ) - reordered_chunks: list[tuple[CacheEngineKey, MemoryObj, int, int]] = [] + reordered_chunks: List[ProcessedChunk] = [] request_configs = kwargs.get("request_configs") if request_configs is not None and len(request_configs) != 0: assert isinstance(request_configs, dict) + # In some scenarios, lookup is called first, and then the original tokens + # is sliced based on the lookup result. In these scenarios, the tokens + # passed in must exist in LMCache, and we can set skip_contains_check to True. + # When skip_contains_check is True and there is only one backend, the `contains` + # call can be skipped. + skip_contains_check = ( + kwargs["skip_contains_check"] if "skip_contains_check" in kwargs else False + ) for start, end, key in self.token_database.process_tokens( tokens=tokens, mask=mask, @@ -1116,14 +1215,21 @@ def _process_tokens_internal( ): assert isinstance(key, CacheEngineKey) + location = None if key in self.lookup_cache: # TODO(Jiayi): we can reduce the number of `contains` calls # by checking the lookup cache first (should be updated in `lookup`) pass else: - # NOTE: key should always be in the lookup cache once - # we support it. - location = self.storage_manager.contains(key) + # NOTE: key should always be in the lookup cache once we support it. + # TODO: use lookup_cache to skip the contains + if ( + skip_contains_check + and len(self.storage_manager.non_allocator_backends) == 1 + ): + location = self.storage_manager.non_allocator_backends[0] + else: + location = self.storage_manager.contains(key) if location is None: break @@ -1358,7 +1464,7 @@ def get_or_create( instance_id: str, config: LMCacheEngineConfig, metadata: LMCacheEngineMetadata, - gpu_connector: GPUConnectorInterface, + gpu_connector: Optional[GPUConnectorInterface], broadcast_fn: Callable[[torch.Tensor, int], None], broadcast_object_fn: Callable[[Any, int], Any], ) -> LMCacheEngine: diff --git a/lmcache/v1/check/__init__.py b/lmcache/v1/check/__init__.py new file mode 100644 index 00000000000..216ea0478cf --- /dev/null +++ b/lmcache/v1/check/__init__.py @@ -0,0 +1,74 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Check mode registry implementation""" + +# Standard +from typing import Callable, Dict, Optional +import importlib +import inspect +import os + +# First Party +from lmcache.logging import init_logger + +logger = init_logger(__name__) + + +class CheckModeRegistry: + """Registry for dynamically loaded check modes""" + + def __init__(self): + self.modes: Dict[str, Callable] = {} + self.loaded = False + + def register(self, name: str, func: Callable): + """Register a check mode function""" + if name in self.modes: + raise ValueError(f"Check mode '{name}' already registered") + self.modes[name] = func + + def load_modes(self): + """Dynamically load all check mode modules""" + if self.loaded: + return + + # Get current package + current_dir = os.path.dirname(__file__) + + # Find all modules with check_mode_ prefix + for filename in os.listdir(current_dir): + if filename.startswith("check_mode_") and filename.endswith(".py"): + module_name = filename[:-3] # Remove .py + try: + module = importlib.import_module( + f".{module_name}", package=__package__ + ) + # Find and register mode functions + for name, obj in inspect.getmembers(module): + if inspect.isfunction(obj) and hasattr(obj, "is_check_mode"): + self.register(obj.mode_name, obj) + except ImportError as e: + logger.error(f"Failed to load check mode module {module_name}: {e}") + + self.loaded = True + logger.info(f"Loaded {len(self.modes)} check modes") + + def get_mode(self, name: str) -> Optional[Callable]: + """Get registered mode function. Returns None if the mode is not found.""" + if not self.loaded: + self.load_modes() + return self.modes.get(name) + + +def check_mode(name: str): + """Decorator to mark functions as check modes""" + + def decorator(func): + func.is_check_mode = True + func.mode_name = name + return func + + return decorator + + +# Global registry instance +registry = CheckModeRegistry() diff --git a/lmcache/v1/check/check_mode_gen.py b/lmcache/v1/check/check_mode_gen.py new file mode 100644 index 00000000000..da588795834 --- /dev/null +++ b/lmcache/v1/check/check_mode_gen.py @@ -0,0 +1,86 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Generate mode implementation for key generation""" + +# Third Party +import tqdm + +# First Party +from lmcache.v1.check import check_mode +from lmcache.v1.check.utils import ( + _get_default_metadata, + create_memory_objects_batch, + create_storage_manager_with_config, + create_test_key, + find_remote_backend, + flow_control_check, + wait_put_tasks_complete, +) + + +@check_mode("gen") +async def run_gen_mode( + model: str, num_keys: int, concurrency: int, offset: int = 0, **kwargs +): + """Run key generation mode""" + # Create storage manager using common function + storage_manager = create_storage_manager_with_config(model) + metadata = _get_default_metadata(model) + + try: + print("Generate: Passed - Created storage manager with valid config") + + # Find remote backend for flow control + remote_backend = find_remote_backend(storage_manager) + + # Create limited number of memory objects for reuse (memory efficiency) + batch_size = min(concurrency, 100) # Limit to 100 for memory efficiency + memory_objs = create_memory_objects_batch(storage_manager, metadata, batch_size) + + if not memory_objs: + print("Generate: Failed - Could not allocate any memory objects") + return + + # Create progress bar + progress_bar = tqdm.tqdm( + total=num_keys, desc="Generating keys", unit="key", unit_scale=True + ) + sleep_count = 1.0 + # Process keys in batches of concurrency size + for batch_start in range(0, num_keys, concurrency): + batch_end = min(batch_start + concurrency, num_keys) + batch_keys = [] + batch_memory_objs = [] + + # Create keys and reuse memory objects for this batch + for i in range(batch_start, batch_end): + key = create_test_key(model, f"gen_{offset + i}") + # Reuse memory objects in round-robin fashion + memory_obj = memory_objs[i % len(memory_objs)] + batch_keys.append(key) + batch_memory_objs.append(memory_obj) + memory_obj.ref_count_up() + + # Flow control: check if remote backend has too many pending tasks + sleep_count = await flow_control_check( + remote_backend, concurrency, sleep_count + ) + + # Use batched_put to store the batch of memory objects + storage_manager.batched_put(batch_keys, batch_memory_objs) + + # Update progress bar + progress_bar.update(len(batch_keys)) + + progress_bar.close() + print(f"Generate: Successfully generated {num_keys} keys") + + # Wait for remote backend put_tasks to complete + wait_put_tasks_complete(find_remote_backend(storage_manager)) + + except Exception as e: + print( + f"Generate: Failed - Error creating storage manager with valid config: {e}" + ) + finally: + if storage_manager: + storage_manager.close() diff --git a/lmcache/v1/check/check_mode_test_remote.py b/lmcache/v1/check/check_mode_test_remote.py new file mode 100644 index 00000000000..11ee3f472d3 --- /dev/null +++ b/lmcache/v1/check/check_mode_test_remote.py @@ -0,0 +1,130 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Test mode implementation for basic checks""" + +# Standard +import asyncio + +# First Party +from lmcache.integration.vllm.utils import lmcache_get_or_create_config +from lmcache.v1.check import check_mode + +# Import shared utilities +from lmcache.v1.check.utils import ( + EventLoopManager, + _get_default_metadata, + create_test_key, + run_common_test_framework, + validate_get_results, +) +from lmcache.v1.memory_management import MemoryObj + +# Import from lmcache with absolute paths +from lmcache.v1.storage_backend import RemoteBackend +from lmcache.v1.storage_backend.connector import InstrumentedRemoteConnector +from lmcache.v1.storage_backend.local_cpu_backend import LocalCPUBackend + + +async def async_contains_backend(backend, key): + """Async wrapper for backend contains method""" + return backend.contains(key) + + +async def async_get_backend(backend, key): + """Async wrapper for backend get_blocking method""" + return backend.get_blocking(key) + + +async def async_submit_put_backend(backend, key, memory_obj): + """Async wrapper for backend submit_put_task""" + future = backend.submit_put_task(key, memory_obj) + # Wait for the future to complete with timeout + try: + await asyncio.wait_for(asyncio.wrap_future(future), timeout=10.0) + return True + except asyncio.TimeoutError: + print(f"Put task timed out for key: {key}") + return False + + +def create_test_memory_obj( + backend: RemoteBackend, local_cpu_backend: LocalCPUBackend +) -> MemoryObj: + """Create a test MemoryObj for testing.""" + if backend.connection is None: + raise ValueError("Backend connection is None") + + if isinstance(backend.connection, InstrumentedRemoteConnector): + connector = backend.connection.getWrappedConnector() + else: + connector = backend.connection + + return local_cpu_backend.allocate( + connector.meta_shape, connector.meta_dtype, connector.meta_fmt + ) + + +def create_test_data_for_backend(backend, local_cpu_backend, model, num_tests): + """Create test data for backend based tests""" + # Group 1: Non-existing keys + non_exist_keys = [ + create_test_key(model, f"non_exist_{i}") for i in range(num_tests) + ] + + # Group 2: Existing keys + exist_keys = [create_test_key(model, f"exist_{i}") for i in range(num_tests)] + exist_memories = [ + create_test_memory_obj(backend, local_cpu_backend) for _ in range(num_tests) + ] + + return non_exist_keys, exist_keys, exist_memories, num_tests + + +@check_mode("test_remote") +async def run_test_mode(model: str, **kwargs): + """Run connector test mode""" + config = lmcache_get_or_create_config() + metadata = _get_default_metadata(model) + + # Create and start event loop manager + loop_manager = EventLoopManager() + loop_manager.start() + + local_cpu_backend = LocalCPUBackend( + config=config, metadata=metadata, dst_device="cpu" + ) + + backend = RemoteBackend( + config=config, + metadata=metadata, + loop=loop_manager.get_loop(), + local_cpu_backend=local_cpu_backend, + dst_device="cpu", + ) + + try: + # Create test context for the common framework + test_context = { + "create_test_data_func": create_test_data_for_backend, + "async_contains_func": async_contains_backend, + "async_put_func": async_submit_put_backend, + "async_get_func": async_get_backend, + "validate_get_func": validate_get_results, + "test_object": backend, + "extra_args": [local_cpu_backend], # Additional argument for backend tests + } + + # Run the common test framework + await run_common_test_framework(test_context, model, num_tests=5) + + except Exception as e: + print(f"Test Failed - Error: {e}") + finally: + # Clean up + try: + if backend: + backend.close() + except Exception as e: + print(f"Error closing backend: {e}") + + # Stop the event loop + loop_manager.stop() diff --git a/lmcache/v1/check/check_mode_test_storage_manager.py b/lmcache/v1/check/check_mode_test_storage_manager.py new file mode 100644 index 00000000000..55dfe13bad1 --- /dev/null +++ b/lmcache/v1/check/check_mode_test_storage_manager.py @@ -0,0 +1,115 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Test mode implementation for basic checks""" + +# Standard +import asyncio + +# First Party +from lmcache.v1.check import check_mode + +# Import shared utilities +from lmcache.v1.check.utils import ( + create_storage_manager_with_config, + create_test_key, + create_test_memory_obj_for_storage_manager, + find_remote_backend, + run_common_test_framework, + validate_get_results, + wait_put_tasks_complete, +) + + +async def async_contains_storage_manager(storage_manager, key): + """Async wrapper for storage manager contains method""" + # Use asyncio.to_thread to make the synchronous call truly async + # This allows for proper timeout handling and non-blocking execution + result = await asyncio.to_thread(storage_manager.contains, key) + return result is not None + + +async def async_get_storage_manager(storage_manager, key): + """Async wrapper for storage manager get method""" + # Use asyncio.to_thread to make the synchronous call truly async + # This allows for proper timeout handling and non-blocking execution + return await asyncio.to_thread(storage_manager.get, key) + + +async def async_submit_put_storage_manager(storage_manager, key, memory_obj): + """Async wrapper for storage manager batched_put""" + try: + # Use asyncio.to_thread to make the synchronous calls truly async + # This allows for proper timeout handling and non-blocking execution + await asyncio.to_thread(storage_manager.batched_put, [key], [memory_obj]) + await asyncio.to_thread( + wait_put_tasks_complete, find_remote_backend(storage_manager) + ) + return True + except Exception as e: + print(f"Put task failed for key: {key}, error: {e}") + return False + + +def create_test_data_for_storage_manager(storage_manager, metadata, model, num_tests): + """Create test data for storage manager based tests""" + # Group 1: Non-existing keys + non_exist_keys = [ + create_test_key(model, f"non_exist_{i}") for i in range(num_tests) + ] + + # Group 2: Existing keys + exist_keys = [create_test_key(model, f"exist_{i}") for i in range(num_tests)] + exist_memories = [] + for i in range(num_tests): + memory_obj = create_test_memory_obj_for_storage_manager( + storage_manager, metadata + ) + if memory_obj is not None: + # Fill with unique test data for each memory object + if memory_obj.tensor is not None: + # Fill with a pattern based on the index to make each object unique + memory_obj.tensor.fill_(float(i + 1)) + memory_obj.ref_count_up() + exist_memories.append(memory_obj) + + if len(exist_memories) != num_tests: + print( + f"Warning: Could only allocate {len(exist_memories)}/{num_tests} " + f"memory objects" + ) + num_tests = len(exist_memories) + exist_keys = exist_keys[:num_tests] + + return non_exist_keys, exist_keys, exist_memories, num_tests + + +@check_mode("test_storage_manager") +async def run_test_mode(model: str, **kwargs): + """Run connector test mode""" + # Create storage manager using common function + storage_manager = create_storage_manager_with_config(model) + + try: + print("Test: Passed - Created storage manager with valid config") + + # Create test context for the common framework + test_context = { + "create_test_data_func": create_test_data_for_storage_manager, + "async_contains_func": async_contains_storage_manager, + "async_put_func": async_submit_put_storage_manager, + "async_get_func": async_get_storage_manager, + "validate_get_func": validate_get_results, + "test_object": storage_manager, + } + + # Run the common test framework + await run_common_test_framework(test_context, model, num_tests=5) + + except Exception as e: + print(f"Test Failed - Error: {e}") + finally: + # Clean up + try: + if storage_manager: + storage_manager.close() + except Exception as e: + print(f"Error closing storage manager: {e}") diff --git a/lmcache/v1/check/utils.py b/lmcache/v1/check/utils.py new file mode 100644 index 00000000000..eb353dd1b60 --- /dev/null +++ b/lmcache/v1/check/utils.py @@ -0,0 +1,435 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Shared utilities for check modes""" + +# Standard +from typing import Optional +import asyncio +import hashlib +import threading +import time + +# Third Party +import torch + +# First Party +from lmcache.config import LMCacheEngineMetadata +from lmcache.utils import CacheEngineKey + +# Import from lmcache with absolute paths +from lmcache.v1.memory_management import MemoryFormat, MemoryObj +from lmcache.v1.storage_backend.remote_backend import RemoteBackend +from lmcache.v1.storage_backend.storage_manager import StorageManager + + +def _get_default_metadata(model: str) -> LMCacheEngineMetadata: + """Get default metadata for testing""" + return LMCacheEngineMetadata( + model_name=model, + world_size=8, + worker_id=0, + fmt="vllm", + kv_dtype=torch.bfloat16, + kv_shape=(8, 2, 16, 8, 16), + ) + + +def create_test_key(model: str, key_id: str = "test_key") -> CacheEngineKey: + """Create a test CacheEngineKey.""" + return CacheEngineKey( + "vllm", + model, + 8, + 0, + int(hashlib.sha256(key_id.encode()).hexdigest(), 16), + torch.bfloat16, + ) + + +def create_test_memory_obj_for_storage_manager( + storage_manager: StorageManager, metadata: LMCacheEngineMetadata +) -> Optional[MemoryObj]: + """Create a test MemoryObj for testing with StorageManager.""" + # The metadata.kv_shape is in vllm format: + # [num_layers, 2, num_tokens, num_heads, head_size] + # For KV_2LTD format, we need shape: [2, num_layers, num_tokens, hidden_dim] + # where hidden_dim = num_heads * head_size + + vllm_shape = metadata.kv_shape # [num_layers, 2, num_tokens, num_heads, head_size] + num_layers = vllm_shape[0] # 8 + kv_dim = vllm_shape[1] # 2 (K and V) + num_tokens = vllm_shape[2] # 16 + num_heads = vllm_shape[3] # 8 + head_size = vllm_shape[4] # 16 + + # Convert to KV_2LTD format shape: [2, num_layers, num_tokens, hidden_dim] + hidden_dim = num_heads * head_size + kv_2ltd_shape = torch.Size([kv_dim, num_layers, num_tokens, hidden_dim]) + + memory_obj = storage_manager.allocate( + shape=kv_2ltd_shape, + dtype=metadata.kv_dtype, + fmt=MemoryFormat.KV_2LTD, + eviction=True, + busy_loop=False, + ) + return memory_obj + + +def create_storage_manager_with_config(model: str): + """Create storage manager with default configuration""" + # First Party + from lmcache.integration.vllm.utils import lmcache_get_or_create_config + from lmcache.v1.event_manager import EventManager + + config = lmcache_get_or_create_config() + metadata = _get_default_metadata(model) + + # Create event manager + event_manager = EventManager() + + # Create storage manager + storage_manager = StorageManager( + config=config, + metadata=metadata, + event_manager=event_manager, + ) + + return storage_manager + + +def find_remote_backend(storage_manager: StorageManager) -> Optional[RemoteBackend]: + """Find remote backend from storage manager""" + for backend_name, backend in storage_manager.storage_backends.items(): + if isinstance(backend, RemoteBackend): + return backend + return None + + +def wait_put_tasks_complete( + remote_backend: Optional[RemoteBackend], max_wait_time: float = 5.0 +): + """Wait for remote backend put tasks to complete""" + if remote_backend is None: + return + + check_interval = 0.001 + elapsed_time = 0.0 + + while elapsed_time < max_wait_time: + if not remote_backend.put_tasks: + break + time.sleep(check_interval) + elapsed_time += check_interval + + # Log warning if timeout + remaining_tasks = len(remote_backend.put_tasks) + if remaining_tasks > 0: + print( + f"Warning: {remaining_tasks} remote put tasks still " + f"pending after {max_wait_time}s timeout" + ) + + +def create_memory_objects_batch( + storage_manager: StorageManager, metadata: LMCacheEngineMetadata, batch_size: int +) -> list[MemoryObj]: + """Create a batch of memory objects for reuse""" + memory_objs = [] + for i in range(batch_size): + memory_obj = create_test_memory_obj_for_storage_manager( + storage_manager, metadata + ) + if memory_obj is not None: + memory_obj.ref_count_up() + memory_objs.append(memory_obj) + return memory_objs + + +async def flow_control_check( + remote_backend: Optional[RemoteBackend], concurrency: int, sleep_count: float = 1.0 +) -> float: + """Check flow control and wait if necessary""" + if remote_backend is None: + return sleep_count + + high_watermark = 100 * concurrency + low_watermark = 10 * concurrency + current_tasks = len(remote_backend.put_tasks) + + while current_tasks > high_watermark: + current_tasks = len(remote_backend.put_tasks) + if current_tasks > high_watermark: + # Too many pending tasks, wait before proceeding + sleep_sec = 0.1 * sleep_count + current_tasks = len(remote_backend.put_tasks) + await asyncio.sleep(sleep_sec) + current_tasks_after_sleep = len(remote_backend.put_tasks) + if current_tasks_after_sleep > low_watermark: + sleep_count *= 2.0 + elif current_tasks_after_sleep == 0: + sleep_count /= 2.0 + continue + if current_tasks <= low_watermark: + break + + return sleep_count + + +async def run_perf_test_with_timeout(func, args_list, timeout=30.0): + """Common performance test framework with timeout handling""" + times = [] + results = [] # Collect results for each operation + for i, args in enumerate(args_list): + try: + start = time.perf_counter() + result = await asyncio.wait_for(func(*args), timeout=timeout) + end = time.perf_counter() + times.append((end - start) * 1000) + results.append(result) + print( + f" Test {i + 1}/{len(args_list)} completed in " + f"{(end - start) * 1000:.2f}ms" + ) + except asyncio.TimeoutError: + print(f" Test {i + 1}/{len(args_list)} timed out after {timeout}s") + times.append(timeout * 1000) + results.append(None) + except Exception as e: + print(f" Test {i + 1}/{len(args_list)} failed: {e}") + times.append(0) + results.append(None) + + if times: + return { + "time_stats": { + "avg": sum(times) / len(times), + "max": max(times), + "min": min(times), + }, + "results": results, + } + else: + return {"time_stats": {"avg": 0, "max": 0, "min": 0}, "results": []} + + +def print_performance_results(stats_data): + """Print performance results in a formatted table""" + print("\nPerformance Results:") + print("-" * 100) + print( + f"| {'Operation':<20} | {'Avg (ms)':>12} | {'Max (ms)':>12} " + f"| {'Min (ms)':>12} | {'Pass/All':>10} | {'Pass Rate':>10} |" + ) + print("-" * 100) + for op, stats, results, pass_count in stats_data: + total = len(results) + pass_all = f"{pass_count}/{total}" + pass_rate = pass_count / total * 100 if total > 0 else 0 + + print( + f"| {op:<20} | {stats['avg']:>12.6f} | {stats['max']:>12.6f} " + f"| {stats['min']:>12.6f} | {pass_all:>10} | {pass_rate:>9.1f}% |" + ) + print("-" * 100) + + +def validate_get_results(get_results, exist_keys, exist_memories, num_tests): + """Validate GET operation results and return statistics""" + content_valid_count = 0 + for i, result in enumerate(get_results["results"]): + if result is None: + print(f" GET for key {exist_keys[i]} returned None result") + continue + try: + if result.tensor is None: + print(f" GET for key {exist_keys[i]} returned None tensor") + continue + + if exist_memories[i].tensor is None: + print(f" Original memory object {i} has None tensor") + continue + + # Compare data content + data_match = torch.equal(result.tensor, exist_memories[i].tensor) + + if data_match: + content_valid_count += 1 + else: + print(f" GET for key {exist_keys[i]} returned incorrect memory object") + print(" Data content mismatch detected") + + except Exception as e: + print(f" Data comparison failed for key {exist_keys[i]}: {e}") + # Standard + import traceback + + traceback.print_exc() + + # Calculate pass rates + not_none_count = sum(1 for r in get_results["results"] if r is not None) + content_pass_rate = content_valid_count / num_tests * 100 + print(f" Validation (not None): {not_none_count}/{num_tests} passed") + print( + f" Validation (content correct): {content_valid_count}/{num_tests}" + f" passed ({content_pass_rate:.1f}%)" + ) + return content_valid_count, not_none_count + + +async def run_common_test_framework( + test_context, + model: str, + num_tests: int = 5, +): + """ + Common test framework for both storage manager and remote backend tests. + + Args: + test_context: A dictionary containing test-specific functions and objects: + - 'create_test_data_func': Function to create test data + - 'async_contains_func': Async function for contains operations + - 'async_put_func': Async function for put operations + - 'async_get_func': Async function for get operations + - 'validate_get_func': Function to validate get results + - 'test_object': The main test object (storage_manager or backend) + - 'extra_args': Extra arguments for test data creation (optional) + model: Model name for testing + num_tests: Number of tests to run + """ + print("Testing basic operations...") + + # Create test data using the provided function + extra_args = test_context.get("extra_args", []) + if extra_args: + non_exist_keys, exist_keys, exist_memories, num_tests = test_context[ + "create_test_data_func" + ](test_context["test_object"], *extra_args, model, num_tests) + else: + non_exist_keys, exist_keys, exist_memories, num_tests = test_context[ + "create_test_data_func" + ](test_context["test_object"], _get_default_metadata(model), model, num_tests) + + # Phase 1: exists test (key does not exist) + print("Phase 1: Testing exists for non-existing keys...") + + exists_non_exist_res = await run_perf_test_with_timeout( + test_context["async_contains_func"], + [(test_context["test_object"], key) for key in non_exist_keys], + ) + exists_non_exist_stats = exists_non_exist_res["time_stats"] + # Validation: All non-existing keys should return False + exists_non_exist_pass_count = sum( + 1 for r in exists_non_exist_res["results"] if r is False + ) + pass_rate = exists_non_exist_pass_count / len(non_exist_keys) * 100 + print( + f" Validation: {exists_non_exist_pass_count}/{len(non_exist_keys)} " + f"passed ({pass_rate:.1f}%)" + ) + + # Phase 2: put test (create new key) + print("Phase 2: Testing put operations...") + + put_res = await run_perf_test_with_timeout( + test_context["async_put_func"], + [ + (test_context["test_object"], exist_keys[i], exist_memories[i]) + for i in range(num_tests) + ], + ) + put_stats = put_res["time_stats"] + # Validation: All PUT operations should return True + put_pass_count = sum(1 for r in put_res["results"] if r is True) + pass_rate = put_pass_count / num_tests * 100 + print(f" Validation: {put_pass_count}/{num_tests} passed ({pass_rate:.1f}%)") + + # Phase 3: exists test (key exists) + print("Phase 3: Testing exists for existing keys...") + + exists_exist_res = await run_perf_test_with_timeout( + test_context["async_contains_func"], + [(test_context["test_object"], key) for key in exist_keys], + ) + exists_exist_stats = exists_exist_res["time_stats"] + # Validation: All existing keys should return True + exists_exist_pass_count = sum(1 for r in exists_exist_res["results"] if r is True) + pass_rate = exists_exist_pass_count / num_tests * 100 + print( + f" Validation: {exists_exist_pass_count}/{num_tests} passed ({pass_rate:.1f}%)" + ) + + # Phase 4: get test (key exists) + print("Phase 4: Testing get operations...") + + get_res = await run_perf_test_with_timeout( + test_context["async_get_func"], + [(test_context["test_object"], key) for key in exist_keys], + ) + get_stats = get_res["time_stats"] + # Validation: Check for non-None results and content correctness + content_valid_count, not_none_count = test_context["validate_get_func"]( + get_res, exist_keys, exist_memories, num_tests + ) + # Use content_valid_count as the pass_count for GET operations + get_pass_count = content_valid_count + + stats_data = [ + ( + "EXISTS (non-exist)", + exists_non_exist_stats, + exists_non_exist_res["results"], + exists_non_exist_pass_count, + ), + ("PUT", put_stats, put_res["results"], put_pass_count), + ( + "EXISTS (exist)", + exists_exist_stats, + exists_exist_res["results"], + exists_exist_pass_count, + ), + ("GET", get_stats, get_res["results"], get_pass_count), + ] + + # Use common performance results printing + print_performance_results(stats_data) + + +class EventLoopManager: + """Manages a dedicated event loop in a separate thread""" + + def __init__(self): + self.loop = None + self.thread = None + self._loop_started = threading.Event() + + def start(self): + """Start the event loop in a separate thread""" + if self.thread is not None and self.thread.is_alive(): + return + + self.loop = asyncio.new_event_loop() + self.thread = threading.Thread(target=self._run_loop, daemon=True) + self.thread.start() + self._loop_started.wait() + + def _run_loop(self): + """Run the event loop""" + asyncio.set_event_loop(self.loop) + self._loop_started.set() + try: + self.loop.run_forever() + except Exception as e: + print(f"Event loop error: {e}") + finally: + self.loop.close() + + def stop(self): + """Stop the event loop and thread""" + if self.loop and not self.loop.is_closed(): + self.loop.call_soon_threadsafe(self.loop.stop) + if self.thread and self.thread.is_alive(): + self.thread.join(timeout=5.0) + + def get_loop(self): + """Get the event loop""" + return self.loop diff --git a/lmcache/v1/compute/blend/blender.py b/lmcache/v1/compute/blend/blender.py index 6a2d10ce81d..2f6ccf080ac 100644 --- a/lmcache/v1/compute/blend/blender.py +++ b/lmcache/v1/compute/blend/blender.py @@ -95,6 +95,7 @@ def process_qkv( # TODO(Jiayi): remove `[0]` hardcode topk_num = int(total_len * self.common_metadata.recomp_ratios[0]) + topk_num = max(topk_num, 1) top_indices = torch.topk(diff_k, k=topk_num).indices top_indices, _ = torch.sort(top_indices) diff --git a/lmcache/v1/compute/models/base.py b/lmcache/v1/compute/models/base.py new file mode 100644 index 00000000000..093d22f9f4d --- /dev/null +++ b/lmcache/v1/compute/models/base.py @@ -0,0 +1,141 @@ +# SPDX-License-Identifier: Apache-2.0 +# Standard +from abc import ABC, abstractmethod + +# Third Party +from torch import nn +import torch + +# First Party +from lmcache.v1.compute.attention.utils import infer_attn_backend_from_vllm +from lmcache.v1.compute.positional_encoding import get_fused_rope + +# TODO(Jiayi): A few things need to be tested/supported: +# TP, PP, Multimodal + + +class LMCBaseModel(nn.Module, ABC): + def __init__( + self, + vllm_model, + blender, + enable_sparse: bool = False, + ): + super().__init__() + self.vllm_model = vllm_model + + self.num_layers = len(vllm_model.model.layers) + + self.vllm_attn_layers = [] + self.lmc_attn_layers = [] + for i in range(self.num_layers): + vllm_attn = vllm_model.model.layers[i].self_attn.attn + self.vllm_attn_layers.append(vllm_attn) + + self.lmc_attn_layers.append( + infer_attn_backend_from_vllm(vllm_attn, enable_sparse) + ) + + # NOTE(Jiayi): better not to pass the blender in init + # if we want to make this LMCModel more general. + self.blender = blender + + # remove hard code + rotary_emb = vllm_model.model.layers[0].self_attn.rotary_emb + head_dim = rotary_emb.head_size + max_position_embeddings = rotary_emb.max_position_embeddings + rope_scaling = None + base = rotary_emb.base + is_neox_style = rotary_emb.is_neox_style + dtype = rotary_emb.dtype + self.fused_rotary_emb = get_fused_rope( + head_dim, + rotary_dim=head_dim, + max_position=max_position_embeddings, + base=base, + rope_scaling=rope_scaling, + is_neox_style=is_neox_style, + dtype=dtype, + ) + + @abstractmethod + def _process_qkv(self, q, k, v, layer): + """Process QKV tensors. Model-specific implementation.""" + pass + + @torch.compile + def compute_layer( + self, + input_ids: torch.Tensor, + ): + input_ids = input_ids.cuda() + hidden_states = self.vllm_model.get_input_embeddings(input_ids) + residual = None + + attn_output = None + + # TODO(Jiayi): Need to build `attn_metadata` more elegantly. + attn_metadata = self.lmc_attn_layers[0].init_attn_metadata( + input_ids=input_ids, + ) + + for idx, layer in enumerate( + self.vllm_model.model.layers[ + self.vllm_model.model.start_layer : self.vllm_model.model.end_layer + ] + ): + # TODO(Jiayi) The last layer doesn't have to be computed + # hidden_states, residual = layer(positions, hidden_states, residual) + + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = layer.input_layernorm(hidden_states) + else: + hidden_states, residual = layer.input_layernorm(hidden_states, residual) + # hidden_states = self.self_attn(positions=positions, + # hidden_states=hidden_states) + + qkv, _ = layer.self_attn.qkv_proj(hidden_states) + q, k, v = qkv.split( + [ + layer.self_attn.q_size, + layer.self_attn.kv_size, + layer.self_attn.kv_size, + ], + dim=-1, + ) + + # Model-specific QKV processing + q, k, v = self._process_qkv(q, k, v, layer) + + q, k, v, residual, attn_output, attn_metadata = self.blender.process_qkv( + q, k, v, residual, idx, attn_output, attn_metadata + ) + + num_heads = self.vllm_attn_layers[idx].num_heads + num_kv_heads = self.vllm_attn_layers[idx].num_kv_heads + head_size = self.vllm_attn_layers[idx].head_size + + q = q.view(-1, num_heads, head_size) + k = k.view(-1, num_kv_heads, head_size) + v = v.view(-1, num_kv_heads, head_size) + attn_output = attn_output.view(-1, num_heads, head_size) + + attn_output = self.lmc_attn_layers[idx].forward_contiguous( + q, k, v, attn_output, attn_metadata + ) + + attn_output = attn_output.view(-1, num_heads * head_size) + k = k.view(-1, num_kv_heads * head_size) + v = v.view(-1, num_kv_heads * head_size) + + hidden_states, _ = layer.self_attn.o_proj(attn_output) + + # Fully Connected + hidden_states, residual = layer.post_attention_layernorm( + hidden_states, residual + ) + hidden_states = layer.mlp(hidden_states) + + yield diff --git a/lmcache/v1/compute/models/llama.py b/lmcache/v1/compute/models/llama.py index d9bb55bcfdc..7899ee0e25a 100644 --- a/lmcache/v1/compute/models/llama.py +++ b/lmcache/v1/compute/models/llama.py @@ -1,130 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 -# Third Party -from torch import nn -import torch - # First Party -from lmcache.v1.compute.attention.utils import infer_attn_backend_from_vllm -from lmcache.v1.compute.positional_encoding import get_fused_rope - -# TODO(Jiayi): A few things need to be tested/supported: -# TP, PP, Multimodal - - -class LMCLlamaModel(nn.Module): - def __init__( - self, - vllm_model, - blender, - enable_sparse: bool = False, - ): - super().__init__() - self.vllm_model = vllm_model - - self.num_layers = len(vllm_model.model.layers) - - self.vllm_attn_layers = [] - self.lmc_attn_layers = [] - for i in range(self.num_layers): - vllm_attn = vllm_model.model.layers[i].self_attn.attn - self.vllm_attn_layers.append(vllm_attn) - - self.lmc_attn_layers.append( - infer_attn_backend_from_vllm(vllm_attn, enable_sparse) - ) - - # NOTE(Jiayi): better not to pass the blender in init - # if we want to make this LMCModel more general. - self.blender = blender - - # remove hard code - rotary_emb = vllm_model.model.layers[0].self_attn.rotary_emb - head_dim = rotary_emb.head_size - max_position_embeddings = rotary_emb.max_position_embeddings - rope_scaling = None - base = rotary_emb.base - is_neox_style = rotary_emb.is_neox_style - dtype = rotary_emb.dtype - self.fused_rotary_emb = get_fused_rope( - head_dim, - rotary_dim=head_dim, - max_position=max_position_embeddings, - base=base, - rope_scaling=rope_scaling, - is_neox_style=is_neox_style, - dtype=dtype, - ) - - @torch.compile - def compute_layer( - self, - input_ids: torch.Tensor, - ): - input_ids = input_ids.cuda() - hidden_states = self.vllm_model.get_input_embeddings(input_ids) - residual = None - - attn_output = None - - # TODO(Jiayi): Need to build `attn_metadata` more elegantly. - attn_metadata = self.lmc_attn_layers[0].init_attn_metadata( - input_ids=input_ids, - ) - - for idx, layer in enumerate( - self.vllm_model.model.layers[ - self.vllm_model.model.start_layer : self.vllm_model.model.end_layer - ] - ): - # TODO(Jiayi) The last layer doesn't have to be computed - # hidden_states, residual = layer(positions, hidden_states, residual) - - # Self Attention - if residual is None: - residual = hidden_states - hidden_states = layer.input_layernorm(hidden_states) - else: - hidden_states, residual = layer.input_layernorm(hidden_states, residual) - # hidden_states = self.self_attn(positions=positions, - # hidden_states=hidden_states) - - qkv, _ = layer.self_attn.qkv_proj(hidden_states) - q, k, v = qkv.split( - [ - layer.self_attn.q_size, - layer.self_attn.kv_size, - layer.self_attn.kv_size, - ], - dim=-1, - ) - - q, k, v, residual, attn_output, attn_metadata = self.blender.process_qkv( - q, k, v, residual, idx, attn_output, attn_metadata - ) - - num_heads = self.vllm_attn_layers[idx].num_heads - num_kv_heads = self.vllm_attn_layers[idx].num_kv_heads - head_size = self.vllm_attn_layers[idx].head_size - - q = q.view(-1, num_heads, head_size) - k = k.view(-1, num_kv_heads, head_size) - v = v.view(-1, num_kv_heads, head_size) - attn_output = attn_output.view(-1, num_heads, head_size) - - attn_output = self.lmc_attn_layers[idx].forward_contiguous( - q, k, v, attn_output, attn_metadata - ) - - attn_output = attn_output.view(-1, num_heads * head_size) - k = k.view(-1, num_kv_heads * head_size) - v = v.view(-1, num_kv_heads * head_size) - - hidden_states, _ = layer.self_attn.o_proj(attn_output) +from lmcache.v1.compute.models.base import LMCBaseModel - # Fully Connected - hidden_states, residual = layer.post_attention_layernorm( - hidden_states, residual - ) - hidden_states = layer.mlp(hidden_states) - yield +class LMCLlamaModel(LMCBaseModel): + def _process_qkv(self, q, k, v, layer): + """Process QKV tensors for LLaMa model (no additional processing).""" + return q, k, v diff --git a/lmcache/v1/compute/models/qwen3.py b/lmcache/v1/compute/models/qwen3.py new file mode 100644 index 00000000000..df64e9c6d74 --- /dev/null +++ b/lmcache/v1/compute/models/qwen3.py @@ -0,0 +1,24 @@ +# SPDX-License-Identifier: Apache-2.0 +# First Party +from lmcache.v1.compute.models.base import LMCBaseModel + + +class LMCQwen3Model(LMCBaseModel): + def _process_qkv(self, q, k, v, layer): + """Process QKV tensors for Qwen3 model with q_norm and k_norm layers.""" + # Qwen3 has q_norm and k_norm layers + q_by_head = q.view( + *q.shape[:-1], + q.shape[-1] // layer.self_attn.head_dim, + layer.self_attn.head_dim, + ) + q_by_head = layer.self_attn.q_norm(q_by_head) + q = q_by_head.view(q.shape) + k_by_head = k.view( + *k.shape[:-1], + k.shape[-1] // layer.self_attn.head_dim, + layer.self_attn.head_dim, + ) + k_by_head = layer.self_attn.k_norm(k_by_head) + k = k_by_head.view(k.shape) + return q, k, v diff --git a/lmcache/v1/compute/models/utils.py b/lmcache/v1/compute/models/utils.py index d008bbf0e72..31e31ddc654 100644 --- a/lmcache/v1/compute/models/utils.py +++ b/lmcache/v1/compute/models/utils.py @@ -18,6 +18,16 @@ def infer_model_from_vllm(vllm_model, blender, enable_sparse: bool = False): from lmcache.v1.compute.models.llama import LMCLlamaModel return LMCLlamaModel(vllm_model, blender, enable_sparse) + elif model_name == "Qwen2ForCausalLM": + # First Party + from lmcache.v1.compute.models.llama import LMCLlamaModel + + return LMCLlamaModel(vllm_model, blender, enable_sparse) + elif model_name == "Qwen3ForCausalLM": + # First Party + from lmcache.v1.compute.models.qwen3 import LMCQwen3Model + + return LMCQwen3Model(vllm_model, blender, enable_sparse) else: # TODO(Jiayi): Add support for more models raise NotImplementedError( diff --git a/lmcache/v1/config.py b/lmcache/v1/config.py index 785e9bd6ac3..4c3227fc5fd 100644 --- a/lmcache/v1/config.py +++ b/lmcache/v1/config.py @@ -1,9 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # Standard from typing import Any, Optional, Union +import ast import json import os import re +import uuid # Third Party import yaml @@ -71,6 +73,33 @@ def _to_bool( return str(value).strip().lower() in ["true", "1"] +def _parse_quoted_string(value: str) -> str: + """Parse a string that may be surrounded by quotes and handle escape characters. + + Args: + value: The input string that may be quoted + + Returns: + The unquoted string with escape characters properly handled + """ + if not value: + return value + + value = value.strip() + + if len(value) >= 2 and value[0] == value[-1] and value[0] in ("'", '"'): + try: + evaluated = ast.literal_eval(value) + if isinstance(evaluated, str): + return evaluated + except (ValueError, SyntaxError): + # If ast.literal_eval fails, it's not a valid Python literal. + # Fall back to simply stripping the outer quotes. + return value[1:-1] + + return value + + # Configuration aliases and deprecated mappings _CONFIG_ALIASES = { # Maps deprecated names to current names @@ -101,6 +130,7 @@ def _to_bool( "env_converter": _to_bool, }, "max_local_cpu_size": {"type": float, "default": 5.0, "env_converter": float}, + "reserve_local_cpu_size": {"type": float, "default": 0.0, "env_converter": float}, "local_disk": { "type": Optional[str], "default": None, @@ -176,8 +206,8 @@ def _to_bool( "env_converter": _to_bool, }, "lmcache_instance_id": { - "type": str, - "default": "lmcache_default_instance", + "type": Optional[str], + "default": None, "env_converter": str, }, "controller_pull_url": { @@ -361,6 +391,70 @@ def _to_bool( "default": None, "env_converter": float, }, + "lookup_server_worker_ids": { + "type": Optional[list[int]], + "default": None, + "env_converter": _to_int_list, + }, + "enable_scheduler_bypass_lookup": { + "type": bool, + "default": False, + "env_converter": _to_bool, + }, + "script_allowed_imports": { + "type": Optional[list[str]], + "default": None, + "env_converter": _to_str_list, + }, + # Lazy memory allocator configurations + "enable_lazy_memory_allocator": { + "type": bool, + "default": False, + "env_converter": _to_bool, + "description": ( + "Enable lazy memory allocator to reduce initial memory footprint. " + "Memory is allocated on-demand and expanded automatically when needed." + ), + }, + "lazy_memory_initial_ratio": { + "type": float, + "default": 0.2, + "env_converter": float, + "description": ( + "Initial memory allocation ratio (0.0-1.0). " + "Determines the percentage of target memory size to allocate at startup. " + "Default is 0.2 (20%)." + ), + }, + "lazy_memory_expand_trigger_ratio": { + "type": float, + "default": 0.5, + "env_converter": float, + "description": ( + "Memory usage ratio (0.0-1.0) that triggers automatic expansion. " + "When memory usage exceeds this threshold, expansion is triggered. " + "Default is 0.5 (50%)." + ), + }, + "lazy_memory_step_ratio": { + "type": float, + "default": 0.1, + "env_converter": float, + "description": ( + "Memory expansion step ratio (0.0-1.0). " + "Determines the percentage of target memory size to add in each expansion. " + "Default is 0.1 (10%)." + ), + }, + "lazy_memory_safe_size": { + "type": float, + "default": 0.0, + "env_converter": float, + "description": ( + "Safe threshold size in GB. Lazy allocator is only enabled when " + "max_local_cpu_size exceeds this value. Default is 0.0 GB (always enabled)." + ), + }, } @@ -404,7 +498,9 @@ def _create_config_class(): from dataclasses import make_dataclass def _post_init(self): - self.validate() + # Generate random instance ID if not set + if not self.lmcache_instance_id: + self.lmcache_instance_id = f"lmcache_instance_{uuid.uuid4().hex}" cls = make_dataclass( "LMCacheEngineConfig", @@ -415,6 +511,7 @@ def _post_init(self): "log_config": _log_config, "to_original_config": _to_original_config, "get_extra_config_value": _get_extra_config_value, + "get_lookup_server_worker_ids": _get_lookup_server_worker_ids, "from_defaults": classmethod(_from_defaults), "from_legacy": classmethod(_from_legacy), "from_file": classmethod(_from_file), @@ -434,8 +531,9 @@ def _post_init(self): def _validate_config(self): """Validate configuration""" + # auto-adjust save_unfull_chunk for async loading to prevent CPU fragmentation - if self.enable_async_loading or self.use_layerwise: + if self.enable_async_loading: logger.warning( "Automatically setting save_unfull_chunk=False because " "enable_async_loading=True or use_layerwise=True to prevent " @@ -443,6 +541,14 @@ def _validate_config(self): ) self.save_unfull_chunk = False + if self.enable_blending: + if not self.save_unfull_chunk: + logger.warning( + "Automatically setting save_unfull_chunk=True because " + "enable_blending=True" + ) + self.save_unfull_chunk = True + if self.enable_p2p: assert self.enable_controller assert self.controller_pull_url is not None @@ -469,8 +575,7 @@ def _validate_config(self): if enable_nixl_storage: assert self.extra_config.get("nixl_backend") is not None - assert self.extra_config.get("nixl_path") is not None - assert self.extra_config.get("nixl_file_pool_size") is not None + assert self.extra_config.get("nixl_pool_size") is not None assert self.nixl_buffer_size is not None assert self.nixl_buffer_device is not None @@ -515,6 +620,20 @@ def _get_extra_config_value(self, key, default_value=None): return default_value +def _get_lookup_server_worker_ids(self, use_mla, world_size): + if self.lookup_server_worker_ids is None: + # if mla is not enabled, return [], which means start + # lookup server on all worker as default; + # if mla is enabled, return [0], which means start lookup + # server on worker 0 as default. + return [0] if use_mla else [] + + # check the input + for worker_id in self.lookup_server_worker_ids: + assert -1 < worker_id < world_size + return self.lookup_server_worker_ids + + def _from_defaults(cls, **kwargs): """Create configuration from defaults""" config_values = {} @@ -586,6 +705,7 @@ def _from_legacy(cls, **kwargs): config_values[name] = config["default"] instance = cls(**config_values) + instance.validate() return instance @@ -646,13 +766,17 @@ def get_env_name(attr_name: str) -> str: for name, config in _CONFIG_DEFINITIONS.items(): if name in resolved_config: try: - value = resolved_config[name] + # Parse quoted strings and handle escape characters + raw_value = resolved_config[name] # Keep original value for logging + value = _parse_quoted_string(raw_value) converted_value = config["env_converter"](value) setattr(self, name, converted_value) except (ValueError, json.JSONDecodeError) as e: - logger.warning(f"Failed to parse {get_env_name(name)}: {e}") + logger.warning( + f"Failed to parse {get_env_name(name)}={raw_value!r}: {e}" + ) # Keep existing value if conversion fails - + self.validate() return self diff --git a/lmcache/v1/event_manager.py b/lmcache/v1/event_manager.py index 80b193c064a..df3004afebb 100644 --- a/lmcache/v1/event_manager.py +++ b/lmcache/v1/event_manager.py @@ -19,13 +19,15 @@ class EventManager: """ A thread-safe event manager to manage asynchronous events. Each event is identified by its type and a unique id. + Events are organized by status for efficient counting. """ def __init__(self) -> None: - self.events: dict[EventType, dict[str, tuple[EventStatus, asyncio.Future]]] = {} - for event_type in EventType: - self.events[event_type] = {} - + # Guard by lock + # Structure: events[event_type][event_status][event_id] = future + self.events: dict[EventType, dict[EventStatus, dict[str, asyncio.Future]]] = { + et: {es: {} for es in EventStatus} for et in EventType + } self.lock = threading.Lock() def add_event( @@ -38,11 +40,11 @@ def add_event( Add an event with the given type and id. """ with self.lock: - sub_events_dict = self.events.get(event_type, None) - assert sub_events_dict is not None, ( + status_dict = self.events.get(event_type, None) + assert status_dict is not None, ( f"Invalid event type {event_type} in EventManager." ) - sub_events_dict[event_id] = (EventStatus.ONGOING, future) + status_dict[EventStatus.ONGOING][event_id] = future def pop_event( self, @@ -53,21 +55,15 @@ def pop_event( Pop and return the event with the given type and id. """ with self.lock: - sub_events_dict = self.events.get(event_type, None) - assert sub_events_dict is not None, ( + status_dict = self.events.get(event_type, None) + assert status_dict is not None, ( f"Invalid event type {event_type} in EventManager." ) - assert event_id in sub_events_dict, ( - f"Event {event_id} of type {event_type} not found in EventManager." - ) - status, event = sub_events_dict.pop(event_id) - assert status == EventStatus.DONE, ( - f"Event {event_id} of type {event_type} is not done." - ) - assert event is not None, ( - f"Event {event_id} of type {event_type} not found in EventManager." + done_events = status_dict[EventStatus.DONE] + assert event_id in done_events, ( + f"Event {event_id} of type {event_type} is not done or not found." ) - return event + return done_events.pop(event_id) def update_event_status( self, @@ -79,16 +75,23 @@ def update_event_status( Update the status of the event with the given type and id. """ with self.lock: - sub_events_dict = self.events.get(event_type, None) - assert sub_events_dict is not None, ( + status_dict = self.events.get(event_type, None) + assert status_dict is not None, ( f"Invalid event type {event_type} in EventManager." ) - if event_id in sub_events_dict: - _, event = sub_events_dict[event_id] - sub_events_dict[event_id] = (status, event) - else: + # Find the event in any status dict + event = None + for s in EventStatus: + if event_id in status_dict[s]: + event = status_dict[s].pop(event_id) + break + + if event is None: raise KeyError(f"Event {event_id} of type {event_type} not found.") + # Move to new status dict + status_dict[status][event_id] = event + def get_event_status( self, event_type: EventType, @@ -98,12 +101,26 @@ def get_event_status( Get the status of the event with the given type and id. """ with self.lock: - sub_events_dict = self.events.get(event_type, None) - assert sub_events_dict is not None, ( + status_dict = self.events.get(event_type, None) + assert status_dict is not None, ( f"Invalid event type {event_type} in EventManager." ) - if event_id in sub_events_dict: - status, _ = sub_events_dict[event_id] - return status - else: - return EventStatus.NOT_FOUND + for status in EventStatus: + if event_id in status_dict[status]: + return status + return EventStatus.NOT_FOUND + + def get_events_count_by_status( + self, + event_type: EventType, + status: EventStatus, + ) -> int: + """ + Get the count of events for the given event type and status. + This is a lightweight O(1) operation using dict length. + """ + with self.lock: + status_dict = self.events.get(event_type, None) + if status_dict is None: + return 0 + return len(status_dict[status]) diff --git a/lmcache/v1/gpu_connector.py b/lmcache/v1/gpu_connector.py index a80c4266c73..bb25fc104ab 100644 --- a/lmcache/v1/gpu_connector.py +++ b/lmcache/v1/gpu_connector.py @@ -562,7 +562,7 @@ def batched_to_gpu(self, starts: List[int], ends: List[int], **kwargs): if self.cache_positions and layer_id == 0: old_positions_full[ start - buf_offset : end - buf_offset - ] = memory_obj.metadata.old_positions + ] = memory_obj.metadata.cached_positions elif layer_id == self.num_layers: yield @@ -686,7 +686,7 @@ def batched_from_gpu( non_blocking=True, ) if self.cache_positions: - memory_obj.metadata.old_positions = old_positions + memory_obj.metadata.cached_positions = old_positions yield self.store_stream.synchronize() @@ -1247,7 +1247,7 @@ def __init__( self.num_kv_cache, dtype=torch.int64, device="cpu" ) self.use_gpu = use_gpu - self.gpu_buffer_allocator: GPUMemoryAllocator + self.gpu_buffer_allocator: Optional[GPUMemoryAllocator] = None def _lazy_initialize_buffer(self, kv_caches): """ @@ -1319,6 +1319,10 @@ def batched_to_gpu(self, starts: List[int], ends: List[int], **kwargs): if self.use_gpu: buffer_shape = self.get_shape(num_tokens) + + assert self.gpu_buffer_allocator is not None, ( + "GPU buffer allocator should be initialized" + ) tmp_gpu_buffer_obj = self.gpu_buffer_allocator.allocate( buffer_shape, self.dtype, MemoryFormat.KV_T2D ) @@ -1344,7 +1348,7 @@ def batched_to_gpu(self, starts: List[int], ends: List[int], **kwargs): memory_obj.tensor, non_blocking=True ) else: - lmc_ops.single_layer_kv_transfer( + lmc_ops.single_layer_kv_transfer_sgl( memory_obj.tensor, self.kvcaches[0][layer_id], self.kvcaches[1][layer_id], @@ -1355,7 +1359,8 @@ def batched_to_gpu(self, starts: List[int], ends: List[int], **kwargs): if self.use_gpu: t, h, d = self.kvcaches[0][layer_id].shape - lmc_ops.single_layer_kv_transfer( + + lmc_ops.single_layer_kv_transfer_sgl( tmp_gpu_buffer_obj.tensor, self.kvcaches[0][layer_id].view(t, 1, h, d), self.kvcaches[1][layer_id].view(t, 1, h, d), @@ -1428,6 +1433,10 @@ def batched_from_gpu( if self.use_gpu: buffer_shape = self.get_shape(num_tokens) + + assert self.gpu_buffer_allocator is not None, ( + "GPU buffer allocator should be initialized" + ) tmp_gpu_buffer_obj = self.gpu_buffer_allocator.allocate( buffer_shape, self.dtype, MemoryFormat.KV_T2D ) @@ -1441,7 +1450,7 @@ def batched_from_gpu( # kvcaches -> gpu_buffer -> memobj if self.use_gpu: t, h, d = self.kvcaches[0][layer_id].shape - lmc_ops.single_layer_kv_transfer( + lmc_ops.single_layer_kv_transfer_sgl( tmp_gpu_buffer_obj.tensor, self.kvcaches[0][layer_id].view(t, 1, h, d), self.kvcaches[1][layer_id].view(t, 1, h, d), @@ -1464,7 +1473,7 @@ def batched_from_gpu( ) start_idx += chunk_len else: - lmc_ops.single_layer_kv_transfer( + lmc_ops.single_layer_kv_transfer_sgl( memory_obj.tensor, self.kvcaches[0][layer_id], self.kvcaches[1][layer_id], diff --git a/lmcache/v1/internal_api_server/cache_api.py b/lmcache/v1/internal_api_server/cache_api.py new file mode 100644 index 00000000000..5c86997d813 --- /dev/null +++ b/lmcache/v1/internal_api_server/cache_api.py @@ -0,0 +1,85 @@ +# SPDX-License-Identifier: Apache-2.0 +# Standard +from typing import Annotated, List, Optional +import json + +# Third Party +from fastapi import APIRouter, Query +from starlette.requests import Request +from starlette.responses import PlainTextResponse + +router = APIRouter() + + +@router.delete("/cache/clear") +async def clear( + request: Request, + locations: Annotated[Optional[List[str]], Query()] = None, + request_configs: Optional[dict] = None, +): + """Clear cached data from the LMCache engine. + + This endpoint provides a way to clear cached KV (Key-Value) data from the + LMCache engine. It can clear all cached data or selectively clear data + from specific storage locations. + + Args: + request (Request): The FastAPI request object containing application state. + locations (Optional[List[str]], optional): List of storage backend locations + to clear cache from. If None, clears from all available locations. + Common values include ["LocalCPUBackend", "LocalDiskBackend"]. + Defaults to None. + request_configs (Optional[dict], optional): Additional configuration + parameters for the clear operation. Currently unused but reserved + for future extensions. Defaults to None. + + Returns: + PlainTextResponse: A plain text response + + Example: + Clear all cached data: + ```bash + curl -X DELETE "http://localhost:8000/cache/clear" + # Response: {"status": "success", "num_removed": 10, + # "locations": null, "request_configs": null} + ``` + + Clear cache from specific locations: + ```bash + curl -X DELETE "http://localhost:8000/cache/clear?locations=LocalCPUBackend&locations=LocalDiskBackend" + # Response: {"status": "success", "num_removed": 5, + # "locations": ["LocalCPUBackend", "LocalDiskBackend"], + # "request_configs": null} + ``` + """ + try: + lmcache_adapter = request.app.state.lmcache_adapter + lmcache_engine = getattr(lmcache_adapter, "lmcache_engine", None) + if not lmcache_engine: + error_info = { + "error": "/cache/clear API is unavailable", + "message": "LMCache engine not configured.", + } + return PlainTextResponse( + content=json.dumps(error_info, indent=2), + media_type="application/json", + status_code=503, # Service Unavailable + ) + num_removed = lmcache_engine.clear( + locations=locations, request_configs=request_configs + ) + success_info = { + "status": "success", + "num_removed": num_removed, + } + return PlainTextResponse( + content=json.dumps(success_info, indent=2), + media_type="application/json", + ) + except Exception as e: + error_info = {"error": "Failed to clear cache", "message": str(e)} + return PlainTextResponse( + content=json.dumps(error_info, indent=2), + media_type="application/json", + status_code=500, + ) diff --git a/lmcache/v1/internal_api_server/env_api.py b/lmcache/v1/internal_api_server/env_api.py new file mode 100644 index 00000000000..88d462d9e93 --- /dev/null +++ b/lmcache/v1/internal_api_server/env_api.py @@ -0,0 +1,22 @@ +# SPDX-License-Identifier: Apache-2.0 +# Standard +import json +import os + +# Third Party +from fastapi import APIRouter +from starlette.responses import PlainTextResponse + +router = APIRouter() + + +@router.get("/env") +async def get_env(): + """ + Get all environment variables + """ + env_dict = dict(os.environ) + return PlainTextResponse( + content=json.dumps(env_dict, indent=2, sort_keys=True), + media_type="text/plain", + ) diff --git a/lmcache/v1/internal_api_server/run_script_api.py b/lmcache/v1/internal_api_server/run_script_api.py index 365e96f7ab2..5ea8a13d6a9 100644 --- a/lmcache/v1/internal_api_server/run_script_api.py +++ b/lmcache/v1/internal_api_server/run_script_api.py @@ -1,12 +1,18 @@ # SPDX-License-Identifier: Apache-2.0 # Standard from typing import Any +import importlib # Third Party from fastapi import APIRouter from starlette.requests import Request from starlette.responses import PlainTextResponse +# First Party +from lmcache.logging import init_logger + +logger = init_logger(__name__) + router = APIRouter() @@ -21,6 +27,26 @@ async def run_script(request: Request): script_content = await script_file.read() try: + # Get allowed imports from config + config = request.app.state.lmcache_adapter.config + allowed_imports = config.script_allowed_imports or [] + + # Pre-import allowed modules + allowed_modules = {} + for module_name in allowed_imports: + try: + module = importlib.import_module(module_name) + allowed_modules[module_name] = module + logger.info(f"Imported allowed module: {module_name}") + except ImportError as e: + logger.warning(f"Failed to import module {module_name}: {e}") + + # Create custom __import__ function that only allows configured modules + def restricted_import(name, globals=None, locals=None, fromlist=(), level=0): + if name in allowed_modules: + return allowed_modules[name] + raise ImportError(f"Import of '{name}' is not allowed") + restricted_globals = { "__builtins__": { "print": print, @@ -31,9 +57,11 @@ async def run_script(request: Request): "dict": dict, "tuple": tuple, "set": set, + "__import__": restricted_import, }, "app": request.app, } + restricted_locals: dict[str, Any] = {} exec(script_content, restricted_globals, restricted_locals) diff --git a/lmcache/v1/lazy_memory_allocator.py b/lmcache/v1/lazy_memory_allocator.py new file mode 100644 index 00000000000..6c6a09bc390 --- /dev/null +++ b/lmcache/v1/lazy_memory_allocator.py @@ -0,0 +1,443 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Lazy memory allocator with async progressive expansion and zero-copy.""" + +# Standard +from typing import TYPE_CHECKING, Callable, List, Optional, Tuple, Union +import threading + +# Third Party +import sortedcontainers +import torch + +# First Party +from lmcache.logging import init_logger +from lmcache.observability import LMCStatsMonitor +from lmcache.utils import _lmcache_nvtx_annotate +from lmcache.v1.memory_management import ( + BufferAllocator, + FreeBlock, + MemoryFormat, + MemoryObj, + MixedMemoryAllocator, + TensorMemoryAllocator, + _allocate_cpu_memory, +) +from lmcache.v1.system_detection import NUMAMapping + +if TYPE_CHECKING: + # First Party + from lmcache.v1.config import LMCacheEngineConfig + +if torch.cuda.is_available(): + # First Party + import lmcache.c_ops as lmc_ops +else: + # First Party + import lmcache.non_cuda_equivalents as lmc_ops + +logger = init_logger(__name__) + + +class CompositeBuffer: + """Manages multiple memory segments with unified view (zero-copy).""" + + def __init__(self, initial_buffer: torch.Tensor): + self.segments: List[torch.Tensor] = [initial_buffer] + self.segment_offsets: List[int] = [0] + self.total_size = initial_buffer.numel() + self.lock = threading.Lock() + + def add_segment(self, new_buffer: torch.Tensor) -> int: + """Add a new memory segment to the composite buffer. + + Thread-safe: Protected by self.lock. + + Returns: + int: The offset of the new segment in the unified address space. + """ + with self.lock: + offset = self.total_size + self.segments.append(new_buffer) + self.segment_offsets.append(offset) + self.total_size += new_buffer.numel() + logger.info( + f"Added segment: {new_buffer.numel()} bytes, total: {self.total_size}" + ) + return offset + + def get_slice(self, start: int, size: int) -> torch.Tensor: + with self.lock: + segment_idx = self._find_segment(start) + if segment_idx == -1: + raise ValueError(f"Invalid offset: {start}") + + segment_start = self.segment_offsets[segment_idx] + segment = self.segments[segment_idx] + end = start + size + + if end <= segment_start + segment.numel(): + local_start = start - segment_start + return segment[local_start : local_start + size] + else: + raise ValueError( + f"Slice spans segments (start={start}, size={size}). " + "Bug in segment-aware coalescing." + ) + + def _find_segment(self, offset: int) -> int: + for i in range(len(self.segments) - 1, -1, -1): + if offset >= self.segment_offsets[i]: + if offset < self.segment_offsets[i] + self.segments[i].numel(): + return i + return -1 + + def numel(self) -> int: + return self.total_size + + +class CompositeTensorMemoryAllocator(TensorMemoryAllocator): + """TensorMemoryAllocator with segment-aware coalescing for CompositeBuffer.""" + + def __init__( + self, + composite_buffer: CompositeBuffer, + align_bytes: int = TensorMemoryAllocator.ALIGN_BYTES, + ): + self.composite_buffer = composite_buffer + self.buffer = composite_buffer.segments[0].view(torch.uint8).flatten() + self.align_bytes = align_bytes + self.explicit_list = sortedcontainers.SortedList(key=lambda x: x.start) + self.explicit_list.add(FreeBlock(start=0, size=self.buffer.numel())) + self.num_active_allocations = 0 + self.total_allocated_size = 0 + self.segment_boundaries = [composite_buffer.segments[0].numel()] + self.stats_monitor = LMCStatsMonitor.GetOrCreate() + + def expand_with_new_segment(self, new_buffer: torch.Tensor): + """Expand the allocator with a new memory segment. + + Thread Safety: + ============== + This method modifies shared data structures (explicit_list, segment_boundaries) + that are also accessed by allocate/free operations in the main thread. + + The caller MUST hold the host_mem_lock before calling this method to prevent + race conditions. + """ + offset = self.composite_buffer.add_segment(new_buffer) + new_size = new_buffer.numel() + self.segment_boundaries.append(offset + new_size) + + new_free_block = FreeBlock(start=offset, size=new_size) + prev_block = self.explicit_list[-1] if len(self.explicit_list) > 0 else None + succ_block = None + + if not self._coalesce(new_free_block, prev_block, succ_block): + self.explicit_list.add(new_free_block) + + logger.info( + f"Expanded: {new_size} bytes, total: {self.composite_buffer.numel()}" + ) + + def _is_segment_boundary(self, offset: int) -> bool: + return offset in self.segment_boundaries + + def _can_merge_with_prev( + self, curr_block: FreeBlock, prev_block: FreeBlock + ) -> bool: + """Override: Add segment boundary check for prev merge.""" + return super()._can_merge_with_prev( + curr_block, prev_block + ) and not self._is_segment_boundary(prev_block.start + prev_block.size) + + def _can_merge_with_succ( + self, curr_block: FreeBlock, succ_block: FreeBlock + ) -> bool: + """Override: Add segment boundary check for succ merge.""" + return super()._can_merge_with_succ( + curr_block, succ_block + ) and not self._is_segment_boundary(curr_block.start + curr_block.size) + + def _get_buffer_slice(self, start: int, size: int) -> torch.Tensor: + """Override: Use composite buffer for multi-segment access.""" + return self.composite_buffer.get_slice(start, size) + + +class AsyncMemoryExpander: + """Asynchronously expands memory in background. + + Design Philosophy: + ================== + This is a ONE-WAY, EXPANSION-ONLY mechanism designed to: + 1. Reduce startup latency: Start with a small initial allocation + 2. Minimize initial memory footprint: Avoid allocating full capacity upfront + 3. Progressive growth: Expand memory as needed in the background + + Key Characteristics: + - NO SHRINKING: Once memory is allocated, it is never released back to + the system + - ONE-TIME EXPANSION: The expander thread runs until target size is + reached, then stops + - LAZY ALLOCATION: Memory is allocated progressively, not all at once + + This design is optimal for workloads with monotonically increasing memory + needs, where the memory will eventually be fully utilized and doesn't need + to be reclaimed. + + Thread Safety Overview: + ======================= + This class manages a background daemon thread (_expansion_worker) that + progressively allocates and adds new memory segments to the allocator. + + Concurrency Model: + - Main thread: Performs allocate/free operations on the allocator + - Expander thread: Adds new memory segments via expand_with_new_segment() + """ + + def __init__( + self, + composite_buffer: CompositeBuffer, + allocator: CompositeTensorMemoryAllocator, + total_size: int, + step_ratio: float, + host_mem_lock: threading.Lock, + numa_mapping: Optional[NUMAMapping] = None, + memory_limit_callback=None, + ): + self.composite_buffer = composite_buffer + self.allocator = allocator + self.total_size = total_size + self.step_ratio = step_ratio + self.numa_mapping = numa_mapping + self.memory_limit_callback = memory_limit_callback + self.host_mem_lock = host_mem_lock + self.expansion_thread: Optional[threading.Thread] = None + self.stop_flag = threading.Event() + self.expansion_lock = threading.Lock() + self.is_expanding = False + + def start_expansion(self): + with self.expansion_lock: + if self.is_expanding: + return + self.is_expanding = True + self.stop_flag.clear() + self.expansion_thread = threading.Thread( + target=self._expansion_worker, daemon=True, name="MemoryExpander" + ) + self.expansion_thread.start() + logger.info("Started async expansion") + + def _get_effective_limit(self, current_size: int) -> Optional[int]: + """Calculate the effective memory limit based on callback. + + Args: + current_size: Current allocated memory size in bytes + + Returns: + Effective memory limit in bytes, or None if expansion should stop + """ + if not self.memory_limit_callback: + return self.total_size + + try: + limit_bytes = self.memory_limit_callback() + if limit_bytes <= 0: + return self.total_size + + effective_limit = min(self.total_size, limit_bytes) + if current_size >= effective_limit: + logger.warning( + f"Expansion stopped: {current_size} >= {effective_limit}" + ) + return None + + return effective_limit + except Exception as e: + logger.warning(f"Memory limit callback failed: {e}") + return self.total_size + + def _expansion_worker(self): + """Background worker that progressively expands memory to target size. + + Runs in daemon thread. Allocates memory in steps (step_ratio at a time) + until total_size is reached or memory limit is hit. Never shrinks. + """ + try: + current_size = self.composite_buffer.numel() + while current_size < self.total_size and not self.stop_flag.is_set(): + effective_limit = self._get_effective_limit(current_size) + if effective_limit is None: + break + + next_size = min( + int(self.total_size * self.step_ratio), + effective_limit - current_size, + ) + if next_size <= 0: + break + + logger.info( + f"Expanding: +{next_size}, current={current_size}, " + f"target={self.total_size}" + ) + + try: + new_buffer = _allocate_cpu_memory(next_size, self.numa_mapping) + except Exception as e: + logger.error(f"Allocation failed: {e}") + break + + with self.host_mem_lock: + self.allocator.expand_with_new_segment(new_buffer) + + current_size += next_size + + logger.info(f"Expansion completed: {self.composite_buffer.numel()} bytes") + except Exception as e: + logger.error(f"Expansion error: {e}", exc_info=True) + finally: + with self.expansion_lock: + self.is_expanding = False + + def stop(self): + self.stop_flag.set() + if self.expansion_thread and self.expansion_thread.is_alive(): + self.expansion_thread.join(timeout=5.0) + + +class LazyMixedMemoryAllocator(MixedMemoryAllocator): + """Lazy allocator: starts small, expands async when needed (zero-copy). + + Starts with initial_ratio of target size, triggers one-time background + expansion when usage exceeds expand_trigger_ratio. Ideal for fast startup + with low initial memory footprint. + + See AsyncMemoryExpander for detailed design philosophy. + """ + + def __init__( + self, + size: int, + config: "LMCacheEngineConfig", + use_paging: bool = False, + memory_limit_callback: Optional[Callable] = None, + **kwargs, + ): + # Extract configuration values from config + initial_ratio = config.lazy_memory_initial_ratio + expand_trigger_ratio = config.lazy_memory_expand_trigger_ratio + step_ratio = config.lazy_memory_step_ratio + + self.total_size = size + self.initial_ratio = initial_ratio + self.expand_trigger_ratio = expand_trigger_ratio + self.step_ratio = step_ratio + self.memory_limit_callback = memory_limit_callback + self.expansion_triggered = False + self.initial_size = int(size * initial_ratio) + self.numa_mapping = kwargs.get("numa_mapping", None) + self.size = self.initial_size + self._unregistered = False + self.async_expander: Optional[AsyncMemoryExpander] + + if not use_paging: + initial_buffer = _allocate_cpu_memory(self.initial_size, self.numa_mapping) + self.composite_buffer = CompositeBuffer(initial_buffer) + self.buffer = initial_buffer + self.pin_allocator = CompositeTensorMemoryAllocator(self.composite_buffer) + self.align_bytes = self.pin_allocator.align_bytes + self.host_mem_lock = threading.Lock() + self.buffer_allocator = BufferAllocator("cpu") + self.async_expander = AsyncMemoryExpander( + self.composite_buffer, + self.pin_allocator, + self.total_size, + self.step_ratio, + self.host_mem_lock, + self.numa_mapping, + self.memory_limit_callback, + ) + else: + logger.warning( + "Paged allocation with lazy expansion not fully supported. " + "Using initial size only." + ) + super().__init__(self.initial_size, use_paging, **kwargs) + self.async_expander = None + + logger.info( + f"LazyAllocator: initial={self.initial_size}B " + f"({initial_ratio * 100:.0f}%), target={self.total_size}B, " + f"trigger={expand_trigger_ratio * 100:.0f}%, " + f"step={step_ratio * 100:.0f}%" + ) + + def _check_and_trigger_expansion(self): + if self.expansion_triggered or not self.async_expander: + return + if not isinstance(self.pin_allocator, CompositeTensorMemoryAllocator): + return + + usage_ratio = ( + self.pin_allocator.total_allocated_size / self.composite_buffer.numel() + ) + if usage_ratio >= self.expand_trigger_ratio: + logger.info( + f"Triggering expansion: usage={usage_ratio * 100:.0f}%, " + f"threshold={self.expand_trigger_ratio * 100:.0f}%" + ) + self.async_expander.start_expansion() + self.expansion_triggered = True + + @_lmcache_nvtx_annotate + def allocate( + self, + shape: Union[torch.Size, Tuple[int, ...]], + dtype: Optional[torch.dtype], + fmt: MemoryFormat = MemoryFormat.KV_2LTD, + allocator_type: Optional[str] = None, + ) -> Optional[MemoryObj]: + result = super().allocate(shape, dtype, fmt, allocator_type) + if result and fmt != MemoryFormat.BINARY_BUFFER: + self._check_and_trigger_expansion() + return result + + @_lmcache_nvtx_annotate + def batched_allocate( + self, + shape: Union[torch.Size, Tuple[int, ...]], + dtype: Optional[torch.dtype], + batch_size: int, + fmt: MemoryFormat = MemoryFormat.KV_2LTD, + allocator_type: Optional[str] = None, + ) -> Optional[List[MemoryObj]]: + result = super().batched_allocate(shape, dtype, batch_size, fmt, allocator_type) + if result and fmt != MemoryFormat.BINARY_BUFFER: + self._check_and_trigger_expansion() + return result + + def close(self): + if hasattr(self, "async_expander") and self.async_expander: + self.async_expander.stop() + + if not self._unregistered: + if torch.cuda.is_available(): + torch.cuda.synchronize() + + if hasattr(self, "composite_buffer"): + for segment in self.composite_buffer.segments: + ptr = segment.data_ptr() + if self.numa_mapping: + lmc_ops.free_pinned_numa_ptr(ptr, segment.numel()) + else: + lmc_ops.free_pinned_ptr(ptr) + logger.info("LazyMixedMemoryAllocator closed and memory freed") + else: + # Fall back to parent's close for paging mode + super().close() + return + self._unregistered = True + + def __str__(self): + return "LazyMixedMemoryAllocator" diff --git a/lmcache/v1/lookup_client/__init__.py b/lmcache/v1/lookup_client/__init__.py index bcb066ad53c..5b2023b2f2a 100644 --- a/lmcache/v1/lookup_client/__init__.py +++ b/lmcache/v1/lookup_client/__init__.py @@ -6,12 +6,16 @@ LMCacheLookupClient, LMCacheLookupServer, ) +from lmcache.v1.lookup_client.lmcache_lookup_client_bypass import ( + LMCacheBypassLookupClient, +) from lmcache.v1.lookup_client.mooncake_lookup_client import MooncakeLookupClient __all__ = [ "LookupClientInterface", "LookupClientFactory", "MooncakeLookupClient", + "LMCacheBypassLookupClient", "LMCacheLookupClient", "LMCacheLookupServer", ] diff --git a/lmcache/v1/lookup_client/abstract_client.py b/lmcache/v1/lookup_client/abstract_client.py index e8b6315b9a0..754d4a520dc 100644 --- a/lmcache/v1/lookup_client/abstract_client.py +++ b/lmcache/v1/lookup_client/abstract_client.py @@ -51,3 +51,12 @@ def supports_producer_reuse(self) -> bool: True if producer reuse is supported, False otherwise """ return False + + def clear_lookup_status(self, lookup_id: str) -> None: + """ + Clear temporary lookup status for a given lookup ID. + + Args: + lookup_id: The lookup ID whose status needs to be cleared. + """ + return diff --git a/lmcache/v1/lookup_client/factory.py b/lmcache/v1/lookup_client/factory.py index 77b005f455b..045f37b1925 100644 --- a/lmcache/v1/lookup_client/factory.py +++ b/lmcache/v1/lookup_client/factory.py @@ -8,6 +8,9 @@ from lmcache.v1.config import LMCacheEngineConfig from lmcache.v1.lookup_client.abstract_client import LookupClientInterface from lmcache.v1.lookup_client.hit_limit_lookup_client import HitLimitLookupClient +from lmcache.v1.lookup_client.lmcache_lookup_client_bypass import ( + LMCacheBypassLookupClient, +) from lmcache.v1.lookup_client.mooncake_lookup_client import MooncakeLookupClient if TYPE_CHECKING: @@ -30,6 +33,7 @@ class LookupClientFactory: def create_lookup_client( vllm_config: "VllmConfig", config: LMCacheEngineConfig, + lmcache_engine: Optional[LMCacheEngine] = None, ) -> LookupClientInterface: """ Create a lookup client based on the configuration. @@ -37,11 +41,13 @@ def create_lookup_client( Args: vllm_config: The vLLM configuration config: The LMCache engine configuration + lmcache_engine: Optional LMCacheEngine instance for bypass lookup client Returns: A lookup client instance """ + client: LookupClientInterface # Check if external_lookup_client is configured if config.external_lookup_client is not None: if config.enable_async_loading: @@ -60,7 +66,10 @@ def create_lookup_client( LMCacheLookupClient, ) - if config.enable_async_loading: + # Check if bypass lookup is enabled and lmcache_engine is provided + if config.enable_scheduler_bypass_lookup and lmcache_engine is not None: + client = LMCacheBypassLookupClient(vllm_config, lmcache_engine) + elif config.enable_async_loading: client = LMCacheAsyncLookupClient(vllm_config) else: client = LMCacheLookupClient(vllm_config) @@ -89,16 +98,13 @@ def create_lookup_server( "LMCache v1 config is expected for lookup server and client" ) - # Only create the KV lookup API server on worker rank 0 - # when there are multiple workers and when not using external lookup client - create_lookup_server_only_on_worker_0_for_mla = config.get_extra_config_value( - "create_lookup_server_only_on_worker_0_for_mla", - lmcache_engine.metadata.use_mla, + lookup_server_worker_ids = config.get_lookup_server_worker_ids( + lmcache_engine.metadata.use_mla, lmcache_engine.metadata.world_size ) if config.external_lookup_client is None and ( - not create_lookup_server_only_on_worker_0_for_mla - or lmcache_engine.metadata.worker_id == 0 + len(lookup_server_worker_ids) == 0 + or lmcache_engine.metadata.worker_id in lookup_server_worker_ids ): # First Party from lmcache.v1.lookup_client.lmcache_async_lookup_client import ( diff --git a/lmcache/v1/lookup_client/hit_limit_lookup_client.py b/lmcache/v1/lookup_client/hit_limit_lookup_client.py index a3ab9e0922b..90482aa5be7 100644 --- a/lmcache/v1/lookup_client/hit_limit_lookup_client.py +++ b/lmcache/v1/lookup_client/hit_limit_lookup_client.py @@ -66,6 +66,9 @@ def lookup( ) return result + def clear_lookup_status(self, lookup_id: str) -> None: + self.actual_lookup_client.clear_lookup_status(lookup_id) + def supports_producer_reuse(self) -> bool: return self.actual_lookup_client.supports_producer_reuse() diff --git a/lmcache/v1/lookup_client/lmcache_async_lookup_client.py b/lmcache/v1/lookup_client/lmcache_async_lookup_client.py index e4e0d1ccd47..780b30b5ded 100644 --- a/lmcache/v1/lookup_client/lmcache_async_lookup_client.py +++ b/lmcache/v1/lookup_client/lmcache_async_lookup_client.py @@ -5,17 +5,20 @@ import time # Third Party -from vllm.utils import make_zmq_socket import msgspec import torch import zmq # First Party -from lmcache.integration.vllm.utils import create_lmcache_metadata, mla_enabled +from lmcache.integration.vllm.utils import create_lmcache_metadata from lmcache.logging import init_logger from lmcache.v1.cache_engine import LMCacheEngine from lmcache.v1.lookup_client.abstract_client import LookupClientInterface -from lmcache.v1.rpc_utils import get_zmq_rpc_path_lmcache +from lmcache.v1.rpc_utils import ( + get_zmq_context, + get_zmq_rpc_path_lmcache, + get_zmq_socket, +) if TYPE_CHECKING: # Third Party @@ -31,8 +34,14 @@ class LMCacheAsyncLookupClient(LookupClientInterface): ZMQ-based lookup client that communicates with a lookup server. Related extra_config: - - create_lookup_server_only_on_worker_0_for_mla: - is a flag to control whether to create lookup server only on worker 0. + - lookup_server_worker_ids: + is a config to control create lookup server on some workers. + if mla is not enabled, default is []; + if mla is enabled, default is [0]; + - if lookup_server_worker_ids is [], start lookup server on all workers + - if lookup_server_worker_ids is [0], start lookup server on worker0 + - if lookup_server_worker_ids is [0, 3, 6], start lookup server on + worker0, worker3 and worker6 """ def __init__( @@ -42,35 +51,39 @@ def __init__( metadata, config = create_lmcache_metadata(vllm_config) self.encoder = msgspec.msgpack.Encoder() - self.ctx = zmq.Context() # type: ignore[attr-defined] + self.ctx = get_zmq_context(use_asyncio=False) rpc_port = vllm_config.kv_transfer_config.get_from_extra_config( "lmcache_rpc_port", 0 ) + self.pipeline_parallel_size = vllm_config.parallel_config.pipeline_parallel_size self.tensor_parallel_size = vllm_config.parallel_config.tensor_parallel_size - use_mla = mla_enabled(vllm_config.model_config) - self.create_lookup_server_only_on_worker_0_for_mla = ( - config.get_extra_config_value( - "create_lookup_server_only_on_worker_0_for_mla", use_mla - ) + self.num_ranks = self.tensor_parallel_size * self.pipeline_parallel_size + self.lookup_server_worker_ids = config.get_lookup_server_worker_ids( + metadata.use_mla, metadata.world_size ) - ranks = self.tensor_parallel_size + self.push_sockets = [] - if self.create_lookup_server_only_on_worker_0_for_mla: - ranks = 1 - for tp_rank in range(ranks): + if len(self.lookup_server_worker_ids) > 0: + ranks = self.lookup_server_worker_ids + self.num_ranks = len(self.lookup_server_worker_ids) + else: + ranks = [i for i in range(self.num_ranks)] + + for rank in ranks: worker_socket_path = get_zmq_rpc_path_lmcache( - vllm_config, "lookup_worker", rpc_port, tp_rank + vllm_config, "lookup_worker", rpc_port, rank ) logger.info( - f"lmcache lookup client connect to tp_rank {tp_rank} " + f"lmcache lookup client connect to rank {rank} " f"with worker socket path {worker_socket_path}" ) - push_socket = make_zmq_socket( + push_socket = get_zmq_socket( self.ctx, worker_socket_path, + "ipc", zmq.PUSH, # type: ignore[attr-defined] - bind=False, + "connect", ) self.push_sockets.append(push_socket) @@ -78,11 +91,12 @@ def __init__( scheduler_socket_path = get_zmq_rpc_path_lmcache( vllm_config, "lookup_scheduler", rpc_port, 0 ) - self.pull_socket = make_zmq_socket( + self.pull_socket = get_zmq_socket( self.ctx, scheduler_socket_path, + "ipc", zmq.PULL, # type: ignore[attr-defined] - bind=True, + "bind", ) logger.info( f"lmcache lookup client connect to scheduler " @@ -107,15 +121,15 @@ def __init__( # (e.g., worker process). self.lock = threading.Lock() - # map from lookup_id to req's status. + # map from lookup_id (i.e., req_id) to req's status. # None indicates ongoing. # int indicates number of hit tokens. self.reqs_status: dict[str, Optional[int]] = {} - # map from lookup_id to number of hit tokens for each worker + # map from lookup_id (i.e., req_id) to number of hit tokens for each worker self.res_for_each_worker: dict[str, list[int]] = {} - # The two parts are [lookup_id, num_hit_tokens] + # The two parts are [lookup_id (i.e., req_id), num_hit_tokens] self.num_parts = 2 self.running = True @@ -146,7 +160,6 @@ def lookup( time.sleep(self.lookup_backoff_time) return None elif req_status != -1: - self.reqs_status.pop(lookup_id) return req_status self.reqs_status[lookup_id] = None hashes = [] @@ -174,10 +187,7 @@ def lookup( request_configs_buf, ] - ranks = self.tensor_parallel_size - if self.create_lookup_server_only_on_worker_0_for_mla: - ranks = 1 - for i in range(ranks): + for i in range(self.num_ranks): self.push_sockets[i].send_multipart(msg_buf, copy=False) time.sleep(self.lookup_backoff_time) return None @@ -196,18 +206,19 @@ def process_responses_from_workers(self): self.res_for_each_worker[lookup_id].append(res) all_res = self.res_for_each_worker[lookup_id] - if len(all_res) == self.tensor_parallel_size or ( - self.create_lookup_server_only_on_worker_0_for_mla - and len(all_res) == 1 - ): + if len(all_res) == self.num_ranks: self.res_for_each_worker.pop(lookup_id) # NOTE: it is possible that the number of hit - # tokens is different across TP ranks, so we + # tokens is different across (TP and PP) ranks, so we # can use the minimum value as the number of # hit tokens. self.reqs_status[lookup_id] = min(all_res) + def clear_lookup_status(self, lookup_id: str) -> None: + with self.lock: + self.reqs_status.pop(lookup_id, None) + def supports_producer_reuse(self) -> bool: """Return True as LMCacheLookupClient supports producer kvcache reuse""" return True @@ -241,17 +252,19 @@ def __init__(self, lmcache_engine: LMCacheEngine, vllm_config: "VllmConfig"): scheduler_socket_path = get_zmq_rpc_path_lmcache( vllm_config, "lookup_scheduler", rpc_port, 0 ) - self.push_socket = make_zmq_socket( + self.push_socket = get_zmq_socket( self.ctx, scheduler_socket_path, + "ipc", zmq.PUSH, # type: ignore[attr-defined] - bind=False, + "connect", ) - self.pull_socket = make_zmq_socket( + self.pull_socket = get_zmq_socket( self.ctx, worker_socket_path, + "ipc", zmq.PULL, # type: ignore[attr-defined] - bind=True, + "bind", ) self.lmcache_engine = lmcache_engine diff --git a/lmcache/v1/lookup_client/lmcache_lookup_client.py b/lmcache/v1/lookup_client/lmcache_lookup_client.py index ff6e0df849b..0627b68cf83 100644 --- a/lmcache/v1/lookup_client/lmcache_lookup_client.py +++ b/lmcache/v1/lookup_client/lmcache_lookup_client.py @@ -1,21 +1,25 @@ # SPDX-License-Identifier: Apache-2.0 # Standard +from collections import namedtuple from typing import TYPE_CHECKING, Optional, Union import json import threading # Third Party -from vllm.utils import make_zmq_socket import msgspec import torch import zmq # First Party -from lmcache.integration.vllm.utils import create_lmcache_metadata, mla_enabled +from lmcache.integration.vllm.utils import create_lmcache_metadata from lmcache.logging import init_logger from lmcache.v1.cache_engine import LMCacheEngine from lmcache.v1.lookup_client.abstract_client import LookupClientInterface -from lmcache.v1.rpc_utils import get_zmq_rpc_path_lmcache +from lmcache.v1.rpc_utils import ( + get_zmq_context, + get_zmq_rpc_path_lmcache, + get_zmq_socket, +) if TYPE_CHECKING: # Third Party @@ -29,8 +33,14 @@ class LMCacheLookupClient(LookupClientInterface): ZMQ-based lookup client that communicates with a lookup server. Related extra_config: - - create_lookup_server_only_on_worker_0_for_mla: - is a flag to control whether to create lookup server only on worker 0. + - lookup_server_worker_ids: + is a config to control create lookup server on some workers. + if mla is not enabled, default is []; + if mla is enabled, default is [0]; + - if lookup_server_worker_ids is [], start lookup server on all workers + - if lookup_server_worker_ids is [0], start lookup server on worker0 + - if lookup_server_worker_ids is [0, 3, 6], start lookup server on + worker0, worker3 and worker6 """ def __init__( @@ -40,44 +50,62 @@ def __init__( metadata, config = create_lmcache_metadata(vllm_config) self.encoder = msgspec.msgpack.Encoder() - self.ctx = zmq.Context() # type: ignore[attr-defined] + self.ctx = get_zmq_context(use_asyncio=False) self.config = config rpc_port = vllm_config.kv_transfer_config.get_from_extra_config( "lmcache_rpc_port", 0 ) + self.pipeline_parallel_size = vllm_config.parallel_config.pipeline_parallel_size self.tensor_parallel_size = vllm_config.parallel_config.tensor_parallel_size - use_mla = mla_enabled(vllm_config.model_config) - self.create_lookup_server_only_on_worker_0_for_mla = ( - config.get_extra_config_value( - "create_lookup_server_only_on_worker_0_for_mla", use_mla - ) + self.num_ranks = self.tensor_parallel_size * self.pipeline_parallel_size + self.lookup_server_worker_ids = config.get_lookup_server_worker_ids( + metadata.use_mla, metadata.world_size ) - ranks = self.tensor_parallel_size - self.sockets = [] - if self.create_lookup_server_only_on_worker_0_for_mla: - ranks = 1 - # Set timeout values from config - timeout_ms = config.lookup_timeout_ms + self.sockets = [] + if len(self.lookup_server_worker_ids) > 0: + ranks = self.lookup_server_worker_ids + self.num_ranks = len(self.lookup_server_worker_ids) + else: + ranks = [i for i in range(self.num_ranks)] - for tp_rank in range(ranks): - socket_path = get_zmq_rpc_path_lmcache( - vllm_config, "lookup", rpc_port, tp_rank + # Store socket creation parameters for recreation + SocketParams = namedtuple("SocketParams", ["socket_path", "rank"]) + self.socket_params = [ + SocketParams( + socket_path=get_zmq_rpc_path_lmcache( + vllm_config, "lookup", rpc_port, rank + ), + rank=rank, ) + for rank in ranks + ] + self.timeout_ms = config.lookup_timeout_ms + + # NOTE: map from lookup_id (i.e., req_id) to req's status. + # int indicates number of hit tokens. + # The assumption here is that once a request is looked up, + # the following lookups of the same request must have the + # same result. + self.reqs_status: dict[str, int] = {} + + for params in self.socket_params: logger.info( - f"lmcache lookup client connect to tp_rank {tp_rank} " - f"with socket path {socket_path}" + "lmcache lookup client connect to rank %s with socket path %s", + params.rank, + params.socket_path, ) - socket = make_zmq_socket( + socket = get_zmq_socket( self.ctx, - socket_path, - zmq.REQ, # type: ignore[attr-defined] - bind=False, + params.socket_path, + "ipc", + zmq.REQ, + "connect", ) # Set socket timeout during initialization - socket.setsockopt(zmq.RCVTIMEO, timeout_ms) - socket.setsockopt(zmq.SNDTIMEO, timeout_ms) + socket.setsockopt(zmq.RCVTIMEO, self.timeout_ms) + socket.setsockopt(zmq.SNDTIMEO, self.timeout_ms) self.sockets.append(socket) @@ -95,21 +123,59 @@ def __init__( else: self.token_database = ChunkedTokenDatabase(config, metadata) - # FIXME(Jiayi): Cacheblend need token ids + def _recreate_socket(self) -> None: + """Recreate all sockets.""" + for rank_idx in range(self.num_ranks): + # Close old socket + old_socket = self.sockets[rank_idx] + if old_socket is not None: + try: + old_socket.close(linger=0) + except zmq.ZMQError as e: + logger.warning( + "ZMQ error closing old socket for rank %s: %s", + rank_idx, + e, + ) + except AttributeError: + # Socket already closed or invalid + pass + + # Create new socket using stored parameters + params = self.socket_params[rank_idx] + logger.info( + "Recreating socket for rank %s with path %s", + params.rank, + params.socket_path, + ) + + new_socket = get_zmq_socket( + self.ctx, + params.socket_path, + "ipc", + zmq.REQ, + "connect", + ) + new_socket.setsockopt(zmq.RCVTIMEO, self.timeout_ms) + new_socket.setsockopt(zmq.SNDTIMEO, self.timeout_ms) + + self.sockets[rank_idx] = new_socket + def lookup( self, token_ids: Union[torch.Tensor, list[int]], lookup_id: str, request_configs: Optional[dict] = None, ) -> Optional[int]: + cached_num_hit_toks = self.reqs_status.get(lookup_id, None) + if cached_num_hit_toks is not None: + return cached_num_hit_toks + lookup_id_buf = lookup_id.encode("utf-8") request_configs_str = "" if request_configs is not None and len(request_configs) != 0: request_configs_str = json.dumps(request_configs) request_configs_buf = request_configs_str.encode("utf-8") - ranks = self.tensor_parallel_size - if self.create_lookup_server_only_on_worker_0_for_mla: - ranks = 1 # NOTE(Jiayi): We cannot only send hashes when blending enabled # because the blender need the input embedding. @@ -130,6 +196,7 @@ def lookup( request_configs_buf, ] else: + # print(len(token_ids)) tokens_buf = self.encoder.encode(token_ids) msg_buf = [ tokens_buf, @@ -138,32 +205,52 @@ def lookup( ] results = [] + failed_rank = -1 try: - for i in range(ranks): + for i in range(self.num_ranks): + failed_rank = i self.sockets[i].send_multipart(msg_buf, copy=False) # TODO(Jiayi): we can use zmq poll to optimize a bit - for i in range(ranks): + for i in range(self.num_ranks): + failed_rank = i resp = self.sockets[i].recv() result = int.from_bytes(resp, "big") results.append(result) - except zmq.Again: - logger.error(f"Timeout occurred for rank {i}") + except zmq.Again as e: + logger.error( + "Timeout occurred for rank %s, recreating all sockets. Error: %s", + failed_rank, + e, + ) + self._recreate_socket() return 0 except zmq.ZMQError as e: - logger.error(f"ZMQ error for rank {i}: {str(e)}") + logger.error( + "ZMQ error for rank %s: %s, recreating all sockets", + failed_rank, + e, + ) + self._recreate_socket() return 0 - assert len(results) == ranks + assert len(results) == self.num_ranks if len(set(results)) > 1: logger.warning( - f"Lookup results (number of hit tokens) differ " - f"across tensor parallel ranks: {results}." + "Lookup results (number of hit tokens) differ " + "across (TP and PP) ranks: %s.", + results, ) # NOTE: it is possible that the number of hit tokens is different - # across TP ranks, so we can use the minimum value as the + # across (TP and PP) ranks, so we can use the minimum value as the # number of hit tokens. - return min(results) + num_hit_toks = min(results) + self.reqs_status[lookup_id] = num_hit_toks + + return num_hit_toks + + def clear_lookup_status(self, lookup_id: str) -> None: + self.reqs_status.pop(lookup_id, None) def supports_producer_reuse(self) -> bool: """Return True as LMCacheLookupClient supports producer kvcache reuse""" @@ -174,13 +261,13 @@ def close(self): try: socket.close(linger=0) except Exception as e: - logger.warning(f"Error closing socket: {e}") + logger.warning("Error closing socket: %s", e) try: if self.ctx: self.ctx.term() except Exception as e: - logger.warning(f"Error terminating ZMQ context: {e}") + logger.warning("Error terminating ZMQ context: %s", e) class LMCacheLookupServer: @@ -195,11 +282,12 @@ def __init__(self, lmcache_engine: LMCacheEngine, vllm_config: "VllmConfig"): socket_path = get_zmq_rpc_path_lmcache( vllm_config, "lookup", rpc_port, vllm_config.parallel_config.rank ) - self.socket = make_zmq_socket( + self.socket = get_zmq_socket( self.ctx, socket_path, + "ipc", zmq.REP, # type: ignore[attr-defined] - bind=True, + "bind", ) self.lmcache_engine = lmcache_engine @@ -239,7 +327,7 @@ def process_request(): response = result.to_bytes(4, "big") self.socket.send(response) - logger.info(f"lmcache lookup server start on {socket_path}") + logger.info("lmcache lookup server start on %s", socket_path) self.thread = threading.Thread(target=process_request, daemon=True) self.thread.start() diff --git a/lmcache/v1/lookup_client/lmcache_lookup_client_bypass.py b/lmcache/v1/lookup_client/lmcache_lookup_client_bypass.py new file mode 100644 index 00000000000..6ddcd74425b --- /dev/null +++ b/lmcache/v1/lookup_client/lmcache_lookup_client_bypass.py @@ -0,0 +1,101 @@ +# SPDX-License-Identifier: Apache-2.0 +# Standard +from typing import TYPE_CHECKING, Optional, Union + +# Third Party +import torch + +# First Party +from lmcache.integration.vllm.utils import create_lmcache_metadata +from lmcache.logging import init_logger +from lmcache.v1.cache_engine import LMCacheEngine +from lmcache.v1.config import LMCacheEngineConfig +from lmcache.v1.lookup_client.abstract_client import LookupClientInterface + +if TYPE_CHECKING: + # Third Party + from vllm.config import VllmConfig + +logger = init_logger(__name__) + + +class LMCacheBypassLookupClient(LookupClientInterface): + """ + Bypass lookup client that directly calls LMCacheEngine.lookup() + instead of using ZMQ communication. This is particularly useful + for MLA scenarios where only rank 0 needs to perform lookups. + """ + + def __init__( + self, + vllm_config: "VllmConfig", + lmcache_engine: LMCacheEngine, + ): + """ + Initialize the bypass lookup client. + + Args: + vllm_config: The vLLM configuration + lmcache_engine: The LMCacheEngine instance to use for lookups + """ + metadata, config = create_lmcache_metadata(vllm_config) + + assert isinstance(config, LMCacheEngineConfig), ( + "LMCache v1 configuration should be passed." + ) + + self.lmcache_engine = lmcache_engine + self.config = config + + # Use the token database from the provided LMCacheEngine + self.token_database = self.lmcache_engine.token_database + self.enable_blending = self.config.enable_blending + + logger.info("LMCacheBypassLookupClient initialized") + + def lookup( + self, + token_ids: Union[torch.Tensor, list[int]], + lookup_id: str, + request_configs: Optional[dict] = None, + ) -> Optional[int]: + try: + if not self.enable_blending: + # Process tokens to get hashes and offsets + hashes = [] + offsets = [] + for start, end, key in self.token_database.process_tokens( + token_ids, make_key=False + ): + hashes.append(key) + offsets.append(end - start) + + # Call LMCacheEngine lookup with hashes and offsets + result = self.lmcache_engine.lookup( + hashes=hashes, + offsets=offsets, + lookup_id=lookup_id, + pin=True, + request_configs=request_configs, + ) + else: + # For blending mode, pass tokens directly + result = self.lmcache_engine.lookup( + tokens=token_ids, + lookup_id=lookup_id, + pin=True, + request_configs=request_configs, + ) + + return result + + except Exception as e: + logger.error(f"Error in bypass lookup: {e}") + return 0 + + def supports_producer_reuse(self) -> bool: + return True + + def close(self): + # No resources to clean up for bypass client + logger.info("LMCacheBypassLookupClient closed") diff --git a/lmcache/v1/memory_management.py b/lmcache/v1/memory_management.py index 8411254e069..48dd2a8847b 100644 --- a/lmcache/v1/memory_management.py +++ b/lmcache/v1/memory_management.py @@ -301,6 +301,8 @@ def _allocate_cpu_memory( size: int, numa_mapping: Optional[NUMAMapping] = None, ) -> torch.Tensor: + if size == 0: + return torch.empty(0, dtype=torch.uint8) if numa_mapping: if torch.cuda.is_available(): current_device_id = torch.cuda.current_device() @@ -336,8 +338,8 @@ def __init__( parent_allocator: Optional["MemoryAllocatorInterface"], ): assert metadata.dtype is not None, "dtype must be specified for TensorMemoryObj" + super().__init__(metadata) self.raw_data = raw_data - self.meta = metadata self.valid = True self.lock = threading.Lock() self.parent_allocator = parent_allocator @@ -486,7 +488,7 @@ def __init__(self, raw_bytes: bytes, metadata: Optional[MemoryObjMetadata] = Non self.raw_data = raw_bytes if metadata is None: bytes_shape = torch.Size([len(self.raw_data), 0, 0, 0]) - self.meta = MemoryObjMetadata( + metadata = MemoryObjMetadata( shape=bytes_shape, dtype=None, address=0, @@ -495,8 +497,7 @@ def __init__(self, raw_bytes: bytes, metadata: Optional[MemoryObjMetadata] = Non pin_count=0, fmt=MemoryFormat.BINARY_BUFFER, ) - else: - self.meta = metadata + super().__init__(metadata) self.valid = True def invalidate(self): @@ -708,6 +709,18 @@ def _Compute_raw_size(shape: torch.Size, dtype: torch.dtype) -> int: def _Compute_aligned_size(raw_size: int, align: int) -> int: return (raw_size + align - 1) & ~(align - 1) + def _can_merge_with_prev( + self, curr_block: FreeBlock, prev_block: FreeBlock + ) -> bool: + """Hook: Check if curr_block can merge with prev_block.""" + return prev_block.can_be_coalesced(curr_block) + + def _can_merge_with_succ( + self, curr_block: FreeBlock, succ_block: FreeBlock + ) -> bool: + """Hook: Check if curr_block can merge with succ_block.""" + return curr_block.can_be_coalesced(succ_block) + @_lmcache_nvtx_annotate def _coalesce( self, @@ -721,15 +734,12 @@ def _coalesce( Returns True if the current block was coalesced, otherwise False. """ - if prev_block is not None and prev_block.can_be_coalesced(curr_block): - merge_prev = True - else: - merge_prev = False - - if succ_block is not None and curr_block.can_be_coalesced(succ_block): - merge_succ = True - else: - merge_succ = False + merge_prev = prev_block is not None and self._can_merge_with_prev( + curr_block, prev_block + ) + merge_succ = succ_block is not None and self._can_merge_with_succ( + curr_block, succ_block + ) if merge_prev and merge_succ: prev_block.size += curr_block.size + succ_block.size # type: ignore @@ -798,14 +808,19 @@ def allocate( self.stats_monitor.update_active_memory_objs_count(self.num_active_allocations) # Allocate the block + raw_data = self._get_buffer_slice(block.start, raw_size) return TensorMemoryObj( - raw_data=self.buffer[block.start : block.start + raw_size], + raw_data=raw_data, metadata=MemoryObjMetadata( - shape, dtype, block.start, aligned_size, 1, False, fmt + shape, dtype, block.start, aligned_size, 1, 0, fmt ), parent_allocator=self, ) + def _get_buffer_slice(self, start: int, size: int) -> torch.Tensor: + """Hook: Get buffer slice. Override for custom buffer access.""" + return self.buffer[start : start + size] + @_lmcache_nvtx_annotate def batched_allocate( self, @@ -876,7 +891,7 @@ def batched_allocate( TensorMemoryObj( raw_data=raw_data, metadata=MemoryObjMetadata( - shape, dtype, temp_start, unit_aligned_size, 1, False, fmt + shape, dtype, temp_start, unit_aligned_size, 1, 0, fmt ), parent_allocator=self, ) @@ -1407,10 +1422,13 @@ def __init__(self, size: int, use_paging: bool = False, **kwargs): :param int size: The size of the pinned memory in bytes. """ - ptr = lmc_ops.alloc_pinned_ptr(size, 0) - array_type = ctypes.c_uint8 * size - buf = array_type.from_address(ptr) - self.buffer = torch.frombuffer(buf, dtype=torch.uint8) + if size == 0: + self.buffer = torch.empty(0, dtype=torch.uint8) + else: + ptr = lmc_ops.alloc_pinned_ptr(size, 0) + array_type = ctypes.c_uint8 * size + buf = array_type.from_address(ptr) + self.buffer = torch.frombuffer(buf, dtype=torch.uint8) self._unregistered = False self.allocator: MemoryAllocatorInterface @@ -1481,6 +1499,8 @@ def close(self): if not self._unregistered: if torch.cuda.is_available(): torch.cuda.synchronize() + if self.buffer.numel() == 0: + return lmc_ops.free_pinned_ptr(self.buffer.data_ptr()) self._unregistered = True @@ -1622,6 +1642,8 @@ def close(self): if not self._unregistered: if torch.cuda.is_available(): torch.cuda.synchronize() + if self.buffer.numel() == 0: + return if self.numa_mapping: lmc_ops.free_pinned_numa_ptr(self.buffer.data_ptr(), self.size) else: @@ -1878,6 +1900,7 @@ def init_cpu_memory_allocator( dtype, fmt, ) + self.align_bytes = self.cpu_allocator.align_bytes def allocate( self, diff --git a/.buildkite/configs/gds.yaml b/lmcache/v1/multiprocess/__init__.py similarity index 100% rename from .buildkite/configs/gds.yaml rename to lmcache/v1/multiprocess/__init__.py diff --git a/lmcache/v1/multiprocess/custom_types.py b/lmcache/v1/multiprocess/custom_types.py new file mode 100644 index 00000000000..bf87cbed4bb --- /dev/null +++ b/lmcache/v1/multiprocess/custom_types.py @@ -0,0 +1,134 @@ +# SPDX-License-Identifier: Apache-2.0 +# Standard +from dataclasses import dataclass +from typing import Any, Callable +import pickle + +# Third Party +import msgspec +import torch + +""" +Defines the types and the customized encoder/decoders for inter-process +communications. +""" + + +class CudaIPCWrapper: + def __init__(self, tensor: torch.Tensor): + assert tensor.storage_offset() == 0 + assert tensor.is_contiguous() + storage = tensor.untyped_storage() + handle = storage._share_cuda_() + + self.handle = handle + self.dtype = tensor.dtype + self.shape = tensor.shape + self.device = tensor.device.index # Explicit device ordinal + + def to_tensor(self): + """ + Note: + This function may break if torch cuda is not initialized. + We should call `torch.cuda.init()` before using this function. + """ + device = self.handle[0] + storage = torch.UntypedStorage._new_shared_cuda(*self.handle) + t = torch.tensor(0, device=device, dtype=self.dtype) + t.set_(storage) + return t.view(self.shape) + + def __eq__(self, other): + if not isinstance(other, CudaIPCWrapper): + return False + return ( + self.handle == other.handle + and self.dtype == other.dtype + and self.shape == other.shape + and self.device == other.device + ) + + @staticmethod + def Serialize(obj: "CudaIPCWrapper") -> bytes: + return pickle.dumps(obj) + + @staticmethod + def Deserialize(data: bytes) -> "CudaIPCWrapper": + return pickle.loads(data) + + +@dataclass(order=True, frozen=True) +class IPCCacheEngineKey: + model_name: str + world_size: int + worker_id: int + chunk_hash: bytes + + @staticmethod + def IntHash2Bytes(chunk_hash: int) -> bytes: + return chunk_hash.to_bytes(4, byteorder="big") + + @staticmethod + def Bytes2IntHash(chunk_hash: bytes) -> int: + return int.from_bytes(chunk_hash, byteorder="big") & ((1 << 64) - 1) + + @classmethod + def from_int_hash( + cls, model_name: str, world_size: int, worker_id: int, chunk_hash: int + ) -> "IPCCacheEngineKey": + return cls( + model_name=model_name, + world_size=world_size, + worker_id=worker_id, + chunk_hash=cls.IntHash2Bytes(chunk_hash), + ) + + @staticmethod + def Serialize(obj: "IPCCacheEngineKey") -> bytes: + return msgspec.msgpack.encode(obj) + + @staticmethod + def Deserialize(data: bytes) -> "IPCCacheEngineKey": + return msgspec.msgpack.decode(data, type=IPCCacheEngineKey) + + +# Type exports +KVCache = list[CudaIPCWrapper] + + +@dataclass +class CustomizedSerdeConfig: + serializer: Callable[[Any], bytes] + deserializer: Callable[[bytes], Any] + code: int + + +_CUSTOMERIZED_SERIALIZERS = { + CudaIPCWrapper: CustomizedSerdeConfig( + serializer=CudaIPCWrapper.Serialize, + deserializer=CudaIPCWrapper.Deserialize, + code=1, + ), +} + + +def get_customized_encoder(type: Any) -> msgspec.msgpack.Encoder: + # TODO: `type` is not used here + def enc_hook(obj: Any) -> Any: + for supported_type, cfg in _CUSTOMERIZED_SERIALIZERS.items(): + if isinstance(obj, supported_type): + data = cfg.serializer(obj) + return msgspec.msgpack.Ext(cfg.code, data) + raise TypeError(f"Unsupported type for serialization: {type(obj)}") + + return msgspec.msgpack.Encoder(enc_hook=enc_hook) + + +def get_customized_decoder(type: Any) -> msgspec.msgpack.Decoder: + def ext_hook(code: int, data: bytes) -> Any: + for cfg in _CUSTOMERIZED_SERIALIZERS.values(): + if cfg.code == code: + return cfg.deserializer(data) + raise TypeError(f"Unsupported ext code for deserialization: {code}") + + return msgspec.msgpack.Decoder(ext_hook=ext_hook, type=type) diff --git a/lmcache/v1/multiprocess/mq.py b/lmcache/v1/multiprocess/mq.py new file mode 100644 index 00000000000..2ce886f24e4 --- /dev/null +++ b/lmcache/v1/multiprocess/mq.py @@ -0,0 +1,548 @@ +# SPDX-License-Identifier: Apache-2.0 +# Standard +from concurrent.futures import Future, ThreadPoolExecutor +from dataclasses import dataclass +from typing import Any, Callable, Generic, Optional, TypeVar +import queue +import threading +import uuid + +# Third Party +import msgspec +import zmq + +# First Party +from lmcache.logging import init_logger +from lmcache.v1.multiprocess.custom_types import ( + CudaIPCWrapper, + get_customized_decoder, + get_customized_encoder, +) +from lmcache.v1.multiprocess.protocol import ( + HandlerType, + RequestType, + get_payload_classes, + get_response_class, +) + +logger = init_logger(__name__) + +T = TypeVar("T") + +# Internal type used for the client-server communication +RequestUID = int + + +# Helper functions +def encode_request_uid(uid: RequestUID) -> bytes: + return msgspec.msgpack.encode(uid) + + +def decode_request_uid(b_uid: bytes) -> RequestUID: + return msgspec.msgpack.decode(b_uid, type=RequestUID) + + +def unwrap_request_payloads( + b_payloads: list[bytes], payload_clss: list[Any] +) -> list[Any]: + if len(b_payloads) != len(payload_clss): + raise ValueError("Payload count does not match expected count") + + decoded_payloads = [ + msgspec_decode(payload, cls=cls) + for payload, cls in zip(b_payloads, payload_clss, strict=False) + ] + return decoded_payloads + + +_SPECIAL_ENCODER_DECODERS = { + CudaIPCWrapper: ( + get_customized_encoder(CudaIPCWrapper), + get_customized_decoder(CudaIPCWrapper), + ), + list[CudaIPCWrapper]: ( + get_customized_encoder(list[CudaIPCWrapper]), + get_customized_decoder(list[CudaIPCWrapper]), + ), +} + + +def msgspec_encode(obj: Any, cls: Any) -> bytes: + # Handle special cases + if cls in _SPECIAL_ENCODER_DECODERS: + encoder, _ = _SPECIAL_ENCODER_DECODERS[cls] + return encoder.encode(obj) + return msgspec.msgpack.encode(obj) + + +def msgspec_decode(b_obj: bytes, cls: Any) -> Any: + # Handle special cases + if cls in _SPECIAL_ENCODER_DECODERS: + _, decoder = _SPECIAL_ENCODER_DECODERS[cls] + return decoder.decode(b_obj) + return msgspec.msgpack.decode(b_obj, type=cls) + + +# Main classes +class MessagingFuture(Generic[T]): + def __init__(self): + self.is_done_ = threading.Event() + self.result_ = None + + def query(self) -> bool: + """ + Check if the future is done. + + Returns: + bool: True if the future is done, False otherwise. + """ + return self.is_done_.is_set() + + def wait(self, timeout: Optional[float] = None) -> bool: + """ + Wait for the future to be done. + + Args: + timeout (Optional[float]): Maximum time to wait in seconds. + If None, wait indefinitely. + + Returns: + bool: True if the future is done, False if the timeout was reached. + """ + return self.is_done_.wait(timeout) + + def result(self, timeout: Optional[float] = None) -> T: + """ + Get the result of the future. + + Args: + timeout (Optional[float]): Maximum time to wait in seconds. + If None, wait indefinitely. + + Returns: + T: The result of the future. + + Raises: + TimeoutError: If the future is not done within the timeout. + """ + flag = self.wait(timeout) + if not flag: + raise TimeoutError("Future result not available within timeout") + return self.result_ + + def set_result(self, result: T) -> None: + self.result_ = result + self.is_done_.set() + + +class MessageQueueClient: + @dataclass + class WrappedRequest: + request_uid: RequestUID + future: MessagingFuture[Any] + request_type: RequestType + request_payloads: list[Any] + + def __init__(self, server_url: str, context: zmq.Context): + # Socket + self.ctx = context + self.socket = self.ctx.socket(zmq.DEALER) + self.socket.connect(server_url) + + # Input queue + self.task_notifier, self.task_waiter = self._prepare_task_sockets() + self.input_queue: queue.Queue = queue.Queue() + + # Poller + self.poller = zmq.Poller() + self.poller.register(self.socket, zmq.POLLIN) + self.poller.register(self.task_waiter, zmq.POLLIN) + + # main thread + self.is_finished = threading.Event() + self.worker_thread = threading.Thread(target=self._main_loop, daemon=True) + self.worker_thread.start() + + # Pending job's futures + self.request_counter = 0 + self.pending_futures: dict[int, MessagingFuture[Any]] = {} + + def _prepare_task_sockets(self) -> tuple[zmq.Socket, zmq.Socket]: + """Create 2 inproc socket pair for the zmq-poller compatible task + queue + + Returns: + tuple[zmq.Socket, zmq.Socket]: The (push_socket, pull_socket) + """ + inproc_url = "inproc://mq_client_task_queue/" + str(uuid.uuid4()) + push_socket = self.ctx.socket(zmq.PUSH) + pull_socket = self.ctx.socket(zmq.PULL) + pull_socket.bind(inproc_url) + push_socket.connect(inproc_url) + return push_socket, pull_socket + + def _process_outbound_task(self): + try: + while wrapped_request := self.input_queue.get_nowait(): + # wrapped_request = self.input_queue.get_nowait() + + # Update the pending futures + request_uid = wrapped_request.request_uid + self.pending_futures[request_uid] = wrapped_request.future + + # Send the request + b_request_uid = msgspec_encode(request_uid, cls=RequestUID) + b_request_type = msgspec_encode( + wrapped_request.request_type, cls=RequestType + ) + payload_classes = get_payload_classes(wrapped_request.request_type) + if len(payload_classes) != len(wrapped_request.request_payloads): + raise ValueError("Payload count does not match expected count") + + b_payloads = [ + msgspec_encode(payload, cls=cls) + for payload, cls in zip( + wrapped_request.request_payloads, + payload_classes, + strict=False, + ) + ] + self.socket.send_multipart([b_request_uid, b_request_type] + b_payloads) + except queue.Empty: + pass + + def _main_loop(self): + # NOTE: make sure we only edit the pending_futures dict in this thread + while not self.is_finished.is_set(): + socks = dict(self.poller.poll(1000)) + inbound_state = socks.get(self.socket, None) + outbound_state = socks.get(self.task_waiter, None) + + if outbound_state and outbound_state & zmq.POLLIN: + # Drain the notifier + while True: + try: + self.task_waiter.recv(zmq.DONTWAIT) + except zmq.Again: + break + + # Process the output tasks + self._process_outbound_task() + + if inbound_state and inbound_state & zmq.POLLIN: + msg = self.socket.recv_multipart() + assert len(msg) >= 2, ( + "Expected at least 2 message part " + "[request_uid, request_type, *response]" + ) + b_request_uid, b_request_type, *b_response = msg + request_uid = msgspec_decode(b_request_uid, cls=RequestUID) + request_type = msgspec_decode(b_request_type, cls=RequestType) + response_cls = get_response_class(request_type) + + # TODO: we need a typing system for responses + if request_uid in self.pending_futures: + future = self.pending_futures.pop(request_uid) + if b_response: + response = msgspec_decode(b_response[0], cls=response_cls) + future.set_result(response) + else: + future.set_result(None) + + def submit_request( + self, + request_type: RequestType, + request_payloads: list[Any], + response_cls: Optional[T] = None, + ) -> MessagingFuture[T]: + """Submit a request to the server. + + Args: + request_type (RequestType): The type of the request. + request_payloads (list[Any]): The payloads of the request. + response_cls (Optional[T]): The expected response class. + This should be get from `get_response_class(request_type)`. + + Returns: + MessagingFuture[T]: A future that will hold the response. + """ + future: MessagingFuture[T] = MessagingFuture() + request_uid = self.request_counter + self.request_counter += 1 + self.input_queue.put( + MessageQueueClient.WrappedRequest( + request_uid=request_uid, + future=future, + request_type=request_type, + request_payloads=request_payloads, + ) + ) + self.task_notifier.send(b"1") + return future + + def close(self) -> None: + self.is_finished.set() + self.worker_thread.join() + self.socket.close() + + +ResponseType = TypeVar("ResponseType", covariant=True) +StateType = TypeVar("StateType", covariant=True) + + +class RequestHandlerBase(Generic[ResponseType]): + def __call__(self, payloads: list[bytes]): + raise NotImplementedError + + def get_response_class(self) -> ResponseType: + raise NotImplementedError + + def get_handler_type(self) -> HandlerType: + raise NotImplementedError + + +class SyncRequestHandler(RequestHandlerBase[ResponseType]): + """ + The handler for those "fast" functions that can be executed in the main loop + """ + + def __init__( + self, + payload_clss: list[Any], + response_cls: ResponseType, + handler: Callable[..., ResponseType], + ): + self.payload_clss = payload_clss + self.response_cls = response_cls + self.handler = handler + + def __call__(self, payloads: list[bytes]) -> ResponseType: + return self.handler(*unwrap_request_payloads(payloads, self.payload_clss)) + + def get_response_class(self) -> ResponseType: + return self.response_cls + + def get_handler_type(self) -> HandlerType: + return HandlerType.SYNC + + +class BlockingRequestHandler(RequestHandlerBase[ResponseType]): + """ + Returns the future of the response. + """ + + def __init__( + self, + executor: ThreadPoolExecutor, + payload_clss: list[Any], + response_cls: ResponseType, + handler: Callable[..., ResponseType], + ): + self.executor = executor + self.payload_clss = payload_clss + self.handler = handler + self.response_cls = response_cls + + def __call__(self, payloads: list[bytes]) -> Future[ResponseType]: + decoded_payloads = unwrap_request_payloads(payloads, self.payload_clss) + future = self.executor.submit(self.handler, *decoded_payloads) + return future + + def get_response_class(self) -> ResponseType: + return self.response_cls + + def get_handler_type(self) -> HandlerType: + return HandlerType.BLOCKING + + +class NonBlockingRequestHandler(Generic[ResponseType, StateType]): + """ + The handler for the "fire and probe" functions that launch async tasks + and have special mechanism to probe the task status. + + It requires 2 callables as the input: + - the first one is to launch the async task. This function should return + a 'state handle' that can be used to probe the task status later. + - the second one is to probe the task status and get the return value + with the 'state handle' returned by the first function. + """ + + # TODO: implement this in the future versions if needed + pass + + +class MessageQueueServer: + def __init__(self, bind_url: str, context: zmq.Context, max_workers: int = 4): + # Socket + self.ctx = context + self.socket = self.ctx.socket(zmq.ROUTER) + self.socket.bind(bind_url) + + # Poller + self.poller = zmq.Poller() + self.poller.register(self.socket, zmq.POLLIN) + + # Main loop thread + self.is_finished = threading.Event() + self.worker_thread = threading.Thread(target=self._main_loop, daemon=True) + + # Thread pool for blocking handlers + self.thread_pool = ThreadPoolExecutor(max_workers=max_workers) + + # Registered handlers: request_type -> (payload_cls, handler) + self.handlers: dict[RequestType, RequestHandlerBase[Any]] = {} + + def _call_sync_handler( + self, + handler_entry: SyncRequestHandler[Any], + payloads: list[bytes], + prefix_frames: list[bytes], + ) -> Any: + """ + Call the sync handler and send the response back to the client. + + Args: + handler_entry (SyncRequestHandler[Any]): The handler entry. + payloads (list[bytes]): The payloads of the request. + prefix_frames (list[bytes]): The prefix frames to send back. + """ + response = handler_entry(payloads) + response_cls = handler_entry.get_response_class() + b_response = msgspec_encode(response, cls=response_cls) + if response is not None: + self.socket.send_multipart(prefix_frames + [b_response]) + else: + self.socket.send_multipart(prefix_frames) + + def _call_blocking_handler( + self, + handler_entry: BlockingRequestHandler[Any], + payloads: list[bytes], + prefix_frames: list[bytes], + ) -> Any: + """ + Call the blocking handler in a separate thread and send the response + back to the client. + + Args: + handler_entry (BlockingRequestHandler[Any]): The handler entry. + payloads (list[bytes]): The payloads of the request. + prefix_frames (list[bytes]): The prefix frames to send back. + """ + future = handler_entry(payloads) + + def _send_response(fut: Future): + try: + response = fut.result() + response_cls = handler_entry.get_response_class() + b_response = msgspec_encode(response, cls=response_cls) + if response is not None: + self.socket.send_multipart(prefix_frames + [b_response]) + else: + self.socket.send_multipart(prefix_frames) + except Exception as e: + logger.error("Error in blocking handler: %s", e) + + future.add_done_callback(_send_response) + + def _call_handler( + self, + handler_entry: RequestHandlerBase[Any], + payloads: list[bytes], + prefix_frames: list[bytes], + ) -> Any: + match handler_entry.get_handler_type(): + case HandlerType.SYNC: + assert isinstance(handler_entry, SyncRequestHandler) + self._call_sync_handler(handler_entry, payloads, prefix_frames) + case HandlerType.BLOCKING: + assert isinstance(handler_entry, BlockingRequestHandler) + self._call_blocking_handler(handler_entry, payloads, prefix_frames) + case HandlerType.NON_BLOCKING: + raise NotImplementedError("Non-blocking handler is not supported yet") + case _: + raise ValueError("Unknown handler type") + + def _main_loop(self): + while not self.is_finished.is_set(): + socks = dict(self.poller.poll(1000)) + if socks.get(self.socket) == zmq.POLLIN: + msg = self.socket.recv_multipart() + assert len(msg) >= 3, ( + "Expected at least 3 message parts " + "[identity, request_uid, request_type, *payloads]" + ) + + identity, b_request_uid, b_request_type, *payloads = msg + request_type = msgspec_decode(b_request_type, cls=RequestType) + + if handler_entry := self.handlers.get(request_type): + try: + self._call_handler( + handler_entry=handler_entry, + payloads=payloads, + prefix_frames=[identity, b_request_uid, b_request_type], + ) + except Exception as e: + logger.error("Error handling request %s: %s", request_type, e) + else: + logger.error( + "No handler registered for request type %s", request_type + ) + logger.error("Available handlers: %s", list(self.handlers.keys())) + + def add_handler( + self, + request_type: RequestType, + payload_clss: list[Any], + handler_type: HandlerType, + handler, + ) -> None: + """Register a handler for a specific request type. + + Args: + request_type (RequestType): The type of the request to handle. + payload_clss (list[Any]): The expected payload classes for the request. + This should be get from `get_payload_classes(request_type)`. + handler (callable): The handler function that takes the payloads + as arguments. + """ + match handler_type: + case HandlerType.SYNC: + self.add_sync_handler(request_type, payload_clss, handler) + case HandlerType.BLOCKING: + self.add_blocking_handler(request_type, payload_clss, handler) + case HandlerType.NON_BLOCKING: + raise NotImplementedError("Non-blocking handler is not supported yet") + case _: + raise ValueError(f"Unknown handler type: {handler_type}") + # self.handlers[request_type] = self.HandlerEntry(payload_clss, handler) + + def add_sync_handler( + self, request_type: RequestType, payload_clss: list[Any], handler + ) -> None: + response_cls = get_response_class(request_type) + self.handlers[request_type] = SyncRequestHandler( + payload_clss, response_cls, handler + ) + + def add_blocking_handler( + self, request_type: RequestType, payload_clss: list[Any], handler + ) -> None: + response_cls = get_response_class(request_type) + self.handlers[request_type] = BlockingRequestHandler( + self.thread_pool, payload_clss, response_cls, handler + ) + + def add_nonblocking_handler( + self, request_type: RequestType, payload_clss: list[Any], handler + ) -> None: + raise NotImplementedError + + def start(self): + self.worker_thread.start() + + def close(self) -> None: + self.is_finished.set() + self.worker_thread.join() + self.socket.close() diff --git a/lmcache/v1/multiprocess/protocol.py b/lmcache/v1/multiprocess/protocol.py new file mode 100644 index 00000000000..4460d10dbf2 --- /dev/null +++ b/lmcache/v1/multiprocess/protocol.py @@ -0,0 +1,146 @@ +# SPDX-License-Identifier: Apache-2.0 +# Standard +from dataclasses import dataclass +from typing import Any, Optional +import enum + +# First Party +from lmcache.v1.multiprocess.custom_types import IPCCacheEngineKey, KVCache + +""" +Main RPC protocol for the LMCache core server and clients. The following +functions are supported: + +- REGISTER_KV_CACHE: + instance_id: int + kv_caches: KVCache + +- UNREGISTER_KV_CACHE: + instance_id: int + +- STORE: + keys: list[KeyType] + instance_id: int + gpu_block_ids: list[int] + +- RETRIEVE: + keys: list[KeyType] + instance_id: int + gpu_block_ids: list[int] + enable_layerwise: Optional[bool] + +- LOOKUP: + keys: list[KeyType] + lock: Optional[bool] +""" + +# Identifier for different vLLM instances +InstanceID = int + + +class RequestType(enum.Enum): + REGISTER_KV_CACHE = enum.auto() + UNREGISTER_KV_CACHE = enum.auto() + STORE = enum.auto() + RETRIEVE = enum.auto() + LOOKUP = enum.auto() + + # For debug, could be used as heartbeats + NOOP = enum.auto() + + +class HandlerType(enum.Enum): + SYNC = enum.auto() # Handler runs directly in the main loop + BLOCKING = enum.auto() # Handler may block, run in a thread pool + NON_BLOCKING = enum.auto() # Not supported yet + + +@dataclass +class ProtocolDefinition: + """ + Helper class for the protocol definition + """ + + payload_classes: list[Any] + response_class: Optional[Any] + handler_type: HandlerType + + +def get_payload_classes(req_type: RequestType) -> list[Any]: + if pd := _PROTOCOL_DEFINTIONS.get(req_type, None): + return pd.payload_classes + else: + raise ValueError(f"Invalid request type: {req_type}") + + +def get_response_class(req_type: RequestType) -> Optional[Any]: + if pd := _PROTOCOL_DEFINTIONS.get(req_type, None): + return pd.response_class + else: + raise ValueError(f"Invalid request type: {req_type}") + + +def get_handler_type(req_type: RequestType) -> HandlerType: + if pd := _PROTOCOL_DEFINTIONS.get(req_type, None): + return pd.handler_type + else: + raise ValueError(f"Invalid request type: {req_type}") + + +KeyType = IPCCacheEngineKey + +_PROTOCOL_DEFINTIONS = { + # Register KV Cache + # - instance_id: int + # - kv_cache: KVCacheType + # Returns: None + RequestType.REGISTER_KV_CACHE: ProtocolDefinition( + payload_classes=[int, KVCache], + response_class=None, + handler_type=HandlerType.SYNC, + ), + # Unregister KV Cache + # - instance_id: int + # Returns: None + RequestType.UNREGISTER_KV_CACHE: ProtocolDefinition( + payload_classes=[int], + response_class=None, + handler_type=HandlerType.SYNC, + ), + # Store + # - keys: list[KeyType] + # - instance_id: int + # - gpu_block_ids: list[int] + # Returns: bool (success) + RequestType.STORE: ProtocolDefinition( + payload_classes=[list[KeyType], int, list[int]], + response_class=bool, + handler_type=HandlerType.BLOCKING, + ), + # Retrieve + # - keys: list[KeyType] + # - instance_id: int + # - gpu_block_ids: list[int] + # Returns: list[bool] + # NOTE: no layerwise support for now + RequestType.RETRIEVE: ProtocolDefinition( + payload_classes=[list[KeyType], int, list[int]], + response_class=list[bool], + handler_type=HandlerType.BLOCKING, + ), + # Lookup + # - keys: list[KeyType] + # - lock: Optional[bool] + # Returns: list[bool] (found or not for each key) + RequestType.LOOKUP: ProtocolDefinition( + payload_classes=[list[KeyType], Optional[bool]], + response_class=list[bool], + handler_type=HandlerType.BLOCKING, + ), + # Debug commands + RequestType.NOOP: ProtocolDefinition( + payload_classes=[], + response_class=str, + handler_type=HandlerType.SYNC, + ), +} diff --git a/lmcache/v1/offload_server/zmq_server.py b/lmcache/v1/offload_server/zmq_server.py index b1084037508..b1bebf84d3f 100644 --- a/lmcache/v1/offload_server/zmq_server.py +++ b/lmcache/v1/offload_server/zmq_server.py @@ -5,7 +5,6 @@ import threading # Third Party -from vllm.utils import make_zmq_socket import msgspec import zmq @@ -13,7 +12,11 @@ from lmcache.v1.cache_engine import LMCacheEngine from lmcache.v1.offload_server.abstract_server import OffloadServerInterface from lmcache.v1.offload_server.message import OffloadMsg, OffloadRetMsg -from lmcache.v1.rpc_utils import get_zmq_rpc_path_lmcache +from lmcache.v1.rpc_utils import ( + get_zmq_context, + get_zmq_rpc_path_lmcache, + get_zmq_socket, +) if TYPE_CHECKING: # Third Party @@ -27,16 +30,17 @@ def __init__( vllm_config: "VllmConfig", tp_rank: int, ): - self.ctx = zmq.Context() # type: ignore[attr-defined] + self.ctx = get_zmq_context(use_asyncio=False) offload_rpc_port = int(os.environ.get("LMCACHE_OFFLOAD_RPC_PORT", 100)) socket_path = get_zmq_rpc_path_lmcache( vllm_config, "offload", offload_rpc_port, tp_rank ) - self.socket = make_zmq_socket( + self.socket = get_zmq_socket( self.ctx, socket_path, + "ipc", zmq.REP, # type: ignore[attr-defined] - bind=True, + "bind", ) self.lmcache_engine = lmcache_engine diff --git a/lmcache/v1/plugin/plugin_launcher.py b/lmcache/v1/plugin/plugin_launcher.py index 76ac3b650df..f11a6b7233d 100644 --- a/lmcache/v1/plugin/plugin_launcher.py +++ b/lmcache/v1/plugin/plugin_launcher.py @@ -115,8 +115,9 @@ def _get_interpreter(self, file: Path) -> str: with open(file, "r", encoding="utf-8") as f: first_line = f.readline().strip() if first_line.startswith("#!"): - # Extract interpreter path from shebang - return first_line[2:].strip() + # Extract interpreter + interpreter_str = first_line[2:].strip() + interpreters.append(interpreter_str) except Exception as e: logger.error( f"Error reading interpreter from plugin file {file} - " diff --git a/lmcache/v1/protocol.py b/lmcache/v1/protocol.py index d693bc38738..c795cc8866e 100644 --- a/lmcache/v1/protocol.py +++ b/lmcache/v1/protocol.py @@ -2,14 +2,14 @@ # Standard from dataclasses import dataclass from enum import IntEnum, auto -from typing import Optional +from typing import Optional, Union import struct # Third Party import torch # First Party -from lmcache.utils import CacheEngineKey +from lmcache.utils import CacheEngineKey, LayerCacheEngineKey, parse_cache_key from lmcache.v1.memory_management import MemoryFormat MAX_KEY_LENGTH = 150 @@ -131,7 +131,7 @@ class ClientMetaMessage: """ command: ClientCommand - key: CacheEngineKey + key: Union[CacheEngineKey, LayerCacheEngineKey] length: int fmt: MemoryFormat dtype: Optional[torch.dtype] @@ -170,7 +170,7 @@ def deserialize(s: bytes) -> "ClientMetaMessage": ) return ClientMetaMessage( ClientCommand(command), - CacheEngineKey.from_string(key.decode().strip()), + parse_cache_key(key.decode().strip()), length, MemoryFormat(fmt), INT_TO_DTYPE[dtype], diff --git a/lmcache/v1/rpc_utils.py b/lmcache/v1/rpc_utils.py index 09c55eadb73..73320a9704d 100644 --- a/lmcache/v1/rpc_utils.py +++ b/lmcache/v1/rpc_utils.py @@ -81,7 +81,7 @@ def get_zmq_rpc_path_lmcache( vllm_config: Optional["VllmConfig"] = None, service_name: ServiceKind = "lookup", rpc_port: int = 0, - tp_rank: int = 0, + rank: int = 0, ) -> str: """Get the ZMQ RPC path for LMCache lookup and offload communication.""" # Third Party @@ -100,9 +100,9 @@ def get_zmq_rpc_path_lmcache( engine_id = vllm_config.kv_transfer_config.engine_id if isinstance(rpc_port, str): - rpc_port = rpc_port + str(tp_rank) + rpc_port = rpc_port + str(rank) else: - rpc_port += tp_rank + rpc_port += rank logger.debug( "Base URL: %s, Engine: %s, Service Name: %s, RPC Port: %s", @@ -113,7 +113,7 @@ def get_zmq_rpc_path_lmcache( ) socket_path = ( - f"ipc://{base_url}/engine_{engine_id}_service_{service_name}_" + f"{base_url}/engine_{engine_id}_service_{service_name}_" f"lmcache_rpc_port_{rpc_port}" ) diff --git a/lmcache/v1/storage_backend/__init__.py b/lmcache/v1/storage_backend/__init__.py index 0b58f9ccd66..15b02c23fe2 100644 --- a/lmcache/v1/storage_backend/__init__.py +++ b/lmcache/v1/storage_backend/__init__.py @@ -27,11 +27,24 @@ logger = init_logger(__name__) +def is_cuda_worker(metadata: LMCacheEngineMetadata) -> bool: + """ + Check if the current role is worker and CUDA is available. + + Args: + metadata: The LMCache engine metadata. + + Returns: + True if the worker is not a scheduler and CUDA is available. + """ + return metadata.role != "scheduler" and torch.cuda.is_available() + + def create_dynamic_backends( config: LMCacheEngineConfig, metadata: LMCacheEngineMetadata, loop: asyncio.AbstractEventLoop, - local_cpu_backend: LocalCPUBackend, + local_cpu_backend: Optional[LocalCPUBackend], dst_device: str, storage_backends: OrderedDict[str, StorageBackendInterface], ) -> None: @@ -100,10 +113,12 @@ def CreateStorageBackends( dst_device: str = "cuda", lmcache_worker: Optional["LMCacheWorker"] = None, ) -> OrderedDict[str, StorageBackendInterface]: - # Replace 'cuda' with 'cuda:' - if dst_device == "cuda": + if is_cuda_worker(metadata): dst_device = f"cuda:{torch.cuda.current_device()}" - + elif dst_device == "xpu": + dst_device = f"xpu:{torch.xpu.current_device()}" + else: + dst_device = "cpu" storage_backends: OrderedDict[str, StorageBackendInterface] = OrderedDict() extra_config = config.extra_config @@ -120,17 +135,25 @@ def CreateStorageBackends( # TODO(Jiayi): The hierarchy is fixed for now # NOTE(Jiayi): The local_cpu backend is always created because # other backends might need it as a buffer. - if not config.enable_pd or config.local_cpu: - local_cpu_backend = LocalCPUBackend( - config, - metadata, - dst_device, - lmcache_worker, - ) - backend_name = str(local_cpu_backend) - storage_backends[backend_name] = local_cpu_backend + local_cpu_backend: Optional[LocalCPUBackend] = None + if metadata.role == "scheduler": + # For scheduler role, local_cpu_backend is None + pass + elif not config.enable_pd or config.local_cpu: + if config.max_local_cpu_size > 0: + local_cpu_backend = LocalCPUBackend( + config, + metadata, + dst_device, + lmcache_worker, + ) + backend_name = str(local_cpu_backend) + storage_backends[backend_name] = local_cpu_backend + else: + logger.info("No cpu memory is allocated as max_local_cpu_size <= 0") if config.enable_p2p: + assert local_cpu_backend is not None p2p_backend = P2PBackend( config, metadata, @@ -152,6 +175,7 @@ def CreateStorageBackends( ) if config.local_disk and config.max_local_disk_size > 0: + assert local_cpu_backend is not None local_disk_backend = LocalDiskBackend( config, loop, local_cpu_backend, dst_device, lmcache_worker ) diff --git a/lmcache/v1/storage_backend/abstract_backend.py b/lmcache/v1/storage_backend/abstract_backend.py index b44061f11e6..6a060235242 100644 --- a/lmcache/v1/storage_backend/abstract_backend.py +++ b/lmcache/v1/storage_backend/abstract_backend.py @@ -116,6 +116,16 @@ def get_blocking( """ raise NotImplementedError + def get_non_blocking( + self, + key: CacheEngineKey, + location: Optional[str] = None, + ) -> Optional[Future]: + """ + A non-blocking function to get the kv cache from the storage backend. + """ + raise NotImplementedError + async def batched_async_contains( self, lookup_id: str, @@ -246,6 +256,31 @@ def close( """ raise NotImplementedError + def support_batched_contains(self) -> bool: + return False + + def batched_contains( + self, + keys: List[CacheEngineKey], + pin: bool = False, + stop_after_first_not_exits: bool = True, + ) -> List[bool]: + """ + Check whether the keys are in the storage backend. + + :param List[CacheEngineKey] keys: The keys of the MemoryObj. + + :param bool pin: Whether to pin the key. + If True, the corresponding KV cache will be + pinned in the storage backend. + + :param bool stop_after_first_not_exits: Stop when find the first not exists key, + all subsequent results will return False directly. + + :return: Return a bool list, True if the key exists, False otherwise. + """ + raise NotImplementedError + class AllocatorBackendInterface(StorageBackendInterface): """ diff --git a/lmcache/v1/storage_backend/connector/__init__.py b/lmcache/v1/storage_backend/connector/__init__.py index 50ac65eab17..73ce7fe81b5 100644 --- a/lmcache/v1/storage_backend/connector/__init__.py +++ b/lmcache/v1/storage_backend/connector/__init__.py @@ -76,6 +76,24 @@ def parse_remote_url(url: str) -> ParsedRemoteURL: ) +class SafeLocalCPUBackend(LocalCPUBackend): + """ + A safe stub for LocalCPUBackend that can be used when local_cpu_backend is None. + """ + + def __init__(self, config: LMCacheEngineConfig): + pass + + def allocate(self, *args, **kwargs): + raise RuntimeError( + "SafeLocalCPUBackend.allocate() should never be called. " + "This indicates a bug where scheduler role is trying to allocate memory." + ) + + def __str__(self): + return "SafeLocalCPUBackend(dummy)" + + class ConnectorContext: """ Context for creating a connector. @@ -84,6 +102,7 @@ class ConnectorContext: url: The remote URL loop: The asyncio event loop local_cpu_backend: The local CPU backend + (wrapped as SafeLocalCPUBackend if None) config: Optional LMCache engine configuration parsed_url: Parsed representation of the URL """ @@ -92,13 +111,19 @@ def __init__( self, url: str, loop: asyncio.AbstractEventLoop, - local_cpu_backend: LocalCPUBackend, + local_cpu_backend: Optional[LocalCPUBackend], config: Optional[LMCacheEngineConfig], metadata: Optional[LMCacheEngineMetadata], ): self.url = url self.loop = loop - self.local_cpu_backend = local_cpu_backend + # Wrap None as SafeLocalCPUBackend to satisfy type requirements + # The SafeLocalCPUBackend will raise an error if allocate() is called + self.local_cpu_backend: LocalCPUBackend = ( + local_cpu_backend + if local_cpu_backend is not None + else SafeLocalCPUBackend(config) + ) self.config = config self.metadata = metadata @@ -132,7 +157,7 @@ def __init__( self, url: str, loop: asyncio.AbstractEventLoop, - local_cpu_backend: LocalCPUBackend, + local_cpu_backend: Optional[LocalCPUBackend], config: Optional[LMCacheEngineConfig] = None, metadata: Optional[LMCacheEngineMetadata] = None, ) -> None: @@ -197,7 +222,7 @@ def create_connector(self) -> RemoteConnector: def CreateConnector( url: str, loop: asyncio.AbstractEventLoop, - local_cpu_backend: LocalCPUBackend, + local_cpu_backend: Optional[LocalCPUBackend], config: Optional[LMCacheEngineConfig] = None, metadata: Optional[LMCacheEngineMetadata] = None, ) -> InstrumentedRemoteConnector: @@ -238,7 +263,7 @@ def CreateConnector( Args: url: The remote URL loop: The asyncio event loop - local_cpu_backend: The local CPU backend + local_cpu_backend: The local CPU backend (can be None for scheduler role) config: Optional LMCache engine configuration metadata: Optional LMCache engine metadata diff --git a/lmcache/v1/storage_backend/connector/audit_adapter.py b/lmcache/v1/storage_backend/connector/audit_adapter.py index 09cba94ec03..681aefbd4bb 100644 --- a/lmcache/v1/storage_backend/connector/audit_adapter.py +++ b/lmcache/v1/storage_backend/connector/audit_adapter.py @@ -73,4 +73,5 @@ def create_connector(self, context: ConnectorContext) -> RemoteConnector: context.config, context.metadata, ) - return AuditConnector(connector.getWrappedConnector(), context.config) + # Metaclass dynamically implements all abstract methods at runtime + return AuditConnector(connector.getWrappedConnector(), context.config) # type: ignore[abstract] diff --git a/lmcache/v1/storage_backend/connector/audit_connector.py b/lmcache/v1/storage_backend/connector/audit_connector.py index feed67c3aec..4ac33d8c0f4 100644 --- a/lmcache/v1/storage_backend/connector/audit_connector.py +++ b/lmcache/v1/storage_backend/connector/audit_connector.py @@ -1,27 +1,196 @@ # SPDX-License-Identifier: Apache-2.0 # Standard from threading import Lock -from typing import Dict, List, Optional +from typing import Dict, Optional +import abc import asyncio +import functools import hashlib +import logging import time # First Party -from lmcache.logging import init_logger from lmcache.utils import CacheEngineKey from lmcache.v1.config import LMCacheEngineConfig from lmcache.v1.memory_management import MemoryObj from lmcache.v1.storage_backend.connector.base_connector import RemoteConnector -logger = init_logger(__name__) +logger = logging.getLogger(__name__) -class AuditConnector(RemoteConnector): - """Audit wrapper for RemoteConnector that verifies data integrity - and logs operations. +class AuditConnectorMeta(abc.ABCMeta): + """Metaclass that dynamically generates wrapper methods for all + RemoteConnector methods + """ + + def __new__(mcs, name, bases, namespace): + # Get all methods from RemoteConnector including abstract methods + # We need to check both the class dict and inherited methods + all_methods = {} + + # Collect methods from RemoteConnector and its bases + for base in RemoteConnector.__mro__: + for method_name, method_obj in base.__dict__.items(): + if method_name not in all_methods and callable(method_obj): + all_methods[method_name] = method_obj + + for method_name, method in all_methods.items(): + # Skip private methods and methods already defined in namespace + if method_name.startswith("_") or method_name in namespace: + continue + + # Skip class methods, static methods, and properties + if isinstance(method, (classmethod, staticmethod, property)): + continue + + # Check if method is marked with @NotAudit + is_not_audit = getattr(method, "_not_audit", False) + + # Determine if method is async + is_async = asyncio.iscoroutinefunction(method) + + # Create appropriate wrapper and add to namespace + if is_not_audit: + if is_async: + namespace[method_name] = mcs._create_passthrough_async_method( + method_name, method + ) + else: + namespace[method_name] = mcs._create_passthrough_sync_method( + method_name, method + ) + else: + if is_async: + namespace[method_name] = mcs._create_audit_async_method( + method_name, method + ) + else: + namespace[method_name] = mcs._create_audit_sync_method( + method_name, method + ) + + # Create the class with all methods in namespace + cls = super().__new__(mcs, name, bases, namespace) + + # Clear abstract methods since we've implemented them all via wrappers + # The wrappers delegate to real_connector which has the actual implementations + if hasattr(cls, "__abstractmethods__"): + cls.__abstractmethods__ = frozenset() + + return cls + + @staticmethod + def _create_passthrough_async_method(method_name: str, original_method): + """Create a pass-through async method without logging""" + + @functools.wraps(original_method) + async def wrapper(self, *args, **kwargs): + real_method = getattr(self.real_connector, method_name) + return await real_method(*args, **kwargs) + + wrapper.__name__ = method_name + wrapper.__qualname__ = f"AuditConnector.{method_name}" + return wrapper + + @staticmethod + def _create_passthrough_sync_method(method_name: str, original_method): + """Create a pass-through sync method without logging""" + + @functools.wraps(original_method) + def wrapper(self, *args, **kwargs): + real_method = getattr(self.real_connector, method_name) + return real_method(*args, **kwargs) + + wrapper.__name__ = method_name + wrapper.__qualname__ = f"AuditConnector.{method_name}" + return wrapper + + @staticmethod + def _create_audit_async_method(method_name: str, original_method): + """Create an audit async method with logging""" + + @functools.wraps(original_method) + async def wrapper(self, *args, **kwargs): + # Special handling for put/get methods with checksum + if method_name == "put": + return await self._audit_put(*args, **kwargs) + elif method_name == "get": + return await self._audit_get(*args, **kwargs) + + # Check if method is in excluded commands + if hasattr(self, "excluded_cmds") and method_name in self.excluded_cmds: + real_method = getattr(self.real_connector, method_name) + return await real_method(*args, **kwargs) + + # Generic audit logging + self.logger.debug( + f"[REMOTE_AUDIT][{self.real_connector}]:{method_name.upper()}|START" + ) + t1 = time.perf_counter() + try: + real_method = getattr(self.real_connector, method_name) + result = await real_method(*args, **kwargs) + t2 = time.perf_counter() + cost = (t2 - t1) * 1000 + self.logger.info( + f"[REMOTE_AUDIT][{self.real_connector}]:{method_name.upper()}|" + f"SUCCESS|Cost:{cost:.6f}ms" + ) + return result + except Exception as e: + self.logger.error( + f"[REMOTE_AUDIT][{self.real_connector}]:{method_name.upper()}|" + f"FAILED|Error: {str(e)}" + ) + raise + + wrapper.__name__ = method_name + wrapper.__qualname__ = f"AuditConnector.{method_name}" + return wrapper + + @staticmethod + def _create_audit_sync_method(method_name: str, original_method): + """Create an audit sync method with logging""" + + @functools.wraps(original_method) + def wrapper(self, *args, **kwargs): + # Check if method is in excluded commands + if hasattr(self, "excluded_cmds") and method_name in self.excluded_cmds: + real_method = getattr(self.real_connector, method_name) + return real_method(*args, **kwargs) + + self.logger.debug( + f"[REMOTE_AUDIT][{self.real_connector}]:{method_name.upper()}|START" + ) + t1 = time.perf_counter() + try: + real_method = getattr(self.real_connector, method_name) + result = real_method(*args, **kwargs) + t2 = time.perf_counter() + cost = (t2 - t1) * 1000 + self.logger.info( + f"[REMOTE_AUDIT][{self.real_connector}]:{method_name.upper()}|" + f"SUCCESS|Cost:{cost:.6f}ms" + ) + return result + except Exception as e: + self.logger.error( + f"[REMOTE_AUDIT][{self.real_connector}]:{method_name.upper()}|" + f"FAILED|Error: {str(e)}" + ) + raise + + wrapper.__name__ = method_name + wrapper.__qualname__ = f"AuditConnector.{method_name}" + return wrapper + + +class AuditConnector(RemoteConnector, metaclass=AuditConnectorMeta): + """Audit wrapper for RemoteConnector that dynamically wraps all methods. Features: - - Wraps any RemoteConnector implementation + - Automatically wraps all RemoteConnector methods + - Methods marked with @NotAudit are forwarded without logging - Configurable checksum verification via URL parameter - Logs all operations with timestamps - Optional checksum validation for put/get operations @@ -56,9 +225,6 @@ def __init__( self.logger = logger.getChild("audit") - # Dynamically replace excluded methods - self._replace_excluded_methods() - logger.info( f"[REMOTE_AUDIT][{self.real_connector}]:INITIALIZED|" f"Calc Checksum:{self.calc_checksum}|" @@ -66,37 +232,11 @@ def __init__( f"Excluded Cmds: {self.excluded_cmds}" ) - def _replace_excluded_methods(self): - """Dynamically replace methods that should be excluded from auditing""" - for method_name in self.excluded_cmds: - if hasattr(self.real_connector, method_name): - # Create a direct pass-through method - real_method = getattr(self.real_connector, method_name) - - if asyncio.iscoroutinefunction(real_method): - - def create_async_wrapper(rm): - async def async_wrapper(*args, **kwargs): - return await rm(*args, **kwargs) - - return async_wrapper - - setattr(self, method_name, create_async_wrapper(real_method)) - else: - - def create_sync_wrapper(rm): - def sync_wrapper(*args, **kwargs): - return rm(*args, **kwargs) - - return sync_wrapper - - setattr(self, method_name, create_sync_wrapper(real_method)) - def _calculate_checksum(self, data: bytes) -> str: """Calculate SHA-256 checksum for data validation""" return hashlib.sha256(data).hexdigest() - async def put(self, key: CacheEngineKey, memory_obj: MemoryObj): + async def _audit_put(self, key: CacheEngineKey, memory_obj: MemoryObj): """Store data with optional checksum tracking""" data = memory_obj.byte_array checksum = self._calculate_checksum(data) if self.calc_checksum else "N/A" @@ -127,7 +267,7 @@ async def put(self, key: CacheEngineKey, memory_obj: MemoryObj): ) raise - async def get(self, key: CacheEngineKey) -> Optional[MemoryObj]: + async def _audit_get(self, key: CacheEngineKey) -> Optional[MemoryObj]: """Retrieve data with optional integrity check""" self.logger.debug( f"[REMOTE_AUDIT][{self.real_connector}]:GET|START|" @@ -178,65 +318,3 @@ async def get(self, key: CacheEngineKey) -> Optional[MemoryObj]: f"FAILED|Key:{key}|Error: {str(e)}" ) raise - - async def exists(self, key: CacheEngineKey) -> bool: - """Check key existence with audit log""" - self.logger.debug( - f"[REMOTE_AUDIT][{self.real_connector}]:EXISTS|START|Key:{key}" - ) - t1 = time.perf_counter() - result = await self.real_connector.exists(key) - t2 = time.perf_counter() - cost = (t2 - t1) * 1000 - self.logger.info( - f"[REMOTE_AUDIT][{self.real_connector}]:EXISTS|{result}|" - f"Cost:{cost:.6f}ms|" - f"Key:{key}" - ) - return result - - def exists_sync(self, key: CacheEngineKey) -> bool: - """Check key existence with audit log synchronized""" - self.logger.debug(f"[REMOTE_AUDIT]EXISTS_SYNC|START|Key:{key}") - result = self.real_connector.exists_sync(key) - self.logger.info(f"[REMOTE_AUDIT]EXISTS_SYNC|{result}|Key: {key}") - return result - - async def list(self) -> List[str]: - """List keys with audit log""" - self.logger.debug("[REMOTE_AUDIT][{self.real_connector}]:LIST|START") - t1 = time.perf_counter() - result = await self.real_connector.list() - t2 = time.perf_counter() - cost = (t2 - t1) * 1000 - self.logger.info( - f"[REMOTE_AUDIT][{self.real_connector}]:LIST|SUCCESS|" - f"Count:{len(result)}|Cost:{cost:.6f}ms" - ) - return result - - async def close(self): - """Cleanup resources with audit log""" - self.logger.debug(f"[REMOTE_AUDIT][{self.real_connector}]:CLOSE|START") - await self.real_connector.close() - self.logger.info(f"[REMOTE_AUDIT][{self.real_connector}]:CLOSE|SUCCESS") - - def support_ping(self) -> bool: - self.logger.debug(f"[REMOTE_AUDIT][{self.real_connector}]:SUPPORT_PING|START") - support = self.real_connector.support_ping() - self.logger.info( - f"[REMOTE_AUDIT][{self.real_connector}]:SUPPORT_PING|{support}" - ) - return support - - async def ping(self) -> int: - self.logger.debug(f"[REMOTE_AUDIT][{self.real_connector}]:PING|START") - t1 = time.perf_counter() - error_code = await self.real_connector.ping() - t2 = time.perf_counter() - cost = (t2 - t1) * 1000 - self.logger.debug( - f"[REMOTE_AUDIT][{self.real_connector}]:PING|{error_code}|" - f"Cost:{cost:.6f}ms" - ) - return error_code diff --git a/lmcache/v1/storage_backend/connector/base_connector.py b/lmcache/v1/storage_backend/connector/base_connector.py index 9dde700f9e6..8f2d1429aa5 100644 --- a/lmcache/v1/storage_backend/connector/base_connector.py +++ b/lmcache/v1/storage_backend/connector/base_connector.py @@ -2,6 +2,7 @@ # Standard from typing import List, Optional import abc +import asyncio # Third Party import torch @@ -16,6 +17,15 @@ logger = init_logger(__name__) +def NotAudit(func): + """ + Decorator to mark methods that should not be audited. + These methods will be directly forwarded to the real connector without logging. + """ + func._not_audit = True + return func + + class RemoteConnector(metaclass=abc.ABCMeta): """ Interface for remote connector @@ -28,6 +38,7 @@ class RemoteConnector(metaclass=abc.ABCMeta): full_chunk_size: Optional[int] = None single_token_size: Optional[int] = None + @NotAudit def init_chunk_meta( self, config: Optional[LMCacheEngineConfig], @@ -73,6 +84,7 @@ def init_chunk_meta( f"single token size: {self.single_token_size}" ) + @NotAudit def reshape_partial_chunk( self, memory_obj: MemoryObj, @@ -103,6 +115,7 @@ def reshape_partial_chunk( return memory_obj + @NotAudit def post_init(self): """ Post-initialization method to be called after the connector is created. @@ -243,10 +256,7 @@ async def batched_put( raise NotImplementedError def support_batched_async_contains(self) -> bool: - """ - Connectors that support batched async contains should override this method. - """ - return False + return True async def batched_async_contains( self, @@ -254,26 +264,99 @@ async def batched_async_contains( keys: List[CacheEngineKey], pin: bool = False, ) -> int: + """Check how many keys exist in file system in batch + + Args: + lookup_id: Identifier for this lookup operation + keys: List of keys to check + pin: Whether to pin the keys (not used in FS connector) + + Returns: + Number of consecutive keys that exist, starting from the first key """ - Check if the remote server contains the keys - """ - raise NotImplementedError + tasks = [self.exists(key) for key in keys] + results = await asyncio.gather(*tasks) + if False in results: + return results.index(False) + return len(results) def support_batched_get_non_blocking(self) -> bool: - """ - Connectors that support batched get non-blocking should override this method. - """ - return False + return True async def batched_get_non_blocking( self, lookup_id: str, keys: List[CacheEngineKey], ) -> List[MemoryObj]: + """Batched get the memory_objs of the corresponding keys (non-blocking) + + This method returns only the consecutive prefix of successfully retrieved + memory objects. Once a key is not found (None) or an exception occurs, + all subsequent memory objects (even if successfully retrieved) will be + released to avoid memory leaks, and only the prefix before the first + failure will be returned. + + Args: + lookup_id: Identifier for this lookup operation + keys: List of keys to get + + Returns: + List of consecutive memory objects from the beginning until the first + failure (None or Exception). Empty list if the first key fails. """ - Batched get the memory_objs of the corresponding keys + # Use asyncio.gather to fetch all keys concurrently + results = await asyncio.gather( + *(self.get(key) for key in keys), return_exceptions=True + ) + + # Only return consecutive prefix of valid memory objects + memory_objs = [] + found_failure = False + for result in results: + if found_failure: + # Release subsequent memory objects to avoid memory leak + if isinstance(result, MemoryObj): + result.ref_count_down() + elif isinstance(result, MemoryObj): + memory_objs.append(result) + else: + # First failure encountered (None or Exception) + if isinstance(result, Exception): + logger.warning(f"Exception during batched get: {result}") + found_failure = True + + return memory_objs + + def remove_sync(self, key: CacheEngineKey) -> bool: + """ + Remove a memory object. + + :param CacheEngineKey key: The key of the MemoryObj. + + :return: a bool indicates whether remove is successful. """ raise NotImplementedError + def batched_contains( + self, keys: List[CacheEngineKey], stop_after_first_not_exits: bool = True + ) -> List[bool]: + """ + Batched contains. + + :param List[CacheEngineKey] keys: The keys to check. + + :param bool stop_after_first_not_exits: Stop when find the first not exists key, + all subsequent results will return False directly. + + :return: Return a bool list, True if the key exists, False otherwise. + """ + raise NotImplementedError + + def support_batched_contains(self) -> bool: + """ + Is supported batched_contains + """ + return False + def __repr__(self) -> str: return f"<{self.__class__.__name__}>" diff --git a/lmcache/v1/storage_backend/connector/fs_adapter.py b/lmcache/v1/storage_backend/connector/fs_adapter.py index f633c484552..fd6f73a68c6 100644 --- a/lmcache/v1/storage_backend/connector/fs_adapter.py +++ b/lmcache/v1/storage_backend/connector/fs_adapter.py @@ -23,13 +23,7 @@ def create_connector(self, context: ConnectorContext) -> RemoteConnector: logger.info(f"Creating FS connector for URL: {context.url}") parse_url = parse_remote_url(context.url) - relative_tmp_dir = ( - None - if context.config is None - else context.config.get_extra_config_value( - "fs_connector_relative_tmp_dir", None - ) - ) + return FSConnector( - parse_url.path, context.loop, context.local_cpu_backend, relative_tmp_dir + parse_url.path, context.loop, context.local_cpu_backend, context.config ) diff --git a/lmcache/v1/storage_backend/connector/fs_connector.py b/lmcache/v1/storage_backend/connector/fs_connector.py index d34fa202d80..152b4f3eaff 100644 --- a/lmcache/v1/storage_backend/connector/fs_connector.py +++ b/lmcache/v1/storage_backend/connector/fs_connector.py @@ -12,6 +12,7 @@ # First Party from lmcache.logging import init_logger from lmcache.utils import CacheEngineKey +from lmcache.v1.config import LMCacheEngineConfig from lmcache.v1.memory_management import MemoryObj from lmcache.v1.protocol import RemoteMetadata from lmcache.v1.storage_backend.connector.base_connector import RemoteConnector @@ -35,13 +36,14 @@ def __init__( base_paths_str: str, loop: asyncio.AbstractEventLoop, local_cpu_backend: LocalCPUBackend, - relative_tmp_dir: Optional[str], + config: Optional[LMCacheEngineConfig], ): """ Args: base_paths_str: Comma separated storage paths loop: Asyncio event loop local_cpu_backend: Memory allocator interface + config: Lmcache engine config """ # Parse comma separated paths self.base_paths = ( @@ -52,15 +54,27 @@ def __init__( self.loop = loop self.local_cpu_backend = local_cpu_backend - self.relative_tmp_dir = ( - None if relative_tmp_dir is None else Path(relative_tmp_dir) + + relative_tmp_dir = ( + None + if config is None + else config.get_extra_config_value("fs_connector_relative_tmp_dir", None) ) - if self.relative_tmp_dir is not None: + self.relative_tmp_dir = None + if relative_tmp_dir is not None: + self.relative_tmp_dir = Path(relative_tmp_dir) assert not self.relative_tmp_dir.is_absolute() + self.read_ahead_size = ( + None + if config is None + else config.get_extra_config_value("fs_connector_read_ahead_size", None) + ) + logger.info( f"Initialized FSConnector with base paths {self.base_paths}, " - f"relative tmp dir: {self.relative_tmp_dir}" + f"relative tmp dir: {self.relative_tmp_dir}, " + f"read ahead size: {self.read_ahead_size}" ) # Create directories for all paths for path in self.base_paths: @@ -114,6 +128,7 @@ async def get(self, key: CacheEngineKey) -> Optional[MemoryObj]: """Get data from file system""" file_path = self._get_file_path(key) + memory_obj = None try: async with aiofiles.open(file_path, "rb") as f: if self.save_chunk_meta: @@ -141,24 +156,48 @@ async def get(self, key: CacheEngineKey) -> Optional[MemoryObj]: # Read the actual data into allocated memory buffer = memory_obj.byte_array - num_read = await f.readinto(buffer) if self.save_chunk_meta: + # if save chunk meta, read meta will trigger + # read ahead if fs supported + num_read = await f.readinto(buffer) if num_read != len(buffer): raise RuntimeError( f"Partial read data {len(buffer)} got {num_read}" ) else: + if self.read_ahead_size is None: + num_read = await f.readinto(buffer) + else: + if not isinstance(buffer, memoryview): + buffer = memoryview(buffer) + + # trigger read head if fs supported + num_read_ahead = await f.readinto( + buffer[: self.read_ahead_size] + ) + assert num_read_ahead <= self.read_ahead_size + + # if num_read_ahead == self.read_ahead_size, + # means there may still be some remaining content + if num_read_ahead == self.read_ahead_size: + num_read_tail = await f.readinto( + buffer[self.read_ahead_size :] + ) + assert num_read_tail is not None + num_read = num_read_ahead + num_read_tail + else: + num_read = num_read_ahead # reshape and check assert num_read is not None memory_obj = self.reshape_partial_chunk(memory_obj, num_read) return memory_obj - except FileNotFoundError: - # Key does not exist is normal case - return None except Exception as e: - logger.error(f"Failed to read from file {file_path}: {str(e)}") + if not isinstance(e, FileNotFoundError): + logger.error(f"Failed to read from file {file_path}: {str(e)}") + if memory_obj is not None: + memory_obj.ref_count_down() return None async def put(self, key: CacheEngineKey, memory_obj: MemoryObj): @@ -194,6 +233,24 @@ async def put(self, key: CacheEngineKey, memory_obj: MemoryObj): await aiofiles.os.unlink(temp_path) # Remove corrupted file raise + def remove_sync(self, key: CacheEngineKey) -> bool: + """ + Remove the file associated with the given key. + + Args: + key: The key to remove. + + Returns: + bool: True if the file was successfully removed, False otherwise. + """ + file_path = self._get_file_path(key) + try: + os.remove(file_path) + return True + except OSError as e: + logger.error(f"Failed to remove file {file_path}: {e}") + return False + @no_type_check async def list(self) -> List[str]: """List all keys in file system""" diff --git a/lmcache/v1/storage_backend/connector/instrumented_connector.py b/lmcache/v1/storage_backend/connector/instrumented_connector.py index 5daa1f02b49..a35bbdebb7a 100644 --- a/lmcache/v1/storage_backend/connector/instrumented_connector.py +++ b/lmcache/v1/storage_backend/connector/instrumented_connector.py @@ -112,7 +112,35 @@ async def batched_get( async def batched_put( self, keys: List[CacheEngineKey], memory_objs: List[MemoryObj] ): - return await self._connector.batched_put(keys, memory_objs) + try: + await self._connector.batched_put(keys, memory_objs) + except Exception as e: + logger.warning(f"batched put error: {e}") + finally: + for memory_obj in memory_objs: + memory_obj.ref_count_down() + + def remove_sync(self, key: CacheEngineKey) -> bool: + return self._connector.remove_sync(key) + + def batched_contains( + self, keys: List[CacheEngineKey], stop_after_first_not_exits: bool = True + ) -> List[bool]: + return self._connector.batched_contains(keys, stop_after_first_not_exits) + + def support_batched_contains(self) -> bool: + return self._connector.support_batched_contains() + + def init_chunk_meta(self, config, metadata) -> None: + return self._connector.init_chunk_meta(config, metadata) + + def reshape_partial_chunk( + self, memory_obj: MemoryObj, bytes_read: int + ) -> MemoryObj: + return self._connector.reshape_partial_chunk(memory_obj, bytes_read) + + def post_init(self): + return self._connector.post_init() def __repr__(self) -> str: return f"InstrumentedRemoteConnector({self._connector})" diff --git a/lmcache/v1/storage_backend/connector/mock_connector.py b/lmcache/v1/storage_backend/connector/mock_connector.py index d51e7aa0e3b..945a2423a1d 100644 --- a/lmcache/v1/storage_backend/connector/mock_connector.py +++ b/lmcache/v1/storage_backend/connector/mock_connector.py @@ -209,7 +209,8 @@ async def exists(self, key: CacheEngineKey) -> bool: ) def exists_sync(self, key: CacheEngineKey) -> bool: - raise NotImplementedError("MockConnector does not support synchronous exists") + """Synchronous exists check without async lock (for testing purposes)""" + return key in self.lru_store.dict async def _get(self, key: CacheEngineKey) -> Optional[MemoryObj]: mock_obj = await self.lru_store.get(key) diff --git a/lmcache/v1/storage_backend/connector/mooncakestore_connector.py b/lmcache/v1/storage_backend/connector/mooncakestore_connector.py index 499e4acc00b..d93743cd248 100644 --- a/lmcache/v1/storage_backend/connector/mooncakestore_connector.py +++ b/lmcache/v1/storage_backend/connector/mooncakestore_connector.py @@ -20,6 +20,7 @@ from lmcache.v1.protocol import RemoteMetadata from lmcache.v1.storage_backend.connector.base_connector import RemoteConnector from lmcache.v1.storage_backend.local_cpu_backend import LocalCPUBackend +from lmcache.v1.system_detection import NUMADetector logger = init_logger(__name__) @@ -44,6 +45,9 @@ def from_file(file_path: str) -> "MooncakeStoreConfig": """Load the config from a JSON file.""" with open(file_path) as fin: config = json.load(fin) + # Read Mooncake-specific knob + prefer_local_alloc = config.get("mooncake_prefer_local_alloc", False) + return MooncakeStoreConfig( local_hostname=config.get("local_hostname"), metadata_server=config.get("metadata_server"), @@ -54,7 +58,7 @@ def from_file(file_path: str) -> "MooncakeStoreConfig": master_server_address=config.get("master_server_address"), transfer_timeout=config.get("transfer_timeout", 1), storage_root_dir=config.get("storage_root_dir", ""), - prefer_local_alloc=config.get("prefer_local_alloc", False), + prefer_local_alloc=prefer_local_alloc, ) @staticmethod @@ -75,6 +79,9 @@ def load_from_lmcache_config( extra_config = config.extra_config if extra_config is None: raise ValueError("The extra config is not set.") + # Read Mooncake-specific knob + prefer_local_alloc = extra_config.get("mooncake_prefer_local_alloc", False) + return MooncakeStoreConfig( local_hostname=extra_config["local_hostname"], metadata_server=extra_config["metadata_server"], @@ -85,7 +92,7 @@ def load_from_lmcache_config( master_server_address=extra_config["master_server_address"], transfer_timeout=extra_config.get("transfer_timeout", 1), storage_root_dir=extra_config.get("storage_root_dir", ""), - prefer_local_alloc=extra_config.get("prefer_local_alloc", False), + prefer_local_alloc=prefer_local_alloc, ) @@ -101,7 +108,11 @@ def __init__( ): try: # Third Party - from mooncake.store import MooncakeDistributedStore, ReplicateConfig + from mooncake.store import ( + MooncakeDistributedStore, + ReplicateConfig, + bind_to_numa_node, + ) except ImportError as e: raise ImportError( "Please install mooncake by following the instructions at " @@ -121,8 +132,9 @@ def __init__( else: raise ValueError("MOONCAKE_CONFIG_PATH/lmcache_config must be provided") - if host != "" and port != 0: - self.config.master_server_address = host + ":" + str(port) + if not self.config.master_server_address: + if host != "" and port != 0: + self.config.master_server_address = host + ":" + str(port) if dev_name != "": self.config.device_name = dev_name logger.info("Mooncake Configuration loaded. config: %s", self.config) @@ -146,6 +158,36 @@ def __init__( logger.info(f" device_name: {self.config.device_name}") logger.info(f" master_server_address: {self.config.master_server_address}") + try: + numa_mapping = getattr( + local_cpu_backend.memory_allocator, "numa_mapping", None + ) + if numa_mapping is None and lmcache_config is not None: + numa_mapping = NUMADetector.get_numa_mapping(lmcache_config) + + if numa_mapping: + current_device_id = torch.cuda.current_device() + gpu_to_numa = getattr(numa_mapping, "gpu_to_numa_mapping", {}) + numa_id = gpu_to_numa.get(current_device_id) + logger.info( + f"NUMA mapping detected (pre-Mooncake setup): {gpu_to_numa}" + ) + if numa_id is not None: + bind_to_numa_node(numa_id) + logger.info( + f"GPU {current_device_id}, NUMA node {numa_id} binding done" + ) + else: + logger.info( + f"NUMA mapping not found for GPU {current_device_id}" + ) + else: + logger.info("NUMA mapping unavailable or disabled") + except Exception as e: + logger.warning( + f"Failed to determine NUMA mapping before Mooncake setup: {e}" + ) + self.store.setup( self.config.local_hostname, self.config.metadata_server, @@ -264,6 +306,22 @@ async def batched_get( # Use optimized mode with local metadata return await self._batch_get_into(keys) + def support_batched_async_contains(self) -> bool: + return True + + async def batched_async_contains( + self, + lookup_id: str, + keys: List[CacheEngineKey], + pin: bool = False, + ) -> int: + num_hit_counts = 0 + for key in keys: + if not self.store.is_exist(key.to_string()): + break + num_hit_counts += 1 + return num_hit_counts + async def _batch_get_into( self, keys: List[CacheEngineKey] ) -> List[Optional[MemoryObj]]: @@ -440,6 +498,71 @@ async def put(self, key: CacheEngineKey, memory_obj: MemoryObj): # Use put_from without metadata (zero-copy) await self._put_without_metadata(key_str, memory_obj) + def support_batched_put(self) -> bool: + return True + + async def batched_put( + self, + keys: List[CacheEngineKey], + memory_objs: List[MemoryObj], + ): + """ + Batched put with clear split by metadata mode. + - save_chunk_meta False: use Mooncake's batch_put_from (zero-copy). + - save_chunk_meta True: no batch API; fall back to sequential put_parts. + """ + if not keys: + return + + if self.save_chunk_meta: + await self._batched_put_with_metadata(keys, memory_objs) + else: + await self._batched_put_zero_copy(keys, memory_objs) + + async def _batched_put_zero_copy( + self, + keys: List[CacheEngineKey], + memory_objs: List[MemoryObj], + ) -> None: + key_strs = [k.to_string() for k in keys] + buffer_ptrs: list[int] = [] + buffer_sizes: list[int] = [] + for obj in memory_objs: + tensor = obj.tensor + assert tensor is not None + buffer_ptrs.append(tensor.data_ptr()) + buffer_sizes.append(tensor.numel() * tensor.element_size()) + + try: + await asyncio.wait_for( + asyncio.to_thread( + self.store.batch_put_from, + key_strs, + buffer_ptrs, + buffer_sizes, + self.replica_config, + ), + timeout=self.config.transfer_timeout, + ) + except asyncio.TimeoutError: + logger.warning( + "Timeout during batch_put_from; some decoders may redo prefill." + ) + finally: + for obj in memory_objs: + obj.ref_count_down() + + async def _batched_put_with_metadata( + self, + keys: List[CacheEngineKey], + memory_objs: List[MemoryObj], + ) -> None: + for key, obj in zip(keys, memory_objs, strict=False): + try: + await self._put_with_metadata(key.to_string(), obj) + finally: + obj.ref_count_down() + async def _put_without_metadata(self, key_str: str, memory_obj: MemoryObj): """ Zero-copy put using put_from when metadata is not stored remotely. diff --git a/lmcache/v1/storage_backend/connector/redis_adapter.py b/lmcache/v1/storage_backend/connector/redis_adapter.py index 07bed2123dd..d9788dccf42 100644 --- a/lmcache/v1/storage_backend/connector/redis_adapter.py +++ b/lmcache/v1/storage_backend/connector/redis_adapter.py @@ -75,3 +75,48 @@ def create_connector(self, context: ConnectorContext) -> RemoteConnector: loop=context.loop, local_cpu_backend=context.local_cpu_backend, ) + + +class RedisClusterConnectorAdapter(ConnectorAdapter): + """Adapter for Redis Cluster connectors.""" + + def __init__(self) -> None: + super().__init__("redis-cluster://") + + def can_parse(self, url: str) -> bool: + return url.startswith(self.schema) + + def create_connector(self, context: ConnectorContext) -> RemoteConnector: + # Local + from .redis_connector import RedisClusterConnector + + logger.info(f"Creating Redis Cluster connector for URL: {context.url}") + url = context.url[len(self.schema) :] + + # Parse username and password + username: str = "" + password: str = "" + if "@" in url: + auth, url = url.split("@", 1) + if ":" in auth: + username, password = auth.split(":", 1) + else: + username = auth + + # Parse host and port + hosts_and_ports: List[Tuple[str, int]] = [] + assert self.schema is not None + for sub_url in url.split(","): + if not sub_url.startswith(self.schema): + sub_url = self.schema + sub_url + + parsed_url = parse_remote_url(sub_url) + hosts_and_ports.append((parsed_url.host, parsed_url.port)) + + return RedisClusterConnector( + hosts_and_ports=hosts_and_ports, + username=username, + password=password, + loop=context.loop, + local_cpu_backend=context.local_cpu_backend, + ) diff --git a/lmcache/v1/storage_backend/connector/redis_connector.py b/lmcache/v1/storage_backend/connector/redis_connector.py index d18d87bafc5..38f3b82e44e 100644 --- a/lmcache/v1/storage_backend/connector/redis_connector.py +++ b/lmcache/v1/storage_backend/connector/redis_connector.py @@ -7,6 +7,7 @@ import os # Third Party +from redis.asyncio.cluster import ClusterNode, RedisCluster import redis.asyncio as redis # First Party @@ -387,3 +388,232 @@ async def list(self) -> List[str]: async def close(self): self.master.close() self.slave.close() + + +class RedisClusterConnector(RemoteConnector): + """ + The remote url starts with "redis-cluster:// and can include one or + multiple hosts:ports, separated by commas. + + Example: + remote_url: "redis-cluster://host1:7000,host2:7000,host3:7000" + + Extra environment variables: + - REDIS_TIMEOUT (optional) -- Timeout in seconds, default is 1 if not set + """ + + def __init__( + self, + hosts_and_ports: List[Tuple[str, int]], + username: str, + password: str, + loop: asyncio.AbstractEventLoop, + local_cpu_backend: LocalCPUBackend, + ): + # Convert hosts_and_ports to startup_nodes format expected by RedisCluster + startup_nodes = [ClusterNode(h, p) for (h, p) in hosts_and_ports] + + # set a large max + self.max_connections = 150 + # redis will crash if we have more than max_connections connections + self.sem = asyncio.Semaphore(self.max_connections) + + # Initialize cluster connection + self.cluster = RedisCluster( + startup_nodes=startup_nodes, + username=username, + password=password, + max_connections=self.max_connections, + decode_responses=False, + ) + self.loop = loop + self.local_cpu_backend = local_cpu_backend + + self.pq_executor = AsyncPQExecutor(loop) + + async def _exists(self, key: CacheEngineKey) -> bool: + async with self.sem: + return bool(await self.cluster.exists(key.to_string() + "metadata")) + + async def exists(self, key: CacheEngineKey) -> bool: + return await self.pq_executor.submit_job( + self._exists, key=key, priority=Priorities.PEEK + ) + + def exists_sync(self, key: CacheEngineKey) -> bool: + future = asyncio.run_coroutine_threadsafe(self.exists(key), self.loop) + return bool(future.result()) + + async def _get(self, key: CacheEngineKey) -> Optional[MemoryObj]: + key_str = key.to_string() + async with self.sem: + metadata_bytes = await self.cluster.get(key_str + "metadata") + + if metadata_bytes is None: + return None + + assert not inspect.isawaitable(metadata_bytes) + + metadata = RemoteMetadata.deserialize(memoryview(metadata_bytes)) + + memory_obj = self.local_cpu_backend.allocate( + metadata.shape, + metadata.dtype, + metadata.fmt, + ) + if memory_obj is None: + logger.warning("Failed to allocate memory during remote receive") + return None + + # TODO(Jiayi): Find a way to do `get` inplace + kv_bytes = await self.cluster.get(key_str + "kv_bytes") + + assert not inspect.isawaitable(kv_bytes) + + if kv_bytes is None: + # TODO (Jiayi): We might need a way to better handle + # consistency issues. + # TODO (Jiayi): A better way is to aggregate metadata + # and kv cache in one key. + logger.warning( + "Key exists but KV cache does not exist." + "Might happen when the cache is evicted by redis." + ) + async with self.sem: + await self.cluster.delete(key_str + "metadata") + return None + + if isinstance(memory_obj.byte_array, memoryview): + view = memory_obj.byte_array + if view.format == " Optional[MemoryObj]: + return await self.pq_executor.submit_job( + self._get, key=key, priority=Priorities.GET + ) + + def support_batched_put(self) -> bool: + return True + + async def _batched_put( + self, keys: List[CacheEngineKey], memory_objs: List[MemoryObj] + ): + # calling self.put will create a circular dependency + await asyncio.gather( + *( + self._put(key, memory_obj) + for key, memory_obj in zip(keys, memory_objs, strict=False) + ) + ) + + async def batched_put( + self, keys: List[CacheEngineKey], memory_objs: List[MemoryObj] + ): + await self.pq_executor.submit_job( + self._batched_put, + keys=keys, + memory_objs=memory_objs, + priority=Priorities.PUT, + ) + + async def _put(self, key: CacheEngineKey, memory_obj: MemoryObj): + # TODO(Jiayi): The following code is ugly. + # Please use a function like `memory_obj.to_meta()`. + kv_bytes = memory_obj.byte_array + kv_shape = memory_obj.get_shape() + kv_dtype = memory_obj.get_dtype() + memory_format = memory_obj.get_memory_format() + + metadata_bytes = RemoteMetadata( + len(kv_bytes), kv_shape, kv_dtype, memory_format + ).serialize() + + key_str = key.to_string() + # kv bytes needs to be set first to avoid race condition + async with self.sem: + await self.cluster.set(key_str + "kv_bytes", kv_bytes) + await self.cluster.set(key_str + "metadata", metadata_bytes) + + async def put(self, key: CacheEngineKey, memory_obj: MemoryObj): + await self.pq_executor.submit_job( + self._put, key=key, memory_obj=memory_obj, priority=Priorities.PUT + ) + + # TODO + @no_type_check + async def list(self) -> List[str]: + pass + + async def close(self): + await self.pq_executor.shutdown(wait=True) + await self.cluster.close() + logger.info("Closed the redis cluster connection") + + def support_batched_async_contains(self) -> bool: + return True + + async def _batched_async_contains( + self, + lookup_id: str, + keys: List[CacheEngineKey], + pin: bool = False, + ) -> int: + num_hit_counts = 0 + for key in keys: + async with self.sem: + if not await self.cluster.exists(key.to_string() + "metadata"): + return num_hit_counts + num_hit_counts += 1 + return num_hit_counts + + async def batched_async_contains( + self, + lookup_id: str, + keys: List[CacheEngineKey], + pin: bool = False, + ) -> int: + return await self.pq_executor.submit_job( + self._batched_async_contains, + lookup_id=lookup_id, + keys=keys, + pin=pin, + priority=Priorities.PEEK, + ) + + def support_batched_get_non_blocking(self) -> bool: + return True + + async def _batched_get_non_blocking( + self, + lookup_id: str, + keys: List[CacheEngineKey], + ) -> List[MemoryObj]: + # calling self.get will create a circular dependency + results = await asyncio.gather(*(self._get(key) for key in keys)) + return [r for r in results if r is not None] + + async def batched_get_non_blocking( + self, + lookup_id: str, + keys: List[CacheEngineKey], + ) -> List[MemoryObj]: + return await self.pq_executor.submit_job( + self._batched_get_non_blocking, + lookup_id=lookup_id, + keys=keys, + priority=Priorities.PREFETCH, + ) diff --git a/lmcache/v1/storage_backend/connector/sagemaker_hyperpod_adapter.py b/lmcache/v1/storage_backend/connector/sagemaker_hyperpod_adapter.py new file mode 100644 index 00000000000..37c1bc07382 --- /dev/null +++ b/lmcache/v1/storage_backend/connector/sagemaker_hyperpod_adapter.py @@ -0,0 +1,242 @@ +# SPDX-License-Identifier: Apache-2.0 +# Standard +from typing import Any, Dict +import os + +# First Party +from lmcache.logging import init_logger +from lmcache.v1.storage_backend.connector import ( + ConnectorAdapter, + ConnectorContext, +) +from lmcache.v1.storage_backend.connector.base_connector import RemoteConnector + +logger = init_logger(__name__) + + +class SageMakerHyperPodConnectorAdapter(ConnectorAdapter): + """Adapter for SageMaker HyperPod connectors.""" + + def __init__(self) -> None: + super().__init__("sagemaker-hyperpod://") + + def create_connector(self, context: ConnectorContext) -> RemoteConnector: + """ + Create a SageMaker HyperPod connector from the given context. + + Args: + context: Connector context containing configuration + + Returns: + Initialized SageMaker HyperPod connector + + Raises: + ValueError: If configuration is invalid + RuntimeError: If shared memory initialization fails + """ + # Local import to avoid circular dependencies + # Local + from .sagemaker_hyperpod_connector import SageMakerHyperPodConnector + + config = context.config + assert config is not None, "Config must not be None" + assert context.loop is not None, "context.loop must not be None" + assert context.local_cpu_backend is not None, ( + "context.local_cpu_backend must not be None" + ) + + # Default configuration values with type hints + defaults: Dict[str, Any] = { + "bucket": "lmcache", + "shared_memory_name": "shared_memory", + "max_concurrent_requests": 100, + "max_connections": 256, + "max_connections_per_host": 128, + "timeout_ms": 5000, + "lease_ttl_s": 30.0, + "put_stream_chunk_bytes": 65536, + "use_https": False, + "max_lease_size_mb": None, + } + + # Extract and validate configuration + extra_config = config.extra_config or {} + + bucket_name = str( + extra_config.get("sagemaker_hyperpod_bucket", defaults["bucket"]) + ) + shared_memory_name = extra_config.get( + "sagemaker_hyperpod_shared_memory_name", defaults["shared_memory_name"] + ) + max_concurrent_requests = self._get_positive_int( + extra_config, + "sagemaker_hyperpod_max_concurrent_requests", + int(defaults["max_concurrent_requests"]), # Cast to int + ) + max_connections = self._get_positive_int( + extra_config, + "sagemaker_hyperpod_max_connections", + int(defaults["max_connections"]), # Cast to int + ) + max_connections_per_host = self._get_positive_int( + extra_config, + "sagemaker_hyperpod_max_connections_per_host", + int(defaults["max_connections_per_host"]), # Cast to int + ) + timeout_ms = self._get_positive_int( + extra_config, + "sagemaker_hyperpod_timeout_ms", + int(defaults["timeout_ms"]), # Cast to int + ) + lease_ttl_s = self._get_positive_float( + extra_config, + "sagemaker_hyperpod_lease_ttl_s", + float(defaults["lease_ttl_s"]), # Cast to float + ) + put_stream_chunk_bytes = self._get_positive_int( + extra_config, + "sagemaker_hyperpod_put_stream_chunk_bytes", + int(defaults["put_stream_chunk_bytes"]), # Cast to int + ) + use_https = bool( + extra_config.get("sagemaker_hyperpod_use_https", defaults["use_https"]) + ) + max_lease_size_mb = extra_config.get( + "sagemaker_hyperpod_max_lease_size_mb", defaults["max_lease_size_mb"] + ) + + if max_lease_size_mb is not None: + try: + max_lease_size_mb = float(max_lease_size_mb) + if max_lease_size_mb <= 0: + raise ValueError( + f"sagemaker_hyperpod_max_lease_size_mb must be positive," + f" got {max_lease_size_mb}" + ) + except (TypeError, ValueError) as e: + raise ValueError( + f"Invalid value for sagemaker_hyperpod_max_lease_size_mb:" + f" {max_lease_size_mb}" + ) from e + + # Parse and construct URL + url = self._parse_url(context.url, use_https) + + logger.info( + f"Creating SageMaker HyperPod connector: url={url}, " + f"bucket={bucket_name}, shared_memory={shared_memory_name}, " + f"max_connections={max_connections}, " + f"max_concurrent_requests={max_concurrent_requests}, " + f"timeout_ms={timeout_ms}, lease_ttl_s={lease_ttl_s}s" + f"max_lease_size_mb=" + f"{max_lease_size_mb if max_lease_size_mb else 'unlimited'}" + ) + + # Create connector instance + connector = SageMakerHyperPodConnector( + sagemaker_hyperpod_url=url, + loop=context.loop, + local_cpu_backend=context.local_cpu_backend, + bucket_name=bucket_name, + shared_memory_name=shared_memory_name, + max_concurrent_requests=max_concurrent_requests, + max_connections=max_connections, + max_connections_per_host=max_connections_per_host, + timeout_ms=timeout_ms, + lease_ttl_s=lease_ttl_s, + put_stream_chunk_bytes=put_stream_chunk_bytes, + max_lease_size_mb=max_lease_size_mb, + ) + + # Initialize shared memory connection + try: + connector.post_init() + except Exception as e: + logger.error(f"Failed to initialize SageMaker HyperPod connector: {e}") + raise RuntimeError( + f"SageMaker HyperPod connector initialization failed: {e}" + ) from e + + logger.info("SageMaker HyperPod connector created successfully") + return connector + + @staticmethod + def _parse_url(url: str, use_https: bool) -> str: + """ + Parse and normalize the SageMaker HyperPod URL. + + Args: + url: Raw URL from context (e.g., "sagemaker-hyperpod://127.0.0.1:9200") + use_https: Whether to use HTTPS protocol + + Returns: + Normalized HTTP/HTTPS URL + """ + assert url, "SageMaker HyperPod URL must not be empty" + + expanded_url = os.path.expandvars(url) + + # Strip the sagemaker-hyperpod:// prefix + raw_url = expanded_url.removeprefix("sagemaker-hyperpod://") + + assert raw_url, ( + "SageMaker HyperPod URL must contain host information after 'sagemaker-hyperpod://'" + ) + + # If URL already has protocol, use it as-is + if raw_url.startswith("http://") or raw_url.startswith("https://"): + return raw_url + + # Otherwise, add appropriate protocol + protocol = "https" if use_https else "http" + return f"{protocol}://{raw_url}" + + @staticmethod + def _get_positive_int(config_dict: dict, key: str, default: int) -> int: + """ + Extract a positive integer from config with validation. + + Args: + config_dict: Configuration dictionary + key: Configuration key + default: Default value if key not found + + Returns: + Validated positive integer + + Raises: + ValueError: If value is not a positive integer + """ + value = config_dict.get(key, default) + try: + int_value = int(value) + if int_value <= 0: + raise ValueError(f"{key} must be positive, got {int_value}") + return int_value + except (TypeError, ValueError) as e: + raise ValueError(f"Invalid value for {key}: {value}") from e + + @staticmethod + def _get_positive_float(config_dict: dict, key: str, default: float) -> float: + """ + Extract a positive float from config with validation. + + Args: + config_dict: Configuration dictionary + key: Configuration key + default: Default value if key not found + + Returns: + Validated positive float + + Raises: + ValueError: If value is not a positive float + """ + value = config_dict.get(key, default) + try: + float_value = float(value) + if float_value <= 0: + raise ValueError(f"{key} must be positive, got {float_value}") + return float_value + except (TypeError, ValueError) as e: + raise ValueError(f"Invalid value for {key}: {value}") from e diff --git a/lmcache/v1/storage_backend/connector/sagemaker_hyperpod_connector.py b/lmcache/v1/storage_backend/connector/sagemaker_hyperpod_connector.py new file mode 100644 index 00000000000..20e2380229e --- /dev/null +++ b/lmcache/v1/storage_backend/connector/sagemaker_hyperpod_connector.py @@ -0,0 +1,950 @@ +# SPDX-License-Identifier: Apache-2.0 +# Standard +from dataclasses import dataclass +from enum import IntEnum, auto +from multiprocessing import shared_memory +from typing import AsyncIterator, List, Optional, Tuple +import asyncio +import json +import urllib.parse + +# Third Party +import aiohttp +import torch + +# First Party +from lmcache.logging import init_logger +from lmcache.utils import CacheEngineKey +from lmcache.v1.memory_management import MemoryObj +from lmcache.v1.protocol import RemoteMetadata +from lmcache.v1.storage_backend.connector.base_connector import RemoteConnector +from lmcache.v1.storage_backend.job_executor.pq_executor import AsyncPQExecutor +from lmcache.v1.storage_backend.local_cpu_backend import LocalCPUBackend + +logger = init_logger(__name__) + + +# Constants +METADATA_SIZE_BYTES = 28 # RemoteMetadata is 7 int32 fields +METADATA_SHAPE_DIMS = 4 # Number of shape dimensions in metadata +DEFAULT_CHUNK_SIZE_BYTES = 65536 # 64KB default for streaming +HTTP_OK = 200 +HTTP_NO_CONTENT = 204 +HTTP_NOT_FOUND = 404 +HTTP_CONFLICT = 409 + + +class Priorities(IntEnum): + """Priority levels for job execution in the priority queue.""" + + LEASE = 0 # Highest priority - lease acquisition/release + PREFETCH = auto() # Medium priority - prefetching data + PUT = auto() # Lower priority - storing data + + +@dataclass +class LeaseInfo: + """Information about a lease obtained from ai-toolkit daemon. + + A lease represents temporary exclusive access to cached data in shared memory. + The daemon manages leases to prevent data from being evicted while in use. + """ + + lease_id: str + offsets: List[Tuple[int, int]] # (offset, length) pairs in shared memory + + +class SageMakerHyperPodConnector(RemoteConnector): + """ + SageMaker HyperPod remote connector for communicating with KV cache daemon + in SageMaker HyperPod. + + This connector provides high-performance access to KV cache data stored in + a remote SageMaker HyperPod service using: + - Shared memory (data plane) - zero-copy access via shared memory segment + - HTTP (control plane) - lease acquisition, release, and PUT operations + + The connector uses a lease-based protocol with immediate release after all reads + """ + + def __init__( + self, + sagemaker_hyperpod_url: str, + loop: asyncio.AbstractEventLoop, + local_cpu_backend: LocalCPUBackend, + bucket_name: str, + shared_memory_name: Optional[str], + max_concurrent_requests: int, + max_connections: int, + max_connections_per_host: int, + timeout_ms: int, + lease_ttl_s: float = 10.0, + put_stream_chunk_bytes: int = DEFAULT_CHUNK_SIZE_BYTES, + max_lease_size_mb: Optional[float] = None, + **kwargs, # Accept and ignore unused legacy parameters + ): + """ + Initialize SageMaker HyperPod connector. + + Args: + sagemaker_hyperpod_url: Base URL of the ai-toolkit daemon + loop: Event loop for async operations + local_cpu_backend: Backend for local memory allocation + bucket_name: Bucket name for KV storage namespace + shared_memory_name: Name of shared memory segment + (if None, shared memory disabled) + max_concurrent_requests: Maximum concurrent control plane requests + max_connections: Maximum total HTTP connections in pool + max_connections_per_host: Maximum HTTP connections per host + timeout_ms: Timeout for lease acquisition requests + lease_ttl_s: Server-side lease timeout (default: 10s) + put_stream_chunk_bytes: Chunk size for + streaming PUT requests (default: 64KB) + **kwargs: Unused legacy parameters (ignored for backward compatibility) + """ + super().__init__() + + # Core configuration + self.base_url = sagemaker_hyperpod_url.rstrip("/") + self.loop = loop + self.local_cpu_backend = local_cpu_backend + self.bucket_name = bucket_name + self.shared_memory_name = shared_memory_name + self.lease_ttl_s = lease_ttl_s + self.put_stream_chunk_bytes = max(1024, put_stream_chunk_bytes) # Minimum 1KB + self.max_lease_size_bytes = ( + int(max_lease_size_mb * 1024 * 1024) if max_lease_size_mb else None + ) + + # HTTP configuration + self.max_concurrent_requests = max(1, max_concurrent_requests) + self.max_connections = max(1, max_connections) + self.max_connections_per_host = max(1, max_connections_per_host) + self.timeout_ms = max(100, timeout_ms) # Minimum 100ms + + # HTTP session (lazy initialized) + self.http_session: Optional[aiohttp.ClientSession] = None + self.session_lock = asyncio.Lock() + + # Concurrency control + self.control_inflight = asyncio.Semaphore(self.max_concurrent_requests) + self.put_inflight = asyncio.Semaphore(self.max_concurrent_requests) + self.pq_executor = AsyncPQExecutor(loop) + + # Shared memory (lazy initialized) + self.shared_memory_obj: Optional[shared_memory.SharedMemory] = None + self.shared_memory_map: Optional[memoryview] = None + + # Statistics for monitoring + self.stats = { + "get_success": 0, + "get_failure": 0, + "put_success": 0, + "put_failure": 0, + "lease_acquired": 0, + "lease_released": 0, + "lease_release_failed": 0, + } + + logger.info( + f"SageMaker HyperPod Connector initialized: url={self.base_url}, " + f"bucket={self.bucket_name}, shared_memory={self.shared_memory_name}, " + f"connections={self.max_connections}, lease_ttl={lease_ttl_s}s" + ) + + def post_init(self): + """Initialize shared memory connection after construction.""" + if self.shared_memory_name: + self._init_shared_memory() + + def _init_shared_memory(self): + """Initialize shared memory connection to ai-toolkit daemon.""" + try: + self.shared_memory_obj = shared_memory.SharedMemory( + name=self.shared_memory_name, create=False + ) + self.shared_memory_map = memoryview(self.shared_memory_obj.buf) + size_mb = len(self.shared_memory_map) / (1024**2) + logger.info( + f"Shared memory opened: {self.shared_memory_name} ({size_mb:.2f} MB)" + ) + except FileNotFoundError: + logger.error( + f"Shared memory segment '{self.shared_memory_name}' not found. " + "Ensure ai-toolkit daemon is running." + ) + self.shared_memory_map = None + raise + except Exception as e: + logger.error(f"Failed to initialize shared memory: {e}") + self.shared_memory_map = None + raise + + async def _ensure_http_session(self) -> aiohttp.ClientSession: + """Ensure HTTP session with connection pooling is initialized.""" + if self.http_session is None: + async with self.session_lock: + if self.http_session is None: # Double-check locking + connector = aiohttp.TCPConnector( + limit=self.max_connections, + limit_per_host=self.max_connections_per_host, + ttl_dns_cache=300, + use_dns_cache=True, + keepalive_timeout=30, + enable_cleanup_closed=True, + ) + + timeout = aiohttp.ClientTimeout( + total=30, + connect=5, + sock_read=10, + ) + + self.http_session = aiohttp.ClientSession( + connector=connector, + timeout=timeout, + headers={"User-Agent": "LMCache-SageMaker-HyperPod/1.0"}, + ) + logger.info( + f"HTTP session created with {self.max_connections} " + f"max connections" + ) + + return self.http_session + + async def _http_request( + self, + method: str, + url: str, + data=None, + params=None, + timeout: float = 5.0, + headers=None, + gate: Optional[asyncio.Semaphore] = None, + ): + """Execute HTTP request with optional semaphore gate.""" + if gate is None: + return await self._http_request_impl( + method, url, data=data, params=params, timeout=timeout, headers=headers + ) + async with gate: + return await self._http_request_impl( + method, url, data=data, params=params, timeout=timeout, headers=headers + ) + + async def _http_request_impl( + self, + method: str, + url: str, + data=None, + params=None, + timeout: float = 5.0, + headers=None, + ): + """Execute HTTP request with connection pooling and error handling.""" + try: + session = await self._ensure_http_session() + request_timeout = aiohttp.ClientTimeout(total=timeout) + + async with session.request( + method, + url, + data=data, + params=params, + timeout=request_timeout, + headers=headers, + ) as response: + # Parse JSON response if available + body_json = None + content_type = response.headers.get("Content-Type", "") + if content_type.startswith("application/json"): + try: + body_json = await response.json() + except aiohttp.ContentTypeError as e: + logger.warning( + f"JSON parsing failed for {method} {url}:" + f"invalid content-type - {e}" + ) + except json.JSONDecodeError as e: + logger.warning( + f"JSON parsing failed for {method} {url}:" + f"malformed JSON - {e}" + ) + except Exception as e: + logger.warning(f"JSON parsing failed for {method} {url}: {e}") + + return { + "status": response.status, + "json": body_json, + } + + except asyncio.TimeoutError: + logger.warning(f"HTTP {method} timeout: {url}") + return None + except aiohttp.ClientError as e: + logger.error(f"HTTP {method} client error: {url} - {e}") + return None + except Exception as e: + logger.error(f"HTTP {method} failed: {url} - {e}") + return None + + def _key_to_string(self, key: CacheEngineKey) -> str: + """Convert CacheEngineKey to URL-safe string format.""" + key_str = key.to_string() + return urllib.parse.quote(key_str, safe="") + + async def _release_lease(self, key: CacheEngineKey, lease_id: str) -> bool: + """ + Release a lease to free server resources immediately. + + Args: + key: The cache key + lease_id: The lease ID to release + + Returns: + True if release successful, False on error + """ + key_str = self._key_to_string(key) + url = f"{self.base_url}/v1/leases/{lease_id}/release" + + try: + result = await self._http_request( + "POST", + url, + timeout=5.0, + gate=self.control_inflight, + ) + + if result and result["status"] == HTTP_OK: + self.stats["lease_released"] += 1 + logger.debug(f"Lease released: key={key_str}, lease_id={lease_id}") + return True + else: + status = result["status"] if result else "TIMEOUT" + self.stats["lease_release_failed"] += 1 + logger.warning( + f"Lease release failed: key={key_str}, lease_id={lease_id}, " + f"status={status}" + ) + return False + + except Exception as e: + self.stats["lease_release_failed"] += 1 + logger.warning( + f"Lease release error: key={key_str}, lease_id={lease_id} - {e}" + ) + return False + + async def _acquire_lease(self, key: CacheEngineKey) -> Optional[LeaseInfo]: + """ + Acquire a lease for the given key. + + A lease prevents the daemon from evicting data while we're reading it. + The response includes offset information for shared memory access. + + Args: + key: The cache key to acquire lease for + + Returns: + LeaseInfo if successful, None otherwise + """ + key_str = self._key_to_string(key) + url = f"{self.base_url}/v1/kv/{self.bucket_name}/{key_str}/leases" + params = { + "timeout_ms": self.timeout_ms, + "ttl_s": self.lease_ttl_s, + } + + result = await self._http_request( + "POST", + url, + params=params, + timeout=self.timeout_ms / 1000.0, + gate=self.control_inflight, + ) + + if not result or result["status"] != HTTP_OK or not result["json"]: + logger.debug(f"Lease acquisition failed: key={key_str}") + return None + + lease_data = result["json"] + offsets = [(o["offset"], o["len"]) for o in lease_data.get("offsets", [])] + + if not offsets: + logger.debug(f"Lease has no offsets: key={key_str}") + return None + + lease_info = LeaseInfo( + lease_id=lease_data["id"], + offsets=offsets, + ) + + total_size = sum(length for _, length in offsets) + + if ( + self.max_lease_size_bytes is not None + and total_size > self.max_lease_size_bytes + ): + logger.warning( + f"Lease size {total_size / 1024:.2f} KB exceeds limit " + f"{self.max_lease_size_bytes / 1024:.2f} KB, releasing" + ) + await self._release_lease(key, lease_info.lease_id) + return None + + self.stats["lease_acquired"] += 1 + + logger.debug( + f"Lease acquired: key={key_str}, lease_id={lease_info.lease_id}, " + f"size={total_size / 1024:.2f} KB, blocks={len(offsets)}" + ) + + return lease_info + + async def _executor_submit_lease_acquisition( + self, key: CacheEngineKey + ) -> Optional[LeaseInfo]: + """Submit lease acquisition to priority executor.""" + return await self.pq_executor.submit_job( + self._acquire_lease, + key=key, + priority=Priorities.LEASE, + ) + + def _read_from_shared_memory( + self, key: CacheEngineKey, lease_info: LeaseInfo + ) -> Optional[MemoryObj]: + """ + Read data from shared memory using lease offsets. + + Data format: [RemoteMetadata header (28 bytes)] + [KV cache payload] + Data may be fragmented across multiple blocks in shared memory. + + Args: + key: The cache key being read + lease_info: Lease information with memory offsets + + Returns: + MemoryObj containing the data, or None on error + """ + if self.shared_memory_map is None: + logger.error("Shared memory not available") + return None + + if not lease_info.offsets: + logger.error("No offsets in lease") + return None + + shm_size = len(self.shared_memory_map) + for offset, length in lease_info.offsets: + if offset < 0 or length < 0: + logger.error( + f"Invalid offset or length: offset={offset}, length={length}" + ) + return None + if offset + length > shm_size: + logger.error( + f"Offset out of bounds: offset={offset}, length={length}, " + f"shm_size={shm_size}" + ) + return None + + memory_obj = None + try: + # Validate total size + total_size = sum(length for _, length in lease_info.offsets) + if total_size < METADATA_SIZE_BYTES: + logger.error(f"Insufficient data for metadata: {total_size} bytes") + return None + + # Read metadata header (may span multiple blocks) + header = self._read_bytes_from_offsets( + lease_info.offsets, 0, METADATA_SIZE_BYTES + ) + if len(header) < METADATA_SIZE_BYTES: + logger.error("Failed to read complete metadata header") + return None + + # Parse metadata + metadata = RemoteMetadata.deserialize(header) + if metadata.length <= 0: + logger.error(f"Invalid payload length: {metadata.length}") + return None + + # Restore original shape (remove padding zeros) + actual_shape = self._parse_shape(metadata.shape) + + # Allocate local CPU memory + memory_obj = self.local_cpu_backend.allocate( + actual_shape, + metadata.dtype, + metadata.fmt, + ) + if memory_obj is None: + logger.error(f"Failed to allocate memory for key {key.to_string()}") + return None + + # Get writable view + view = self._get_writable_view(memory_obj.byte_array) + + # Copy payload data from shared memory (skip header) + copied = self._copy_bytes_from_offsets( + lease_info.offsets, METADATA_SIZE_BYTES, metadata.length, view + ) + + if copied != metadata.length: + logger.error( + f"Data size mismatch: expected {metadata.length}, got {copied}" + ) + memory_obj.ref_count_down() + return None + + logger.debug( + f"Read from shared memory: key={key.to_string()}, " + f"shape={actual_shape}, dtype={metadata.dtype}," + f"size={metadata.length} bytes" + ) + + return memory_obj + + except Exception as e: + logger.error( + f"Error reading from shared memory: key={key.to_string()} - {e}" + ) + if memory_obj is not None: + memory_obj.ref_count_down() + return None + + def _read_bytes_from_offsets( + self, offsets: List[Tuple[int, int]], skip_bytes: int, read_bytes: int + ) -> bytearray: + """Read bytes from shared memory offsets, skipping initial bytes.""" + if self.shared_memory_map is None: + logger.error("Shared memory not available") + return bytearray() + + result = bytearray(read_bytes) + filled = 0 + bytes_to_skip = skip_bytes + shm_size = len(self.shared_memory_map) + + for offset, length in offsets: + if filled >= read_bytes: + break + + # Skip header bytes in first chunk(s) + if bytes_to_skip > 0: + if length <= bytes_to_skip: + bytes_to_skip -= length + continue + offset += bytes_to_skip + length -= bytes_to_skip + bytes_to_skip = 0 + + if length <= 0: + continue + + take = min(read_bytes - filled, length) + + if offset < 0 or take <= 0: + logger.error(f"Invalid read parameters: offset={offset}, take={take}") + break + if offset + take > shm_size: + logger.error( + f"Read would exceed shared memory bounds: " + f"offset={offset}, take={take}, shm_size={shm_size}" + ) + break + + result[filled : filled + take] = self.shared_memory_map[ + offset : offset + take + ] + filled += take + + return result + + def _copy_bytes_from_offsets( + self, + offsets: List[Tuple[int, int]], + skip_bytes: int, + copy_bytes: int, + dest_view: memoryview, + ) -> int: + """Copy bytes from shared memory offsets to destination view.""" + if self.shared_memory_map is None: + logger.error("Shared memory not available") + return 0 + + copied = 0 + bytes_to_skip = skip_bytes + shm_size = len(self.shared_memory_map) + + for offset, length in offsets: + if copied >= copy_bytes: + break + + # Skip header bytes + if bytes_to_skip > 0: + if length <= bytes_to_skip: + bytes_to_skip -= length + continue + offset += bytes_to_skip + length -= bytes_to_skip + bytes_to_skip = 0 + + if length <= 0: + continue + + take = min(copy_bytes - copied, length) + + if offset < 0 or take <= 0: + logger.error(f"Invalid copy parameters: offset={offset}, take={take}") + break + if offset + take > shm_size: + logger.error( + f"Copy would exceed shared memory bounds: " + f"offset={offset}, take={take}, shm_size={shm_size}" + ) + break + + dest_view[copied : copied + take] = self.shared_memory_map[ + offset : offset + take + ] + copied += take + + return copied + + @staticmethod + def _parse_shape(shape: torch.Size) -> torch.Size: + """Parse shape from metadata, removing padding zeros.""" + actual_shape_list: List[int] = [] + for dim in shape: + if dim == 0 and len(actual_shape_list) > 0: + break + actual_shape_list.append(dim) + return torch.Size(actual_shape_list) if actual_shape_list else torch.Size([1]) + + @staticmethod + def _get_writable_view(byte_array) -> memoryview: + """Get a writable memoryview from byte array.""" + if isinstance(byte_array, memoryview): + view = byte_array + if getattr(view, "format", None) == " bool: + """ + Check if a key exists in remote storage. + + Acquires a lease, checks existence, then releases immediately. + """ + lease = await self._executor_submit_lease_acquisition(key) + if lease is None: + return False + + try: + return True + finally: + await self._release_lease(key, lease.lease_id) + + def exists_sync(self, key: CacheEngineKey) -> bool: + """Check if a key exists in remote storage (sync wrapper).""" + future = asyncio.run_coroutine_threadsafe(self.exists(key), self.loop) + return bool(future.result()) + + async def get(self, key: CacheEngineKey) -> Optional[MemoryObj]: + """ + Retrieve KV cache data for the given key. + + Flow: + 1. Acquire a new lease + 2. Read from shared memory using lease offsets + 3. Release lease immediately (in finally block) + + Args: + key: The cache key to retrieve + + Returns: + MemoryObj containing the KV cache data, or None if not found + """ + lease_info = await self._executor_submit_lease_acquisition(key) + + if lease_info is None: + self.stats["get_failure"] += 1 + logger.debug(f"GET failed (no lease): key={key.to_string()}") + return None + + try: + memory_obj = self._read_from_shared_memory(key, lease_info) + + if memory_obj is not None: + self.stats["get_success"] += 1 + logger.debug( + f"GET success: key={key.to_string()}, " + f"shape={memory_obj.get_shape()}" + ) + else: + self.stats["get_failure"] += 1 + logger.error( + f"Failed to read from shared memory: key={key.to_string()}" + ) + + return memory_obj + + except Exception as e: + self.stats["get_failure"] += 1 + logger.error(f"GET error: key={key.to_string()} - {e}") + return None + + finally: + # Always release lease immediately after read + await self._release_lease(key, lease_info.lease_id) + + async def batched_get( + self, keys: List[CacheEngineKey] + ) -> List[Optional[MemoryObj]]: + """Get multiple keys in parallel.""" + tasks = [self.get(key) for key in keys] + return await asyncio.gather(*tasks) + + def support_batched_put(self) -> bool: + """Indicate support for batched PUT operations.""" + return True + + async def batched_put( + self, keys: List[CacheEngineKey], memory_objs: List[MemoryObj] + ): + """Store multiple objects in parallel.""" + await asyncio.gather( + *(self.put(key, mem) for key, mem in zip(keys, memory_objs, strict=True)) + ) + + async def put(self, key: CacheEngineKey, memory_obj: MemoryObj): + """Store data to ai-toolkit (queued with priority).""" + return await self.pq_executor.submit_job( + self._put, + key=key, + memory_obj=memory_obj, + priority=Priorities.PUT, + ) + + async def _put(self, key: CacheEngineKey, memory_obj: MemoryObj): + """Internal PUT operation - sends data via HTTP streaming.""" + key_str = self._key_to_string(key) + url = f"{self.base_url}/v1/kv/{self.bucket_name}/{key_str}" + + try: + # Build streaming payload (header + data) + payload_len, payload_iter = self._build_put_stream(memory_obj) + + logger.debug( + f"PUT: key={key_str}, size={payload_len / 1024:.2f} KB, " + f"shape={memory_obj.get_shape()}" + ) + + # Send HTTP PUT request with streaming + result = await self._http_request( + "PUT", + url, + data=payload_iter, + timeout=self.timeout_ms / 1000.0, + headers={"Content-Length": str(payload_len)}, + gate=self.put_inflight, + ) + + if result and result["status"] == HTTP_OK: + self.stats["put_success"] += 1 + logger.info( + f"PUT success: key={key_str}, size={payload_len / 1024:.2f} KB" + ) + elif result and result["status"] == HTTP_CONFLICT: + # 409 Conflict = key already exists (not an error) + self.stats["put_success"] += 1 + logger.debug(f"PUT skipped (already exists): key={key_str}") + else: + status = result["status"] if result else "TIMEOUT" + self.stats["put_failure"] += 1 + logger.error(f"PUT failed: key={key_str}, status={status}") + + except Exception as e: + self.stats["put_failure"] += 1 + logger.error(f"PUT exception: key={key_str} - {e}") + + def _build_put_stream(self, memory_obj: MemoryObj) -> Tuple[int, AsyncIterator]: + """ + Build streaming payload: [RemoteMetadata (28 bytes)] + [KV cache data] + + Args: + memory_obj: The memory object to stream + + Returns: + Tuple of (total_length, async_generator) + """ + # Prepare data view + kv_view = self._get_writable_view(memory_obj.byte_array) + kv_len = len(kv_view) + + # Prepare metadata + shape = list(memory_obj.get_shape()) + padded_shape = (shape + [0] * METADATA_SHAPE_DIMS)[:METADATA_SHAPE_DIMS] + + metadata = RemoteMetadata( + kv_len, + torch.Size(padded_shape), + memory_obj.get_dtype(), + memory_obj.get_memory_format(), + ) + + # Serialize metadata header + header = bytearray(METADATA_SIZE_BYTES) + metadata.serialize_into(header) + header_bytes = bytes(header) + + total_len = len(header_bytes) + kv_len + chunk_size = self.put_stream_chunk_bytes + + async def generator() -> AsyncIterator: + # First yield header + yield header_bytes + # Then yield data in chunks + offset = 0 + while offset < kv_len: + next_offset = min(kv_len, offset + chunk_size) + yield kv_view[offset:next_offset] + offset = next_offset + + return total_len, generator() + + def support_batched_async_contains(self) -> bool: + """Indicate support for batched async contains operation.""" + return True + + async def _batched_async_contains( + self, lookup_id: str, keys: List[CacheEngineKey], pin: bool = False + ) -> int: + """ + Check existence of keys sequentially until first miss. + + Args: + lookup_id: Lookup identifier (for logging/tracking) + keys: List of keys to check + pin: Whether to pin data (unused, for API compatibility) + + Returns: + Number of consecutive hits from the start + """ + num_hits = 0 + for key in keys: + lease = None + try: + lease = await self._executor_submit_lease_acquisition(key) + if lease is None: + break + + num_hits += 1 + + except Exception as exc: + logger.debug(f"Lease acquisition failed for {key}: {exc}") + break + finally: + if lease is not None: + await self._release_lease(key, lease.lease_id) + + return num_hits + + async def batched_async_contains( + self, lookup_id: str, keys: List[CacheEngineKey], pin: bool = False + ) -> int: + """Check existence of multiple keys (queued with priority).""" + return await self.pq_executor.submit_job( + self._batched_async_contains, + lookup_id=lookup_id, + keys=keys, + pin=pin, + priority=Priorities.LEASE, + ) + + def support_batched_get_non_blocking(self) -> bool: + """Indicate support for non-blocking batched GET.""" + return True + + async def _batched_get_non_blocking( + self, lookup_id: str, keys: List[CacheEngineKey] + ) -> List[MemoryObj]: + """Prefetch multiple keys and filter out None results.""" + results = await self.batched_get(keys) + return [r for r in results if r is not None] + + async def batched_get_non_blocking( + self, lookup_id: str, keys: List[CacheEngineKey] + ) -> List[MemoryObj]: + """Prefetch multiple keys (queued with priority).""" + return await self.pq_executor.submit_job( + self._batched_get_non_blocking, + lookup_id=lookup_id, + keys=keys, + priority=Priorities.PREFETCH, + ) + + def support_batched_get(self) -> bool: + """Indicate support for batched GET operations.""" + return True + + async def list(self) -> List[str]: + """List operation not supported by ai-toolkit.""" + return [] + + def remove_sync(self, key: CacheEngineKey) -> bool: + """Remove operation not supported by ai-toolkit.""" + return True + + def support_ping(self) -> bool: + """Indicate ping operation is not supported.""" + return False + + async def ping(self) -> int: + """Ping operation not implemented.""" + raise NotImplementedError( + "Ping operation not supported by SageMaker HyperPod connector" + ) + + async def close(self): + """Clean up all resources and log statistics.""" + # Log final statistics + logger.info( + f"SageMaker HyperPod Connector Statistics: " + f"GET(ok/fail)={self.stats['get_success']}/{self.stats['get_failure']}, " + f"PUT(ok/fail)={self.stats['put_success']}/{self.stats['put_failure']}, " + f"leases(acq/rel/fail)={self.stats['lease_acquired']}/" + f"{self.stats['lease_released']}/{self.stats['lease_release_failed']}" + ) + + # Shutdown priority queue executor + try: + await self.pq_executor.shutdown(wait=True) + except Exception as e: + logger.warning(f"Error shutting down executor: {e}") + + # Close HTTP session + if self.http_session is not None: + try: + await self.http_session.close() + except Exception as e: + logger.warning(f"Error closing HTTP session: {e}") + self.http_session = None + + # Release shared memory + if self.shared_memory_map is not None: + self.shared_memory_map = None + + if self.shared_memory_obj is not None: + try: + self.shared_memory_obj.close() + except Exception as e: + logger.warning(f"Error closing shared memory object: {e}") + self.shared_memory_obj = None + + logger.info("SageMaker HyperPod connector closed") diff --git a/lmcache/v1/storage_backend/connector/valkey_adapter.py b/lmcache/v1/storage_backend/connector/valkey_adapter.py new file mode 100644 index 00000000000..d5bf5ded78a --- /dev/null +++ b/lmcache/v1/storage_backend/connector/valkey_adapter.py @@ -0,0 +1,68 @@ +# SPDX-License-Identifier: Apache-2.0 +# Standard +from typing import List, Tuple + +# First Party +from lmcache.logging import init_logger +from lmcache.v1.storage_backend.connector import ( + ConnectorAdapter, + ConnectorContext, + parse_remote_url, +) +from lmcache.v1.storage_backend.connector.base_connector import RemoteConnector + +logger = init_logger(__name__) + + +class ValkeyConnectorAdapter(ConnectorAdapter): + """Adapter for Valkey Server connectors.""" + + def __init__(self) -> None: + super().__init__("valkey://") + + def create_connector(self, context: ConnectorContext) -> RemoteConnector: + # Local + from .valkey_connector import ValkeyClusterConnector, ValkeyConnector + + config = context.config + + if config is not None and config.extra_config is not None: + self.valkey_username = config.extra_config.get("valkey_username", "") + self.valkey_password = config.extra_config.get("valkey_password", "") + self.valkey_database = config.extra_config.get("valkey_database", None) + self.valkey_mode = config.extra_config.get("valkey_mode", "standalone") + else: + self.valkey_username = "" + self.valkey_password = "" + self.valkey_database = None + self.valkey_mode = "standalone" + + logger.info(f"Creating Valkey connector for URL: {context.url}") + + if self.valkey_mode == "cluster": + hosts_and_ports: List[Tuple[str, int]] = [] + assert self.schema is not None + for sub_url in context.url.split(","): + if not sub_url.startswith(self.schema): + sub_url = self.schema + sub_url + + parsed_url = parse_remote_url(sub_url) + hosts_and_ports.append((parsed_url.host, parsed_url.port)) + + return ValkeyClusterConnector( + hosts_and_ports=hosts_and_ports, + loop=context.loop, + local_cpu_backend=context.local_cpu_backend, + username=self.valkey_username, + password=self.valkey_password, + ) + else: + url = context.url[len(self.schema) :] + return ValkeyConnector( + url=url, + loop=context.loop, + local_cpu_backend=context.local_cpu_backend, + username=self.valkey_username, + password=self.valkey_password, + database_id=self.valkey_database, + ) diff --git a/lmcache/v1/storage_backend/connector/valkey_connector.py b/lmcache/v1/storage_backend/connector/valkey_connector.py new file mode 100644 index 00000000000..ff11cee73d7 --- /dev/null +++ b/lmcache/v1/storage_backend/connector/valkey_connector.py @@ -0,0 +1,389 @@ +# SPDX-License-Identifier: Apache-2.0 +# Standard +from enum import IntEnum, auto +from typing import List, Optional, Tuple, no_type_check +import asyncio +import inspect + +# Third Party +from glide import ( + Batch, + ClusterBatch, + GlideClient, + GlideClientConfiguration, + GlideClusterClient, + GlideClusterClientConfiguration, + NodeAddress, + ServerCredentials, +) + +# First Party +from lmcache.logging import init_logger +from lmcache.utils import CacheEngineKey +from lmcache.v1.memory_management import MemoryObj +from lmcache.v1.protocol import RemoteMetadata +from lmcache.v1.storage_backend.connector.base_connector import RemoteConnector +from lmcache.v1.storage_backend.job_executor.pq_executor import AsyncPQExecutor +from lmcache.v1.storage_backend.local_cpu_backend import LocalCPUBackend + +logger = init_logger(__name__) + + +class Priorities(IntEnum): + PEEK = auto() + PREFETCH = auto() + GET = auto() + PUT = auto() + + +class ValkeyConnector(RemoteConnector): + def __init__( + self, + url: str, + loop: asyncio.AbstractEventLoop, + local_cpu_backend: LocalCPUBackend, + username: str, + password: str, + database_id: Optional[int] = None, + ): + if ":" in url: + host, port_str = url.split(":", 1) + port = int(port_str) + else: + host = url + port = 6379 # Default Valkey port + + self.host = host + self.port = port + self.database_id = database_id + self.username = username + self.password = password + self.loop = loop + self.local_cpu_backend = local_cpu_backend + self.executor = AsyncPQExecutor(loop) + + # Create connection properly using async create + self.connection = self._init_connection() + + def _init_connection(self): + """Initialize GlideClient connection with credentials and database""" + + async def create_connection(): + try: + # Setup credentials if provided + credentials = None + if self.username or self.password: + credentials = ServerCredentials(self.username, self.password) + + # Build config with optional database_id + config_kwargs = { + "addresses": [NodeAddress(self.host, self.port)], + "credentials": credentials, + } + + if self.database_id is not None: + config_kwargs["database_id"] = self.database_id + + config = GlideClientConfiguration(**config_kwargs) + return await GlideClient.create(config) + except Exception as e: + raise RuntimeError(f"Fail to init valkey connection {e}") from e + + future = asyncio.run_coroutine_threadsafe(create_connection(), self.loop) + connection = future.result(timeout=1.0) + return connection + + def _get_keys(self, key: CacheEngineKey) -> Tuple[str, str]: + """Generate metadata and kv_bytes keys""" + key_str = key.to_string() + metadata_key = f"{key_str}:metadata" + kv_key = f"{key_str}:kv_bytes" + return metadata_key, kv_key + + async def _exists(self, key: CacheEngineKey) -> bool: + metadata_key, _ = self._get_keys(key) + return bool(await self.connection.exists([metadata_key])) + + async def exists(self, key: CacheEngineKey) -> bool: + return await self.executor.submit_job( + self._exists, key=key, priority=Priorities.PEEK + ) + + def exists_sync(self, key: CacheEngineKey) -> bool: + future = asyncio.run_coroutine_threadsafe( + self.executor.submit_job(self._exists, key=key, priority=Priorities.PEEK), + self.loop, + ) + return future.result() + + async def _get(self, key: CacheEngineKey) -> Optional[MemoryObj]: + metadata_key, kv_key = self._get_keys(key) + + results = await self.connection.mget([metadata_key, kv_key]) + + if len(results) != 2: + return None + + metadata_bytes, kv_bytes = results[0], results[1] + + if metadata_bytes is None: + return None + + assert not inspect.isawaitable(metadata_bytes) + + metadata = RemoteMetadata.deserialize(memoryview(metadata_bytes)) + + memory_obj = self.local_cpu_backend.allocate( + metadata.shape, + metadata.dtype, + metadata.fmt, + ) + if memory_obj is None: + logger.warning("Failed to allocate memory during remote receive") + return None + + if kv_bytes is None: + logger.warning( + "Key exists but KV cache does not exist." + "Might happen when the cache is evicted by valkey." + ) + memory_obj.ref_count_down() + return None + + assert not inspect.isawaitable(kv_bytes) + + try: + if isinstance(memory_obj.byte_array, memoryview): + view = memory_obj.byte_array + if view.format == " Optional[MemoryObj]: + return await self.executor.submit_job( + self._get, key=key, priority=Priorities.GET + ) + + async def _put(self, key: CacheEngineKey, memory_obj: MemoryObj): + try: + kv_bytes = bytes(memory_obj.byte_array) + kv_shape = memory_obj.get_shape() + kv_dtype = memory_obj.get_dtype() + memory_format = memory_obj.get_memory_format() + + metadata_bytes = RemoteMetadata( + len(kv_bytes), kv_shape, kv_dtype, memory_format + ).serialize() + + metadata_key, kv_key = self._get_keys(key) + + # Use batch to set both keys in one operation + # kv bytes needs to be set first to avoid race condition + batch = Batch(False) + batch.set(kv_key, kv_bytes) + batch.set(metadata_key, metadata_bytes) + + await self.connection.exec(batch, raise_on_error=False) + except Exception as exc: + logger.error(f"Fail to put data: {exc}") + + async def put(self, key: CacheEngineKey, memory_obj: MemoryObj): + await self.executor.submit_job( + self._put, key=key, memory_obj=memory_obj, priority=Priorities.PUT + ) + + @no_type_check + async def list(self) -> List[str]: + pass + + async def close(self): + await self.executor.shutdown(wait=True) + await self.connection.close() + logger.info("Closed the Valkey connection") + + +class ValkeyClusterConnector(RemoteConnector): + """ + Uses GlideClusterClient to connect to a Valkey cluster. + Supports both URL-based and hosts_and_ports-based initialization. + """ + + def __init__( + self, + loop: asyncio.AbstractEventLoop, + local_cpu_backend: LocalCPUBackend, + username: str, + password: str, + hosts_and_ports: Optional[List[Tuple[str, int]]], + ): + self.loop = loop + self.local_cpu_backend = local_cpu_backend + self.executor = AsyncPQExecutor(loop) + self.username = username + self.password = password + self.hosts_and_ports = hosts_and_ports + + # Create connection + self.connection = self._init_connection() + + def _init_connection(self): + """Initialize GlideClusterClient connection with credentials""" + + async def create_connection(): + try: + # Setup credentials if provided + credentials = None + if self.username or self.password: + credentials = ServerCredentials(self.username, self.password) + + addresses = [ + NodeAddress(host, port) for host, port in self.hosts_and_ports + ] + config = GlideClusterClientConfiguration( + addresses=addresses, credentials=credentials + ) + return await GlideClusterClient.create(config) + except Exception as e: + raise RuntimeError(f"Fail to init valkey connection {e}") from e + + future = asyncio.run_coroutine_threadsafe(create_connection(), self.loop) + connection = future.result(timeout=1.0) + return connection + + def _get_keys_with_hash_tag(self, key: CacheEngineKey) -> Tuple[str, str]: + """Generate metadata and kv_bytes keys with hash tag for same slot placement""" + key_str = key.to_string() + # Use hash tag to ensure both keys go to same slot + metadata_key = f"{{{key_str}}}:metadata" + kv_key = f"{{{key_str}}}:kv_bytes" + return metadata_key, kv_key + + async def _exists(self, key: CacheEngineKey) -> bool: + metadata_key, _ = self._get_keys_with_hash_tag(key) + return bool(await self.connection.exists([metadata_key])) + + async def exists(self, key: CacheEngineKey) -> bool: + return await self.executor.submit_job( + self._exists, key=key, priority=Priorities.PEEK + ) + + def exists_sync(self, key: CacheEngineKey) -> bool: + future = asyncio.run_coroutine_threadsafe( + self.executor.submit_job(self._exists, key=key, priority=Priorities.PEEK), + self.loop, + ) + return future.result() + + async def _get(self, key: CacheEngineKey) -> Optional[MemoryObj]: + metadata_key, kv_key = self._get_keys_with_hash_tag(key) + + results = await self.connection.mget([metadata_key, kv_key]) + + if len(results) != 2: + return None + + metadata_bytes, kv_bytes = results[0], results[1] + + if metadata_bytes is None: + return None + + assert not inspect.isawaitable(metadata_bytes) + + metadata = RemoteMetadata.deserialize(memoryview(metadata_bytes)) + + memory_obj = self.local_cpu_backend.allocate( + metadata.shape, + metadata.dtype, + metadata.fmt, + ) + if memory_obj is None: + logger.warning("Failed to allocate memory during remote receive") + return None + + if kv_bytes is None: + logger.warning( + "Key exists but KV cache does not exist." + "Might happen when the cache is evicted by valkey." + ) + memory_obj.ref_count_down() + return None + + assert not inspect.isawaitable(kv_bytes) + + try: + if isinstance(memory_obj.byte_array, memoryview): + view = memory_obj.byte_array + if view.format == " Optional[MemoryObj]: + return await self.executor.submit_job( + self._get, key=key, priority=Priorities.GET + ) + + async def _put(self, key: CacheEngineKey, memory_obj: MemoryObj): + try: + kv_bytes = bytes(memory_obj.byte_array) + kv_shape = memory_obj.get_shape() + kv_dtype = memory_obj.get_dtype() + memory_format = memory_obj.get_memory_format() + + metadata_bytes = RemoteMetadata( + len(kv_bytes), kv_shape, kv_dtype, memory_format + ).serialize() + + metadata_key, kv_key = self._get_keys_with_hash_tag(key) + + # Use cluster batch to set both keys in one operation + # kv bytes needs to be set first to avoid race condition + batch = ClusterBatch(False) + batch.set(kv_key, kv_bytes) + batch.set(metadata_key, metadata_bytes) + + await self.connection.exec(batch, raise_on_error=False) + except Exception as exc: + logger.error(f"Fail to put data: {exc}") + + async def put(self, key: CacheEngineKey, memory_obj: MemoryObj): + await self.executor.submit_job( + self._put, key=key, memory_obj=memory_obj, priority=Priorities.PUT + ) + + @no_type_check + async def list(self) -> List[str]: + pass + + async def close(self): + await self.executor.shutdown(wait=True) + await self.connection.close() + logger.info("Closed the Valkey connection") diff --git a/lmcache/v1/storage_backend/gds_backend.py b/lmcache/v1/storage_backend/gds_backend.py index d429743a76a..7038b3e7818 100644 --- a/lmcache/v1/storage_backend/gds_backend.py +++ b/lmcache/v1/storage_backend/gds_backend.py @@ -88,7 +88,7 @@ def get_fstype(path): return best_fstype -def pack_metadata(tensor, **extra_metadata) -> bytes: +def pack_metadata(tensor, fmt: MemoryFormat, **extra_metadata) -> bytes: if tensor.dtype not in torch_dtypes: raise RuntimeError(f"unhandled dtype {tensor.dtype}") @@ -98,6 +98,7 @@ def pack_metadata(tensor, **extra_metadata) -> bytes: "dtype": torch_dtypes[tensor.dtype], "shape": list(tensor.size()), "data_offsets": [0, data_size], + "fmt": fmt.value, "__metadata__": extra_metadata, } meta = {"kvcache": tensor_meta} @@ -124,11 +125,12 @@ def unpack_metadata(buffer: bytes): shape = tensor_meta["shape"] dtype_str = tensor_meta["dtype"] data_offsets = tensor_meta["data_offsets"] + fmt = MemoryFormat(tensor_meta["fmt"]) nbytes = data_offsets[1] - data_offsets[0] dtype = torch_dtypes_inverse[dtype_str] - return torch.Size(shape), dtype, nbytes, tensor_meta["__metadata__"] + return torch.Size(shape), dtype, nbytes, fmt, tensor_meta["__metadata__"] def rand_suffix(rand, n: int): @@ -330,13 +332,24 @@ def _read_metadata(self, key, filename, subdir_key): with open(filename, "rb") as f: buf = f.read(_METADATA_MAX_SIZE) - shape, dtype, size, extra_metadata = unpack_metadata(buf) + shape, dtype, size, fmt, extra_metadata = unpack_metadata(buf) if extra_metadata["lmcache_version"] != str(_METADATA_VERSION): raise RuntimeError("unhandled lmcache metadata") - + logger.debug( + f"Read metadata for {key} from {filename}: " + f"shape={shape}, dtype={dtype}, size={size}, fmt={fmt}, " + f"extra_metadata={extra_metadata}" + ) # TODO(extra_metadata) + # TODO(Jiayi): need to support `cached_positions`. + # Currently we just fill it as None. metadata = DiskCacheMetadata( - filename.removesuffix(_METADATA_FILE_SUFFIX), size, shape, dtype + filename.removesuffix(_METADATA_FILE_SUFFIX), + size, + shape, + dtype, + None, + fmt, ) with self.hot_lock: self.metadata_dirs.add(subdir_key) @@ -433,15 +446,21 @@ async def _async_save_bytes_to_disk( os.makedirs(os.path.join(self.gds_path, l1_dir, l2_dir), exist_ok=True) self.metadata_dirs.add(subdir_key) tmp = ".tmp" + rand_suffix(self.rand, 8) + fmt = memory_obj.metadata.fmt metadata = await asyncio.to_thread( self._save_gds, path, tmp, kv_chunk, + fmt, self.cufile_base_pointer, memory_obj.metadata.address, ) + logger.debug( + f"Saved {kv_chunk.numel()} elements of {kv_chunk.dtype} " + f"to {path} with metadata {metadata}" + ) self.insert_key(key, memory_obj) memory_obj.ref_count_down() @@ -458,8 +477,10 @@ def insert_key(self, key: CacheEngineKey, memory_obj: MemoryObj) -> None: size = memory_obj.get_physical_size() shape = memory_obj.metadata.shape dtype = memory_obj.metadata.dtype + fmt = memory_obj.metadata.fmt with self.hot_lock: - self.hot_cache[key] = DiskCacheMetadata(path, size, shape, dtype) + # TODO(Jiayi): need to support `cached_positions`. + self.hot_cache[key] = DiskCacheMetadata(path, size, shape, dtype, None, fmt) def submit_prefetch_task( self, @@ -473,10 +494,12 @@ def submit_prefetch_task( # path = entry.path # dtype = entry.dtype # shape = entry.shape + # fmt = entry.fmt # assert dtype is not None # assert shape is not None + # assert fmt is not None # return asyncio.run_coroutine_threadsafe( - # self._async_load_bytes_from_disk(key, path, dtype, shape), self.loop + # self._async_load_bytes_from_disk(key, path, dtype, shape,fmt), self.loop # ) # TODO(Jiayi): Need to modify this when prefetch interface is determined. @@ -490,8 +513,9 @@ async def _async_load_bytes_from_disk( path: str, dtype: torch.dtype, shape: torch.Size, + fmt: MemoryFormat, ) -> Optional[MemoryObj]: - return self._load_bytes_from_disk(key, path, dtype, shape) + return self._load_bytes_from_disk(key, path, dtype, shape, fmt=fmt) def get_blocking( self, @@ -505,9 +529,12 @@ def get_blocking( path = entry.path dtype = entry.dtype shape = entry.shape + fmt = entry.fmt + logger.warning(entry) assert dtype is not None assert shape is not None - return self._load_bytes_from_disk(key, path, dtype=dtype, shape=shape) + assert fmt is not None + return self._load_bytes_from_disk(key, path, dtype=dtype, shape=shape, fmt=fmt) def _load_bytes_from_disk( self, @@ -515,11 +542,12 @@ def _load_bytes_from_disk( path: str, dtype: torch.dtype, shape: torch.Size, + fmt: MemoryFormat, ) -> Optional[MemoryObj]: """ Load byte array from disk. """ - memory_obj = self.memory_allocator.allocate(shape, dtype) + memory_obj = self.memory_allocator.allocate(shape, dtype, fmt=fmt) if memory_obj is None: logger.debug("Memory allocation failed during sync disk load.") return None @@ -556,6 +584,7 @@ def _load_bytes_from_disk( def get_non_blocking( self, key: CacheEngineKey, + location: Optional[str] = None, ) -> Optional[Future]: # TODO: Using a dummy wrapper around prefetch for now. if not self.submit_prefetch_task(key): @@ -569,6 +598,7 @@ def _save_gds( path: str, tmp: str, kv_chunk: torch.Tensor, + fmt: MemoryFormat, base_pointer: int, device_offset: int, ): @@ -582,7 +612,9 @@ def _save_gds( offset = _METADATA_MAX_SIZE # TODO: We can add the chunk's metadata here, e.g. Tensor parallelism shard # and pipeline parallelism index. - metadata = pack_metadata(kv_chunk, lmcache_version=str(_METADATA_VERSION)) + metadata = pack_metadata( + kv_chunk, fmt=fmt, lmcache_version=str(_METADATA_VERSION) + ) try: with open(tmp_path, "wb") as f: f.write(metadata) diff --git a/lmcache/v1/storage_backend/local_cpu_backend.py b/lmcache/v1/storage_backend/local_cpu_backend.py index 2a1e88fab77..01d6ea526b8 100644 --- a/lmcache/v1/storage_backend/local_cpu_backend.py +++ b/lmcache/v1/storage_backend/local_cpu_backend.py @@ -15,6 +15,7 @@ from lmcache.utils import CacheEngineKey, _lmcache_nvtx_annotate from lmcache.v1.cache_controller.message import KVAdmitMsg, KVEvictMsg from lmcache.v1.config import LMCacheEngineConfig +from lmcache.v1.lazy_memory_allocator import LazyMixedMemoryAllocator from lmcache.v1.memory_management import ( MemoryAllocatorInterface, MemoryFormat, @@ -24,7 +25,7 @@ ) from lmcache.v1.storage_backend.abstract_backend import AllocatorBackendInterface from lmcache.v1.storage_backend.cache_policy import get_cache_policy -from lmcache.v1.system_detection import NUMADetector +from lmcache.v1.system_detection import NUMADetector, SystemMemoryDetector if TYPE_CHECKING: # First Party @@ -250,6 +251,62 @@ def remove(self, key: CacheEngineKey, force: bool = True) -> bool: # other backends might still (temporarily) hold the memory object. return True + def _calculate_effective_cpu_size( + self, + configured_cpu_size: float, + config: LMCacheEngineConfig, + metadata: Optional[LMCacheEngineMetadata] = None, + ) -> float: + """ + Calculate the effective CPU memory size based on system available memory + and reserve memory configuration. + + Args: + configured_cpu_size: The configured CPU memory size in GB + config: The LMCache engine configuration + metadata: Optional metadata for first rank handling + + Returns: + The effective CPU memory size in GB + """ + + save_only_first_rank = ( + metadata is not None + and config.get_extra_config_value("save_only_first_rank", metadata.use_mla) + and metadata.use_mla + ) + if not save_only_first_rank: + # Do not adjust cpu_size if save_only_first_rank is False for now + return configured_cpu_size + + # Get the system available memory and calculate effective cpu_size + system_available_memory_gb = SystemMemoryDetector.get_available_memory_gb() + # Get reserve memory size from config + reserve_cpu_size = config.reserve_local_cpu_size + + # TODO(baoloongmao): For disable save_only_first_rank case, + # we need to avoid multi-rank race condition in future. + # But for enable save_only_first_rank case, + # we can handle reserve memory simply since non-first ranks + # do not allocate memory. + # Effective memory: min(configured_size, available_memory - reserve_size) + if system_available_memory_gb > 0: + max_usable_memory = max(0, system_available_memory_gb - reserve_cpu_size) + effective_cpu_size = min(configured_cpu_size, max_usable_memory) + logger.info( + f"Adjusted CPU memory size from {configured_cpu_size:.2f} GB " + f"to {effective_cpu_size:.2f} GB " + f"(system available: {system_available_memory_gb:.2f} GB, " + f"reserve: {reserve_cpu_size:.2f} GB)" + ) + assert effective_cpu_size > 0 + return effective_cpu_size + else: + logger.warning( + "Could not determine system available memory, using configured cpu_size" + ) + return configured_cpu_size + def initialize_allocator( self, config: LMCacheEngineConfig, @@ -266,18 +323,21 @@ def initialize_allocator( if save_only_first_rank and metadata.is_first_rank(): # Only the first rank will save the cache, - # so we need to set it lager than other ranks - cpu_size = ( - config.extra_config.get("first_rank_max_local_cpu_size", cpu_size) - if config.extra_config - else cpu_size + # so we need to set it larger than other ranks + cpu_size = config.get_extra_config_value( + "first_rank_max_local_cpu_size", cpu_size ) # Detect the numa mapping numa_mapping = NUMADetector.get_numa_mapping(config) logger.info(f"NUMA mapping {numa_mapping}") + # Calculate effective CPU memory size + cpu_size = self._calculate_effective_cpu_size(cpu_size, config, metadata) + if config.enable_p2p: + # TODO(baoloongmao): Add lazy memory allocator support for P2P mode + # For now, keep the original P2P implementation assert metadata is not None meta_shape = torch.Size(metadata.kv_shape) # TODO(Jiayi): remove this hardcode @@ -299,11 +359,43 @@ def initialize_allocator( ) return paged_mem_allocator else: - return MixedMemoryAllocator( - int(cpu_size * 1024**3), - numa_mapping=numa_mapping, + # Check if lazy memory allocator should be enabled + use_lazy = ( + config.enable_lazy_memory_allocator + and cpu_size > config.lazy_memory_safe_size ) + if use_lazy: + logger.info( + f"Using LazyMixedMemoryAllocator with " + f"initial_ratio={config.lazy_memory_initial_ratio}, " + f"expand_trigger_ratio=" + f"{config.lazy_memory_expand_trigger_ratio}, " + f"step_ratio={config.lazy_memory_step_ratio}" + ) + return LazyMixedMemoryAllocator( + int(cpu_size * 1024**3), + config=config, + numa_mapping=numa_mapping, + memory_limit_callback=lambda: int( + self._calculate_effective_cpu_size(cpu_size, config, metadata) + * 1024**3 + ), + ) + else: + if config.enable_lazy_memory_allocator: + logger.info( + f"LazyMixedMemoryAllocator is disabled because " + f"cpu_size ({cpu_size:.2f} GB) does not exceed " + f"lazy_memory_safe_size " + f"({config.lazy_memory_safe_size:.2f} GB). " + f"Using MixedMemoryAllocator instead." + ) + return MixedMemoryAllocator( + int(cpu_size * 1024**3), + numa_mapping=numa_mapping, + ) + @_lmcache_nvtx_annotate def allocate( self, @@ -516,8 +608,7 @@ def calculate_chunk_budget(self) -> int: Returns: int: The estimated chunk budget for concurrent allocations """ - logger.info("Attempting to calculate chunk budget for async loading") - assert isinstance(self.memory_allocator, MixedMemoryAllocator) + logger.debug("Attempting to calculate chunk budget for async loading") assert self.metadata is not None, ( "metadata required for chunk budget calculation" ) @@ -541,23 +632,23 @@ def calculate_chunk_budget(self) -> int: else: # full: [kv_size, num_layers, chunk_tokens, hidden_dim] chunk_bytes = kv_size * num_layers * chunk_tokens * hidden_dim * dtype_size - logger.info( + logger.debug( f"Stats received: num_layers={num_layers}, kv_size={kv_size}, " f"chunk_tokens={chunk_tokens}, head_dim={head_size}, " f"dtype_size={dtype_size}, " f"hidden_dim={hidden_dim}" ) - logger.info(f"Calculated bytes per chunk per rank: {chunk_bytes}") + logger.debug(f"Calculated bytes per chunk per rank: {chunk_bytes}") # add alignment overhead # (MixedMemoryAllocator uses TensorMemoryAllocator with 4KB alignment) + assert hasattr(self.memory_allocator, "align_bytes") alignment = self.memory_allocator.align_bytes aligned_chunk_bytes = ((chunk_bytes + alignment - 1) // alignment) * alignment # calculate budget with safety margin max_chunks = total_memory // aligned_chunk_bytes - chunk_budget = int(max_chunks) - return chunk_budget + return max_chunks def get_keys(self) -> List[CacheEngineKey]: """ @@ -577,7 +668,7 @@ def clear(self) -> int: with self.cpu_lock: for key in self.hot_cache: memory_obj = self.hot_cache[key] - if memory_obj.can_evict: + if not memory_obj.can_evict: continue clear_keys.append(key) num_cleared_tokens += memory_obj.get_num_tokens() diff --git a/lmcache/v1/storage_backend/local_disk_backend.py b/lmcache/v1/storage_backend/local_disk_backend.py index bcd48cb23b5..faa2378949f 100644 --- a/lmcache/v1/storage_backend/local_disk_backend.py +++ b/lmcache/v1/storage_backend/local_disk_backend.py @@ -248,17 +248,20 @@ def insert_key( shape: torch.Size, dtype: torch.dtype, fmt: MemoryFormat, + cached_positions: Optional[torch.Tensor] = None, ) -> None: path = self._key_to_path(key) has_stored = False with self.disk_lock: - # Need to do reinsert to update cache recency if key in self.dict: - self.dict.pop(key) + # Update cache recency + self.cache_policy.update_on_hit(key, self.dict) has_stored = True - - self.dict[key] = DiskCacheMetadata(path, size, shape, dtype, fmt, False) + else: + self.dict[key] = DiskCacheMetadata( + path, size, shape, dtype, cached_positions, fmt, 0 + ) # push kv admit msg if self.lmcache_worker is not None and not has_stored: @@ -343,11 +346,6 @@ def get_blocking( self.disk_lock.release() return None - self.cache_policy.update_on_hit(key, self.dict) - - self.disk_lock.release() - - self.disk_lock.acquire() # Update cache recency self.cache_policy.update_on_hit(key, self.dict) @@ -375,14 +373,11 @@ async def batched_get_non_blocking( mem_objs: list[MemoryObj] = [] paths: list[str] = [] - logger.info(f"lookup_id: {lookup_id}; Prefetching {len(keys)} keys from disk.") + logger.debug(f"lookup_id: {lookup_id}; Prefetching {len(keys)} keys from disk.") for key in keys: self.disk_lock.acquire() assert key in self.dict, f"Key {key} not found in disk cache after pinning" - # NOTE(Jiayi): Currently, we consider prefetch as cache hit. - self.cache_policy.update_on_hit(key, self.dict) - path = self.dict[key].path dtype = self.dict[key].dtype shape = self.dict[key].shape @@ -403,6 +398,7 @@ async def batched_get_non_blocking( self.dict[key].pin() + # NOTE(Jiayi): Currently, we consider prefetch as cache hit. # Update cache recency self.cache_policy.update_on_hit(key, self.dict) @@ -463,13 +459,16 @@ def async_save_bytes_to_disk( # `submit_put_task` above. # Ref count down better be before `insert_key` for testing # purposes (e.g., testing mem_leak). + # TODO(Jiayi): This could be problematic if the + # freed memory object is immediately reused. size = memory_obj.get_physical_size() shape = memory_obj.metadata.shape dtype = memory_obj.metadata.dtype fmt = memory_obj.metadata.fmt + cached_positions = memory_obj.metadata.cached_positions memory_obj.ref_count_down() - self.insert_key(key, size, shape, dtype, fmt) + self.insert_key(key, size, shape, dtype, fmt, cached_positions=cached_positions) self.disk_worker.remove_put_task(key) @@ -490,6 +489,11 @@ def batched_async_load_bytes_from_disk( buffer = mem_obj.byte_array self.read_file(key, buffer, path) + # TODO(Jiayi): Please recover the metadata in a more + # elegant way in the future. + cached_positions = self.dict[key].cached_positions + mem_obj.metadata.cached_positions = cached_positions + self.disk_lock.acquire() self.dict[key].unpin() self.disk_lock.release() @@ -513,6 +517,12 @@ def load_bytes_from_disk( buffer = memory_obj.byte_array self.read_file(key, buffer, path) + + # TODO(Jiayi): Please recover the metadata in a more + # elegant way in the future. + cached_positions = self.dict[key].cached_positions + memory_obj.metadata.cached_positions = cached_positions + return memory_obj def write_file(self, buffer, path): @@ -550,6 +560,7 @@ def read_file(self, key, buffer, path): with os.fdopen(fd, "rb", buffering=0) as fdo: fdo.readinto(buffer) except FileNotFoundError: + logger.warning(f"File not found on disk: {path}") if self.dict.get(key, None): self.dict.pop(key) return diff --git a/lmcache/v1/storage_backend/nixl_storage_backend.py b/lmcache/v1/storage_backend/nixl_storage_backend.py index bb38aeb1538..ad42e32312b 100644 --- a/lmcache/v1/storage_backend/nixl_storage_backend.py +++ b/lmcache/v1/storage_backend/nixl_storage_backend.py @@ -14,6 +14,7 @@ # limitations under the License. # Standard +from abc import ABC, abstractmethod from dataclasses import dataclass from typing import Any, List, Optional, Sequence, Set, cast import asyncio @@ -43,6 +44,7 @@ PagedTensorMemoryAllocator, ) from lmcache.v1.storage_backend.abstract_backend import AllocatorBackendInterface +from lmcache.v1.storage_backend.cache_policy import get_cache_policy from lmcache.v1.transfer_channel.transfer_utils import get_correct_device logger = init_logger(__name__) @@ -51,16 +53,17 @@ @dataclass class NixlStorageConfig: buffer_size: int - file_pool_size: int + pool_size: int buffer_device: str path: str backend: str + backend_params: dict[str, str] @staticmethod def validate_nixl_backend(backend: str, device: str): if backend in ("GDS", "GDS_MT"): return device == "cpu" or device == "cuda" - elif backend in ("POSIX", "HF3FS"): + elif backend in ("POSIX", "HF3FS", "OBJ"): return device == "cpu" else: return False @@ -75,39 +78,47 @@ def from_cache_engine_config( extra_config = config.extra_config assert extra_config is not None assert extra_config.get("enable_nixl_storage") - assert extra_config.get("nixl_backend") is not None - assert extra_config.get("nixl_path") is not None - assert extra_config.get("nixl_file_pool_size") is not None + + pool_size = extra_config.get("nixl_pool_size") + backend = extra_config.get("nixl_backend") + path = extra_config.get("nixl_path") + + assert pool_size is not None + assert backend is not None assert NixlStorageConfig.validate_nixl_backend( - extra_config.get("nixl_backend"), config.nixl_buffer_device + backend, config.nixl_buffer_device ), "Invalid NIXL backend & device combination" + backend_params = extra_config.get("nixl_backend_params") + if backend_params is None: + backend_params = {} + corrected_device = get_correct_device( config.nixl_buffer_device, metadata.worker_id ) return NixlStorageConfig( buffer_size=config.nixl_buffer_size, - file_pool_size=extra_config.get("nixl_file_pool_size"), + pool_size=pool_size, buffer_device=corrected_device, - path=extra_config.get("nixl_path"), - backend=extra_config.get("nixl_backend"), + path=path, + backend=backend, + backend_params=backend_params, ) -class NixlFilePool: - def __init__(self, path: str, size: int): +class NixlDescPool(ABC): + def __init__(self, size: int): self.lock = threading.Lock() self.size: int = size - self.fds: List[int] = [] self.indices: List[int] = [] - for i in reversed(range(size)): - tmp_path = path + f"obj_{i}_{uuid.uuid4().hex[0:4]}.bin" - fd = os.open(tmp_path, os.O_CREAT | os.O_RDWR) - self.fds.append(fd) - self.indices.append(i) + self.indices.extend(reversed(range(size))) + + def get_num_available_descs(self) -> int: + with self.lock: + return len(self.indices) def pop(self) -> int: with self.lock: @@ -119,6 +130,24 @@ def push(self, index: int): assert len(self.indices) < self.size self.indices.append(index) + @abstractmethod + def close(self): + pass + + +class NixlFilePool(NixlDescPool): + def __init__(self, size: int, path: str): + super().__init__(size) + self.fds: List[int] = [] + + assert path is not None + + for i in reversed(range(size)): + filename = f"obj_{i}_{uuid.uuid4().hex[0:4]}.bin" + tmp_path = os.path.join(path, filename) + fd = os.open(tmp_path, os.O_CREAT | os.O_RDWR) + self.fds.append(fd) + def close(self): # TODO: do we need to delete the files? with self.lock: @@ -127,38 +156,86 @@ def close(self): os.close(fd) +class NixlObjectPool(NixlDescPool): + def __init__(self, size: int): + super().__init__(size) + self.keys: List[str] = [] + + for i in reversed(range(size)): + key = f"obj_{i}_{uuid.uuid4().hex[0:4]}" + self.keys.append(key) + + def close(self): + pass + + +@dataclass +class NixlKeyMetadata: + index: int + shape: Optional[torch.Size] = None + dtype: Optional[torch.dtype] = None + fmt: Optional[MemoryFormat] = None + pin_count: int = 0 + + def pin(self) -> bool: + self.pin_count += 1 + return True + + def unpin(self) -> bool: + self.pin_count -= 1 + return True + + @property + def is_pinned(self) -> bool: + return self.pin_count > 0 + + @property + def can_evict(self) -> bool: + """ + Check if the related key can be evicted. + """ + return not self.is_pinned + + class NixlStorageAgent: agent_name: str nixl_agent: NixlAgent - file_pool: NixlFilePool - reg_descs: nixlBind.nixlRegDList - file_reg_descs: nixlBind.nixlRegDList - xfer_descs: nixlBind.nixlXferDList - file_xfer_descs: nixlBind.nixlXferDList - xfer_handler: NixlDlistHandle - file_xfer_handler: NixlDlistHandle + pool: NixlDescPool + mem_reg_descs: nixlBind.nixlRegDList + storage_reg_descs: nixlBind.nixlRegDList + mem_xfer_descs: nixlBind.nixlXferDList + storage_xfer_descs: nixlBind.nixlXferDList + mem_xfer_handler: NixlDlistHandle + storage_xfer_handler: NixlDlistHandle def __init__( self, allocator: PagedTensorMemoryAllocator, - file_pool: NixlFilePool, + pool: NixlDescPool, device: str, backend: str, + backend_params: dict[str, str], ): buffer_ptr = allocator.buffer_ptr buffer_size = allocator.buffer_size page_size = allocator.align_bytes self.agent_name = "NixlAgent_" + str(uuid.uuid4()) - nixl_conf = NixlAgentConfig(backends=[backend]) + nixl_conf = NixlAgentConfig(backends=[]) self.nixl_agent = NixlAgent(self.agent_name, nixl_conf) + self.nixl_agent.create_backend(backend, backend_params) device_id = torch.cuda.current_device() - self.init_handlers(device, buffer_ptr, buffer_size, page_size, device_id) + self.init_mem_handlers(device, buffer_ptr, buffer_size, page_size, device_id) - self.init_file_handlers(page_size, file_pool.fds) + if isinstance(pool, NixlFilePool): + self.init_storage_handlers_file(page_size, pool.fds) + elif isinstance(pool, NixlObjectPool): + self.init_storage_handlers_object(page_size, pool.keys) + else: + raise TypeError(f"Unsupported pool type: {type(pool).__name__}") - def init_handlers(self, device, buffer_ptr, buffer_size, page_size, device_id): + def init_mem_handlers(self, device, buffer_ptr, buffer_size, page_size, device_id): reg_list = [(buffer_ptr, buffer_size, device_id, "")] xfer_desc = [ (base_addr, page_size, device_id) @@ -176,35 +253,58 @@ def init_handlers(self, device, buffer_ptr, buffer_size, page_size, device_id): "", xfer_descs, mem_type=mem_type ) - self.reg_descs = reg_descs - self.xfer_descs = xfer_descs - self.xfer_handler = xfer_handler + self.mem_reg_descs = reg_descs + self.mem_xfer_descs = xfer_descs + self.mem_xfer_handler = xfer_handler - def init_file_handlers(self, page_size, fds): - reg_list = [(0, page_size, fd, "") for fd in fds] - xfer_desc = [(0, page_size, fd) for fd in fds] + def init_storage_handlers_file(self, page_size, fds): + reg_list = [] + xfer_desc = [] + for fd in fds: + reg_list.append((0, page_size, fd, "")) + xfer_desc.append((0, page_size, fd)) reg_descs = self.nixl_agent.register_memory(reg_list, mem_type="FILE") xfer_descs = self.nixl_agent.get_xfer_descs(xfer_desc, mem_type="FILE") xfer_handler = self.nixl_agent.prep_xfer_dlist( self.agent_name, xfer_desc, mem_type="FILE" ) - self.file_reg_descs = reg_descs - self.file_xfer_descs = xfer_descs - self.file_xfer_handler = xfer_handler + self.storage_reg_descs = reg_descs + self.storage_xfer_descs = xfer_descs + self.storage_xfer_handler = xfer_handler + + def init_storage_handlers_object(self, page_size, keys): + reg_list = [] + xfer_desc = [] + for i, key in enumerate(keys): + reg_list.append((0, page_size, i, key)) + xfer_desc.append((0, page_size, i)) + reg_descs = self.nixl_agent.register_memory(reg_list, mem_type="OBJ") + xfer_descs = self.nixl_agent.get_xfer_descs(xfer_desc, mem_type="OBJ") + xfer_handler = self.nixl_agent.prep_xfer_dlist( + self.agent_name, xfer_desc, mem_type="OBJ" + ) + + self.storage_reg_descs = reg_descs + self.storage_xfer_descs = xfer_descs + self.storage_xfer_handler = xfer_handler - def get_gpu_to_file_handle(self, mem_indices, file_indices) -> NixlXferHandle: + def get_mem_to_storage_handle(self, mem_indices, storage_indices) -> NixlXferHandle: return self.nixl_agent.make_prepped_xfer( "WRITE", - self.xfer_handler, + self.mem_xfer_handler, mem_indices, - self.file_xfer_handler, - file_indices, + self.storage_xfer_handler, + storage_indices, ) - def get_file_to_gpu_handle(self, mem_indices, file_indices) -> NixlXferHandle: + def get_storage_to_mem_handle(self, mem_indices, storage_indices) -> NixlXferHandle: return self.nixl_agent.make_prepped_xfer( - "READ", self.xfer_handler, mem_indices, self.file_xfer_handler, file_indices + "READ", + self.mem_xfer_handler, + mem_indices, + self.storage_xfer_handler, + storage_indices, ) def post_blocking(self, handle: NixlXferHandle): @@ -219,10 +319,10 @@ def release_handle(self, handle): self.nixl_agent.release_xfer_handle(handle) def close(self): - self.nixl_agent.release_dlist_handle(self.file_xfer_handler) - self.nixl_agent.release_dlist_handle(self.xfer_handler) - self.nixl_agent.deregister_memory(self.file_reg_descs) - self.nixl_agent.deregister_memory(self.reg_descs) + self.nixl_agent.release_dlist_handle(self.storage_xfer_handler) + self.nixl_agent.release_dlist_handle(self.mem_xfer_handler) + self.nixl_agent.deregister_memory(self.storage_reg_descs) + self.nixl_agent.deregister_memory(self.mem_reg_descs) class NixlStorageBackend(AllocatorBackendInterface): @@ -233,6 +333,15 @@ class NixlStorageBackend(AllocatorBackendInterface): implementation. """ + @staticmethod + def createPool(backend: str, size: int, path: str): + if backend in ("GDS", "GDS_MT", "POSIX", "HF3FS"): + return NixlFilePool(size, path) + elif backend in ("OBJ"): + return NixlObjectPool(size) + else: + raise ValueError(f"Unsupported NIXL backend: {backend}") + def __init__( self, nixl_config: NixlStorageConfig, @@ -249,21 +358,26 @@ def __init__( super().__init__(dst_device=nixl_config.buffer_device) self.loop = loop - self.key_lock = threading.Lock() - self.key_dict: dict[int, MemoryObjMetadata] = {} + self.key_lock = threading.RLock() + self.cache_policy = get_cache_policy(config.cache_policy) + self.key_dict = self.cache_policy.init_mutable_mapping() - self.progress_lock = threading.Lock() - self.progress_set: Set[int] = set() + self.progress_lock = threading.RLock() + self.progress_set: Set[CacheEngineKey] = set() self.memory_allocator = self.initialize_allocator(config, metadata) - self.file_pool = NixlFilePool(nixl_config.path, nixl_config.file_pool_size) + self.pool = NixlStorageBackend.createPool( + nixl_config.backend, nixl_config.pool_size, nixl_config.path + ) + assert self.pool is not None self.agent = NixlStorageAgent( self.memory_allocator, - self.file_pool, + self.pool, nixl_config.buffer_device, nixl_config.backend, + nixl_config.backend_params, ) def contains(self, key: CacheEngineKey, pin: bool = False) -> bool: @@ -277,8 +391,9 @@ def contains(self, key: CacheEngineKey, pin: bool = False) -> bool: """ with self.key_lock: - chunk_hash = key.chunk_hash - if chunk_hash in self.key_dict and not self.exists_in_put_tasks(key): + if key in self.key_dict: + if pin: + self.key_dict[key].pin() return True else: return False @@ -291,54 +406,55 @@ def exists_in_put_tasks(self, key: CacheEngineKey) -> bool: :return: True if the key exists in put tasks, False otherwise """ with self.progress_lock: - return key.chunk_hash in self.progress_set + return key in self.progress_set def add_key_to_dict( self, key: CacheEngineKey, obj: MemoryObjMetadata, index: int ) -> None: with self.key_lock: - assert key.chunk_hash not in self.key_dict - self.key_dict[key.chunk_hash] = MemoryObjMetadata( + assert key not in self.key_dict + self.key_dict[key] = NixlKeyMetadata( shape=obj.shape, dtype=obj.dtype, fmt=obj.fmt, - phy_size=obj.phy_size, - ref_count=1, - address=index, + index=index, ) + self.cache_policy.update_on_put(key) - async def gpu_to_file( + async def mem_to_storage( self, keys: Sequence[CacheEngineKey], mem_objs: List[MemoryObj] ) -> None: mem_indices = [mem_obj.meta.address for mem_obj in mem_objs] - file_indices = [] + storage_indices = [] for i in range(len(keys)): - index = self.file_pool.pop() - file_indices.append(index) + index = self.pool.pop() + storage_indices.append(index) self.add_key_to_dict(keys[i], mem_objs[i].meta, index) - handle = self.agent.get_gpu_to_file_handle(mem_indices, file_indices) + handle = self.agent.get_mem_to_storage_handle(mem_indices, storage_indices) self.agent.post_blocking(handle) self.agent.release_handle(handle) for key in keys: with self.progress_lock: - self.progress_set.discard(key.chunk_hash) + self.progress_set.discard(key) - async def file_to_gpu( + async def storage_to_mem( self, keys: list[CacheEngineKey] ) -> list[Optional[MemoryObj]]: obj_list: list[Optional[MemoryObj]] = [] mem_indices = [] - file_indices = [] + storage_indices = [] with self.key_lock: for key in keys: - metadata = self.key_dict.get(key.chunk_hash) + metadata = self.key_dict.get(key) if metadata is None: obj_list.append(None) continue + self.cache_policy.update_on_hit(key, self.key_dict) + dtype = metadata.dtype shape = metadata.shape fmt = metadata.fmt @@ -352,12 +468,12 @@ async def file_to_gpu( obj_list.append(obj) mem_indices.append(obj.metadata.address) - file_indices.append(metadata.address) + storage_indices.append(metadata.index) if not mem_indices: return obj_list - handle = self.agent.get_file_to_gpu_handle(mem_indices, file_indices) + handle = self.agent.get_storage_to_mem_handle(mem_indices, storage_indices) self.agent.post_blocking(handle) self.agent.release_handle(handle) @@ -369,11 +485,29 @@ def batched_submit_put_task( memory_objs: List[MemoryObj], transfer_spec: Any = None, ) -> None: + with self.key_lock: + available_descs = self.pool.get_num_available_descs() + num_evict = len(keys) - available_descs + if num_evict > 0: + evict_keys = self.cache_policy.get_evict_candidates( + self.key_dict, num_candidates=num_evict + ) + + if not evict_keys: + logger.warning( + "No eviction candidates found. Backend under pressure." + ) + return None + + self.batched_remove(evict_keys, force=False) + with self.progress_lock: for key in keys: - self.progress_set.add(key.chunk_hash) + self.progress_set.add(key) - asyncio.run_coroutine_threadsafe(self.gpu_to_file(keys, memory_objs), self.loop) + asyncio.run_coroutine_threadsafe( + self.mem_to_storage(keys, memory_objs), self.loop + ) def get_blocking(self, key: CacheEngineKey) -> Optional[MemoryObj]: """ @@ -384,7 +518,7 @@ def get_blocking(self, key: CacheEngineKey) -> Optional[MemoryObj]: :return: MemoryObj. None if the key does not exist. """ - future = asyncio.run_coroutine_threadsafe(self.file_to_gpu([key]), self.loop) + future = asyncio.run_coroutine_threadsafe(self.storage_to_mem([key]), self.loop) if future is None: return None @@ -398,7 +532,7 @@ async def batched_get_non_blocking( keys: list[CacheEngineKey], transfer_spec: Any = None, ) -> list[MemoryObj]: - obj_list = await self.file_to_gpu(keys) + obj_list = await self.storage_to_mem(keys) assert None not in obj_list return cast(list[MemoryObj], obj_list) @@ -410,18 +544,30 @@ def remove(self, key: CacheEngineKey, force: bool = True) -> bool: """ with self.key_lock: - metadata = self.key_dict.pop(key.chunk_hash, None) + metadata = self.key_dict.pop(key, None) if metadata is None: return False + if force: + self.cache_policy.update_on_force_evict(key) - self.file_pool.push(metadata.address) + self.pool.push(metadata.index) return True def pin(self, key: CacheEngineKey) -> bool: - return False + with self.key_lock: + if key in self.key_dict: + self.key_dict[key].pin() + return True + else: + return False def unpin(self, key: CacheEngineKey) -> bool: - return False + with self.key_lock: + if key in self.key_dict: + self.key_dict[key].unpin() + return True + else: + return False def close(self) -> None: """ @@ -429,7 +575,7 @@ def close(self) -> None: """ self.agent.close() - self.file_pool.close() + self.pool.close() self.memory_allocator.close() @@ -480,8 +626,6 @@ def allocate( eviction: bool = True, busy_loop: bool = True, ) -> Optional[MemoryObj]: - if eviction: - logger.warning("NixlStorageBackend does not support eviction for now") if busy_loop: logger.warning("NixlStorageBackend does not support busy loop for now") @@ -496,8 +640,6 @@ def batched_allocate( eviction: bool = True, busy_loop: bool = True, ) -> Optional[list[MemoryObj]]: - if eviction: - logger.warning("NixlStorageBackend does not support eviction for now") if busy_loop: logger.warning("NixlStorageBackend does not support busy loop for now") diff --git a/lmcache/v1/storage_backend/p2p_backend.py b/lmcache/v1/storage_backend/p2p_backend.py index 6ec7e6bcc7d..6de80f7707c 100644 --- a/lmcache/v1/storage_backend/p2p_backend.py +++ b/lmcache/v1/storage_backend/p2p_backend.py @@ -11,6 +11,7 @@ # First Party from lmcache.config import LMCacheEngineMetadata from lmcache.logging import init_logger +from lmcache.observability import LMCStatsMonitor from lmcache.utils import CacheEngineKey from lmcache.v1.cache_controller.message import ( BatchedP2PLookupMsg, @@ -108,7 +109,7 @@ def __init__( self.config = config self.loop = loop self.lmcache_worker = lmcache_worker - + self.stats_monitor = LMCStatsMonitor.GetOrCreate() assert config.p2p_host is not None, "p2p_host must be specified" assert config.p2p_init_ports is not None, "p2p_init_ports must be specified" assert config.p2p_lookup_ports is not None, "p2p_lookup_ports must be specified" @@ -143,6 +144,7 @@ def __init__( self.full_size_shape = list(self.memory_allocator.cpu_allocator.shape) # TODO(Jiayi): remove this hardcode self.fmt: MemoryFormat = MemoryFormat.KV_2LTD + self.chunk_size = config.chunk_size self.transfer_channel = CreateTransferChannel( channel_type=config.transfer_channel, @@ -230,6 +232,9 @@ async def _handle_peer_requests(self): msg_bytes = await self.async_peer_socket.recv() msg = msgspec.msgpack.decode(msg_bytes, type=P2PMsg) + num_tokens = len(msg.mem_indexes) * self.chunk_size + monitor_req_id = self.stats_monitor.on_p2p_transfer_request(num_tokens) + if isinstance(msg, BatchedLookupAndGetMsg): logger.info("Received P2P batched get msg") @@ -310,6 +315,9 @@ async def _handle_peer_requests(self): num_read_chunks=len(local_mem_objs), ) + logger.info(f"P2P transfer finished for request {monitor_req_id}") + self.stats_monitor.on_p2p_transfer_finished(monitor_req_id) + await self.async_peer_socket.send(msgspec.msgpack.encode(ret_msg)) async def _ensure_peer_connection( diff --git a/lmcache/v1/storage_backend/pd_backend.py b/lmcache/v1/storage_backend/pd_backend.py index 203b2978ca2..c68625f475c 100644 --- a/lmcache/v1/storage_backend/pd_backend.py +++ b/lmcache/v1/storage_backend/pd_backend.py @@ -245,6 +245,8 @@ def allocate( eviction: bool = True, busy_loop: bool = True, ) -> Optional[MemoryObj]: + if fmt is None: + fmt = MemoryFormat.KV_2LTD # NOTE: no eviction and busy_loop in PD return self.memory_allocator.allocate( shape=shape, dtype=dtype, fmt=fmt, allocator_type="gpu" @@ -261,6 +263,8 @@ def batched_allocate( eviction: bool = True, busy_loop: bool = True, ): + if fmt is None: + fmt = MemoryFormat.KV_2LTD return self.memory_allocator.batched_allocate( shape, dtype, batch_size, fmt, allocator_type="gpu" ) diff --git a/lmcache/v1/storage_backend/remote_backend.py b/lmcache/v1/storage_backend/remote_backend.py index 5b3741aed9e..9f00430d2a1 100644 --- a/lmcache/v1/storage_backend/remote_backend.py +++ b/lmcache/v1/storage_backend/remote_backend.py @@ -28,7 +28,7 @@ def __init__( config: LMCacheEngineConfig, metadata: LMCacheEngineMetadata, loop: asyncio.AbstractEventLoop, - local_cpu_backend: LocalCPUBackend, + local_cpu_backend: Optional[LocalCPUBackend], dst_device: str = "cuda", ): super().__init__(dst_device=dst_device) @@ -136,14 +136,7 @@ def contains(self, key: CacheEngineKey, pin: bool = False) -> bool: # For MLA worker id as 0 mode, use worker_id 0 if self._mla_worker_id_as0_mode: - key = CacheEngineKey( - key.fmt, - key.model_name, - key.world_size, - 0, - key.chunk_hash, - key.request_configs, - ) + key = key.with_new_worker_id(0) try: if self.config.extra_config is not None and self.config.extra_config.get( @@ -161,6 +154,32 @@ def contains(self, key: CacheEngineKey, pin: bool = False) -> bool: logger.warning("Returning False") return False + def support_batched_contains(self) -> bool: + return ( + self.connection is not None and self.connection.support_batched_contains() + ) + + def batched_contains( + self, + keys: List[CacheEngineKey], + pin: bool = False, + stop_after_first_not_exits: bool = True, + ) -> List[bool]: + if self.connection is None: + logger.warning( + "Connection is None in batched_contains, returning all False" + ) + return [False] * len(keys) + + if self._mla_worker_id_as0_mode: + keys = [key.with_new_worker_id(0) for key in keys] + + try: + return self.connection.batched_contains(keys, stop_after_first_not_exits) + except Exception as e: + logger.warning(f"Remote connection failed in batched_contains: {e}") + return [False] * len(keys) + def exists_in_put_tasks(self, key: CacheEngineKey) -> bool: with self.lock: return key in self.put_tasks @@ -257,20 +276,20 @@ def get_blocking( """ Blocking get function. """ + # Check if local_cpu_backend is available (required for memory allocation) + if self.local_cpu_backend is None: + logger.warning( + "local_cpu_backend is None in get_blocking " + "(likely scheduler role), returning None" + ) + return None if self.connection is None: logger.warning("Connection is None in get_blocking, returning None") return None # For MLA worker id as 0 mode, use worker_id 0 if self._mla_worker_id_as0_mode: - key = CacheEngineKey( - key.fmt, - key.model_name, - key.world_size, - 0, - key.chunk_hash, - key.request_configs, - ) + key = key.with_new_worker_id(0) t1 = time.perf_counter() future = asyncio.run_coroutine_threadsafe(self.connection.get(key), self.loop) @@ -300,31 +319,27 @@ def batched_get_blocking( self, keys: List[CacheEngineKey], ) -> List[Optional[MemoryObj]]: + # Check if local_cpu_backend is available (required for memory allocation) + if self.local_cpu_backend is None: + logger.warning( + "local_cpu_backend is None in batched_get_blocking " + "(likely scheduler role), returning None list" + ) + return [None] * len(keys) + if self.connection is None: logger.warning("Connection is None in batched_get_blocking, returning None") return [None] * len(keys) # For MLA worker id as 0 mode, use worker_id 0 if self._mla_worker_id_as0_mode: - new_keys = [ - CacheEngineKey( - key.fmt, - key.model_name, - key.world_size, - 0, - key.chunk_hash, - key.request_configs, - ) - for key in keys - ] - else: - new_keys = keys + keys = [key.with_new_worker_id(0) for key in keys] t1 = time.perf_counter() # batched get if self.connection.support_batched_get(): future = asyncio.run_coroutine_threadsafe( - self.connection.batched_get(new_keys), self.loop + self.connection.batched_get(keys), self.loop ) try: memory_objs = future.result(self.blocking_timeout_secs) @@ -334,17 +349,16 @@ def batched_get_blocking( "batched get blocking timeout, trigger cancel the future task" ) future.cancel() - with self.lock: - self.connection = None - self.failure_time = time.time() - logger.warning( - f"Error occurred in batched_get_blocking: {e}, returning None list" - ) + else: + logger.warning( + f"Error occurred in batched_get_blocking: {e}, " + f"returning None list" + ) return [None] * len(keys) else: futures = [ asyncio.run_coroutine_threadsafe(self.connection.get(key), self.loop) - for key in new_keys + for key in keys ] memory_objs = [] failed = False @@ -359,12 +373,10 @@ def batched_get_blocking( "get blocking timeout, trigger cancel the future task" ) fut.cancel() - with self.lock: - self.connection = None - self.failure_time = time.time() - logger.warning( - f"Error occurred in get_blocking: {e}, returning None" - ) + else: + logger.warning( + f"Error occurred in get_blocking: {e}, returning None" + ) memory_obj = None memory_objs.append(memory_obj) else: @@ -404,17 +416,7 @@ async def batched_async_contains( logger.warning("Connection is None in batched_async_contains, returning 0") return 0 if self._mla_worker_id_as0_mode: - keys = [ - CacheEngineKey( - key.fmt, - key.model_name, - key.world_size, - 0, - key.chunk_hash, - key.request_configs, - ) - for key in keys - ] + keys = [key.with_new_worker_id(0) for key in keys] try: assert self.connection.support_batched_async_contains(), ( @@ -437,6 +439,14 @@ async def batched_get_non_blocking( keys: List[CacheEngineKey], transfer_spec: Any = None, ) -> List[MemoryObj]: + # Check if local_cpu_backend is available (required for memory allocation) + if self.local_cpu_backend is None: + logger.warning( + "local_cpu_backend is None in batched_get_non_blocking " + "(likely scheduler role), returning empty list" + ) + return [] + if self.connection is None: logger.warning( "Connection is None in batched_get_non_blocking, returning empty list" @@ -459,9 +469,23 @@ def unpin(self, key: CacheEngineKey) -> bool: return True def remove(self, key, force=True): - raise NotImplementedError("Remote backend does not support remove now.") + if self.connection is None: + logger.warning("Connection is None in remove, returning False") + return False + + try: + return self.connection.remove_sync(key) + except Exception as e: + logger.exception( + f"Failed to remove key {key} from remote backend, error: {e}" + ) + return False def get_allocator_backend(self): + assert self.local_cpu_backend is not None, ( + "local_cpu_backend is required for get_allocator_backend, " + "should not be called in scheduler role" + ) return self.local_cpu_backend def close(self): diff --git a/lmcache/v1/storage_backend/storage_manager.py b/lmcache/v1/storage_backend/storage_manager.py index 51c47889c52..5c3320ea87e 100644 --- a/lmcache/v1/storage_backend/storage_manager.py +++ b/lmcache/v1/storage_backend/storage_manager.py @@ -21,6 +21,7 @@ # First Party from lmcache.config import LMCacheEngineMetadata from lmcache.logging import init_logger +from lmcache.observability import PrometheusLogger from lmcache.utils import ( CacheEngineKey, _lmcache_nvtx_annotate, @@ -32,7 +33,7 @@ MemoryFormat, MemoryObj, ) -from lmcache.v1.storage_backend import CreateStorageBackends +from lmcache.v1.storage_backend import CreateStorageBackends, is_cuda_worker from lmcache.v1.storage_backend.abstract_backend import ( AllocatorBackendInterface, StorageBackendInterface, @@ -119,7 +120,7 @@ async def acquire(self, n: int = 1) -> None: ) async with self._cond: - logger.info(f"WeightedSemaphore: Attempting to acquire {n} chunks") + logger.debug(f"WeightedSemaphore: Attempting to acquire {n} chunks") if n <= self._concurrent_budget_cap: await self._cond.wait_for(lambda: self._current_chunks >= n) self._current_chunks -= n @@ -130,7 +131,7 @@ async def acquire(self, n: int = 1) -> None: ) # Reserve everything self._current_chunks = 0 - logger.info( + logger.debug( f"WeightedSemaphore: Acquired {n} chunks, " f"remaining chunks: {self._current_chunks}" ) @@ -195,7 +196,8 @@ def __init__( ) self.thread.start() - if torch.cuda.is_available(): + # For scheduler role, always use CPU device + if is_cuda_worker(metadata): dst_device = "cuda" else: dst_device = "cpu" @@ -209,9 +211,14 @@ def __init__( ) ) + # the backend used for actual storage + self.non_allocator_backends = self.get_non_allocator_backends() + self.enable_pd = config.enable_pd - self.allocator_backend = self._get_allocator_backend(config) + self.allocator_backend = None + if metadata.role != "scheduler": + self.allocator_backend = self._get_allocator_backend(config) if config.local_cpu: self.local_cpu_backend = self.storage_backends["LocalCPUBackend"] @@ -224,13 +231,39 @@ def __init__( self.event_manager = event_manager self.async_lookup_server: Optional["LMCacheAsyncLookupServer"] = None + self.async_serializer: Optional[AsyncSerializer] = None # The cuda stream for internal copies during put - if torch.cuda.is_available(): + if is_cuda_worker(metadata): self.internal_copy_stream = torch.cuda.Stream() else: self.internal_copy_stream = None + self._setup_metrics() + + def _setup_metrics(self): + prometheus_logger = PrometheusLogger.GetInstanceOrNone() + if prometheus_logger is None: + logger.warning( + "PrometheusLogger is not initialized, " + "event metrics will not be collected" + ) + return + + metric_map = { + "storage_events_ongoing_count": EventStatus.ONGOING, + "storage_events_done_count": EventStatus.DONE, + "storage_events_not_found_count": EventStatus.NOT_FOUND, + } + + for metric_name, status in metric_map.items(): + metric = getattr(prometheus_logger, metric_name) + metric.set_function( + lambda s=status: self.event_manager.get_events_count_by_status( + EventType.LOADING, s + ) + ) + def post_init(self, **kwargs) -> None: if "async_lookup_server" in kwargs: assert not self.config.save_unfull_chunk, ( @@ -238,6 +271,11 @@ def post_init(self, **kwargs) -> None: "async loading." ) self.async_lookup_server = kwargs.pop("async_lookup_server") + # PDBackend has't supported calculate_chunk_budget + if not self.enable_pd and ( + self.config.enable_async_loading or self.config.use_layerwise + ): + assert self.allocator_backend is not None self.async_serializer = AsyncSerializer(self.allocator_backend, self.loop) def _get_allocator_backend( @@ -265,6 +303,7 @@ def allocate( """ # TODO (Jiayi): We might need to pre-allocate and management # disk in a similar way as CPU. + assert self.allocator_backend is not None return self.allocator_backend.allocate( shape, dtype, fmt, eviction=eviction, busy_loop=busy_loop ) @@ -285,6 +324,8 @@ def batched_allocate( """ # TODO (Jiayi): We might need to pre-allocate and management # disk in a similar way as CPU. + if self.allocator_backend is None: + raise RuntimeError("Allocator backend not available for scheduler role") return self.allocator_backend.batched_allocate( shape, dtype, batch_size, fmt, eviction=eviction, busy_loop=busy_loop ) @@ -328,6 +369,9 @@ def batched_put( str, tuple[Sequence[CacheEngineKey], list[MemoryObj]], ] = {} + if self.allocator_backend is None: + # For scheduler role, no allocator backend available + raise RuntimeError("Batched put not available for scheduler role") obj_dict[get_backend_cname(self.allocator_backend)] = ( keys, memory_objs, @@ -371,7 +415,10 @@ def get( # are allocated by the allocator backend. memory_obj = backend.get_blocking(key) if memory_obj: - if backend_name not in ["LocalCPUBackend", "PDBackend"]: + if ( + backend_name not in ["LocalCPUBackend", "PDBackend"] + and "LocalCPUBackend" in self.storage_backends + ): local_cpu_backend = self.storage_backends["LocalCPUBackend"] assert isinstance(local_cpu_backend, LocalCPUBackend) local_cpu_backend.submit_put_task(key, memory_obj) @@ -379,6 +426,27 @@ def get( return None + def get_non_blocking( + self, + key: CacheEngineKey, + location: Optional[str] = None, + ) -> Optional[Future]: + """ + Non-blocking function to get the memory object from the storages. + """ + # TODO (Jiayi): incorporate prefetching here + + # Search all backends for non-blocking get + for backend_name, backend in self.storage_backends.items(): + if location and backend_name != location: + continue + # NOTE(Jiayi): bypass the allocator for now + task = backend.get_non_blocking(key) + if task: + # TODO (Jiayi): add write-back logic here + return task + return None + def batched_get( self, keys: List[CacheEngineKey], @@ -415,20 +483,19 @@ def layerwise_batched_get( """ if location is None: location = "LocalCPUBackend" - for keys_multi_chunk in keys: # Retrieve all chunks for one layer backend = self.storage_backends[location] # TODO(Jiayi): need to make async loading and layerwise compatible - task = asyncio.run_coroutine_threadsafe( - self.async_serializer.run( - backend.batched_get_non_blocking( - "fake_lookup_id", keys_multi_chunk - ), - len(keys_multi_chunk), - ), - self.loop, + assert self.async_serializer is not None, ( + "Async serializer must be initialized via post_init before using " + "layerwise_batched_get." + ) + coro = self.async_serializer.run( + backend.batched_get_non_blocking("fake_lookup_id", keys_multi_chunk), + len(keys_multi_chunk), ) + task = asyncio.run_coroutine_threadsafe(coro, self.loop) yield task def prefetch_single_done_callback( @@ -516,16 +583,19 @@ async def async_lookup_and_prefetch( num_total_hit_chunks += num_hit_chunks - loading_task = asyncio.create_task( - self.async_serializer.run( - backend.batched_get_non_blocking( - lookup_id, - keys[:num_hit_chunks], - {"cum_chunk_lengths": cum_chunk_lengths[: num_hit_chunks + 1]}, - ), - num_hit_chunks, - ) + assert self.async_serializer is not None, ( + "Async serializer must be initialized via post_init before using " + "async_lookup_and_prefetch." ) + get_coro = self.async_serializer.run( + backend.batched_get_non_blocking( + lookup_id, + keys[:num_hit_chunks], + {"cum_chunk_lengths": cum_chunk_lengths[: num_hit_chunks + 1]}, + ), + num_hit_chunks, + ) + loading_task = asyncio.create_task(get_coro) loading_task.add_done_callback( functools.partial( self.prefetch_single_done_callback, @@ -600,6 +670,67 @@ def contains( return None + def batched_contains( + self, + keys: List[CacheEngineKey], + search_range: Optional[List[str]] = None, + pin: bool = False, + stop_after_first_not_exits: bool = True, + ) -> List[bool]: + """ + Check whether the key exists in the storage backend. + + :param List[CacheEngineKey] keys: The keys to check. + + :param Optional[List[str]] search_range: The range of storage backends + to search in. Should be a subset of ["LocalCPUBackend", + "LocalDiskBackend"] for now. + If None, search in all backends. + + :param bool pin: Whether to pin the key. + + :param bool stop_after_first_not_exits: Stop when find the first not exists key, + all subsequent results will return False directly. + + return: True if the key exists in the specified storage backends else False. + """ + + # TODO: Only single-layer batched_contains is supported currently. + # Only allocate backend is LocalCPUBackend and do not enable hot cache, + # check another backend is supported batched_contains + if ( + len(self.storage_backends) == 2 + and not self.config.enable_pd + and not self.config.local_cpu + and (search_range is None or len(search_range) == 1) + ): + for backend_name, backend in self.storage_backends.items(): + if backend_name == "LocalCPUBackend": + continue + if ( + search_range is None or search_range[0] == backend_name + ) and backend.support_batched_contains(): + return backend.batched_contains( + keys, pin, stop_after_first_not_exits + ) + + # default implementation + contains_res = [] + for key in keys: + res = self.contains(key, search_range, pin) + if res is not None: + contains_res.append(True) + else: + if stop_after_first_not_exits: + # fill the contains_res with None + current_len = len(contains_res) + contains_res.extend([False] * (len(keys) - current_len)) + break + else: + contains_res.append(False) + + return contains_res + def touch_cache(self): for backend_name, backend in self.storage_backends.items(): if backend_name == "LocalCPUBackend" or backend_name == "LocalDiskBackend": @@ -724,6 +855,23 @@ def memcheck(self) -> bool: return False return True + def get_non_allocator_backends(self) -> List[str]: + """ + Get the names of the actual storage backends. Some backends, + such as LocalCPUBackend and PDBackend, in some cases, only + serve as a backend for allocation. + """ + storage_names = [] + for backend_name, backend in self.storage_backends.items(): + if "LocalCPUBackend" == backend_name and not self.config.local_cpu: + # if local_cpu is False, means LocalCPUBackend is only a allocator + continue + if "PDBackend" == backend_name and backend.pd_config.role == "sender": # type: ignore + # if pd_config.role is sender, means PDBackend is only a allocator + continue + storage_names.append(backend_name) + return storage_names + def close(self): for backend in self.storage_backends.values(): backend.close() diff --git a/lmcache/v1/system_detection.py b/lmcache/v1/system_detection.py index b47182efc59..8b542bd6190 100644 --- a/lmcache/v1/system_detection.py +++ b/lmcache/v1/system_detection.py @@ -2,12 +2,19 @@ # Standard from dataclasses import dataclass from typing import Optional +import platform # Third Party +import psutil import torch if torch.cuda.is_available(): - from lmcache.c_ops import get_gpu_pci_bus_id + try: + # First Party + from lmcache.c_ops import get_gpu_pci_bus_id + except ImportError: + # Fallback if c_ops is not available + get_gpu_pci_bus_id = None # First Party from lmcache.logging import init_logger @@ -21,6 +28,30 @@ class NUMAMapping: gpu_to_numa_mapping: dict[int, int] +class SystemMemoryDetector: + @staticmethod + def get_available_memory_gb() -> float: + """ + Get system available memory in GB using psutil. + This method is cross-platform and doesn't require subprocess calls. + + Returns: + Available memory in GB, or 0.0 if detection fails. + """ + try: + # Use psutil to get virtual memory information + memory = psutil.virtual_memory() + available_gb = memory.available / (1024**3) + + system = platform.system() + logger.info(f"{system} system available memory: {available_gb:.2f} GB") + return available_gb + + except Exception as e: + logger.warning(f"Failed to get system available memory using psutil: {e}") + return 0.0 + + class NUMADetector: @staticmethod def get_numa_mapping(config: LMCacheEngineConfig) -> Optional[NUMAMapping]: diff --git a/lmcache/v1/token_database.py b/lmcache/v1/token_database.py index 427e0f31bdf..d8c76017663 100644 --- a/lmcache/v1/token_database.py +++ b/lmcache/v1/token_database.py @@ -1,7 +1,20 @@ # SPDX-License-Identifier: Apache-2.0 +""" +vLLM compatibility notes: +- PR#20511: Introduced kv_cache_utils.init_none_hash() + https://github.com/vllm-project/vllm/pull/20511 +- PR#23673: Renamed sha256_cbor_64bit to sha256_cbor + https://github.com/vllm-project/vllm/pull/23673 +- PR#27151: Moved hash functions to vllm.utils.hashing module + https://github.com/vllm-project/vllm/pull/27151 + +TODO(baoloongmao): Move this to vllm_v1_adapter to decouple from vLLM +""" + # Standard from typing import Any, Iterable, List, Optional, Tuple, Union import abc +import os # Third Party from transformers import AutoTokenizer @@ -17,6 +30,10 @@ NONE_HASH: int +# Type alias for process_tokens return value +# (start_index, end_index, cache_engine_key|hash) +ProcessTokensResult = Tuple[int, int, Union[CacheEngineKey, int]] + class TokenDatabase(metaclass=abc.ABCMeta): """TokenDatabase is used to convert input tokens into list of @@ -37,42 +54,109 @@ def __init__( ): global NONE_HASH - vllm_is_available = True - try: - # Third Party - from vllm.utils import sha256, sha256_cbor_64bit - except ImportError: - # sha256, sha256_cbor_64bit are available through vLLM only - vllm_is_available = False - - hash_algorithm: str - if config is not None: - hash_algorithm = config.pre_caching_hash_algorithm - else: # Default value - hash_algorithm = "builtin" # fallback to builtin hash - - # Need to support vLLM hashing functions at a minimum - self.hash_func = ( - sha256_cbor_64bit - if hash_algorithm == "sha256_cbor_64bit" and vllm_is_available - else sha256 - if hash_algorithm == "sha256" and vllm_is_available - else hash + hash_algorithm: str = ( + config.pre_caching_hash_algorithm if config is not None else "builtin" ) + # Get hash function with vLLM version compatibility + self.hash_func = self._get_vllm_hash_func(hash_algorithm) + + # Initialize NONE_HASH (vLLM >= PR#20511) # NOTE: For centralized cache sharing, ensure PYTHONHASHSEED is # set consistently across all processes (e.g., export PYTHONHASHSEED=0). try: # Third Party from vllm.v1.core import kv_cache_utils - kv_cache_utils.init_none_hash(self.hash_func) - NONE_HASH = kv_cache_utils.NONE_HASH + if hasattr(kv_cache_utils, "init_none_hash"): + kv_cache_utils.init_none_hash(self.hash_func) + NONE_HASH = kv_cache_utils.NONE_HASH + logger.info( + f"Initialized NONE_HASH={NONE_HASH} from vLLM (>= PR#20511)" + ) + else: + NONE_HASH = 0 + logger.info("Using default NONE_HASH=0 (vLLM < PR#20511)") except (ImportError, AttributeError): NONE_HASH = 0 + logger.info("Using default NONE_HASH=0 (vLLM not available)") + logger.info(f"Using hash algorithm: {hash_algorithm}") self.metadata = metadata + def _get_vllm_hash_func(self, hash_algorithm: str): + """Get hash function from vLLM with version compatibility. + + Tries multiple import paths to support different vLLM versions: + - vllm.utils.hashing.get_hash_fn_by_name (>= PR#27151) + - vllm.utils.get_hash_fn_by_name (< PR#27151) + - Direct imports as fallback + - sha256_cbor_64bit -> sha256_cbor rename (PR#23673) + """ + # Try get_hash_fn_by_name from both locations (PR#27151) + for module_path in ["vllm.utils.hashing", "vllm.utils"]: + try: + module = __import__(module_path, fromlist=["get_hash_fn_by_name"]) + get_hash_fn_by_name = module.get_hash_fn_by_name + return self._try_get_hash( + get_hash_fn_by_name, hash_algorithm, module_path + ) + except (ImportError, AttributeError, ValueError): + continue + + # Try direct imports as fallback (for older vLLM versions) + func_names = ( + ["sha256_cbor", "sha256_cbor_64bit"] + if hash_algorithm in ("sha256_cbor", "sha256_cbor_64bit") + else [hash_algorithm] + ) + for module_path in ["vllm.utils.hashing", "vllm.utils"]: + for func_name in func_names: + try: + module = __import__(module_path, fromlist=[func_name]) + hash_func = getattr(module, func_name) + logger.info( + f"Loaded '{func_name}' from {module_path} (direct import)" + ) + return hash_func + except (ImportError, AttributeError): + continue + + # Fallback to builtin hash + logger.warning( + f"Could not load '{hash_algorithm}' from vLLM. Using builtin hash. " + "This may cause inconsistencies in distributed caching." + ) + + # Check PYTHONHASHSEED when using builtin hash + if os.getenv("PYTHONHASHSEED") is None: + logger.warning( + "Using builtin hash without PYTHONHASHSEED set. " + "For production environments (non-testing scenarios), you MUST set " + "PYTHONHASHSEED to ensure consistent hashing across processes. " + "Example: export PYTHONHASHSEED=0" + ) + + return hash + + def _try_get_hash(self, get_hash_fn_by_name, hash_algorithm: str, module_name: str): + """Try to get hash function, handling sha256_cbor_64bit rename.""" + # Handle sha256_cbor_64bit -> sha256_cbor rename (PR#23673) + names_to_try = ( + ["sha256_cbor", "sha256_cbor_64bit"] + if hash_algorithm in ("sha256_cbor", "sha256_cbor_64bit") + else [hash_algorithm] + ) + + for name in names_to_try: + try: + hash_func = get_hash_fn_by_name(name) + logger.info(f"Loaded '{name}' from {module_name}") + return hash_func + except ValueError: + continue + raise ValueError(f"Hash function '{hash_algorithm}' not found in {module_name}") + @abc.abstractmethod def process_tokens( self, @@ -82,7 +166,7 @@ def process_tokens( mask: Optional[torch.Tensor] = None, make_key: bool = True, request_configs: Optional[dict] = None, - ) -> Iterable[Tuple[int, int, Union[CacheEngineKey, int]]]: + ) -> Iterable[ProcessTokensResult]: """Process the tokens and return the corresponding cache engine keys. :param Optional[Union[torch.Tensor, List[int]]] tokens: The tokens to process. @@ -120,6 +204,7 @@ def _make_key_by_hash( self.metadata.world_size, self.metadata.worker_id, chunk_hash, + self.metadata.kv_dtype, request_configs, ) @@ -155,9 +240,6 @@ def __init__( self.save_unfull_chunk = config.save_unfull_chunk # Check for cross-process cache sharing setup - # Standard - import os - if os.getenv("PYTHONHASHSEED") is None: if config.remote_url is not None: logger.warning( @@ -220,7 +302,7 @@ def process_tokens( mask: Optional[torch.Tensor] = None, make_key: bool = True, request_configs: Optional[dict] = None, - ) -> Iterable[Tuple[int, int, Union[CacheEngineKey, int]]]: + ) -> Iterable[ProcessTokensResult]: """Process the tokens/hashes and return the corresponding cache engine keys. :param Optional[Union[torch.Tensor, List[int]]] tokens: The tokens to process. @@ -345,7 +427,7 @@ def process_tokens( mask: Optional[torch.Tensor] = None, make_key: bool = True, request_configs: Optional[dict] = None, - ) -> Iterable[Tuple[int, int, Union[CacheEngineKey, int]]]: + ) -> Iterable[ProcessTokensResult]: """Process the tokens and return the corresponding cache engine keys. :param Union[torch.Tensor, List[int]] tokens: The tokens to process. diff --git a/lmcache/v1/transfer_channel/nixl_channel.py b/lmcache/v1/transfer_channel/nixl_channel.py index 50aa4f630ae..f628c487023 100644 --- a/lmcache/v1/transfer_channel/nixl_channel.py +++ b/lmcache/v1/transfer_channel/nixl_channel.py @@ -555,10 +555,11 @@ def close(self): for thread in self.running_threads: thread.join() self.zmq_context.term() - self.agent.deregister_memory(self.reg_descs) - self.agent.release_dlist_handle(self.xfer_handler) + self.nixl_agent.deregister_memory(self.nixl_wrapper.reg_descs) + self.nixl_agent.release_dlist_handle(self.nixl_wrapper.xfer_handler) + for remote_xfer_handler in self.remote_xfer_handlers_dict.values(): - self.agent.release_dlist_handle(remote_xfer_handler) + self.nixl_agent.release_dlist_handle(remote_xfer_handler) @dataclass diff --git a/lmcache/v1/xpu_connector.py b/lmcache/v1/xpu_connector.py new file mode 100644 index 00000000000..fac285fb13e --- /dev/null +++ b/lmcache/v1/xpu_connector.py @@ -0,0 +1,207 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright 2024-2025 LMCache Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Standard +from typing import List, Optional + +# Third Party +import torch + +# First Party +from lmcache.logging import init_logger +from lmcache.v1.gpu_connector import VLLMPagedMemGPUConnectorV2 +from lmcache.v1.memory_management import MemoryFormat, MemoryObj + +logger = init_logger(__name__) + + +class VLLMPagedMemXPUConnectorV2(VLLMPagedMemGPUConnectorV2): + """ + The GPU KV cache should be a nested tuple of K and V tensors. + More specifically, we have: + - GPUTensor = Tuple[KVLayer, ...] + - KVLayer = Tuple[Tensor, Tensor] + - Tensor: [num_blocks, block_size, num_heads, head_size] + + It will produce / consume memory object with KV_2LTD format + """ + + def __init__( + self, + hidden_dim_size: int, + num_layers: int, + use_gpu: bool = False, + **kwargs, + ): + """ + If use_gpu is true, it will create a gpu intermediate buffer. In this + case, it requires the following kwargs: + - chunk_size: The MAX size of the chunk to be copied to GPU. + - dtype: The data type of the intermediate buffer. + """ + self.hidden_dim_size = hidden_dim_size + self.num_layers = num_layers + self.kv_cache_pointers = torch.empty( + num_layers, dtype=torch.int64, device="cpu" + ) + # Not sure we need a dict here. Maybe a single GPU connector always + # works with a single device? + self.kv_cache_pointers_on_gpu: dict[int, torch.Tensor] = {} + self.page_buffer_size = 0 + + self.kvcaches: Optional[List[torch.Tensor]] = None + self.gpu_buffer: Optional[torch.Tensor] = None + self.use_mla = "use_mla" in kwargs and kwargs["use_mla"] + if use_gpu: + assert "chunk_size" in kwargs, ( + "chunk_size should be provided to create a GPU buffer." + ) + assert "dtype" in kwargs, "dtype should be provided to create a GPU buffer." + assert "device" in kwargs, ( + "device should be provided to create a GPU buffer." + ) + shape = self.get_shape(kwargs["chunk_size"]) + self.gpu_buffer = torch.empty( + shape, dtype=kwargs["dtype"], device=kwargs["device"] + ) + + def to_gpu(self, memory_obj: MemoryObj, start: int, end: int, **kwargs): + """Expect a kwarg 'kvcaches' which is a nested tuple of K and V tensors. + The kvcaches should correspond to the "WHOLE token sequence". + + Note: + 1. This function expects the 'slot_mapping' is a "full slot mapping" + where it's length is the same as the whole token sequence. + 2. In the case that there is prefix caching, slot_mapping will starts + with -1s until the end of the matched prefix. The start and end + should NEVER overlap with the prefix caching (which means the + underlying CUDA kernel will never see -1 in slot_mapping) + + + :raises ValueError: If 'kvcaches' is not provided in kwargs. + :raises AssertionError: If the memory object does not have a tensor. + :raises ValueError: If 'slot_mapping' is not provided in kwargs. + """ + assert memory_obj.tensor is not None + + if self.use_mla: + if memory_obj.metadata.fmt != MemoryFormat.KV_MLA_FMT: + raise ValueError( + "The memory object should be in KV_MLA_FMT format in" + " order to be processed by VLLMPagedMemXPUConnector" + ) + else: + if memory_obj.metadata.fmt != MemoryFormat.KV_2LTD: + raise ValueError( + "The memory object should be in KV_2LTD format in" + " order to be processed by VLLMPagedMemXPUConnector" + ) + + if "kvcaches" not in kwargs: + raise ValueError("'kvcaches' should be provided in kwargs.") + + if "slot_mapping" not in kwargs: + raise ValueError("'slot_mapping' should be provided in kwargs.") + + kvcaches: List[torch.Tensor] = kwargs["kvcaches"] + slot_mapping: torch.Tensor = kwargs["slot_mapping"] + slices = slot_mapping[start:end] + + if self.use_mla: + tmp = memory_obj.tensor[0].to(slot_mapping.device) + num_blocks, block_size, head_size = kvcaches[0].shape + total_blocks = num_blocks * block_size + for i, kvcache in enumerate(kvcaches): + kvcache.view(total_blocks, head_size).index_copy_(0, slices, tmp[i]) + else: + tmp_k = memory_obj.tensor[0].to(slot_mapping.device) + tmp_v = memory_obj.tensor[1].to(slot_mapping.device) + num_blocks, block_size, num_heads, head_size = kvcaches[0][0].shape + total_blocks = num_blocks * block_size + d = num_heads * head_size + for i, (kcache, vcache) in enumerate(kvcaches): + kcache.view(total_blocks, d).index_copy_(0, slices, tmp_k[i]) + vcache.view(total_blocks, d).index_copy_(0, slices, tmp_v[i]) + + def from_gpu(self, memory_obj: MemoryObj, start: int, end: int, **kwargs): + """Expect a kwarg 'kvcaches' which is a nested tuple of K and V tensors. + The kvcaches should correspond to the "WHOLE token sequence". + + Will set the memory_obj.metadata.fmt to MemoryFormat.KV_2LTD. + + Note: + 1. This function expects the 'slot_mapping' is a "full slot mapping" + where it's length is the same as the whole token sequence. + 2. In the case that there is prefix caching, slot_mapping will starts + with -1s until the end of the matched prefix. The start and end + should NEVER overlap with the prefix caching (which means the + underlying CUDA kernel will never see -1 in slot_mapping) + + :raises ValueError: If 'kvcaches' is not provided in kwargs, + :raises AssertionError: If the memory object does not have a tensor. + :raises ValueError: If 'slot_mapping' is not provided in kwargs. + """ + assert memory_obj.tensor is not None + + if "kvcaches" not in kwargs: + raise ValueError("'kvcaches' should be provided in kwargs.") + + if "slot_mapping" not in kwargs: + raise ValueError("'slot_mapping' should be provided in kwargs.") + + kvcaches: List[torch.Tensor] = kwargs["kvcaches"] + slot_mapping: torch.Tensor = kwargs["slot_mapping"] + slices = slot_mapping[start:end] + + if self.use_mla: + num_blocks, block_size, head_size = kvcaches[0].shape + total_blocks = num_blocks * block_size + tmp = torch.stack( + [ + kvcache.view(total_blocks, head_size).index_select(0, slices) + for kvcache in kvcaches + ] + ) + else: + num_blocks, block_size, num_heads, head_size = kvcaches[0][0].shape + total_blocks = num_blocks * block_size + d = num_heads * head_size + tmp_k = torch.stack( + [ + kvcache[0].view(total_blocks, d).index_select(0, slices) + for kvcache in kvcaches + ] + ) + tmp_v = torch.stack( + [ + kvcache[1].view(total_blocks, d).index_select(0, slices) + for kvcache in kvcaches + ] + ) + tmp = torch.stack([tmp_k, tmp_v]) + memory_obj.tensor.copy_(tmp, non_blocking=True) + + if not memory_obj.tensor.is_xpu: + # Force a synchronize if the target buffer is NOT XPU device + # NOTE: for better performance, we may not want to sync for every + # memory object + torch.xpu.synchronize() + + if self.use_mla: + memory_obj.metadata.fmt = MemoryFormat.KV_MLA_FMT + + # TODO(Jiayi): need to optimize to enable real batching + def batched_to_gpu(self, memory_objs, starts, ends, **kwargs): + for memory_obj, start, end in zip(memory_objs, starts, ends, strict=False): + self.to_gpu(memory_obj, start, end, **kwargs) diff --git a/requirements/bench.txt b/requirements/bench.txt index 051b770f571..9c8d191ac76 100644 --- a/requirements/bench.txt +++ b/requirements/bench.txt @@ -1,7 +1,7 @@ requests tqdm transformers -numpy +numpy<=2.2 pandas matplotlib seaborn \ No newline at end of file diff --git a/requirements/common.txt b/requirements/common.txt index f72aa0b97df..68067619f99 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -3,8 +3,13 @@ aiofiles aiohttp awscrt cufile-python +fastapi +httpx msgspec -numpy +# if nixl decides to support >=3.13 in the future, we can remove this constraint +nixl; python_version < "3.13" +# nixl uses numba which requires numpy<=2.2 +numpy<=2.2 nvtx prometheus_client >= 0.18.0 psutil @@ -29,3 +34,4 @@ sortedcontainers # 4. this torch version may also be overridden inside of LMCache/docker/Dockerfile torch transformers >= 4.51.1 +uvicorn \ No newline at end of file diff --git a/requirements/docs.txt b/requirements/docs.txt index cba0726dea6..3c52439fb82 100644 --- a/requirements/docs.txt +++ b/requirements/docs.txt @@ -11,4 +11,5 @@ sphinxawesome_theme==5.3.2 sphinx-copybutton==0.5.2 sphinxcontrib-mermaid==1.0.0 sphinx-multiversion==0.2.4 +sphinxcontrib-images msgspec diff --git a/tests/conftest.py b/tests/conftest.py index 27d60d2d571..d07196d2d14 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -221,6 +221,32 @@ def slave_for( return self.slave_redis +class MockRedisCluster: + def __init__( + self, startup_nodes=None, max_connections=None, decode_responses=False, **kwargs + ): + self.startup_nodes = startup_nodes or [] + self.max_connections = max_connections + self.decode_responses = decode_responses + self.store = {} + + async def set(self, key, value): + self.store[key] = value + return True + + async def get(self, key): + return self.store.get(key, None) + + async def exists(self, key): + return key in self.store + + async def delete(self, key): + return self.store.pop(key, None) is not None + + async def close(self): + pass + + @dataclass class LMCacheServerProcess: server_url: str @@ -249,6 +275,12 @@ def mock_redis_sentinel(): yield mock +@pytest.fixture(scope="function", autouse=True) +def mock_redis_cluster(): + with patch("redis.asyncio.cluster.RedisCluster", MockRedisCluster) as mock: + yield mock + + @pytest.fixture(scope="module") def lmserver_v1_process(request): def ensure_connection(host, port): diff --git a/tests/disagg/test_nixl_channel.py b/tests/disagg/test_nixl_channel.py index be5d8655aa1..b6272b1e0bc 100644 --- a/tests/disagg/test_nixl_channel.py +++ b/tests/disagg/test_nixl_channel.py @@ -2,7 +2,6 @@ # Standard from typing import List, Tuple import argparse -import threading import time # Third Party @@ -13,19 +12,11 @@ from lmcache.logging import init_logger from lmcache.utils import CacheEngineKey from lmcache.v1.memory_management import AdHocMemoryAllocator, MemoryFormat, MemoryObj +from lmcache.v1.transfer_channel.nixl_channel import NixlChannel +from lmcache.v1.transfer_channel.transfer_utils import get_correct_device pytest.importorskip("nixl", reason="nixl package is required for nixl tests") -# First Party -# from lmcache.v1.storage_backend.connector.nixl_connector import ( -# NixlChannel, NixlConfig, NixlObserverInterface, NixlRole) -from lmcache.v1.storage_backend.connector.nixl_connector import ( - NixlChannel, - NixlConfig, - NixlObserverInterface, - NixlRole, -) - logger = init_logger(__name__) @@ -45,6 +36,7 @@ def generate_test_data( world_size=1, worker_id=0, chunk_hash=i, + dtype=dtype, ) ) obj = allocator.allocate(shape, dtype, fmt=MemoryFormat.KV_2LTD) @@ -61,173 +53,6 @@ def calculate_throughput(total_bytes: int, elapsed_time: float) -> float: return gb / elapsed_time -class TestObserver(NixlObserverInterface): - def __init__(self): - self.received_keys = [] - self.received_tensors = [] - self.received_objs = [] - self.received_event = threading.Event() - self.expected_count = None - self.reset() - - def set_expected_count(self, count: int): - self.expected_count = count - - def __call__(self, keys, objs, is_view=True): - logger.info(f"Observer received {len(keys)} keys and {len(objs)} objects") - - # Clear previous data if we're starting a new batch - if len(self.received_keys) == 0: - self.reset() - - self.received_keys.extend(keys) - - # If these are views, we need to make copies - if is_view: - for obj in objs: - copied_tensor = ( - obj.tensor.clone().detach() - ) # Detach to ensure no gradient history - self.received_tensors.append(copied_tensor) - # copied_obj = TensorMemoryObj(copied_tensor, obj.metadata) - else: - self.received_objs.extend(objs) - - if self.expected_count and len(self.received_objs) >= self.expected_count: - self.received_event.set() - - def summarize(self): - logger.info( - f"Received {len(self.received_keys)} keys and " - f"{len(self.received_tensors)} tensors" - ) - - def reset(self): - # Explicitly free any existing tensors - if hasattr(self, "received_objs"): - for obj in self.received_objs: - del obj.raw_data - del self.received_objs - - if hasattr(self, "received_keys"): - del self.received_keys - - if hasattr(self, "received_tensors"): - del self.received_tensors - - self.received_keys = [] - self.received_tensors = [] - self.received_event = threading.Event() - self.expected_count = None - torch.cuda.empty_cache() # Force CUDA memory cleanup - - -def send_and_measure_throughput( - channel: NixlChannel, - keys: List[CacheEngineKey], - objs: List[MemoryObj], - total_size: int, -) -> float: - """Send data through the channel and measure throughput. - - Args: - channel: The NixlChannel to send data through - keys: List of cache engine keys - objs: List of memory objects to send - total_size: Total size of objects in bytes - - Returns: - float: Throughput in GB/s - """ - # Wait a bit for the receiver to set up - time.sleep(2) - - # Send the data - logger.info(f"Sending {len(objs)} objects...") - start_time = time.time() - channel.send(keys, objs) - end_time = time.time() - - elapsed_time = end_time - start_time - logger.info(f"Sent {len(objs)} objects in {elapsed_time:.6f} seconds") - throughput = calculate_throughput(total_size, elapsed_time) - logger.info(f"Throughput: {throughput:.2f} GB/s") - return throughput - - -def receive_and_verify_data( - observer: TestObserver, - channel: NixlChannel, - expected_keys: List[CacheEngineKey], - expected_objs: List[MemoryObj], - timeout: int = 60, -) -> bool: - """Receive data through the channel and verify it matches expected data. - - Args: - channel: The NixlChannel to receive data through - expected_keys: List of expected cache engine keys - expected_objs: List of expected memory objects - timeout: Maximum time to wait for data in seconds - - Returns: - bool: True if all data was received and verified successfully - """ - # Create and register an observer - - try: - # Wait for all data to be received - logger.info("Waiting to receive data...") - start_time = time.time() - - while len(observer.received_tensors) < len(expected_keys): - if time.time() - start_time > timeout: - logger.error("Timed out waiting for data") - return False - logger.info( - f"Received {len(observer.received_tensors)}/" - f"{len(expected_keys)} tensors so far..." - ) - time.sleep(1) - - if len(observer.received_tensors) == len(expected_keys): - logger.info( - f"Received all {len(observer.received_keys)} keys and " - f"{len(observer.received_tensors)} tensors" - ) - - # Verify the received data - success = True - for i, (received_tensor, original_tensor) in enumerate( - zip(observer.received_tensors, expected_objs, strict=False) - ): - if not torch.allclose(received_tensor, original_tensor.tensor): - logger.error(f"Data mismatch at index {i}") - success = False - break - - for i, (received_key, original_key) in enumerate( - zip(observer.received_keys, expected_keys, strict=False) - ): - if received_key != original_key: - logger.error(f"Key mismatch at index {i}") - success = False - break - - return success - else: - logger.error( - f"Only received {len(observer.received_objs)}/" - f"{len(expected_keys)} objects before timeout" - ) - return False - finally: - # Always cleanup, even if verification fails - observer.summarize() - observer.reset() - torch.cuda.empty_cache() - - if __name__ == "__main__": parser = argparse.ArgumentParser( description="Test NixlChannel with sender/receiver roles" @@ -268,33 +93,67 @@ def receive_and_verify_data( ) # Common configuration - config = NixlConfig( - role=NixlRole(args.role), - receiver_host=args.host, - receiver_port=args.port, - buffer_size=2**32, # 4GB - buffer_device="cuda", - enable_gc=False, - ) + buffer_size = 2**32 # 4GB + buffer_device = get_correct_device("cuda", 0) # Use first GPU + + # Get buffer pointer from first object + buffer_ptr = objs[0].metadata.address + align_bytes = 4096 # Standard page size # Create the NixlChannel - channel = NixlChannel(config) + channel = NixlChannel( + async_mode=False, + role=args.role, + buffer_ptr=buffer_ptr, + buffer_size=buffer_size, + align_bytes=align_bytes, + tp_rank=0, + peer_init_url=f"tcp://{args.host}:{args.port}", + backends=["UCX"], + ) if args.role == "sender": throughputs = [] for i in range(args.num_rounds): logger.info(f"Round {i + 1}/{args.num_rounds}") - throughput = send_and_measure_throughput(channel, keys, objs, total_size) + start_time = time.time() + num_sent = channel.batched_send(objs) + end_time = time.time() + elapsed_time = end_time - start_time + throughput = calculate_throughput(total_size, elapsed_time) + logger.info(f"Sent {num_sent} objects in {elapsed_time:.6f} seconds") + logger.info(f"Throughput: {throughput:.2f} GB/s") throughputs.append(throughput) avg_throughput = sum(throughputs) / len(throughputs) logger.info(f"Average throughput: {avg_throughput:.2f} GB/s") else: # receiver - observer = TestObserver() - observer.set_expected_count(len(keys)) - channel.register_receive_observer(observer) for i in range(args.num_rounds): logger.info(f"Round {i + 1}/{args.num_rounds}") - success = receive_and_verify_data(observer, channel, keys, objs) + start_time = time.time() + num_received = channel.batched_recv(objs) + end_time = time.time() + elapsed_time = end_time - start_time + throughput = calculate_throughput(total_size, elapsed_time) + logger.info( + f"Received {num_received} objects in {elapsed_time:.6f} seconds" + ) + logger.info(f"Throughput: {throughput:.2f} GB/s") + + # Verify data + for i, (received_obj, original_obj) in enumerate( + zip(objs, objs, strict=False) + ): + assert received_obj.tensor is not None, ( + f"Received tensor is None at index {i}" + ) + assert original_obj.tensor is not None, ( + f"Original tensor is None at index {i}" + ) + assert torch.allclose(received_obj.tensor, original_obj.tensor), ( + f"Data mismatch at index {i}: " + f"received {received_obj.tensor.mean().item()}" + f" but expected {original_obj.tensor.mean().item()}" + ) # Wait a bit before closing time.sleep(2) diff --git a/tests/disagg/test_nixl_channel_v2.py b/tests/disagg/test_nixl_channel_v2.py deleted file mode 100644 index 55d6d6d8146..00000000000 --- a/tests/disagg/test_nixl_channel_v2.py +++ /dev/null @@ -1,427 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# Standard -from typing import List, Optional, Tuple -import argparse -import threading -import time - -# Third Party -import pytest -import torch - -pytest.importorskip("nixl", reason="nixl package is required for nixl tests") - -# First Party -from lmcache.logging import init_logger -from lmcache.utils import CacheEngineKey -from lmcache.v1.memory_management import AdHocMemoryAllocator, MemoryFormat, MemoryObj -from lmcache.v1.storage_backend.connector.nixl_connector_v2 import ( - NixlChannel, - NixlConfig, - NixlObserverInterface, - NixlRole, -) - -logger = init_logger(__name__) - - -def generate_test_data( - num_objs: int, shape: torch.Size, dtype: torch.dtype = torch.bfloat16 -) -> Tuple[List[CacheEngineKey], List[MemoryObj]]: - keys = [] - objs = [] - allocator = AdHocMemoryAllocator( - device="cuda", # Assuming we are using CUDA for the test - ) - for i in range(num_objs): - keys.append( - CacheEngineKey( - fmt="test", - model_name="test_model", - world_size=1, - worker_id=0, - chunk_hash=i, - ) - ) - obj = allocator.allocate(shape, dtype, fmt=MemoryFormat.KV_2LTD) - obj.tensor.fill_( - (i + 1) / num_objs - ) # Fill with some test data, e.g., the index - objs.append(obj) - return keys, objs - - -def calculate_throughput(total_bytes: int, elapsed_time: float) -> float: - """Calculate throughput in GB/s""" - if elapsed_time == 0: - return float("inf") - gb = total_bytes / (1024 * 1024 * 1024) - return gb / elapsed_time - - -class TestObserver(NixlObserverInterface): - def __init__(self): - self.key_to_tensors = {} # Map keys to received tensors - self.received_event = threading.Event() - self.expected_count = None - self.num_expected_senders = 1 # Default to 1 sender - self.reset() - - def set_expected_count(self, count: int): - self.expected_count = count - - def set_num_expected_senders(self, num_senders: int): - self.num_expected_senders = num_senders - - def __call__(self, keys, objs, is_view=True): - logger.info(f"Observer received {len(keys)} keys and {len(objs)} objects") - - # If these are views, we need to make copies - if is_view: - for i, obj in enumerate(objs): - copied_tensor = obj.tensor.clone().detach() - - # Store tensor by key for verification - key = keys[i] - if key not in self.key_to_tensors: - self.key_to_tensors[key] = [] - self.key_to_tensors[key].append(copied_tensor) - else: - # For non-view objects, still store them by key - for i, obj in enumerate(objs): - key = keys[i] - if key not in self.key_to_tensors: - self.key_to_tensors[key] = [] - self.key_to_tensors[key].append(obj.tensor) - - # Calculate total received tensors - total_received = sum(len(tensors) for tensors in self.key_to_tensors.values()) - - if ( - self.expected_count - and total_received >= self.expected_count * self.num_expected_senders - ): - self.received_event.set() - - def summarize(self): - total_tensors = sum(len(tensors) for tensors in self.key_to_tensors.values()) - logger.info( - f"Received {len(self.key_to_tensors)} unique keys and " - f"{total_tensors} total tensors" - ) - - def reset(self): - # Explicitly free any existing tensors - if hasattr(self, "key_to_tensors"): - for tensors in self.key_to_tensors.values(): - for tensor in tensors: - del tensor - del self.key_to_tensors - - self.key_to_tensors = {} - self.received_event = threading.Event() - self.expected_count = None - torch.cuda.empty_cache() # Force CUDA memory cleanup - - -def send_and_measure_throughput_v2( - channel: NixlChannel, - keys: List[CacheEngineKey], - objs: List[MemoryObj], - total_size: int, - batch_size: Optional[int] = None, - simulate_workload: bool = False, -) -> float: - """Send data through the channel and measure throughput using V2 API. - - Args: - channel: The NixlChannel to send data through - keys: List of cache engine keys - objs: List of memory objects to send - total_size: Total size of objects in bytes - batch_size: Size of batches to send (if None, send all at once) - simulate_workload: If True, sleep 50ms between batches - - Returns: - float: Throughput in GB/s - """ - logger.info(f"Sending {len(objs)} objects using zero_copy_send_with_callback...") - - elapsed_time = 0.0 - - if batch_size is None: - # Original behavior - send all at once - start_time = time.time() - metadatas = [obj.metadata for obj in objs] - channel.zero_copy_send_with_callback( - keys=keys, - metadatas=metadatas, - callback=lambda dest_obj, idx=0: dest_obj.tensor.copy_(objs[idx].tensor), # type: ignore - ) - elapsed_time = time.time() - start_time - else: - # Send in batches - elapsed_times: list[float] = [] - for i in range(0, len(objs), batch_size): - start_time = time.time() - batch_keys = keys[i : i + batch_size] - batch_objs = objs[i : i + batch_size] - batch_metadatas = [obj.metadata for obj in batch_objs] - - def callback(dest_obj, idx, batch_objs=batch_objs): - dest_obj.tensor.copy_(batch_objs[idx].tensor) - - channel.zero_copy_send_with_callback( - keys=batch_keys, metadatas=batch_metadatas, callback=callback - ) - this_round = time.time() - start_time - elapsed_times.append(this_round) - logger.info( - f"Sent batch {i // batch_size + 1}" - f"/{len(objs) // batch_size}" - f" in {this_round:.6f} seconds" - ) - if simulate_workload: - time.sleep(0.05) # Sleep 50ms between batches - elapsed_time = sum(elapsed_times) # type: ignore - logger.info(f"Elapsed times: {elapsed_times}") - - logger.info(f"Sent {len(objs)} objects in {elapsed_time:.6f} seconds") - throughput = calculate_throughput(total_size, elapsed_time) - logger.info(f"Throughput: {throughput:.2f} GB/s") - time.sleep(2) - return throughput - - -def receive_and_verify_data( - observer: TestObserver, - expected_keys: List[CacheEngineKey], - expected_objs: List[MemoryObj], - num_expected_senders: int = 1, - timeout: int = 60, -) -> bool: - """Receive data through the channel and verify it matches expected data. - - Args: - observer: The TestObserver that receives data - expected_keys: List of expected cache engine keys - expected_objs: List of expected memory objects - num_expected_senders: Number of senders expected to send the same data - timeout: Maximum time to wait for data in seconds - - Returns: - bool: True if all data was received and verified successfully - """ - try: - # Wait for all data to be received - logger.info("Waiting to receive data...") - start_time = time.time() - expected_total = len(expected_keys) * num_expected_senders - - # Calculate total received tensors - total_received = sum( - len(tensors) for tensors in observer.key_to_tensors.values() - ) - - while total_received < expected_total: - if time.time() - start_time > timeout: - logger.error("Timed out waiting for data") - return False - logger.info(f"Received {total_received}/{expected_total} tensors so far...") - time.sleep(1) - # Update total received count - total_received = sum( - len(tensors) for tensors in observer.key_to_tensors.values() - ) - - if total_received >= expected_total: - logger.info( - f"Received all {len(observer.key_to_tensors)} unique keys and " - f"{total_received} total tensors" - ) - - # Verify the received data - success = True - - # Check that we received the expected number of tensors for each key - for key in expected_keys: - if key not in observer.key_to_tensors: - logger.error(f"Missing key: {key}") - success = False - continue - - if len(observer.key_to_tensors[key]) != num_expected_senders: - logger.error( - f"Expected {num_expected_senders} objs for key {key}, " - f"but got {len(observer.key_to_tensors[key])}" - ) - success = False - continue - - chunk_hash = key.chunk_hash - try: - idx = chunk_hash - expected_value = (idx + 1) / len( - expected_keys - ) # Match the value in generate_test_data - - # Verify the data for this key - for tensor in observer.key_to_tensors[key]: - # Check if tensor values match expected value - if not torch.allclose( - tensor, torch.full_like(tensor, expected_value) - ): - logger.error( - f"Data mismatch for key {key}. " - f"Received value: {tensor.flatten()[0]}. " - f"Expected value: {expected_value}" - ) - success = False - except (IndexError, ValueError) as e: - logger.error(f"Error parsing chunk_hash {chunk_hash}: {e}") - success = False - - return success - else: - logger.error( - f"Only received {total_received}/{expected_total} " - "tensors before timeout" - ) - return False - finally: - # Always cleanup, even if verification fails - observer.summarize() - observer.reset() - torch.cuda.empty_cache() - - -@pytest.mark.skip(reason="test needs to be parameterized") -def test_allocate_for_send( - channel: NixlChannel, shape: torch.Size, dtype: torch.dtype -) -> None: - """Test the allocate_for_send API""" - logger.info("Testing allocate_for_send API...") - - # Create test keys - keys = [ - CacheEngineKey( - fmt="test", - model_name="test_model", - world_size=1, - worker_id=0, - chunk_hash=i, - ) - for i in range(3) - ] - - # Create test metadatas - allocator = AdHocMemoryAllocator(device="cuda") - temp_objs = [allocator.allocate(shape, dtype) for _ in range(3)] - metadatas = [obj.metadata for obj in temp_objs] - - # Prepare send - channel.prepare_send(keys, metadatas) - - # Allocate and fill objects - for i in range(3): - obj = channel.allocate_for_send(shape, dtype) - assert obj is not None, "Failed to allocate memory for send" - assert obj.tensor is not None - obj.tensor.fill_(i + 10) # Fill with test data - - # Finish send - channel.finish_send() - logger.info("allocate_for_send test completed") - - -def main(): - parser = argparse.ArgumentParser( - description="Test NixlChannel V2 with sender/receiver roles" - ) - parser.add_argument( - "--role", - type=str, - required=True, - choices=["sender", "receiver"], - help="Role of this instance (sender or receiver)", - ) - parser.add_argument( - "--host", - type=str, - default="localhost", - help="Host name/IP for connection", - ) - parser.add_argument( - "--port", type=int, default=5555, help="Port number for connection" - ) - parser.add_argument( - "--num-objs", type=int, default=100, help="Number of objects to send" - ) - parser.add_argument( - "--batch-size", - type=int, - help="Size of batches to send (default: send all at once)", - ) - parser.add_argument( - "--simulate-workload", - action="store_true", - help="Simulate workload by sleeping 50ms between batches", - ) - parser.add_argument( - "--num-expected-senders", - type=int, - default=1, - help="Number of senders expected to connect (receiver only)", - ) - args = parser.parse_args() - - # Generate test data - keys, objs = generate_test_data(args.num_objs, torch.Size([32, 2, 256, 1024])) - total_size = sum(obj.get_size() for obj in objs) - logger.info( - f"Generated {len(objs)} objects with total size " - f"{total_size / (1024 * 1024):.2f} MB" - ) - - # Common configuration - config = NixlConfig( - role=NixlRole(args.role), - receiver_host=args.host, - receiver_port=args.port, - buffer_size=2**32, # 4GB - buffer_device="cuda:0", - enable_gc=False, - ) - - # Create the NixlChannel - channel = NixlChannel(config) - - if args.role == "sender": - throughput = send_and_measure_throughput_v2( - channel, - keys, - objs, - total_size, - batch_size=args.batch_size, - simulate_workload=args.simulate_workload, - ) - logger.info(f"Throughput: {throughput:.2f} GB/s") - else: # receiver - observer = TestObserver() - observer.set_expected_count(len(keys)) - observer.set_num_expected_senders(args.num_expected_senders) - channel.register_receive_observer(observer) - success = receive_and_verify_data( - observer, keys, objs, args.num_expected_senders - ) - if not success: - logger.error("Data verification failed") - - # Wait a bit before closing - time.sleep(2) - channel.close() - logger.info("Test completed") - - -if __name__ == "__main__": - main() diff --git a/tests/disagg/test_nixl_pipe.py b/tests/disagg/test_nixl_pipe.py deleted file mode 100644 index d2f37841394..00000000000 --- a/tests/disagg/test_nixl_pipe.py +++ /dev/null @@ -1,197 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# Standard -from typing import List, Tuple -import argparse -import time - -# Third Party -import pytest -import torch -import zmq - -pytest.importorskip("nixl", reason="nixl package is required for nixl tests") - -# First Party -from lmcache.logging import init_logger -from lmcache.utils import CacheEngineKey -from lmcache.v1.memory_management import AdHocMemoryAllocator, MemoryFormat, MemoryObj -from lmcache.v1.storage_backend.connector.nixl_connector import ( - NixlConfig, - NixlPipe, - NixlRole, -) - -logger = init_logger(__name__) - - -def generate_test_data( - num_objs: int, shape: torch.Size, dtype: torch.dtype = torch.bfloat16 -) -> Tuple[List[CacheEngineKey], List[MemoryObj]]: - keys = [] - objs = [] - allocator = AdHocMemoryAllocator( - device="cuda", # Assuming we are using CUDA for the test - ) - for i in range(num_objs): - keys.append( - CacheEngineKey( - fmt="test", - model_name="test_model", - world_size=1, - worker_id=0, - chunk_hash=i, - ) - ) - obj = allocator.allocate(shape, dtype, fmt=MemoryFormat.KV_2LTD) - obj.tensor.fill_(i + 1) # Fill with some test data, e.g., the index - objs.append(obj) - return keys, objs - - -def calculate_throughput(total_bytes: int, elapsed_time: float) -> float: - """Calculate throughput in GB/s""" - if elapsed_time == 0: - return float("inf") - gb = total_bytes / (1024 * 1024 * 1024) - return gb / elapsed_time - - -if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="Test NixlChannel with sender/receiver roles" - ) - parser.add_argument( - "--role", - type=str, - required=True, - choices=["sender", "receiver"], - help="Role of this instance (sender or receiver)", - ) - parser.add_argument( - "--host", - type=str, - default="localhost", - help="Host name/IP for connection", - ) - parser.add_argument( - "--port", type=int, default=5555, help="Port number for connection" - ) - parser.add_argument( - "--num-rounds", - type=int, - default=1, - help="Number of rounds to run the experiment", - ) - - args = parser.parse_args() - - keys, objs = generate_test_data(100, torch.Size([32, 2, 256, 1024])) - - # Common configuration - config = NixlConfig( - role=NixlRole(args.role), - receiver_host=args.host, - receiver_port=args.port, - buffer_size=2**32, # 4GB - buffer_device="cuda", - enable_gc=False, - ) - - context = zmq.Context() # type: ignore - side_channel = context.socket(zmq.PAIR) # type: ignore - if args.role == "sender": - side_channel.bind(f"tcp://{args.host}:{args.port}") - else: - side_channel.connect(f"tcp://{args.host}:{args.port}") - - # Test the NIXLPipe - pipe = NixlPipe(config, side_channel) - - total_commit_time = 0.0 - total_wait_time = 0.0 - total_bytes_transferred = 0 - - for round_num in range(args.num_rounds): - logger.info(f"Starting round {round_num + 1}/{args.num_rounds}") - - initial_uuid = f"test_{round_num}" - next_uuid = f"new_test_{round_num}" - - if args.role == "sender": - # Write data to buffer (not timed) - num_objs, total_size = pipe.write_buffer(objs) - logger.info(f"Wrote {num_objs} objects to the buffer") - - # Measure commit time (actual transfer) - commit_start = time.time() - new_uuid = pipe.commit_write(total_size, initial_uuid) - commit_end = time.time() - commit_time = commit_end - commit_start - - total_commit_time += commit_time - total_bytes_transferred += total_size - - logger.info(f"New UUID: {new_uuid}") - logger.info(f"Transfer time: {commit_time:.6f} seconds") - transfer_throughput = calculate_throughput(total_size, commit_time) - logger.info(f"Transfer throughput: {transfer_throughput:.2f} GB/s") - - assert new_uuid == next_uuid, ( - f"Expected new UUID '{next_uuid}', but got '{new_uuid}'" - ) - else: - # Measure wait time (actual transfer) - wait_start = time.time() - pipe.wait_read(initial_uuid) - wait_end = time.time() - wait_time = wait_end - wait_start - - total_wait_time += wait_time - - logger.info(f"Transfer wait time: {wait_time:.6f} seconds") - - # Read data from buffer (not timed) - metadatas = [obj.metadata for obj in objs] - received_objs = pipe.read_buffer(metadatas) - total_size = sum(obj.get_size() for obj in received_objs) - - total_bytes_transferred += total_size - - logger.info(f"Received {len(received_objs)} objects") - transfer_throughput = calculate_throughput(total_size, wait_time) - logger.info(f"Transfer throughput: {transfer_throughput:.2f} GB/s") - - # Check if the received objects are the same as the original objects - for received_obj, original_obj in zip(received_objs, objs, strict=False): - assert received_obj.tensor is not None - assert original_obj.tensor is not None - assert torch.allclose(received_obj.tensor, original_obj.tensor), ( - f"Data mismatch: received {received_obj.tensor.mean()}" - f" but expected {original_obj.tensor.mean()}" - ) - - # Send acknowledgment - pipe.ack_receive(next_uuid) - - # Print aggregate statistics - if args.num_rounds > 1: - if args.role == "sender": - avg_time = total_commit_time / args.num_rounds - logger.info(f"Average transfer time: {avg_time:.6f} seconds") - else: - avg_time = total_wait_time / args.num_rounds - logger.info(f"Average wait time: {avg_time:.6f} seconds") - - avg_throughput = calculate_throughput( - total_bytes_transferred, - total_commit_time if args.role == "sender" else total_wait_time, - ) - logger.info( - f"Average throughput over {args.num_rounds} rounds: " - f"{avg_throughput:.2f} GB/s" - ) - - # Wait a bit before closing - time.sleep(2) - pipe.close() - logger.info("Test completed successfully") diff --git a/tests/disagg/test_nixl_pipe_v2.py b/tests/disagg/test_nixl_pipe_v2.py deleted file mode 100644 index c8d3957ce95..00000000000 --- a/tests/disagg/test_nixl_pipe_v2.py +++ /dev/null @@ -1,231 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# Standard -from typing import List, Tuple -import argparse -import time - -# Third Party -import pytest -import torch -import zmq - -pytest.importorskip("nixl", reason="nixl package is required for nixl tests") - -# First Party -from lmcache.logging import init_logger -from lmcache.utils import CacheEngineKey -from lmcache.v1.memory_management import ( - AdHocMemoryAllocator, - MemoryFormat, - MemoryObj, - TensorMemoryObj, -) -from lmcache.v1.storage_backend.connector.nixl_connector_v2 import ( - NixlConfig, - NixlPipe, - NixlRole, -) - -logger = init_logger(__name__) - - -def generate_test_data( - num_objs: int, shape: torch.Size, dtype: torch.dtype = torch.bfloat16 -) -> Tuple[List[CacheEngineKey], List[MemoryObj]]: - keys = [] - objs = [] - allocator = AdHocMemoryAllocator( - device="cuda", # Assuming we are using CUDA for the test - ) - for i in range(num_objs): - keys.append( - CacheEngineKey( - fmt="test", - model_name="test_model", - world_size=1, - worker_id=0, - chunk_hash=i, - ) - ) - obj = allocator.allocate(shape, dtype, fmt=MemoryFormat.KV_2LTD) - obj.tensor.fill_(i + 1) # Fill with some test data, e.g., the index - objs.append(obj) - return keys, objs - - -def calculate_throughput(total_bytes: int, elapsed_time: float) -> float: - """Calculate throughput in GB/s""" - if elapsed_time == 0: - return float("inf") - gb = total_bytes / (1024 * 1024 * 1024) - return gb / elapsed_time - - -if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="Test NixlPipe V2 with sender/receiver roles" - ) - parser.add_argument( - "--role", - type=str, - required=True, - choices=["sender", "receiver"], - help="Role of this instance (sender or receiver)", - ) - parser.add_argument( - "--host", - type=str, - default="localhost", - help="Host name/IP for connection", - ) - parser.add_argument( - "--port", type=int, default=5555, help="Port number for connection" - ) - parser.add_argument( - "--num-rounds", - type=int, - default=1, - help="Number of rounds to run the experiment", - ) - parser.add_argument( - "--num-objs", - type=int, - default=100, - help="Number of objects to transfer", - ) - parser.add_argument( - "--simulate-work", - action="store_true", - help="Simulate some work on both sides", - ) - - args = parser.parse_args() - - keys, objs = generate_test_data(args.num_objs, torch.Size([32, 2, 256, 1024])) - - # Common configuration - config = NixlConfig( - role=NixlRole(args.role), - receiver_host=args.host, - receiver_port=args.port, - buffer_size=2**32, # 4GB - buffer_device="cuda:0", - enable_gc=False, - ) - - context = zmq.Context() # type: ignore - side_channel = context.socket(zmq.PAIR) # type: ignore - if args.role == "sender": - side_channel.bind(f"tcp://{args.host}:{args.port}") - else: - side_channel.connect(f"tcp://{args.host}:{args.port}") - - # Test the NIXLPipe - pipe = NixlPipe(config, side_channel) - - total_transfer_time = 0.0 - total_bytes_transferred = 0 - - for round_num in range(args.num_rounds): - logger.info(f"Starting round {round_num + 1}/{args.num_rounds}") - - if args.role == "sender": - # Sender side - total_size = 0 - - # Allocate and write data to buffer - transfer_time = 0.0 - for idx, obj in enumerate(objs): - if args.simulate_work and idx % 10 == 0: - time.sleep(0.05) # Simulate some work - - # Use the new allocate_for_write method - transfer_start = time.time() - new_obj = pipe.allocate_for_write( - obj.get_shape(), obj.get_dtype(), obj.metadata.fmt - ) - if new_obj is not None and new_obj.tensor is not None: - # Copy data from original object to the new one - new_obj.tensor.copy_(obj.tensor) - total_size += new_obj.get_size() - transfer_time += time.time() - transfer_start - # Measure transfer time - flush_start = time.time() - pipe.flush() # This will wait for receiver's ack - flush_end = time.time() - transfer_time += flush_end - flush_start - - total_transfer_time += transfer_time - total_bytes_transferred += total_size - - logger.info(f"Transfer time: {transfer_time:.6f} seconds") - transfer_throughput = calculate_throughput(total_size, transfer_time) - logger.info(f"Transfer throughput: {transfer_throughput:.2f} GB/s") - - else: - # Receiver side - # Read data from buffer - transfer_start = time.time() - metadatas = [obj.metadata for obj in objs] - received_objs: list[MemoryObj] = [] - while len(received_objs) < len(metadatas): - pipe.wait_read() - new_objs = pipe.read_buffer(metadatas[len(received_objs) :]) - nobj_before = len(received_objs) - for idx, obj in enumerate(new_objs): - cloned_tensor = obj.tensor.detach().clone() - received_objs.append( - TensorMemoryObj(cloned_tensor, obj.metadata, None) - ) - - # Simulate some work: 20ms per 10 objects - if args.simulate_work and len(received_objs) % 10 == 0: - time.sleep(0.02) - - pipe.ack_receive() - transfer_end = time.time() - transfer_time = transfer_end - transfer_start - total_size = sum(obj.get_size() for obj in received_objs) - - total_bytes_transferred += total_size - total_transfer_time += transfer_time - - logger.info(f"Received {len(received_objs)} objects") - transfer_throughput = calculate_throughput(total_size, transfer_time) - logger.info(f"Transfer throughput: {transfer_throughput:.2f} GB/s") - - # Check if the received objects are the same as the original objects - assert len(received_objs) == len(objs), ( - "Number of received objects does not match the number of " - "original objects" - ) - for i, (received_obj, original_obj) in enumerate( - zip(received_objs, objs, strict=False) - ): - assert received_obj.tensor is not None - assert original_obj.tensor is not None - assert torch.allclose(received_obj.tensor, original_obj.tensor), ( - f"Data mismatch at index {i}: received " - f"{received_obj.tensor.mean()} " - f"but expected {original_obj.tensor.mean()}" - ) - logger.info("Round passed") - - # Print aggregate statistics - if args.num_rounds > 1: - avg_time = total_transfer_time / args.num_rounds - logger.info(f"Average transfer time: {avg_time:.6f} seconds") - - avg_throughput = calculate_throughput( - total_bytes_transferred, total_transfer_time - ) - logger.info(f"Total bytes transferred: {total_bytes_transferred}") - logger.info( - f"Average throughput over {args.num_rounds} rounds: " - f"{avg_throughput:.2f} GB/s" - ) - - # Wait a bit before closing - time.sleep(5) - pipe.close() - logger.info("Test completed successfully") diff --git a/tests/disagg/test_nixl_storage_backend.py b/tests/disagg/test_nixl_storage_backend.py index 17ebc69720a..82028710be1 100644 --- a/tests/disagg/test_nixl_storage_backend.py +++ b/tests/disagg/test_nixl_storage_backend.py @@ -2,6 +2,10 @@ # Standard from typing import List, Tuple import argparse +import asyncio +import os +import tempfile +import threading import time # Third Party @@ -11,11 +15,15 @@ pytest.importorskip("nixl", reason="nixl package is required for nixl tests") # First Party +from lmcache.config import LMCacheEngineMetadata from lmcache.logging import init_logger from lmcache.utils import CacheEngineKey +from lmcache.v1.config import LMCacheEngineConfig from lmcache.v1.memory_management import AdHocMemoryAllocator, MemoryFormat, MemoryObj -from lmcache.v1.storage_backend.connector.nixl_connector import NixlConfig, NixlRole -from lmcache.v1.storage_backend.nixl_backend import NixlBackend +from lmcache.v1.storage_backend.nixl_storage_backend import ( + NixlStorageBackend, + NixlStorageConfig, +) logger = init_logger(__name__) @@ -26,7 +34,7 @@ def generate_test_data( keys = [] objs = [] allocator = AdHocMemoryAllocator( - device="cuda", # Assuming we are using CUDA for the test + device="cuda" if torch.cuda.is_available() else "cpu", ) for i in range(num_objs): keys.append( @@ -36,12 +44,11 @@ def generate_test_data( world_size=1, worker_id=0, chunk_hash=i, + dtype=dtype, ) ) obj = allocator.allocate(shape, dtype, fmt=MemoryFormat.KV_2LTD) - obj.tensor.fill_( - (i + 1) / num_objs - ) # Fill with some test data, e.g., the index + obj.tensor.fill_((i + 1) / num_objs) # Fill with some test data objs.append(obj) return keys, objs @@ -54,189 +61,331 @@ def calculate_throughput(total_bytes: int, elapsed_time: float) -> float: return gb / elapsed_time -def send_and_measure_throughput( - backend: NixlBackend, - keys: List[CacheEngineKey], - objs: List[MemoryObj], - wait_time: float = 2.0, -) -> float: - """Send objects through the backend and measure throughput. - - Args: - backend: The NixlBackend instance - keys: List of cache engine keys - objs: List of memory objects to send - wait_time: Time to wait for receiver setup in seconds - - Returns: - float: Throughput in GB/s - """ - # Wait for the receiver to set up - time.sleep(wait_time) - - total_size = sum(obj.get_size() for obj in objs) - logger.info("Sending %d objects...", len(objs)) - - backend.register_put_tasks(keys, [obj.metadata for obj in objs]) - start_time = time.time() - backend.batched_submit_put_task(keys, objs) - backend.flush_put_tasks() - end_time = time.time() - - elapsed_time = end_time - start_time - logger.info("Sent %d objects in %.6f seconds", len(objs), elapsed_time) - throughput = calculate_throughput(total_size, elapsed_time) - logger.info("Throughput: %.2f GB/s", throughput) - - return throughput - - -def receive_and_verify_data( - backend: NixlBackend, - keys: List[CacheEngineKey], - num_objs: int, - timeout: float = 60.0, -) -> bool: - """Receive and verify data through the backend. - - Args: - backend: The NixlBackend instance - keys: List of cache engine keys to check - num_objs: Number of objects expected - timeout: Maximum time to wait for data in seconds - - Returns: - bool: True if all data was received and verified correctly - """ - logger.info("Waiting to receive data...") - - # Poll until we receive all objects or timeout - received_count = 0 - start_time = time.time() - - while received_count < num_objs: - received_count = sum(1 for key in keys if backend.contains(key)) - - if received_count == num_objs: - break - - if time.time() - start_time > timeout: - logger.error( - "Timed out waiting for data. Received only %d/%d objects.", - received_count, - num_objs, - ) - return False - - time.sleep(0.1) # Small sleep to avoid busy waiting - - passed_check = True - if received_count == num_objs: - logger.info("Received all %d objects", num_objs) - - # Verify the received data - for i, key in enumerate(keys): - received_obj = backend.get_blocking(key) - if received_obj is None or received_obj.tensor is None: - logger.error(f"Failed to retrieve object for key {key}") - passed_check = False - break - - # Check if the received object matches the original object - expected_value = (i + 1) / num_objs - actual_mean = received_obj.tensor.mean().item() - - # For bfloat16, we need some tolerance in the comparison - if abs(actual_mean - expected_value) > 0.01: - logger.error( - "Mismatch for key %s: received mean %f but expected %f", - key, - actual_mean, - expected_value, - ) - passed_check = False - break - - if passed_check: - logger.info("All data verified successfully!") - else: - logger.error("Data verification failed!") +def create_test_config( + buffer_device: str = "cuda" if torch.cuda.is_available() else "cpu", + backend: str = "GDS_MT" if torch.cuda.is_available() else "POSIX", +) -> LMCacheEngineConfig: + """Create a test configuration for NixlStorageBackend""" + config = LMCacheEngineConfig() + config.nixl_buffer_size = 2**32 # 4GB + config.nixl_buffer_device = buffer_device + config.extra_config = { + "enable_nixl_storage": True, + "nixl_backend": backend, + "nixl_file_pool_size": 10, + "nixl_path": tempfile.mkdtemp(), # Create a temporary directory for testing + } + return config + + +def create_test_metadata() -> LMCacheEngineMetadata: + """Create test metadata for NixlStorageBackend""" + return LMCacheEngineMetadata( + model_name="test_model", + worker_id=0, + world_size=1, + fmt="test", + kv_dtype=torch.bfloat16, + kv_shape=( + 32, + 2, + 256, + 1024, + 128, + ), # (num_layer, 2, chunk_size, num_kv_head, head_size) + ) + + +@pytest.mark.no_shared_allocator +def test_nixl_storage_config(): + """Test NixlStorageConfig creation and validation""" + config = create_test_config() + metadata = create_test_metadata() + + nixl_config = NixlStorageConfig.from_cache_engine_config(config, metadata) + assert nixl_config.buffer_size == config.nixl_buffer_size + assert nixl_config.buffer_device == config.nixl_buffer_device + assert nixl_config.backend == config.extra_config["nixl_backend"] + assert nixl_config.file_pool_size == config.extra_config["nixl_file_pool_size"] + assert nixl_config.path == config.extra_config["nixl_path"] + + # Test validation + assert NixlStorageConfig.validate_nixl_backend("GDS", "cuda") + assert NixlStorageConfig.validate_nixl_backend("GDS", "cpu") + assert NixlStorageConfig.validate_nixl_backend("GDS_MT", "cuda") + assert NixlStorageConfig.validate_nixl_backend("GDS_MT", "cpu") + assert NixlStorageConfig.validate_nixl_backend("POSIX", "cpu") + assert not NixlStorageConfig.validate_nixl_backend("POSIX", "cuda") + assert not NixlStorageConfig.validate_nixl_backend("INVALID", "cpu") + + +@pytest.mark.no_shared_allocator +@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA") +def test_nixl_storage_backend_basic(): + """Test basic NixlStorageBackend operations""" + config = create_test_config() + metadata = create_test_metadata() + + thread_loop = None + thread = None + backend = None + try: + thread_loop = asyncio.new_event_loop() + thread = threading.Thread(target=thread_loop.run_forever) + thread.start() + + backend = NixlStorageBackend.CreateNixlStorageBackend( + config=config, + loop=thread_loop, + metadata=metadata, + ) + + # Test allocation + shape = torch.Size([32, 2, 256, 1024]) + dtype = torch.bfloat16 + obj = backend.allocate(shape, dtype) + assert obj is not None + assert obj.tensor is not None + assert obj.tensor.shape == shape + assert obj.tensor.dtype == dtype + + # Test batched allocation + batch_size = 5 + objs = backend.batched_allocate(shape, dtype, batch_size) + assert objs is not None + assert len(objs) == batch_size + for obj in objs: + assert obj.tensor is not None + assert obj.tensor.shape == shape + assert obj.tensor.dtype == dtype + + except Exception: + raise + finally: + if backend: + backend.close() + if thread_loop and thread_loop.is_running(): + thread_loop.call_soon_threadsafe(thread_loop.stop) + if thread and thread.is_alive(): + thread.join() + # Cleanup temporary directory + if os.path.exists(config.extra_config["nixl_path"]): + os.rmdir(config.extra_config["nixl_path"]) + + +@pytest.mark.no_shared_allocator +@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA") +def test_nixl_storage_backend_put_get(): + """Test put and get operations in NixlStorageBackend""" + config = create_test_config() + metadata = create_test_metadata() + + thread_loop = None + thread = None + backend = None + try: + thread_loop = asyncio.new_event_loop() + thread = threading.Thread(target=thread_loop.run_forever) + thread.start() + + backend = NixlStorageBackend.CreateNixlStorageBackend( + config=config, + loop=thread_loop, + metadata=metadata, + ) + + # Generate test data + keys, objs = generate_test_data(10, torch.Size([32, 2, 256, 1024])) + + # Test contains before put + for key in keys: + assert not backend.contains(key) + assert not backend.exists_in_put_tasks(key) + + # Test put + backend.batched_submit_put_task(keys, objs) + + # Test get + for key, original_obj in zip(keys, objs, strict=False): + assert backend.contains(key) + retrieved_obj = backend.get_blocking(key) + assert retrieved_obj is not None + assert retrieved_obj.tensor is not None + assert torch.equal(retrieved_obj.tensor, original_obj.tensor) + + # Test batched get + retrieved_objs = asyncio.run( + backend.batched_get_non_blocking(lookup_id="test", keys=keys) + ) + assert len(retrieved_objs) == len(objs) + for retrieved_obj, original_obj in zip(retrieved_objs, objs, strict=False): + assert retrieved_obj is not None + assert retrieved_obj.tensor is not None + assert torch.equal(retrieved_obj.tensor, original_obj.tensor) + # Test remove for key in keys: backend.remove(key) + assert not backend.contains(key) + + except Exception: + raise + finally: + if backend: + backend.close() + if thread_loop and thread_loop.is_running(): + thread_loop.call_soon_threadsafe(thread_loop.stop) + if thread and thread.is_alive(): + thread.join() + # Cleanup temporary directory + if os.path.exists(config.extra_config["nixl_path"]): + os.rmdir(config.extra_config["nixl_path"]) + + +@pytest.mark.no_shared_allocator +@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA") +def test_nixl_storage_backend_different_backends(): + """Test NixlStorageBackend with different backend types""" + backends = ( + [ + ("GDS_MT", "cuda"), + ("GDS", "cuda"), + ("GDS_MT", "cpu"), + ("GDS", "cpu"), + ("POSIX", "cpu"), + ] + if torch.cuda.is_available() + else [ + ("GDS_MT", "cpu"), + ("GDS", "cpu"), + ("POSIX", "cpu"), + ] + ) - return passed_check - else: - logger.error("Only received %d/%d objects", received_count, num_objs) - return False + for backend_type, device in backends: + config = create_test_config(buffer_device=device, backend=backend_type) + metadata = create_test_metadata() + + thread_loop = None + thread = None + backend = None + try: + thread_loop = asyncio.new_event_loop() + thread = threading.Thread(target=thread_loop.run_forever) + thread.start() + + backend = NixlStorageBackend.CreateNixlStorageBackend( + config=config, + loop=thread_loop, + metadata=metadata, + ) + + # Basic allocation test + obj = backend.allocate(torch.Size([32, 2, 256, 1024]), torch.bfloat16) + assert obj is not None + assert obj.tensor is not None + + except Exception: + raise + finally: + if backend: + backend.close() + if thread_loop and thread_loop.is_running(): + thread_loop.call_soon_threadsafe(thread_loop.stop) + if thread and thread.is_alive(): + thread.join() + # Cleanup temporary directory + if os.path.exists(config.extra_config["nixl_path"]): + os.rmdir(config.extra_config["nixl_path"]) if __name__ == "__main__": parser = argparse.ArgumentParser( - description="Test NixlBackend with sender/receiver roles" + description="Test NixlStorageBackend with different configurations" ) parser.add_argument( - "--role", + "--backend", type=str, - required=True, - choices=["sender", "receiver"], - help="Role of this instance (sender or receiver)", + default="GDS_MT", + choices=["GDS_MT", "GDS", "POSIX"], + help="NIXL backend type to use", ) parser.add_argument( - "--host", + "--device", type=str, - default="localhost", - help="Host name/IP for connection", - ) - parser.add_argument( - "--port", type=int, default=5555, help="Port number for connection" - ) - parser.add_argument( - "--num-objs", type=int, default=100, help="Number of objects to send" + default="cuda", + choices=["cuda", "cpu"], + help="Device to use for buffer", ) parser.add_argument( - "--num-rounds", + "--num-objs", type=int, - default=1, - help="Number of rounds to run the experiment", + default=100, + help="Number of objects to test with", ) args = parser.parse_args() - # Generate test data - keys, objs = generate_test_data(args.num_objs, torch.Size([32, 2, 256, 1024])) - total_size = sum(obj.get_size() for obj in objs) - logger.info( - "Generated %d objects with total size %.2f MB", - len(objs), - total_size / (1024 * 1024), - ) + # Create config and metadata + config = create_test_config(buffer_device=args.device, backend=args.backend) + metadata = create_test_metadata() + + thread_loop = None + thread = None + backend = None + try: + thread_loop = asyncio.new_event_loop() + thread = threading.Thread(target=thread_loop.run_forever) + thread.start() + + # Create backend + backend = NixlStorageBackend.CreateNixlStorageBackend( + config=config, + loop=thread_loop, + metadata=metadata, + ) - # Common configuration - config = NixlConfig( - role=NixlRole(args.role), - receiver_host=args.host, - receiver_port=args.port, - buffer_size=2**32, # 4GB - buffer_device="cuda", - ) + # Generate and test with data + keys, objs = generate_test_data(args.num_objs, torch.Size([32, 2, 256, 1024])) + total_size = sum(obj.get_size() for obj in objs) + logger.info( + "Generated %d objects with total size %.2f MB", + len(objs), + total_size / (1024 * 1024), + ) - # Create the NixlBackend - backend = NixlBackend(config) - - if args.role == "sender": - throughputs = [] - for i in range(args.num_rounds): - logger.info("Round %d/%d", i + 1, args.num_rounds) - throughput = send_and_measure_throughput(backend, keys, objs) - throughputs.append(throughput) - avg_throughput = sum(throughputs) / len(throughputs) - logger.info("Average throughput: %.2f GB/s", avg_throughput) - else: # receiver - for i in range(args.num_rounds): - logger.info("Round %d/%d", i + 1, args.num_rounds) - success = receive_and_verify_data(backend, keys, args.num_objs) - - # Wait a bit before closing - time.sleep(2) - backend.close() - logger.info("Test completed") + # Test put performance + start_time = time.time() + backend.batched_submit_put_task(keys, objs) + end_time = time.time() + elapsed_time = end_time - start_time + throughput = calculate_throughput(total_size, elapsed_time) + logger.info("Put throughput: %.2f GB/s", throughput) + + # Test get performance + start_time = time.time() + retrieved_objs = asyncio.run( + backend.batched_get_non_blocking(lookup_id="test", keys=keys) + ) + end_time = time.time() + elapsed_time = end_time - start_time + throughput = calculate_throughput(total_size, elapsed_time) + logger.info("Get throughput: %.2f GB/s", throughput) + + # Verify data + for retrieved_obj, original_obj in zip(retrieved_objs, objs, strict=False): + assert torch.equal(retrieved_obj.tensor, original_obj.tensor) + + logger.info("All tests passed successfully!") + + except Exception: + raise + finally: + if backend: + backend.close() + if thread_loop and thread_loop.is_running(): + thread_loop.call_soon_threadsafe(thread_loop.stop) + if thread and thread.is_alive(): + thread.join() + # Cleanup temporary directory + if os.path.exists(config.extra_config["nixl_path"]): + os.rmdir(config.extra_config["nixl_path"]) diff --git a/tests/test_observability.py b/tests/test_observability.py index cd1c6318991..7f4204a13bc 100644 --- a/tests/test_observability.py +++ b/tests/test_observability.py @@ -62,20 +62,28 @@ def test_update_local_storage_usage(stats_monitor): def test_on_lookup_request(stats_monitor): stats_monitor.on_lookup_request(num_tokens=50) + assert len(stats_monitor.lookup_requests) == 1 stats = stats_monitor.get_stats_and_clear() assert stats.interval_lookup_requests == 1 assert stats.interval_lookup_tokens == 50 assert stats.lookup_hit_rate == 0 + assert len(stats.interval_lookup_hit_rates) == 0 + # on_lookup_finished is not called, lookup_requests is not clear + assert len(stats_monitor.lookup_requests) == 1 def test_on_lookup_finished(stats_monitor): - stats_monitor.on_lookup_request(num_tokens=100) - stats_monitor.on_lookup_finished(num_hit_tokens=80) + lookup_request_id = stats_monitor.on_lookup_request(num_tokens=100) + assert len(stats_monitor.lookup_requests) == 1 + stats_monitor.on_lookup_finished(request_id=lookup_request_id, num_hit_tokens=80) stats = stats_monitor.get_stats_and_clear() assert stats.interval_lookup_requests == 1 assert stats.interval_lookup_tokens == 100 assert stats.interval_lookup_hits == 80 assert stats.lookup_hit_rate == 0.8 + assert len(stats.interval_lookup_hit_rates) == 1 + assert stats.interval_lookup_hit_rates[0] == 0.8 + assert len(stats_monitor.lookup_requests) == 0 def test_remote_read_metrics(stats_monitor): @@ -143,16 +151,21 @@ def test_retrieve_and_store_speed(stats_monitor): def test_multiple_lookup_operations(stats_monitor): # Test multiple lookup operations - stats_monitor.on_lookup_request(num_tokens=100) - stats_monitor.on_lookup_finished(num_hit_tokens=80) - stats_monitor.on_lookup_request(num_tokens=200) - stats_monitor.on_lookup_finished(num_hit_tokens=150) + lookup_request_id_1 = stats_monitor.on_lookup_request(num_tokens=100) + stats_monitor.on_lookup_finished(request_id=lookup_request_id_1, num_hit_tokens=80) + lookup_request_id_2 = stats_monitor.on_lookup_request(num_tokens=200) + stats_monitor.on_lookup_finished(request_id=lookup_request_id_2, num_hit_tokens=150) + assert len(stats_monitor.lookup_requests) == 2 + assert stats_monitor.lookup_requests[lookup_request_id_1].hit_rate() == 0.8 + assert stats_monitor.lookup_requests[lookup_request_id_2].hit_rate() == 0.75 stats = stats_monitor.get_stats_and_clear() assert stats.interval_lookup_requests == 2 assert stats.interval_lookup_tokens == 300 assert stats.interval_lookup_hits == 230 assert stats.lookup_hit_rate == 230 / 300 + assert len(stats.interval_lookup_hit_rates) == 2 + assert len(stats_monitor.lookup_requests) == 0 def test_mixed_remote_operations(stats_monitor): @@ -205,21 +218,47 @@ def test_combined_operations(stats_monitor): def test_stats_clearing(stats_monitor): # Add some data - stats_monitor.on_lookup_request(num_tokens=100) + lookup_request_id = stats_monitor.on_lookup_request(num_tokens=100) stats_monitor.update_interval_remote_read_metrics(read_bytes=1024) stats_monitor.update_remote_ping_latency(latency=25.0) + assert len(stats_monitor.lookup_requests) == 1 + # Get stats (which should clear them) stats = stats_monitor.get_stats_and_clear() assert stats.interval_lookup_requests == 1 + assert stats.interval_lookup_tokens == 100 + assert stats.interval_lookup_hits == 0 assert stats.interval_remote_read_requests == 1 + assert stats.interval_remote_read_bytes == 1024 assert stats.interval_remote_ping_latency == 25.0 + assert len(stats.interval_lookup_hit_rates) == 0 # Get stats again - should be cleared stats2 = stats_monitor.get_stats_and_clear() assert stats2.interval_lookup_requests == 0 + assert stats2.interval_lookup_tokens == 0 + assert stats2.interval_lookup_hits == 0 assert stats2.interval_remote_read_requests == 0 + assert stats2.interval_remote_read_bytes == 0 assert stats2.interval_remote_ping_latency == 0 + assert len(stats2.interval_lookup_hit_rates) == 0 + + assert len(stats_monitor.lookup_requests) == 1 + + # finish lookup request + stats_monitor.on_lookup_finished(request_id=lookup_request_id, num_hit_tokens=80) + stats3 = stats_monitor.get_stats_and_clear() + assert stats3.interval_lookup_requests == 0 + assert stats3.interval_lookup_tokens == 0 + assert stats3.interval_lookup_hits == 80 + assert stats3.interval_remote_read_requests == 0 + assert stats3.interval_remote_read_bytes == 0 + assert stats3.interval_remote_ping_latency == 0 + assert len(stats3.interval_lookup_hit_rates) == 1 + assert stats3.interval_lookup_hit_rates[0] == 0.8 + + assert len(stats_monitor.lookup_requests) == 0 def test_zero_division_protection(stats_monitor): diff --git a/tests/v1/data/nixl.yaml b/tests/v1/data/nixl.yaml index 83a9b76deba..12748804741 100644 --- a/tests/v1/data/nixl.yaml +++ b/tests/v1/data/nixl.yaml @@ -3,4 +3,8 @@ chunk_size: 256 nixl_buffer_size: 1073741824 nixl_buffer_device: cpu -extra_config: {enable_nixl_storage: true, nixl_backend: POSIX, nixl_file_pool_size: 64, nixl_path: /tmp/nixl/cache/} +extra_config: + enable_nixl_storage: true + nixl_backend: POSIX + nixl_pool_size: 2 + nixl_path: /tmp/nixl/cache diff --git a/tests/v1/internal_api_server/test_cache_clear.py b/tests/v1/internal_api_server/test_cache_clear.py new file mode 100644 index 00000000000..cee978befdd --- /dev/null +++ b/tests/v1/internal_api_server/test_cache_clear.py @@ -0,0 +1,134 @@ +# SPDX-License-Identifier: Apache-2.0 +# Standard +from unittest.mock import MagicMock +import json + +# Third Party +from fastapi.testclient import TestClient +import pytest + +# First Party +from lmcache.v1.cache_engine import LMCacheEngine +from lmcache.v1.internal_api_server.api_server import app + + +class TestCacheClearAPI: + """Test suite for the /cache/clear API endpoint.""" + + @pytest.fixture + def mock_lmcache_adapter(self): + """Create a mock LMCacheConnectorV1Impl adapter.""" + adapter = MagicMock() + + # Create a mock LMCache engine + mock_engine = MagicMock(spec=LMCacheEngine) + mock_engine.clear.return_value = 5 # Mock return value for clear operation + + adapter.lmcache_engine = mock_engine + return adapter + + @pytest.fixture + def client_with_adapter(self, mock_lmcache_adapter): + """Create a test client with mocked adapter.""" + app.state.lmcache_adapter = mock_lmcache_adapter + return TestClient(app) + + def test_cache_clear_success(self, client_with_adapter, mock_lmcache_adapter): + """Test successful cache clear operation.""" + # Act + response = client_with_adapter.delete("/cache/clear") + + # Assert + assert response.status_code == 200 + response_data = json.loads(response.text) + assert response_data["status"] == "success" + assert response_data["num_removed"] == 5 + + # Verify that the clear method was called with correct parameters + mock_lmcache_adapter.lmcache_engine.clear.assert_called_once_with( + locations=None, request_configs=None + ) + + @pytest.mark.parametrize( + "locations,test_description", + [ + (["LocalCPUBackend", "LocalDiskBackend"], "multiple locations"), + (["LocalCPUBackend"], "single location"), + ], + ) + def test_cache_clear_with_locations( + self, client_with_adapter, mock_lmcache_adapter, locations, test_description + ): + """Test cache clear with specific locations.""" + # Act + # curl -X DELETE "http://localhost:8000/cache/clear?locations=LocalCPUBackend&locations=LocalDiskBackend" + response = client_with_adapter.delete( + "/cache/clear", params={"locations": locations} + ) + + # Assert + assert response.status_code == 200 + response_data = json.loads(response.text) + assert response_data["status"] == "success" + assert response_data["num_removed"] == 5 + + # Verify that the clear method was called with the correct locations + mock_lmcache_adapter.lmcache_engine.clear.assert_called_once() + call_args = mock_lmcache_adapter.lmcache_engine.clear.call_args + + # Verify the locations parameter was passed correctly + assert call_args is not None + assert "locations" in call_args.kwargs + assert "request_configs" in call_args.kwargs + + # Assert that the locations parameter matches what we sent + assert call_args.kwargs["locations"] == locations + + def test_cache_clear_engine_exception( + self, client_with_adapter, mock_lmcache_adapter + ): + """Test cache clear when engine raises an exception.""" + # Arrange + mock_lmcache_adapter.lmcache_engine.clear.side_effect = Exception("Cache error") + + # Act + response = client_with_adapter.delete("/cache/clear") + + # Assert + assert response.status_code == 500 + response_data = json.loads(response.text) + assert response_data["error"] == "Failed to clear cache" + assert response_data["message"] == "Cache error" + + def test_cache_clear_negative_return_value( + self, client_with_adapter, mock_lmcache_adapter + ): + """Test cache clear when engine returns negative value (edge case).""" + # Arrange + mock_lmcache_adapter.lmcache_engine.clear.return_value = -1 + + # Act + response = client_with_adapter.delete("/cache/clear") + + # Assert + assert response.status_code == 200 + response_data = json.loads(response.text) + assert response_data["status"] == "success" + assert response_data["num_removed"] == -1 + + def test_cache_clear_adapter_attribute_error(self): + """Test cache clear when adapter doesn't have lmcache_engine attribute.""" + + # Arrange + class AdapterWithoutEngine: + pass + + app.state.lmcache_adapter = AdapterWithoutEngine() + client = TestClient(app) + # Act + response = client.delete("/cache/clear") + # Assert + assert response.status_code == 503 + response_data = json.loads(response.text) + assert response_data["error"] == "/cache/clear API is unavailable" + assert response_data["message"] == "LMCache engine not configured." diff --git a/tests/v1/multiprocess/__init__.py b/tests/v1/multiprocess/__init__.py new file mode 100644 index 00000000000..db4392428f5 --- /dev/null +++ b/tests/v1/multiprocess/__init__.py @@ -0,0 +1,2 @@ +# SPDX-License-Identifier: Apache-2.0 + diff --git a/tests/v1/multiprocess/test_custom_types.py b/tests/v1/multiprocess/test_custom_types.py new file mode 100644 index 00000000000..32fa739bd67 --- /dev/null +++ b/tests/v1/multiprocess/test_custom_types.py @@ -0,0 +1,215 @@ +# SPDX-License-Identifier: Apache-2.0 +# Standard +from multiprocessing import Queue +import multiprocessing as mp + +# Third Party +import msgspec +import pytest +import torch + +# First Party +from lmcache.v1.multiprocess.custom_types import ( + CudaIPCWrapper, + IPCCacheEngineKey, + get_customized_decoder, + get_customized_encoder, +) + + +def test_ipc_cache_engine_key_serialization(): + """Test encoding and decoding of IPCCacheEngineKey using msgspec.""" + # Create a sample IPCCacheEngineKey + original_key = IPCCacheEngineKey.from_int_hash( + model_name="test_model", world_size=4, worker_id=1, chunk_hash=123456789 + ) + + # Encode the key + encoded = msgspec.msgpack.encode(original_key) + + # Decode the key + decoded_key = msgspec.msgpack.decode(encoded, type=IPCCacheEngineKey) + + # Verify correctness + assert original_key == decoded_key, "IPCCacheEngineKeys do not match!" + + +@pytest.mark.skipif( + not torch.cuda.is_available(), + reason="CUDA is required for CudaIPCWrapper tests", +) +def test_cudaipc_wrapper_serialization(): + """Test custom encoder/decoder for single CudaIPCWrapper object.""" + encoder = get_customized_encoder(type=CudaIPCWrapper) + decoder = get_customized_decoder(type=CudaIPCWrapper) + + # Create a sample tensor + original_tensor = torch.randn(3, 4, device="cuda") + wrapper = CudaIPCWrapper(original_tensor) + + # Encode the wrapper + encoded = encoder.encode(wrapper) + + # Decode the wrapper + decoded_wrapper = decoder.decode(encoded) + assert isinstance(decoded_wrapper, CudaIPCWrapper), ( + "Decoded object is not of type CudaIPCWrapper" + ) + assert decoded_wrapper == wrapper, ( + "Decoded CudaIPCWrapper does not match the original" + ) + + +@pytest.mark.skipif( + not torch.cuda.is_available(), + reason="CUDA is required for CudaIPCWrapper tests", +) +def test_cudaipc_wrapper_list_serialization(): + """Test custom encoder/decoder for list of CudaIPCWrapper objects.""" + wrappers = [] + for _ in range(5): + tensor = torch.randn(2, 2, device="cuda") + wrapper = CudaIPCWrapper(tensor) + wrappers.append(wrapper) + + encoder = get_customized_encoder(type=list[CudaIPCWrapper]) + decoder = get_customized_decoder(type=list[CudaIPCWrapper]) + + # Encode the list of wrappers + encoded = encoder.encode(wrappers) + + # Decode the list of wrappers + decoded_wrappers = decoder.decode(encoded) + + assert len(decoded_wrappers) == len(wrappers), ( + "Decoded list length does not match original" + ) + + for original, decoded in zip(wrappers, decoded_wrappers, strict=False): + assert original == decoded, "Decoded CudaIPCWrapper does not match the original" + + +def _worker_process_deserialize_and_reconstruct( + encoded_data: bytes, result_queue: Queue +): + """ + Worker function that runs in a separate process. + Deserializes CudaIPCWrapper list and reconstructs tensors. + """ + try: + # Decode the list of wrappers + torch.cuda.init() + decoder = get_customized_decoder(type=list[CudaIPCWrapper]) + decoded_wrappers = decoder.decode(encoded_data) + + # Convert each wrapper back to tensor and compute checksum + checksums = [] + shapes = [] + for wrapper in decoded_wrappers: + tensor = wrapper.to_tensor() + # Compute checksum as sum of all elements + checksum = float(tensor.sum().cpu().item()) + checksums.append(checksum) + shapes.append(list(tensor.shape)) + + # Do add 1 on the tensor to ensure it's writable + tensor.add_(1) + + result_queue.put(("success", checksums, shapes)) + except Exception as e: + result_queue.put(("error", str(e), None)) + + +@pytest.mark.skipif( + not torch.cuda.is_available(), + reason="CUDA is required for CudaIPCWrapper multiprocessing tests", +) +def test_cudaipc_wrapper_multiprocess_serialization(): + """ + Test CudaIPCWrapper serialization across processes using spawn method. + This verifies that CUDA IPC handles can be properly shared between processes. + """ + # Set multiprocessing start method to spawn + ctx = mp.get_context("spawn") + + # Create test tensors and wrappers in the main process + num_tensors = 3 + tensors = [] + test_data = [] + wrappers = [] + + for i in range(num_tensors): + # Create a tensor with known values + tensor = torch.full( + (2, 3), fill_value=float(i + 1), dtype=torch.float32, device="cuda" + ) + tensors.append(tensor) + wrapper = CudaIPCWrapper(tensor) + wrappers.append(wrapper) + + # Store expected checksum and shape + expected_checksum = float(tensor.sum().cpu().item()) + expected_shape = list(tensor.shape) + test_data.append((expected_checksum, expected_shape)) + + # Serialize the wrappers + encoder = get_customized_encoder(type=list[CudaIPCWrapper]) + encoded_data = encoder.encode(wrappers) + + # Create a queue for results + result_queue = ctx.Queue() + + # Start worker process + process = ctx.Process( + target=_worker_process_deserialize_and_reconstruct, + args=(encoded_data, result_queue), + ) + process.start() + + # Wait for result with timeout + process.join(timeout=10) + + # Check if process completed successfully + if process.is_alive(): + process.terminate() + process.join() + pytest.fail("Worker process timed out") + + assert process.exitcode == 0, ( + f"Worker process failed with exit code {process.exitcode}" + ) + + # Get result from queue + assert not result_queue.empty(), "No result received from worker process" + status, checksums, shapes = result_queue.get() + + assert status == "success", f"Worker process encountered error: {checksums}" + assert len(checksums) == num_tensors, "Number of checksums does not match" + assert len(shapes) == num_tensors, "Number of shapes does not match" + + # Verify checksums and shapes match + for i, ( + (expected_checksum, expected_shape), + actual_checksum, + actual_shape, + ) in enumerate(zip(test_data, checksums, shapes, strict=False)): + assert actual_shape == expected_shape, ( + f"Tensor {i}: shape mismatch. Expected {expected_shape}, got {actual_shape}" + ) + assert abs(actual_checksum - expected_checksum) < 1e-5, ( + f"Tensor {i}: checksum mismatch. Expected {expected_checksum}, " + f"got {actual_checksum}" + ) + + # Verify that the tensors are being modified in the worker process + for i, (tensor, (expected_checksum, _)) in enumerate( + zip(tensors, test_data, strict=False) + ): + # After adding 1 to each element, the new checksum should be: + num_elements = tensor.numel() + new_expected_checksum = expected_checksum + float(num_elements) + actual_checksum = float(tensor.sum().cpu().item()) + assert abs(actual_checksum - new_expected_checksum) < 1e-5, ( + f"Tensor {i}: post-modification checksum mismatch. " + f"Expected {new_expected_checksum}, got {actual_checksum}" + ) diff --git a/tests/v1/multiprocess/test_mq.py b/tests/v1/multiprocess/test_mq.py new file mode 100644 index 00000000000..8580af3b4b3 --- /dev/null +++ b/tests/v1/multiprocess/test_mq.py @@ -0,0 +1,670 @@ +# SPDX-License-Identifier: Apache-2.0 +# Standard +from multiprocessing.synchronize import Event as EventClass +from typing import Any, Callable +import multiprocessing as mp +import sys +import threading +import time + +# Third Party +import pytest +import torch +import zmq + +# First Party +from lmcache.v1.multiprocess.custom_types import CudaIPCWrapper, IPCCacheEngineKey +from lmcache.v1.multiprocess.mq import ( + MessageQueueClient, + MessageQueueServer, + MessagingFuture, +) +from lmcache.v1.multiprocess.protocol import ( + RequestType, + get_handler_type, + get_payload_classes, +) + +# Test helpers +from tests.v1.multiprocess import test_mq_handler_helpers + + +def test_messaging_future_basic_usage(): + """Test basic usage of MessagingFuture: set result and retrieve it.""" + future = MessagingFuture[int]() + + # Initially, future should not be done + assert not future.query(), "Future should not be done initially" + + # Set result + future.set_result(42) + + # Future should now be done + assert future.query(), "Future should be done after setting result" + + # Get result (should be immediate) + result = future.result(timeout=1) + assert result == 42, f"Expected result 42, got {result}" + + +def test_messaging_future_with_thread(): + """Test MessagingFuture with result set from another thread.""" + future = MessagingFuture[str]() + + def set_future_result(): + time.sleep(0.5) + future.set_result("Hello from thread") + + # Start thread that will set the result + thread = threading.Thread(target=set_future_result) + thread.start() + + # Initially should not be done + assert not future.query(), "Future should not be done before thread sets result" + + # Wait for result + result = future.result(timeout=2) + assert result == "Hello from thread", f"Expected 'Hello from thread', got {result}" + + # Should be done now + assert future.query(), "Future should be done after getting result" + + thread.join() + + +def test_messaging_future_wait_success(): + """Test wait method when result becomes available.""" + future = MessagingFuture[int]() + + def set_future_result(): + time.sleep(0.3) + future.set_result(100) + + thread = threading.Thread(target=set_future_result) + thread.start() + + # Wait should return True when result is set + success = future.wait(timeout=1) + assert success, "Wait should return True when result is available" + assert future.query(), "Future should be done after wait returns True" + + thread.join() + + +def test_messaging_future_wait_timeout(): + """Test wait method when timeout is reached.""" + future = MessagingFuture[int]() + + # Wait with short timeout (result never set) + start_time = time.time() + success = future.wait(timeout=0.2) + elapsed = time.time() - start_time + + assert not success, "Wait should return False on timeout" + assert not future.query(), "Future should not be done after timeout" + assert 0.15 < elapsed < 0.3, f"Wait should respect timeout, elapsed: {elapsed}" + + +def test_messaging_future_result_timeout(): + """Test result method raises TimeoutError when timeout is reached.""" + future = MessagingFuture[int]() + + # Try to get result with timeout (result never set) + with pytest.raises( + TimeoutError, match="Future result not available within timeout" + ): + future.result(timeout=0.2) + + assert not future.query(), "Future should not be done after timeout" + + +def test_messaging_future_wait_no_timeout(): + """Test wait method without timeout (waits indefinitely until result is set).""" + future = MessagingFuture[float]() + + def set_future_result(): + time.sleep(0.3) + future.set_result(3.14) + + thread = threading.Thread(target=set_future_result) + thread.start() + + # Wait without timeout should wait until result is available + success = future.wait() # No timeout parameter + assert success, "Wait should return True when result is set" + assert future.result() == 3.14, "Result should be accessible after wait" + + thread.join() + + +def test_messaging_future_multiple_result_calls(): + """Test that result can be retrieved multiple times after being set.""" + future = MessagingFuture[str]() + future.set_result("persistent value") + + # Get result multiple times + result1 = future.result(timeout=0.1) + result2 = future.result(timeout=0.1) + result3 = future.result(timeout=0.1) + + assert result1 == result2 == result3 == "persistent value", ( + "Result should be retrievable multiple times" + ) + + +def test_messaging_future_complex_type(): + """Test MessagingFuture with complex types like lists and dicts.""" + future = MessagingFuture[dict]() + + complex_data = {"key1": [1, 2, 3], "key2": {"nested": "value"}, "key3": 42} + + def set_future_result(): + time.sleep(0.2) + future.set_result(complex_data) + + thread = threading.Thread(target=set_future_result) + thread.start() + + result = future.result(timeout=1) + assert result == complex_data, "Complex types should be preserved" + + thread.join() + + +# ============================================================================== +# MessageQueueServer and MessageQueueClient Tests Infrastructure +# ============================================================================== + + +def _server_process( + server_url: str, + ready_event: EventClass, + shutdown_event: EventClass, + request_handlers: dict[RequestType, Callable], +): + """ + Server process that runs the MessageQueueServer. + + Args: + server_url: URL to bind the server to + ready_event: Event to signal when server is ready + shutdown_event: Event to signal server shutdown + request_handlers: Dict mapping RequestType to handler functions + """ + context = zmq.Context.instance() + server = MessageQueueServer(server_url, context) + + # Register all handlers + for request_type, handler in request_handlers.items(): + payload_classes = get_payload_classes(request_type) + handler_type = get_handler_type(request_type) + server.add_handler(request_type, payload_classes, handler_type, handler) + + server.start() + + # Signal that server is ready + ready_event.set() + + # Wait for shutdown signal + shutdown_event.wait() + + # Cleanup + server.close() + + +def _run_client_test( + server_url: str, + ready_event: EventClass, + request_type: RequestType, + payloads: list[Any], + expected_response: Any, + num_requests: int = 1, + client_id: int = 0, +) -> None: + """ + Client process that sends requests and validates responses. + + Args: + server_url: URL to connect to + ready_event: Event to wait for server to be ready + request_type: Type of request to send + payloads: List of payloads for the request + expected_response: Expected response from server + num_requests: Number of requests to send + client_id: ID of this client (for debugging) + + Returns: + bool: True if all tests passed, False otherwise + """ + # Wait for server to be ready + if not ready_event.wait(timeout=5): + print(f"Client {client_id}: Server failed to start within timeout") + sys.exit(1) + + # Small delay to ensure server is fully initialized + time.sleep(0.1) + + context = zmq.Context.instance() + client = MessageQueueClient(server_url, context) + successful = True + + try: + futures = [] + # Submit requests + for _ in range(num_requests): + future = client.submit_request(request_type, payloads) # type: ignore + futures.append(future) + + # Validate responses + for i, future in enumerate(futures): + response = future.result(timeout=5) + if response != expected_response: + print( + f"Client {client_id}, Request {i}: Expected " + f"{expected_response}, got {response}" + ) + + # Exit with error code + client.close() + sys.exit(1) + + except Exception as e: + print(f"Client {client_id} test failed with exception: {e}") + successful = False + finally: + client.close() + if not successful: + sys.exit(1) + + +class MessageQueueTestHelper: + """ + Helper class to facilitate testing MessageQueueServer and MessageQueueClient. + + Supports testing with single or multiple concurrent clients, where each client + can send multiple requests to the server. + + Usage: + 1. Create an instance with server URL + 2. Register handlers for different RequestTypes + 3. Call run_test() to execute the test with client requests + + Example: + helper = MessageQueueTestHelper(server_url="tcp://127.0.0.1:5556") + helper.register_handler(RequestType.NOOP, noop_handler) + helper.run_test( + request_type=RequestType.NOOP, + payloads=[], + expected_response="NOOP_OK", + num_requests=10, # Each client sends 10 requests + num_clients=3, # Start 3 concurrent clients + ) + """ + + def __init__(self, server_url: str = "tcp://127.0.0.1:5556"): + self.server_url = server_url + self.handlers: dict[RequestType, Callable] = {} + self.ctx = mp.get_context("spawn") + + def register_handler( + self, + request_type: RequestType, + handler: Callable, + ) -> "MessageQueueTestHelper": + """ + Register a handler for a specific RequestType. + + Args: + request_type: The type of request to handle + handler: Handler function that matches the protocol signature + + Returns: + self for method chaining + """ + self.handlers[request_type] = handler + return self + + def run_test( + self, + request_type: RequestType, + payloads: list[Any], + expected_response: Any, + num_requests: int = 1, + num_clients: int = 1, + timeout: float = 10.0, + ) -> None: + """ + Run a test by starting server and client processes. + + Args: + request_type: Type of request to send + payloads: List of payloads for the request + expected_response: Expected response from server + num_requests: Number of requests each client should send + num_clients: Number of client processes to start + timeout: Maximum time to wait for test completion + + Raises: + AssertionError: If test fails + """ + ready_event = self.ctx.Event() + shutdown_event = self.ctx.Event() + + # Start server process + server_process = self.ctx.Process( + target=_server_process, + args=(self.server_url, ready_event, shutdown_event, self.handlers), + ) + server_process.start() + + # Start multiple client processes + client_processes = [] + for client_id in range(num_clients): + client_process = self.ctx.Process( + target=_run_client_test, + args=( + self.server_url, + ready_event, + request_type, + payloads, + expected_response, + num_requests, + client_id, + ), + ) + client_process.start() + client_processes.append(client_process) + + # Wait for all clients to complete + failed_clients = [] + for client_id, client_process in enumerate(client_processes): + client_process.join(timeout=timeout) + + # Check if client completed successfully + if client_process.is_alive(): + client_process.terminate() + client_process.join() + failed_clients.append((client_id, "timeout")) + elif client_process.exitcode != 0: + failed_clients.append( + (client_id, f"exit code {client_process.exitcode}") + ) + + # Shutdown server + shutdown_event.set() + server_process.join(timeout=2) + + if server_process.is_alive(): + server_process.terminate() + server_process.join() + + # Report any failures + if failed_clients: + failure_details = ", ".join( + [f"Client {cid}: {reason}" for cid, reason in failed_clients] + ) + pytest.fail(f"Some clients failed: {failure_details}") + + if server_process.exitcode != 0: + pytest.fail( + f"Server process failed with exit code {server_process.exitcode}" + ) + + +# ============================================================================== +# Tests for Different RequestTypes +# ============================================================================== + + +def test_mq_noop_request(): + """ + Test MessageQueue with NOOP request type. + NOOP takes no payloads and returns a string response. + """ + # Create test helper and register handler + helper = MessageQueueTestHelper(server_url="tcp://127.0.0.1:5556") + helper.register_handler(RequestType.NOOP, test_mq_handler_helpers.noop_handler) + + # Run test with single request + helper.run_test( + request_type=RequestType.NOOP, + payloads=[], + expected_response="NOOP_OK", + num_requests=1, + ) + + +def test_mq_noop_multiple_requests(): + """ + Test MessageQueue with multiple NOOP requests. + Verifies that server can handle multiple sequential requests. + """ + helper = MessageQueueTestHelper(server_url="tcp://127.0.0.1:5557") + helper.register_handler(RequestType.NOOP, test_mq_handler_helpers.noop_handler) + + # Run test with multiple requests + helper.run_test( + request_type=RequestType.NOOP, + payloads=[], + expected_response="NOOP_OK", + num_requests=10, + ) + + +def test_mq_noop_multiple_clients(): + """ + Test MessageQueue with multiple concurrent clients. + Verifies that server can handle requests from multiple clients simultaneously. + """ + helper = MessageQueueTestHelper(server_url="tcp://127.0.0.1:5558") + helper.register_handler(RequestType.NOOP, test_mq_handler_helpers.noop_handler) + + # Run test with multiple clients, each sending multiple requests + helper.run_test( + request_type=RequestType.NOOP, + payloads=[], + expected_response="NOOP_OK", + num_requests=5, + num_clients=3, + ) + + +@pytest.mark.skipif( + not torch.cuda.is_available(), + reason="CUDA is required for REGISTER_KV_CACHE tests", +) +def test_mq_register_kv_cache(): + """ + Test MessageQueue with REGISTER_KV_CACHE request type. + REGISTER_KV_CACHE takes (gpu_id: int, kv_cache: KVCache) and returns None. + """ + # Create test KV cache (list of CudaIPCWrapper objects) + kv_cache = [] + for _ in range(3): + tensor = torch.randn(2, 4, device="cuda") + wrapper = CudaIPCWrapper(tensor) + kv_cache.append(wrapper) + + gpu_id = 0 + + # Create test helper and register handler + helper = MessageQueueTestHelper(server_url="tcp://127.0.0.1:5559") + helper.register_handler( + RequestType.REGISTER_KV_CACHE, test_mq_handler_helpers.register_kv_cache_handler + ) + + # Run test with REGISTER_KV_CACHE request + helper.run_test( + request_type=RequestType.REGISTER_KV_CACHE, + payloads=[gpu_id, kv_cache], + expected_response=None, + num_requests=1, + ) + + +def test_mq_unregister_kv_cache(): + """ + Test MessageQueue with UNREGISTER_KV_CACHE request type. + UNREGISTER_KV_CACHE takes (gpu_id: int) and returns None. + """ + gpu_id = 0 + + # Create test helper and register handler + helper = MessageQueueTestHelper(server_url="tcp://127.0.0.1:5560") + helper.register_handler( + RequestType.UNREGISTER_KV_CACHE, + test_mq_handler_helpers.unregister_kv_cache_handler, + ) + + # Run test with UNREGISTER_KV_CACHE request + helper.run_test( + request_type=RequestType.UNREGISTER_KV_CACHE, + payloads=[gpu_id], + expected_response=None, + num_requests=1, + ) + + +def test_mq_unregister_kv_cache_multiple_clients(): + """ + Test MessageQueue with UNREGISTER_KV_CACHE from multiple clients. + Verifies that multiple clients can unregister KV caches concurrently. + """ + gpu_id = 0 + + # Create test helper and register handler + helper = MessageQueueTestHelper(server_url="tcp://127.0.0.1:5561") + helper.register_handler( + RequestType.UNREGISTER_KV_CACHE, + test_mq_handler_helpers.unregister_kv_cache_handler, + ) + + # Run test with multiple clients + helper.run_test( + request_type=RequestType.UNREGISTER_KV_CACHE, + payloads=[gpu_id], + expected_response=None, + num_requests=3, + num_clients=2, + ) + + +def test_mq_store(): + """ + Test MessageQueue with STORE request type. + STORE takes (keys: list[KeyType], gpu_id: int, gpu_block_ids: list[int]) + and returns bool. + """ + # Create test keys + keys = [ + IPCCacheEngineKey.from_int_hash( + model_name="test_model", world_size=1, worker_id=0, chunk_hash=i + ) + for i in range(3) + ] + gpu_id = 0 + gpu_block_ids = [0, 1, 2] + + # Create test helper and register handler + helper = MessageQueueTestHelper(server_url="tcp://127.0.0.1:5562") + helper.register_handler(RequestType.STORE, test_mq_handler_helpers.store_handler) + + # Run test with STORE request + helper.run_test( + request_type=RequestType.STORE, + payloads=[keys, gpu_id, gpu_block_ids], + expected_response=True, + num_requests=1, + ) + + +def test_mq_retrieve(): + """ + Test MessageQueue with RETRIEVE request type. + RETRIEVE takes (keys: list[KeyType], gpu_id: int, gpu_block_ids: list[int]) + and returns bool. + """ + # Create test keys + keys = [ + IPCCacheEngineKey.from_int_hash( + model_name="test_model", world_size=1, worker_id=0, chunk_hash=i + ) + for i in range(3) + ] + gpu_id = 0 + gpu_block_ids = [0, 1, 2] + + # Create test helper and register handler + helper = MessageQueueTestHelper(server_url="tcp://127.0.0.1:5563") + helper.register_handler( + RequestType.RETRIEVE, test_mq_handler_helpers.retrieve_handler + ) + + # Run test with RETRIEVE request + helper.run_test( + request_type=RequestType.RETRIEVE, + payloads=[keys, gpu_id, gpu_block_ids], + expected_response=[True, True, True], + num_requests=1, + ) + + +def test_mq_lookup(): + """ + Test MessageQueue with LOOKUP request type. + LOOKUP takes (keys: list[KeyType], lock: Optional[bool]) + and returns list[bool]. + """ + # Create test keys + keys = [ + IPCCacheEngineKey.from_int_hash( + model_name="test_model", world_size=1, worker_id=0, chunk_hash=i + ) + for i in range(4) + ] + lock = True + + # Expected response: alternating True/False for each key + expected_response = [True, False, True, False] + + # Create test helper and register handler + helper = MessageQueueTestHelper(server_url="tcp://127.0.0.1:5564") + helper.register_handler(RequestType.LOOKUP, test_mq_handler_helpers.lookup_handler) + + # Run test with LOOKUP request + helper.run_test( + request_type=RequestType.LOOKUP, + payloads=[keys, lock], + expected_response=expected_response, + num_requests=1, + ) + + +def test_mq_lookup_with_none_lock(): + """ + Test MessageQueue with LOOKUP request type with None lock parameter. + Tests that Optional[bool] parameter works correctly with None value. + """ + # Create test keys + keys = [ + IPCCacheEngineKey.from_int_hash( + model_name="test_model", world_size=1, worker_id=0, chunk_hash=i + ) + for i in range(3) + ] + lock = None + + # Expected response: alternating True/False for each key + expected_response = [True, False, True] + + # Create test helper and register handler + helper = MessageQueueTestHelper(server_url="tcp://127.0.0.1:5565") + helper.register_handler(RequestType.LOOKUP, test_mq_handler_helpers.lookup_handler) + + # Run test with LOOKUP request with None lock + helper.run_test( + request_type=RequestType.LOOKUP, + payloads=[keys, lock], + expected_response=expected_response, + num_requests=1, + ) diff --git a/tests/v1/multiprocess/test_mq_handler_helpers.py b/tests/v1/multiprocess/test_mq_handler_helpers.py new file mode 100644 index 00000000000..fcaeeb62099 --- /dev/null +++ b/tests/v1/multiprocess/test_mq_handler_helpers.py @@ -0,0 +1,156 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Helper handler functions for MessageQueue tests. + +These handlers are defined at module level to allow them to be pickled +and passed between processes during multiprocessing tests. +""" + +# Standard +from typing import Optional + +# First Party +from lmcache.v1.multiprocess.custom_types import KVCache +from lmcache.v1.multiprocess.protocol import KeyType + +# ============================================================================== +# NOOP Request Handlers +# ============================================================================== + + +def noop_handler() -> str: + """ + Dummy handler for NOOP requests. + Takes no arguments and returns a simple string response. + """ + return "NOOP_OK" + + +# ============================================================================== +# REGISTER_KV_CACHE Request Handlers +# ============================================================================== + + +def register_kv_cache_handler(gpu_id: int, kv_cache: KVCache) -> None: + """ + Dummy handler for REGISTER_KV_CACHE requests. + + Args: + gpu_id: GPU device ID + kv_cache: List of CudaIPCWrapper objects representing KV cache + + Returns: + None + """ + # In a real implementation, this would register the KV cache + # For testing, we just validate the inputs are received correctly + assert isinstance(gpu_id, int), f"Expected gpu_id to be int, got {type(gpu_id)}" + assert isinstance(kv_cache, list), ( + f"Expected kv_cache to be list, got {type(kv_cache)}" + ) + # No return value (returns None implicitly) + + +# ============================================================================== +# UNREGISTER_KV_CACHE Request Handlers +# ============================================================================== + + +def unregister_kv_cache_handler(gpu_id: int) -> None: + """ + Dummy handler for UNREGISTER_KV_CACHE requests. + + Args: + gpu_id: GPU device ID + + Returns: + None + """ + # In a real implementation, this would unregister the KV cache for the given GPU + # For testing, we just validate the input is received correctly + assert isinstance(gpu_id, int), f"Expected gpu_id to be int, got {type(gpu_id)}" + # No return value (returns None implicitly) + + +# ============================================================================== +# STORE Request Handlers +# ============================================================================== + + +def store_handler(keys: list[KeyType], gpu_id: int, gpu_block_ids: list[int]) -> bool: + """ + Dummy handler for STORE requests. + + Args: + keys: List of cache keys to store + gpu_id: GPU device ID + gpu_block_ids: List of GPU block IDs + + Returns: + bool: True if store succeeded + """ + # In a real implementation, this would store KV cache data + # For testing, we just validate the inputs are received correctly + assert isinstance(keys, list), f"Expected keys to be list, got {type(keys)}" + assert isinstance(gpu_id, int), f"Expected gpu_id to be int, got {type(gpu_id)}" + assert isinstance(gpu_block_ids, list), ( + f"Expected gpu_block_ids to be list, got {type(gpu_block_ids)}" + ) + # Return success + return True + + +# ============================================================================== +# RETRIEVE Request Handlers +# ============================================================================== + + +def retrieve_handler( + keys: list[KeyType], gpu_id: int, gpu_block_ids: list[int] +) -> list[bool]: + """ + Dummy handler for RETRIEVE requests. + + Args: + keys: List of cache keys to retrieve + gpu_id: GPU device ID + gpu_block_ids: List of GPU block IDs + + Returns: + bool: True if retrieve succeeded + """ + # In a real implementation, this would retrieve KV cache data + # For testing, we just validate the inputs are received correctly + assert isinstance(keys, list), f"Expected keys to be list, got {type(keys)}" + assert isinstance(gpu_id, int), f"Expected gpu_id to be int, got {type(gpu_id)}" + assert isinstance(gpu_block_ids, list), ( + f"Expected gpu_block_ids to be list, got {type(gpu_block_ids)}" + ) + # Return success + return [True for _ in keys] + + +# ============================================================================== +# LOOKUP Request Handlers +# ============================================================================== + + +def lookup_handler(keys: list[KeyType], lock: Optional[bool]) -> list[bool]: + """ + Dummy handler for LOOKUP requests. + + Args: + keys: List of cache keys to look up + lock: Optional flag to lock the keys + + Returns: + list[bool]: List indicating whether each key was found + """ + # In a real implementation, this would look up keys in the cache + # For testing, we just validate the inputs and return dummy results + assert isinstance(keys, list), f"Expected keys to be list, got {type(keys)}" + assert lock is None or isinstance(lock, bool), ( + f"Expected lock to be None or bool, got {type(lock)}" + ) + # Return a result for each key (alternating True/False for testing) + return [i % 2 == 0 for i in range(len(keys))] diff --git a/tests/v1/storage_backend/test_audit_connector.py b/tests/v1/storage_backend/test_audit_connector.py new file mode 100644 index 00000000000..bc886f629ae --- /dev/null +++ b/tests/v1/storage_backend/test_audit_connector.py @@ -0,0 +1,637 @@ +# SPDX-License-Identifier: Apache-2.0 +# Standard +from io import StringIO +import asyncio +import logging + +# Third Party +import pytest +import torch + +# First Party +from lmcache.utils import CacheEngineKey +from lmcache.v1.config import LMCacheEngineConfig +from lmcache.v1.memory_management import ( + AdHocMemoryAllocator, + MemoryFormat, + MemoryObj, + TensorMemoryObj, +) +from lmcache.v1.storage_backend.connector.audit_connector import AuditConnector +from lmcache.v1.storage_backend.connector.mock_connector import MockConnector +from lmcache.v1.storage_backend.local_cpu_backend import LocalCPUBackend + + +def create_test_key(key_id: str) -> CacheEngineKey: + """Helper to create a test CacheEngineKey""" + return CacheEngineKey("vllm", "test_model", 3, 123, hash(key_id), dtype=torch.uint8) + + +def create_mock_memory_obj(backend: LocalCPUBackend, data: bytes) -> MemoryObj: + """Helper to create a mock MemoryObj with proper structure""" + tensor = torch.tensor( + [ord(c) for c in data.decode("latin1") if ord(c) < 256], dtype=torch.uint8 + ) + if len(tensor) == 0: + tensor = torch.tensor([0], dtype=torch.uint8) + + memory_obj = backend.allocate( + shape=tensor.shape, + dtype=tensor.dtype, + fmt=MemoryFormat.KV_2LTD, + ) + + if memory_obj is not None and isinstance(memory_obj, TensorMemoryObj): + memory_obj.tensor.copy_(tensor) + + return memory_obj + + +@pytest.fixture +def event_loop(): + """Create an event loop for async tests""" + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + yield loop + + # Simplified cleanup logic + try: + pending = asyncio.all_tasks(loop) + for task in pending: + task.cancel() + + if not loop.is_closed(): + loop.call_soon(loop.stop) + loop.run_forever() + loop.close() + except Exception: + pass + finally: + asyncio.set_event_loop(None) + + +@pytest.fixture +def local_cpu_backend(): + """Fixture for LocalCPUBackend""" + config = LMCacheEngineConfig.from_defaults( + chunk_size=1, + remote_url="mock://test", + extra_config={}, + ) + allocator = AdHocMemoryAllocator(device="cpu") + return LocalCPUBackend(config=config, memory_allocator=allocator) + + +@pytest.fixture +def mock_connector(event_loop, local_cpu_backend): + """Fixture for mock connector""" + connector = MockConnector( + url="mock://test", + loop=event_loop, + local_cpu_backend=local_cpu_backend, + capacity=1, + peeking_latency=0.0, + read_throughput=100.0, + write_throughput=100.0, + ) + yield connector + # No cleanup needed, let event_loop fixture handle it + + +class LogCaptureHandler(logging.Handler): + def __init__(self): + super().__init__() + self.records = [] + self.stream = StringIO() + + def emit(self, record): + self.records.append(record) + msg = self.format(record) + self.stream.write(msg + "\n") + + def get_records(self): + return self.records + + def get_logs(self): + return self.stream.getvalue() + + def clear(self): + self.records = [] + self.stream = StringIO() + + +@pytest.fixture +def log_capture(): + """Fixture for capturing logs""" + handler = LogCaptureHandler() + handler.setFormatter(logging.Formatter("%(name)s - %(levelname)s - %(message)s")) + + # Get audit logger and add handler + audit_logger = logging.getLogger( + "lmcache.v1.storage_backend.connector.audit_connector" + ) + original_level = audit_logger.level + audit_logger.setLevel(logging.INFO) + audit_logger.addHandler(handler) + + yield handler + + # Cleanup + audit_logger.removeHandler(handler) + audit_logger.setLevel(original_level) + + +class TestAuditConnector: + """Test AuditConnector functionality""" + + def test_initialization_basic(self, mock_connector, log_capture): + """Test basic initialization""" + config = LMCacheEngineConfig.from_defaults( + chunk_size=1, + remote_url="mock://test", + extra_config={}, + ) + audit = AuditConnector(mock_connector, config) + + assert audit.real_connector is mock_connector + assert audit.verify_checksum is False + assert audit.calc_checksum is False + assert len(audit.excluded_cmds) == 0 + + print("\nCaptured logs:") + print(log_capture.get_logs()) + + init_logs = [ + r.msg + for r in log_capture.get_records() + if "[REMOTE_AUDIT]" in r.msg and "INITIALIZED" in r.msg + ] + + assert len(init_logs) > 0, ( + f"Expected INITIALIZED log. Got logs: {log_capture.get_logs()}" + ) + + def test_initialization_with_checksum(self, mock_connector, log_capture): + """Test initialization with checksum enabled""" + config = LMCacheEngineConfig.from_defaults( + chunk_size=1, + remote_url="mock://test", + extra_config={ + "audit_calc_checksum": True, + "audit_verify_checksum": True, + }, + ) + audit = AuditConnector(mock_connector, config) + + assert audit.verify_checksum is True + assert audit.calc_checksum is True + assert audit.registry_lock is not None + + print("\nCaptured logs:") + print(log_capture.get_logs()) + + init_logs = [r for r in log_capture.get_records() if "INITIALIZED" in r.msg] + assert len(init_logs) > 0, ( + f"Expected INITIALIZED log. Got logs: {log_capture.get_logs()}" + ) + assert "Calc Checksum:True" in init_logs[0].msg + + def test_initialization_with_excluded_cmds(self, mock_connector, log_capture): + """Test initialization with excluded commands""" + config = LMCacheEngineConfig.from_defaults( + chunk_size=1, + remote_url="mock://test", + extra_config={ + "audit_exclude_cmds": "exists,list", + }, + ) + audit = AuditConnector(mock_connector, config) + + assert "exists" in audit.excluded_cmds + assert "list" in audit.excluded_cmds + + print("\nCaptured logs:") + print(log_capture.get_logs()) + + init_logs = [r for r in log_capture.get_records() if "INITIALIZED" in r.msg] + assert len(init_logs) > 0, ( + f"Expected INITIALIZED log. Got logs: {log_capture.get_logs()}" + ) + assert "Excluded Cmds:" in init_logs[0].msg + + def test_put_and_get_with_audit_log( + self, mock_connector, local_cpu_backend, event_loop, log_capture + ): + """Test put and get operations with audit log verification""" + config = LMCacheEngineConfig.from_defaults( + chunk_size=1, + remote_url="mock://test", + extra_config={}, + ) + + async def run_test(): + audit = AuditConnector(mock_connector, config) + key = create_test_key("test_key") + memory_obj = create_mock_memory_obj(local_cpu_backend, b"test_data") + + log_capture.clear() + + await audit.put(key, memory_obj) + + print("\nAfter PUT:") + print(log_capture.get_logs()) + + put_logs = [ + r + for r in log_capture.get_records() + if "PUT" in r.msg and "SUCCESS" in r.msg + ] + assert len(put_logs) > 0, ( + f"Expected PUT audit log. Got logs: {log_capture.get_logs()}" + ) + assert "Cost:" in put_logs[0].msg + assert "Size:" in put_logs[0].msg + + log_capture.clear() + + result = await audit.get(key) + assert result is not None + assert result.metadata.shape == memory_obj.metadata.shape + + print("\nAfter GET:") + print(log_capture.get_logs()) + + get_logs = [ + r + for r in log_capture.get_records() + if "GET" in r.msg and "SUCCESS" in r.msg + ] + assert len(get_logs) > 0, ( + f"Expected GET audit log. Got logs: {log_capture.get_logs()}" + ) + assert "Cost:" in get_logs[0].msg + + event_loop.run_until_complete(run_test()) + + def test_exists_with_audit_log( + self, mock_connector, local_cpu_backend, event_loop, log_capture + ): + """Test exists operation with audit log""" + config = LMCacheEngineConfig.from_defaults( + chunk_size=1, + remote_url="mock://test", + extra_config={}, + ) + + async def run_test(): + audit = AuditConnector(mock_connector, config) + key = create_test_key("test_key") + + log_capture.clear() + + result = await audit.exists(key) + assert result is False + + print("\nAfter EXISTS:") + print(log_capture.get_logs()) + + exists_logs = [ + r + for r in log_capture.get_records() + if "EXISTS" in r.msg and "SUCCESS" in r.msg + ] + assert len(exists_logs) > 0, ( + f"Expected EXISTS audit log. Got logs: {log_capture.get_logs()}" + ) + assert "Cost:" in exists_logs[0].msg + + event_loop.run_until_complete(run_test()) + + def test_put_with_checksum_audit_log( + self, mock_connector, local_cpu_backend, event_loop, log_capture + ): + """Test put with checksum calculation and audit log""" + config = LMCacheEngineConfig.from_defaults( + chunk_size=1, + remote_url="mock://test", + extra_config={ + "audit_calc_checksum": True, + "audit_verify_checksum": True, + }, + ) + + async def run_test(): + audit = AuditConnector(mock_connector, config) + key = create_test_key("test_key") + memory_obj = create_mock_memory_obj(local_cpu_backend, b"test_data") + + log_capture.clear() + + await audit.put(key, memory_obj) + + print("\nAfter PUT with checksum:") + print(log_capture.get_logs()) + + put_logs = [ + r + for r in log_capture.get_records() + if "PUT" in r.msg and "SUCCESS" in r.msg + ] + assert len(put_logs) > 0, ( + f"Expected PUT audit log. Got logs: {log_capture.get_logs()}" + ) + assert "Checksum:" in put_logs[0].msg + assert "Checksum:N/A" not in put_logs[0].msg + + event_loop.run_until_complete(run_test()) + + def test_put_and_get_with_kwargs( + self, mock_connector, local_cpu_backend, event_loop, log_capture + ): + """Test put and get operations with keyword arguments + to ensure robust argument handling.""" + config = LMCacheEngineConfig.from_defaults( + chunk_size=1, + remote_url="mock://test", + extra_config={}, + ) + + async def run_test(): + audit = AuditConnector(mock_connector, config) + key = create_test_key("test_key_kwargs") + memory_obj = create_mock_memory_obj(local_cpu_backend, b"test_data_kwargs") + + log_capture.clear() + + # Use keyword arguments to test argument passing + await audit.put(key=key, memory_obj=memory_obj) + + put_logs = [ + r + for r in log_capture.get_records() + if "PUT" in r.msg and "SUCCESS" in r.msg + ] + assert len(put_logs) > 0, "Expected PUT audit log with keyword arguments." + + log_capture.clear() + + result = await audit.get(key=key) + assert result is not None + + get_logs = [ + r + for r in log_capture.get_records() + if "GET" in r.msg and "SUCCESS" in r.msg + ] + assert len(get_logs) > 0, "Expected GET audit log with keyword arguments." + + event_loop.run_until_complete(run_test()) + + def test_excluded_command_no_audit_log( + self, mock_connector, event_loop, log_capture + ): + """Test that excluded commands don't generate audit logs""" + config = LMCacheEngineConfig.from_defaults( + chunk_size=1, + remote_url="mock://test", + extra_config={ + "audit_exclude_cmds": "exists", + }, + ) + + async def run_test(): + audit = AuditConnector(mock_connector, config) + key = create_test_key("test_key") + + log_capture.clear() + + await audit.exists(key) + + print("\nAfter excluded EXISTS:") + print(log_capture.get_logs()) + + exists_logs = [ + r + for r in log_capture.get_records() + if "EXISTS" in r.msg and "SUCCESS" in r.msg + ] + assert len(exists_logs) == 0, ( + f"Excluded command should not log. Got logs: {log_capture.get_logs()}" + ) + + event_loop.run_until_complete(run_test()) + + def test_non_excluded_command_has_audit_log( + self, mock_connector, local_cpu_backend, event_loop, log_capture + ): + """Test that non-excluded commands generate audit logs""" + config = LMCacheEngineConfig.from_defaults( + chunk_size=1, + remote_url="mock://test", + extra_config={ + "audit_exclude_cmds": "exists", + }, + ) + + async def run_test(): + audit = AuditConnector(mock_connector, config) + key = create_test_key("test_key") + memory_obj = create_mock_memory_obj(local_cpu_backend, b"test_data") + + log_capture.clear() + + await audit.put(key, memory_obj) + + print("\nAfter non-excluded PUT:") + print(log_capture.get_logs()) + + put_logs = [ + r + for r in log_capture.get_records() + if "PUT" in r.msg and "SUCCESS" in r.msg + ] + assert len(put_logs) > 0, ( + f"Non-excluded command should log. Got logs: {log_capture.get_logs()}" + ) + + event_loop.run_until_complete(run_test()) + + def test_not_audit_decorator(self, mock_connector, log_capture): + """Test that @NotAudit methods don't generate operation audit logs""" + config = LMCacheEngineConfig.from_defaults( + chunk_size=1, + remote_url="mock://test", + extra_config={}, + ) + audit = AuditConnector(mock_connector, config) + + log_capture.clear() + + audit.post_init() + audit.init_chunk_meta(None, None) + + print("\nAfter @NotAudit methods:") + print(log_capture.get_logs()) + + operation_logs = [ + r + for r in log_capture.get_records() + if "POST_INIT" in r.msg or "INIT_CHUNK_META" in r.msg + ] + operation_logs = [r for r in operation_logs if "SUCCESS" in r.msg] + assert len(operation_logs) == 0, ( + f"@NotAudit methods should not generate operation audit logs. " + f"Got logs: {log_capture.get_logs()}" + ) + + def test_batched_get_with_audit_log( + self, mock_connector, local_cpu_backend, event_loop, log_capture + ): + """Test batched get with audit log""" + config = LMCacheEngineConfig.from_defaults( + chunk_size=1, + remote_url="mock://test", + extra_config={}, + ) + + async def run_test(): + audit = AuditConnector(mock_connector, config) + key1 = create_test_key("key1") + key2 = create_test_key("key2") + + await audit.put(key1, create_mock_memory_obj(local_cpu_backend, b"data1")) + await audit.put(key2, create_mock_memory_obj(local_cpu_backend, b"data2")) + + log_capture.clear() + + results = await audit.batched_get([key1, key2]) + assert len(results) == 2 + assert all(r is not None for r in results) + + print("\nAfter BATCHED_GET:") + print(log_capture.get_logs()) + + batched_logs = [ + r + for r in log_capture.get_records() + if "BATCHED_GET" in r.msg and "SUCCESS" in r.msg + ] + assert len(batched_logs) > 0, ( + f"Expected BATCHED_GET audit log. Got logs: {log_capture.get_logs()}" + ) + assert "Cost:" in batched_logs[0].msg + + event_loop.run_until_complete(run_test()) + + def test_batched_async_contains_with_audit_log( + self, mock_connector, local_cpu_backend, event_loop, log_capture + ): + """Test batched async contains with audit log""" + config = LMCacheEngineConfig.from_defaults( + chunk_size=1, + remote_url="mock://test", + extra_config={}, + ) + + async def run_test(): + audit = AuditConnector(mock_connector, config) + key1 = create_test_key("key1") + key2 = create_test_key("key2") + key3 = create_test_key("key3") + + await audit.put(key1, create_mock_memory_obj(local_cpu_backend, b"data1")) + await audit.put(key2, create_mock_memory_obj(local_cpu_backend, b"data2")) + + log_capture.clear() + + count = await audit.batched_async_contains("lookup1", [key1, key2, key3]) + assert count == 2 + + print("\nAfter BATCHED_ASYNC_CONTAINS:") + print(log_capture.get_logs()) + + batched_logs = [ + r + for r in log_capture.get_records() + if "BATCHED_ASYNC_CONTAINS" in r.msg and "SUCCESS" in r.msg + ] + assert len(batched_logs) > 0, ( + f"Expected BATCHED_ASYNC_CONTAINS audit log. " + f"Got logs: {log_capture.get_logs()}" + ) + assert "Cost:" in batched_logs[0].msg + + event_loop.run_until_complete(run_test()) + + def test_support_methods(self, mock_connector): + """Test support_* methods are properly forwarded""" + config = LMCacheEngineConfig.from_defaults( + chunk_size=1, + remote_url="mock://test", + extra_config={}, + ) + audit = AuditConnector(mock_connector, config) + + assert audit.support_batched_get() is True + assert audit.support_batched_async_contains() is True + + def test_exists_sync_with_audit_log(self, mock_connector, log_capture): + """Test exists_sync synchronous operation with audit log""" + config = LMCacheEngineConfig.from_defaults( + chunk_size=1, + remote_url="mock://test", + extra_config={}, + ) + audit = AuditConnector(mock_connector, config) + key = create_test_key("test_key") + + log_capture.clear() + + result = audit.exists_sync(key) + assert result is False + + print("\nAfter EXISTS_SYNC:") + print(log_capture.get_logs()) + + exists_sync_logs = [ + r + for r in log_capture.get_records() + if "EXISTS_SYNC" in r.msg and "SUCCESS" in r.msg + ] + assert len(exists_sync_logs) > 0, ( + f"Expected EXISTS_SYNC audit log. Got logs: {log_capture.get_logs()}" + ) + assert "Cost:" in exists_sync_logs[0].msg + + def test_exists_sync_excluded_no_audit_log(self, mock_connector, log_capture): + """Test that excluded exists_sync doesn't generate audit logs""" + config = LMCacheEngineConfig.from_defaults( + chunk_size=1, + remote_url="mock://test", + extra_config={ + "audit_exclude_cmds": "exists_sync", + }, + ) + audit = AuditConnector(mock_connector, config) + key = create_test_key("test_key") + + log_capture.clear() + + result = audit.exists_sync(key) + assert result is False + + print("\nAfter excluded EXISTS_SYNC:") + print(log_capture.get_logs()) + + exists_sync_logs = [ + r + for r in log_capture.get_records() + if "EXISTS_SYNC" in r.msg and "SUCCESS" in r.msg + ] + assert len(exists_sync_logs) == 0, ( + f"Excluded exists_sync should not log. Got logs: {log_capture.get_logs()}" + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) diff --git a/tests/v1/storage_backend/test_connector_completeness.py b/tests/v1/storage_backend/test_connector_completeness.py new file mode 100644 index 00000000000..3dfe1b9fd51 --- /dev/null +++ b/tests/v1/storage_backend/test_connector_completeness.py @@ -0,0 +1,238 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Test to ensure wrapper connectors (AuditConnector and InstrumentedRemoteConnector) +implement all methods defined in the base RemoteConnector class. +""" + +# Standard +from typing import Dict, List, Set +import inspect + +# Third Party +import pytest + +# First Party +from lmcache.v1.storage_backend.connector.audit_connector import AuditConnector +from lmcache.v1.storage_backend.connector.base_connector import RemoteConnector +from lmcache.v1.storage_backend.connector.instrumented_connector import ( + InstrumentedRemoteConnector, +) + + +def get_all_methods_from_base(base_class) -> Set[str]: + """ + Get all public methods defined in the base class (excluding inherited from object). + """ + methods = set() + for name in dir(base_class): + # Skip private and special methods + if name.startswith("_"): + continue + attr = getattr(base_class, name) + if callable(attr): + methods.add(name) + return methods + + +def get_methods_implemented_in_class(cls) -> Set[str]: + """ + Get methods that are actually implemented in the class itself (not just inherited). + This checks if the method is defined in the class's own __dict__ + or its direct parents (excluding the base class we're checking against). + + For classes that dynamically generate methods (like AuditConnector), we check + the class __dict__ directly without instantiation. + """ + implemented = set() + + # Check the class's own __dict__ for methods + # This works for both regular methods and dynamically generated ones + for name in cls.__dict__: + if name.startswith("_"): + continue + attr = cls.__dict__[name] + # Check if it's callable (function, method, etc.) + if callable(attr): + implemented.add(name) + + # Also check using getattr to catch any dynamically added methods + # that might not be in __dict__ but are accessible + for name in dir(cls): + if name.startswith("_"): + continue + if name in implemented: + continue # Already found + try: + attr = getattr(cls, name) + if callable(attr): + # Verify it's not inherited from RemoteConnector + # by checking if it exists in the class's MRO + # (excluding RemoteConnector) + for base in cls.__mro__: + if base is RemoteConnector: + break + if name in base.__dict__: + implemented.add(name) + break + except AttributeError: + pass + + return implemented + + +def get_abstract_methods(cls) -> Set[str]: + """ + Get all abstract methods from a class. + """ + abstract_methods = set() + for name, method in inspect.getmembers(cls, predicate=inspect.isfunction): + if getattr(method, "__isabstractmethod__", False): + abstract_methods.add(name) + return abstract_methods + + +def check_method_signatures(base_class, wrapper_class, wrapper_name: str) -> List[Dict]: + """ + Check if method signatures in wrapper class match the base class. + Returns a list of mismatches. + """ + base_methods = get_all_methods_from_base(base_class) + signature_mismatches = [] + + for method_name in base_methods: + base_method = getattr(base_class, method_name) + wrapper_method = getattr(wrapper_class, method_name, None) + + if wrapper_method is None: + continue + + try: + base_sig = inspect.signature(base_method) + wrapper_sig = inspect.signature(wrapper_method) + + # Compare parameter names (excluding 'self') + base_params = [p for p in base_sig.parameters.keys() if p != "self"] + wrapper_params = [p for p in wrapper_sig.parameters.keys() if p != "self"] + + if base_params != wrapper_params: + signature_mismatches.append( + { + "method": method_name, + "base_params": base_params, + "wrapper_params": wrapper_params, + } + ) + except (ValueError, TypeError): + # Some methods might not have inspectable signatures + pass + + return signature_mismatches + + +class TestConnectorCompleteness: + """Test that wrapper connectors implement all base connector methods""" + + def test_audit_connector_completeness(self): + """ + Comprehensive test to verify AuditConnector implements all methods + from RemoteConnector with correct signatures. + """ + # 1. Get all methods from base class + base_methods = get_all_methods_from_base(RemoteConnector) + + # 2. Get methods actually implemented in AuditConnector + audit_implemented = get_methods_implemented_in_class(AuditConnector) + + # 3. Check which base methods are missing in the implementation + missing_methods = base_methods - audit_implemented + + assert len(missing_methods) == 0, ( + f"AuditConnector is missing {len(missing_methods)} methods from " + f"RemoteConnector: {sorted(missing_methods)}\n" + f"Base methods: {sorted(base_methods)}\n" + f"Implemented methods: {sorted(audit_implemented)}" + ) + + # 4. Check all abstract methods are implemented + abstract_methods = get_abstract_methods(RemoteConnector) + audit_missing_abstract = [] + for method_name in abstract_methods: + # Check if the method is actually implemented in AuditConnector + if method_name not in audit_implemented: + audit_missing_abstract.append(method_name) + + assert len(audit_missing_abstract) == 0, ( + f"AuditConnector has not implemented {len(audit_missing_abstract)} " + f"abstract methods: {sorted(audit_missing_abstract)}" + ) + + # 5. Check method signatures match + signature_mismatches = check_method_signatures( + RemoteConnector, AuditConnector, "AuditConnector" + ) + + assert len(signature_mismatches) == 0, ( + f"AuditConnector has {len(signature_mismatches)} method signature " + f"mismatches:\n" + + "\n".join( + f" - {m['method']}: base={m['base_params']}, " + f"audit={m['wrapper_params']}" + for m in signature_mismatches + ) + ) + + def test_instrumented_connector_completeness(self): + """ + Comprehensive test to verify InstrumentedRemoteConnector implements all methods + from RemoteConnector with correct signatures. + """ + # 1. Get all methods from base class + base_methods = get_all_methods_from_base(RemoteConnector) + + # 2. Get methods actually implemented in InstrumentedRemoteConnector + instrumented_implemented = get_methods_implemented_in_class( + InstrumentedRemoteConnector + ) + + # 3. Check which base methods are missing in the implementation + missing_methods = base_methods - instrumented_implemented + + assert len(missing_methods) == 0, ( + f"InstrumentedRemoteConnector is missing {len(missing_methods)} methods " + f"from RemoteConnector: {sorted(missing_methods)}\n" + f"Base methods: {sorted(base_methods)}\n" + f"Implemented methods: {sorted(instrumented_implemented)}" + ) + + # 4. Check all abstract methods are implemented + abstract_methods = get_abstract_methods(RemoteConnector) + instrumented_missing_abstract = [] + for method_name in abstract_methods: + # Check if the method is actually implemented in InstrumentedRemoteConnector + if method_name not in instrumented_implemented: + instrumented_missing_abstract.append(method_name) + + assert len(instrumented_missing_abstract) == 0, ( + f"InstrumentedRemoteConnector has not implemented " + f"{len(instrumented_missing_abstract)} abstract methods: " + f"{sorted(instrumented_missing_abstract)}" + ) + + # 5. Check method signatures match + signature_mismatches = check_method_signatures( + RemoteConnector, InstrumentedRemoteConnector, "InstrumentedRemoteConnector" + ) + + assert len(signature_mismatches) == 0, ( + f"InstrumentedRemoteConnector has {len(signature_mismatches)} method " + f"signature mismatches:\n" + + "\n".join( + f" - {m['method']}: base={m['base_params']}, " + f"instrumented={m['wrapper_params']}" + for m in signature_mismatches + ) + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) diff --git a/tests/v1/storage_backend/test_fs_connector.py b/tests/v1/storage_backend/test_fs_connector.py index 2df8cddef76..23aeeb41123 100644 --- a/tests/v1/storage_backend/test_fs_connector.py +++ b/tests/v1/storage_backend/test_fs_connector.py @@ -47,7 +47,7 @@ def create_test_metadata(): def create_test_key(key_id: int = 0) -> CacheEngineKey: """Create a test CacheEngineKey.""" - return CacheEngineKey("vllm", "test_model", 3, 123, hash(key_id)) + return CacheEngineKey("vllm", "test_model", 3, 123, hash(key_id), torch.bfloat16) def create_test_memory_obj(shape=(2, 16, 8, 128), dtype=torch.bfloat16) -> MemoryObj: diff --git a/tests/v1/storage_backend/test_gds_backend.py b/tests/v1/storage_backend/test_gds_backend.py index 85829a43994..3493068682e 100644 --- a/tests/v1/storage_backend/test_gds_backend.py +++ b/tests/v1/storage_backend/test_gds_backend.py @@ -31,7 +31,7 @@ def create_test_config(gds_path: str): def create_test_key(key_id: int = 0) -> CacheEngineKey: - return CacheEngineKey("vllm", "testmodel", 3, 123, key_id) + return CacheEngineKey("vllm", "testmodel", 3, 123, key_id, torch.bfloat16) def create_test_memory_obj( diff --git a/tests/v1/storage_backend/test_lazy_allocator_integration.py b/tests/v1/storage_backend/test_lazy_allocator_integration.py new file mode 100644 index 00000000000..6878828a526 --- /dev/null +++ b/tests/v1/storage_backend/test_lazy_allocator_integration.py @@ -0,0 +1,94 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Integration tests for LazyMixedMemoryAllocator with LocalCPUBackend""" + +# Standard +import unittest + +# Third Party +import torch + +# First Party +from lmcache.config import LMCacheEngineMetadata +from lmcache.v1.config import LMCacheEngineConfig +from lmcache.v1.lazy_memory_allocator import LazyMixedMemoryAllocator +from lmcache.v1.storage_backend.local_cpu_backend import LocalCPUBackend + + +class TestLazyAllocatorIntegration(unittest.TestCase): + """Test LazyMixedMemoryAllocator integration with LocalCPUBackend""" + + def setUp(self): + self.config = LMCacheEngineConfig.from_defaults() + self.config.local_cpu = True + self.config.max_local_cpu_size = 1.0 + self.config.chunk_size = 256 + self.config.enable_lazy_memory_allocator = True + self.config.lazy_memory_initial_ratio = 0.2 + self.config.lazy_memory_expand_trigger_ratio = 0.5 + self.config.lazy_memory_step_ratio = 0.1 + + self.metadata = LMCacheEngineMetadata( + model_name="test_model", + world_size=1, + worker_id=0, + fmt="vllm", + kv_shape=(32, 2, 256, 32, 128), + kv_dtype=torch.float16, + ) + + def _create_backend(self, config=None): + return LocalCPUBackend( + config=config or self.config, metadata=self.metadata, dst_device="cpu" + ) + + def test_lazy_allocator_enabled(self): + """Test lazy allocator is used when enabled""" + backend = self._create_backend() + self.assertIsInstance(backend.memory_allocator, LazyMixedMemoryAllocator) + expected_size = int(self.config.max_local_cpu_size * 1024**3 * 0.2) + self.assertEqual(backend.memory_allocator.initial_size, expected_size) + backend.close() + + def test_lazy_allocator_disabled(self): + """Test regular allocator when disabled""" + self.config.enable_lazy_memory_allocator = False + backend = self._create_backend() + self.assertNotIsInstance(backend.memory_allocator, LazyMixedMemoryAllocator) + backend.close() + + def test_memory_limit_callback(self): + """Test memory limit callback is set""" + backend = self._create_backend() + allocator = backend.memory_allocator + self.assertIsInstance(allocator, LazyMixedMemoryAllocator) + self.assertIsNotNone(allocator.async_expander.memory_limit_callback) + self.assertGreater(allocator.async_expander.memory_limit_callback(), 0) + backend.close() + + def test_allocation_with_lazy_allocator(self): + """Test basic allocation""" + backend = self._create_backend() + shape, dtype = torch.Size([256, 2, 4096]), torch.float16 + mem_obj = backend.allocate(shape, dtype, eviction=False, busy_loop=False) + self.assertIsNotNone(mem_obj) + self.assertEqual(mem_obj.meta.shape, shape) + self.assertEqual(mem_obj.meta.dtype, dtype) + backend.memory_allocator.free(mem_obj) + backend.close() + + def test_config_parameters(self): + """Test config parameters are applied""" + cfg = LMCacheEngineConfig.from_defaults() + cfg.local_cpu = True + cfg.max_local_cpu_size = 2.0 + cfg.enable_lazy_memory_allocator = True + cfg.lazy_memory_initial_ratio = 0.3 + cfg.lazy_memory_expand_trigger_ratio = 0.6 + cfg.lazy_memory_step_ratio = 0.15 + + backend = self._create_backend(cfg) + alloc = backend.memory_allocator + self.assertEqual(alloc.initial_ratio, 0.3) + self.assertEqual(alloc.expand_trigger_ratio, 0.6) + self.assertEqual(alloc.step_ratio, 0.15) + backend.close() diff --git a/tests/v1/storage_backend/test_local_cpu_backend.py b/tests/v1/storage_backend/test_local_cpu_backend.py index 4b17d354cbc..b907fb42a6b 100644 --- a/tests/v1/storage_backend/test_local_cpu_backend.py +++ b/tests/v1/storage_backend/test_local_cpu_backend.py @@ -54,7 +54,7 @@ def create_test_config( def create_test_key(key_id: str = "test_key") -> CacheEngineKey: """Create a test CacheEngineKey.""" - return CacheEngineKey("vllm", "test_model", 3, 123, hash(key_id)) + return CacheEngineKey("vllm", "test_model", 3, 123, hash(key_id), torch.bfloat16) def create_test_memory_obj(shape=(2, 16, 8, 128), dtype=torch.bfloat16) -> MemoryObj: diff --git a/tests/v1/storage_backend/test_local_disk_backend.py b/tests/v1/storage_backend/test_local_disk_backend.py index 8523f07269a..ae56be00261 100644 --- a/tests/v1/storage_backend/test_local_disk_backend.py +++ b/tests/v1/storage_backend/test_local_disk_backend.py @@ -66,13 +66,13 @@ def create_test_metadata(): def create_test_key(key_id: int = 0) -> CacheEngineKey: """Create a test CacheEngineKey.""" - return CacheEngineKey("vllm", "test_model", 3, 123, hash(key_id)) + return CacheEngineKey("vllm", "test_model", 3, 123, hash(key_id), torch.bfloat16) def create_test_memory_obj(shape=(2, 16, 8, 128), dtype=torch.bfloat16) -> MemoryObj: """Create a test MemoryObj using AdHocMemoryAllocator for testing.""" # First Party - from lmcache.v1.memory_management import AdHocMemoryAllocator, MemoryFormat + from lmcache.v1.memory_management import AdHocMemoryAllocator allocator = AdHocMemoryAllocator(device="cpu") memory_obj = allocator.allocate(shape, dtype, fmt=MemoryFormat.KV_T2D) @@ -194,69 +194,3 @@ def test_get_blocking_key_not_exists(self, local_disk_backend): assert result is None local_disk_backend.local_cpu_backend.memory_allocator.close() - - def test_async_load_bytes_from_disk(self, local_disk_backend): - """Test async_load_bytes_from_disk()""" - key = create_test_key(3) - memory_obj = create_test_memory_obj() - - # Create the file first - path = local_disk_backend._key_to_path(key) - with open(path, "wb") as f: - f.write(memory_obj.byte_array) - - result = local_disk_backend.load_bytes_from_disk( - key, - path, - memory_obj.metadata.dtype, - memory_obj.metadata.shape, - memory_obj.metadata.fmt, - ) - - assert result is not None - assert isinstance(result, MemoryObj) - assert result.metadata.shape == memory_obj.metadata.shape - assert result.metadata.dtype == memory_obj.metadata.dtype - - local_disk_backend.local_cpu_backend.memory_allocator.close() - - def test_load_bytes_from_disk(self, local_disk_backend): - """Test load_bytes_from_disk().""" - key = create_test_key(3) - memory_obj = create_test_memory_obj() - - # Create the file first - path = local_disk_backend._key_to_path(key) - with open(path, "wb") as f: - f.write(memory_obj.byte_array) - - result = local_disk_backend.load_bytes_from_disk( - key, - path, - memory_obj.metadata.dtype, - memory_obj.metadata.shape, - memory_obj.metadata.fmt, - ) - - assert result is not None - assert isinstance(result, MemoryObj) - assert result.metadata.shape == memory_obj.metadata.shape - assert result.metadata.dtype == memory_obj.metadata.dtype - - local_disk_backend.local_cpu_backend.memory_allocator.close() - - def test_file_operations_error_handling(self, local_disk_backend): - """Test error handling in file operations.""" - # Test with non-existent file - key = create_test_key(3) - non_existent_path = "/non/existent/path/file.pt" - - memory_obj = local_disk_backend.load_bytes_from_disk( - key, - non_existent_path, - torch.bfloat16, - torch.Size([2, 16, 8, 128]), - MemoryFormat.KV_T2D, - ) - assert memory_obj is not None - local_disk_backend.local_cpu_backend.memory_allocator.close() diff --git a/tests/v1/test_config.py b/tests/v1/test_config.py index 94b35cc61c1..f4ee4703876 100644 --- a/tests/v1/test_config.py +++ b/tests/v1/test_config.py @@ -3,6 +3,9 @@ from pathlib import Path import os +# Third Party +import pytest + # First Party from lmcache.v1.config import LMCacheEngineConfig @@ -31,3 +34,105 @@ def check_extra_config(config: "LMCacheEngineConfig"): assert len(config.extra_config) == 2 assert config.extra_config["key1"] == "value1" assert config.extra_config["key2"] == "value2" + + +def test_update_config_from_env_basic(): + config = LMCacheEngineConfig.from_defaults() + original_chunk_size = config.chunk_size + os.environ["LMCACHE_CHUNK_SIZE"] = " 512 " + os.environ["LMCACHE_REMOTE_URL"] = " http://example.com:8080 " + config.update_config_from_env() + assert config.chunk_size == 512 and config.chunk_size != original_chunk_size + assert config.remote_url == "http://example.com:8080" + del os.environ["LMCACHE_CHUNK_SIZE"] + del os.environ["LMCACHE_REMOTE_URL"] + + +def test_update_config_from_env_quotes(): + config = LMCacheEngineConfig.from_defaults() + os.environ["LMCACHE_REMOTE_URL"] = "'http://example.com:8080'" + os.environ["LMCACHE_PD_ROLE"] = '"sender"' + os.environ["LMCACHE_BLEND_SPECIAL_STR"] = "' ### '" + config.update_config_from_env() + assert config.remote_url == "http://example.com:8080" + assert config.pd_role == "sender" and config.blend_special_str == " ### " + del os.environ["LMCACHE_REMOTE_URL"] + del os.environ["LMCACHE_PD_ROLE"] + del os.environ["LMCACHE_BLEND_SPECIAL_STR"] + + +def test_update_config_from_env_extra_config(): + config = LMCacheEngineConfig.from_defaults() + test_cases = [ + ( + ' {"test_key": "test_value", "number": 42} ', + {"test_key": "test_value", "number": 42}, + ), + ('\'{"nested": {"key": "value"}}\'', {"nested": {"key": "value"}}), + ('"{\\"config\\": \\"prod\\"}"', {"config": "prod"}), + ] + for test_input, expected in test_cases: + os.environ["LMCACHE_EXTRA_CONFIG"] = test_input + config.update_config_from_env() + assert config.extra_config == expected + del os.environ["LMCACHE_EXTRA_CONFIG"] + + +def test_update_config_from_env_internal_api_server_include_index_list(): + config = LMCacheEngineConfig.from_defaults() + test_cases = [ + (" 1,2,3,4 ", [1, 2, 3, 4]), + ('"1,2,3,4"', [1, 2, 3, 4]), + ("'1,2,3,4'", [1, 2, 3, 4]), + (" 1 , 2 , 3 , 4 ", [1, 2, 3, 4]), + (" 5 ", [5]), + ('"10"', [10]), + ] + for test_input, expected in test_cases: + os.environ["LMCACHE_INTERNAL_API_SERVER_INCLUDE_INDEX_LIST"] = test_input + config.update_config_from_env() + assert config.internal_api_server_include_index_list == expected + del os.environ["LMCACHE_INTERNAL_API_SERVER_INCLUDE_INDEX_LIST"] + + +def test_update_config_from_env_error_handling(): + config = LMCacheEngineConfig.from_defaults() + original_chunk_size, original_extra_config = config.chunk_size, config.extra_config + os.environ["LMCACHE_CHUNK_SIZE"] = "invalid_number" + os.environ["LMCACHE_EXTRA_CONFIG"] = "invalid_json{" + config.update_config_from_env() + assert ( + config.chunk_size == original_chunk_size + and config.extra_config == original_extra_config + ) + os.environ["LMCACHE_CONTROLLER_PULL_URL"] = "http://controller.example.com" + config.update_config_from_env() + assert config.controller_pull_url == "http://controller.example.com" + del os.environ["LMCACHE_CHUNK_SIZE"] + del os.environ["LMCACHE_EXTRA_CONFIG"] + del os.environ["LMCACHE_CONTROLLER_PULL_URL"] + + +@pytest.mark.parametrize("use_mla", [True, False]) +def test_get_lookup_server_worker_ids(use_mla): + config = LMCacheEngineConfig.from_defaults() + lookup_server_worker_ids = config.get_lookup_server_worker_ids(use_mla, 8) + # test default value + if use_mla: + assert lookup_server_worker_ids == [0] + else: + assert lookup_server_worker_ids == [] + + # test different config + # TODO: not support format "[]" or "[0, 3, 6] + os.environ["LMCACHE_LOOKUP_SERVER_WORKER_IDS"] = "1" + config.update_config_from_env() + lookup_server_worker_ids = config.get_lookup_server_worker_ids(use_mla, 8) + assert lookup_server_worker_ids == [1] + + os.environ["LMCACHE_LOOKUP_SERVER_WORKER_IDS"] = "0, 3, 6" + config.update_config_from_env() + lookup_server_worker_ids = config.get_lookup_server_worker_ids(use_mla, 8) + assert lookup_server_worker_ids == [0, 3, 6] + + del os.environ["LMCACHE_LOOKUP_SERVER_WORKER_IDS"] diff --git a/tests/v1/test_connector.py b/tests/v1/test_connector.py index 9a2825398df..f6d6113ca2b 100644 --- a/tests/v1/test_connector.py +++ b/tests/v1/test_connector.py @@ -12,6 +12,7 @@ from lmcache.config import LMCacheEngineMetadata from lmcache.v1.config import LMCacheEngineConfig from lmcache.v1.memory_management import PinMemoryAllocator +from lmcache.v1.protocol import RemoteMetadata from lmcache.v1.storage_backend.connector import CreateConnector # Local @@ -297,3 +298,103 @@ def test_redis_sentinel_connector(url, autorelease_v1): close_asyncio_loop(async_loop, async_thread) memory_allocator.close() + + +REDIS_CLUSTER_URLS = [ + "redis-cluster://host1:7000,host2:7000,host3:7000", + "redis-cluster://clustercfg.cluster-name.id.region.cache.amazonaws.com:6379", + "redis-cluster://user:password@host1:7000,host2:7000,host3:7000", +] + + +@pytest.mark.parametrize("url", REDIS_CLUSTER_URLS) +def test_redis_cluster_connector(url, autorelease_v1): + """Test Redis Cluster connector: exists, put, get operations. + + This test uses the MockRedisCluster from conftest.py to simulate + Redis Cluster behavior without requiring an actual Redis Cluster setup. + """ + + # Standard + import os + + os.environ["REDIS_TIMEOUT"] = "3.5" + + async_loop, async_thread = init_asyncio_loop() + memory_allocator = PinMemoryAllocator(1024 * 1024 * 1024) + + connector = autorelease_v1(CreateConnector(url, async_loop, memory_allocator)) + + random_key = dumb_cache_engine_key() + + # Test 1: Verify key doesn't exist initially, test contains key not exist + future = asyncio.run_coroutine_threadsafe(connector.exists(random_key), async_loop) + assert not future.result() + + # Test 2: Create and store test data + num_tokens = 1000 + mem_obj_shape = [2, 32, num_tokens, 1024] + dtype = torch.bfloat16 + memory_obj = memory_allocator.allocate(mem_obj_shape, dtype) + memory_obj.ref_count_up() + + # Fill with deterministic test data + torch.manual_seed(42) + test_tensor = torch.randint(0, 100, memory_obj.raw_data.shape, dtype=torch.int64) + memory_obj.raw_data.copy_(test_tensor.to(torch.float32).to(dtype)) + + # Test 3: Put data + future = asyncio.run_coroutine_threadsafe( + connector.put(random_key, memory_obj), async_loop + ) + future.result() + + # Test 4: Verify key exists after putting data, test contains key exists + future = asyncio.run_coroutine_threadsafe(connector.exists(random_key), async_loop) + assert future.result() + assert memory_obj.get_ref_count() == 1 + + # Test 5: Retrieve and verify data + future = asyncio.run_coroutine_threadsafe(connector.get(random_key), async_loop) + retrieved_memory_obj = future.result() + + check_mem_obj_equal([retrieved_memory_obj], [memory_obj]) + + close_asyncio_loop(async_loop, async_thread) + memory_allocator.close() + + +@pytest.mark.parametrize("url", REDIS_CLUSTER_URLS) +def test_cluster_metadata_without_kv_bytes(url, autorelease_v1): + async_loop, async_thread = init_asyncio_loop() + memory_allocator = PinMemoryAllocator(1024 * 1024 * 1024) + connector = autorelease_v1(CreateConnector(url, async_loop, memory_allocator)) + + random_key = dumb_cache_engine_key() + # build a small mem obj to get correct metadata bytes + memory_obj = memory_allocator.allocate([2, 32, 8, 64], torch.bfloat16) + kv_bytes = memory_obj.byte_array + meta = RemoteMetadata( + len(kv_bytes), + memory_obj.get_shape(), + memory_obj.get_dtype(), + memory_obj.get_memory_format(), + ) + metadata_bytes = meta.serialize() + + # clean up memory object after getting metadata + memory_obj.ref_count_down() + + # inject only metadata, no kv_bytes + meta_key = random_key.to_string() + "metadata" + connector._connector.cluster.set(meta_key, metadata_bytes) + + # get() should return None and remove the metadata without kv_bytes pair + future = asyncio.run_coroutine_threadsafe(connector.get(random_key), async_loop) + assert future.result() is None + + future = asyncio.run_coroutine_threadsafe(connector.exists(random_key), async_loop) + assert not future.result() + + close_asyncio_loop(async_loop, async_thread) + memory_allocator.close() diff --git a/tests/v1/test_gds.py b/tests/v1/test_gds.py index 2e887494009..db66c242dde 100644 --- a/tests/v1/test_gds.py +++ b/tests/v1/test_gds.py @@ -16,7 +16,7 @@ from lmcache.utils import CacheEngineKey from lmcache.v1.cache_engine import LMCacheEngineBuilder from lmcache.v1.config import LMCacheEngineConfig -from lmcache.v1.memory_management import CuFileMemoryAllocator +from lmcache.v1.memory_management import CuFileMemoryAllocator, MemoryFormat from lmcache.v1.storage_backend import CreateStorageBackends from lmcache.v1.storage_backend.gds_backend import pack_metadata, unpack_metadata @@ -25,11 +25,12 @@ def test_gds_backend_metadata(): # This is a sanity check that packing and unpacking works. We can add # more tensor types to be sure. for [tensor, expected_nbytes] in [(torch.randn(3, 10), 120)]: - r = pack_metadata(tensor, version="test") - size, dtype, nbytes, meta = unpack_metadata(r) + r = pack_metadata(tensor, fmt=MemoryFormat.KV_2LTD, version="test") + size, dtype, nbytes, fmt, meta = unpack_metadata(r) assert size == tensor.size() assert dtype == tensor.dtype assert expected_nbytes == nbytes + assert fmt == MemoryFormat.KV_2LTD assert meta["version"] == "test" # Make sure that safetensors can load this @@ -56,6 +57,7 @@ def test_gds_backend_sanity(): world_size=8, worker_id=0, chunk_hash="e3229141e680fb413d2c5d3ebb416c4ad300d381e309fc9e417757b91406c157", + dtype=torch.uint8, ) BACKEND_NAME = "GdsBackend" diff --git a/tests/v1/test_lazy_memory_allocator.py b/tests/v1/test_lazy_memory_allocator.py new file mode 100644 index 00000000000..5dc18ae5391 --- /dev/null +++ b/tests/v1/test_lazy_memory_allocator.py @@ -0,0 +1,211 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Test cases for LazyMixedMemoryAllocator""" + +# Standard +import time + +# Third Party +import torch + +# First Party +from lmcache.v1.config import LMCacheEngineConfig +from lmcache.v1.lazy_memory_allocator import ( + CompositeBuffer, + CompositeTensorMemoryAllocator, + LazyMixedMemoryAllocator, +) +from lmcache.v1.memory_management import MemoryFormat + + +class TestLazyMemoryAllocator: + """Test suite for LazyMixedMemoryAllocator""" + + @staticmethod + def _create_allocator(size_mb=10, initial=0.2, trigger=0.5, step=0.1): + config = LMCacheEngineConfig.from_defaults( + lazy_memory_initial_ratio=initial, + lazy_memory_expand_trigger_ratio=trigger, + lazy_memory_step_ratio=step, + ) + return LazyMixedMemoryAllocator( + size=size_mb * 1024 * 1024, + config=config, + ) + + @staticmethod + def _verify_mem_obj(mem_obj, shape, dtype=torch.float32): + assert mem_obj and mem_obj.is_valid() + assert mem_obj.get_shape() == shape + if dtype: + assert mem_obj.get_dtype() == dtype + + def test_basic_allocation(self): + """Test basic allocation and tensor access""" + allocator = self._create_allocator() + try: + shape = torch.Size([1024, 256]) + mem_obj = allocator.allocate(shape, torch.float32, MemoryFormat.KV_2LTD) + self._verify_mem_obj(mem_obj, shape) + assert mem_obj.tensor is not None and mem_obj.tensor.shape == shape + mem_obj.ref_count_down() + finally: + allocator.close() + + def test_async_expansion_trigger(self): + """Test async expansion trigger""" + allocator = self._create_allocator(size_mb=20, initial=0.2, step=0.2) + try: + shape = torch.Size([512, 1024]) + mem_obj = allocator.allocate(shape, torch.float32, MemoryFormat.KV_2LTD) + assert mem_obj and allocator.expansion_triggered + time.sleep(0.5) + if allocator.async_expander: + assert allocator.composite_buffer.numel() >= allocator.initial_size + mem_obj.ref_count_down() + finally: + allocator.close() + + def test_multiple_allocations(self): + """Test multiple allocations with expansion""" + allocator = self._create_allocator(size_mb=50, initial=0.1, step=0.2) + try: + shape = torch.Size([256, 512]) + mem_objs = [ + allocator.allocate(shape, torch.float32, MemoryFormat.KV_2LTD) + for _ in range(10) + ] + mem_objs = [m for m in mem_objs if m] + assert len(mem_objs) > 0 + for m in mem_objs: + self._verify_mem_obj(m, shape) + time.sleep(1.0) + for _ in range(5): + m = allocator.allocate(shape, torch.float32, MemoryFormat.KV_2LTD) + if m: + mem_objs.append(m) + for m in mem_objs: + m.ref_count_down() + finally: + allocator.close() + + def test_batched_allocation(self): + """Test batched allocation""" + allocator = self._create_allocator(size_mb=30, initial=0.3, step=0.2) + try: + shape, batch_size = torch.Size([128, 256]), 5 + mem_objs = allocator.batched_allocate( + shape, torch.float32, batch_size, MemoryFormat.KV_2LTD + ) + assert mem_objs and len(mem_objs) == batch_size + for m in mem_objs: + self._verify_mem_obj(m, shape) + m.ref_count_down() + finally: + allocator.close() + + def test_free_and_reuse(self): + """Test memory reuse after free""" + allocator = self._create_allocator(initial=0.5, trigger=0.8) + try: + shape = torch.Size([256, 256]) + mem_obj1 = allocator.allocate(shape, torch.float32, MemoryFormat.KV_2LTD) + assert mem_obj1 + mem_obj1.ref_count_down() + mem_obj2 = allocator.allocate(shape, torch.float32, MemoryFormat.KV_2LTD) + assert mem_obj2 + mem_obj2.ref_count_down() + finally: + allocator.close() + + def test_buffer_allocator_passthrough(self): + """Test BINARY_BUFFER format""" + allocator = self._create_allocator() + try: + mem_obj = allocator.allocate( + torch.Size([1024]), None, MemoryFormat.BINARY_BUFFER + ) + assert mem_obj and mem_obj.get_memory_format() == MemoryFormat.BINARY_BUFFER + mem_obj.ref_count_down() + finally: + allocator.close() + + def test_composite_buffer_growth(self): + """Test composite buffer segment growth""" + allocator = self._create_allocator(size_mb=100, initial=0.1) + try: + initial_segs = len(allocator.composite_buffer.segments) + assert initial_segs == 1 + mem_obj = allocator.allocate( + torch.Size([1024, 1024]), torch.float32, MemoryFormat.KV_2LTD + ) + assert mem_obj + time.sleep(1.0) + assert len(allocator.composite_buffer.segments) >= initial_segs + mem_obj.ref_count_down() + finally: + allocator.close() + + def test_segment_aware_coalescing(self): + """Test coalescing respects segment boundaries""" + buf = torch.empty(1024 * 1024, dtype=torch.uint8) + comp_buf = CompositeBuffer(buf) + alloc = CompositeTensorMemoryAllocator(comp_buf) + + # Allocate and verify + shape1 = torch.Size([256, 1024]) + mem1 = alloc.allocate(shape1, torch.float32, MemoryFormat.KV_2LTD) + assert mem1, "First allocation failed" + + # Add second segment + alloc.expand_with_new_segment(torch.empty(1024 * 1024, dtype=torch.uint8)) + alloc.free(mem1) + + # Verify no blocks span boundaries + for block in alloc.explicit_list: + block_end = block.start + block.size + for boundary in alloc.segment_boundaries[:-1]: + assert not (block.start < boundary < block_end), ( + f"Block crosses boundary at {boundary}" + ) + + # Allocate from second segment + mem2 = alloc.allocate( + torch.Size([128, 1024]), torch.float32, MemoryFormat.KV_2LTD + ) + assert mem2, "Second allocation failed" + + # Verify allocation doesn't span segments + alloc_end = mem2.meta.address + mem2.meta.phy_size + for boundary in alloc.segment_boundaries[:-1]: + assert not (mem2.meta.address < boundary < alloc_end), ( + f"Allocation spans boundary at {boundary}" + ) + + # Verify get_slice works + try: + comp_buf.get_slice(mem2.meta.address, mem2.meta.phy_size) + except ValueError as e: + if "spans multiple segments" in str(e): + raise AssertionError(f"get_slice error: {e}") from e + raise + + def test_cross_segment_allocation_prevented(self): + """Test allocations don't span segments""" + buf = torch.empty(100 * 1024, dtype=torch.uint8) + comp_buf = CompositeBuffer(buf) + alloc = CompositeTensorMemoryAllocator(comp_buf) + + # Allocate most of first segment + mem1 = alloc.allocate( + torch.Size([20, 1024]), torch.float32, MemoryFormat.KV_2LTD + ) + assert mem1 + + # Add second segment and free first allocation + alloc.expand_with_new_segment(torch.empty(100 * 1024, dtype=torch.uint8)) + alloc.free(mem1) + + # Verify separate free blocks (no cross-segment coalescing) + assert len(alloc.explicit_list) == 2, ( + f"Expected 2 free blocks, got {len(alloc.explicit_list)}" + ) diff --git a/tests/v1/test_nixl_storage.py b/tests/v1/test_nixl_storage.py index 0ef09651b09..05fb9d51814 100644 --- a/tests/v1/test_nixl_storage.py +++ b/tests/v1/test_nixl_storage.py @@ -26,6 +26,7 @@ def create_key(chunk_hash: str): world_size=8, worker_id=0, chunk_hash=int(chunk_hash, base=16), + dtype=torch.bfloat16, ) @@ -39,8 +40,13 @@ def run(config: LMCacheEngineConfig, shape, dtype): keys.append( create_key("e3229141e680fb413d2c5d3ebb416c4ad300d381e309fc9e417757b91406d268") ) + keys.append( + create_key("e3229141e680fb413d2c5d3ebb416c4ad300d381e309fc9e417757b91406e379") + ) bad_key = create_key("deadbeefdeadbeef") + thread_loop = None + thread = None try: thread_loop = asyncio.new_event_loop() thread = threading.Thread(target=thread_loop.run_forever) @@ -60,6 +66,7 @@ def run(config: LMCacheEngineConfig, shape, dtype): config, metadata, thread_loop, + dst_device=config.nixl_buffer_device, # Pass the device directly ) assert len(backends) == 2 # NixlStorageBackend + LocalCPUBackend assert BACKEND_NAME in backends @@ -86,9 +93,15 @@ def run(config: LMCacheEngineConfig, shape, dtype): objs[1].tensor[300, 400] = 1e-2 objs[1].tensor[400, 300] = 1e-5 - nixl_backend.batched_submit_put_task(keys, objs) + objs[2].tensor[300, 400] = 3e-2 + objs[2].tensor[400, 300] = 4e-5 + + # Insert first 2 keys + first_keys = keys[0:2] + first_objs = objs[0:2] + nixl_backend.batched_submit_put_task(first_keys, first_objs) - for key, obj in zip(keys, objs, strict=False): + for key, obj in zip(first_keys, first_objs, strict=False): returned_memory_obj = nixl_backend.get_blocking(key) assert returned_memory_obj is not None assert returned_memory_obj.get_size() == obj.get_size() @@ -98,10 +111,10 @@ def run(config: LMCacheEngineConfig, shape, dtype): assert torch.equal(returned_memory_obj.tensor, obj.tensor) obj_list = asyncio.run( - nixl_backend.batched_get_non_blocking(lookup_id="test", keys=keys) + nixl_backend.batched_get_non_blocking(lookup_id="test", keys=first_keys) ) - for i, obj in enumerate(objs): + for i, obj in enumerate(first_objs): returned_memory_obj = obj_list[i] assert returned_memory_obj is not None assert returned_memory_obj.get_size() == obj.get_size() @@ -110,16 +123,70 @@ def run(config: LMCacheEngineConfig, shape, dtype): assert returned_memory_obj.metadata.address != obj.metadata.address assert torch.equal(returned_memory_obj.tensor, obj.tensor) - bad_obj = nixl_backend.get_blocking(bad_key) - assert bad_obj is None + def test_eviction(new_idx, old_idx): + nixl_backend.batched_submit_put_task([keys[new_idx]], [objs[new_idx]]) + + obj = nixl_backend.get_blocking(keys[new_idx]) + assert obj is not None + assert obj.tensor is not None + assert torch.equal(obj.tensor, objs[new_idx].tensor) + + obj = nixl_backend.get_blocking(keys[old_idx]) + assert obj is None + + ######## Test bad key lookup ######### + obj = nixl_backend.get_blocking(bad_key) + assert obj is None + + ######## Test eviction ######### + obj = nixl_backend.get_blocking(keys[0]) + assert obj is not None + + # At this point, key 0 & key 1 are cached. Key 1 is LRU key. + # Submitting key 2 should evict key 1. + + test_eviction(new_idx=2, old_idx=1) + + ######## Test pin ######### + val = nixl_backend.pin(keys[2]) + assert val is True + + obj = nixl_backend.get_blocking(keys[0]) + assert obj is not None + + # At this point, key 0 & key 2 are cached. + # Key 2 is LRU key, but is pinned. + # Submitting key 1 should evict key 0. + + test_eviction(new_idx=1, old_idx=0) + + ######## Test unpin ######### + val = nixl_backend.unpin(keys[2]) + assert val is True + + obj = nixl_backend.get_blocking(keys[1]) + assert obj is not None + + # At this point, key 1 & key 2 are cached. + # Key 2 is LRU key, and is now unpinned. + # Submitting key 0 should evict key 2. + + test_eviction(new_idx=0, old_idx=2) + + for backend in backends.values(): + backend.close() + + except Exception: + raise finally: - if thread_loop.is_running(): + if thread_loop and thread_loop.is_running(): thread_loop.call_soon_threadsafe(thread_loop.stop) - if thread.is_alive(): + if thread and thread.is_alive(): thread.join() @pytest.mark.no_shared_allocator +@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA") def test_nixl_gds_mt_cuda_backend(): BASE_DIR = Path(__file__).parent config = LMCacheEngineConfig.from_file(BASE_DIR / "data/nixl.yaml") @@ -127,8 +194,9 @@ def test_nixl_gds_mt_cuda_backend(): dtype = torch.bfloat16 shape = [2048, 2048] - config.nixl_buffer_device = "cuda" + config.nixl_buffer_device = "cuda:0" # Use explicit device config.extra_config["nixl_backend"] = "GDS_MT" + config.extra_config["enable_cuda"] = True run(config, shape, dtype) @@ -143,11 +211,13 @@ def test_nixl_gds_mt_cpu_backend(): config.nixl_buffer_device = "cpu" config.extra_config["nixl_backend"] = "GDS_MT" + config.extra_config["enable_cuda"] = False run(config, shape, dtype) @pytest.mark.no_shared_allocator +@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA") def test_nixl_gds_cuda_backend(): BASE_DIR = Path(__file__).parent config = LMCacheEngineConfig.from_file(BASE_DIR / "data/nixl.yaml") @@ -155,8 +225,9 @@ def test_nixl_gds_cuda_backend(): dtype = torch.bfloat16 shape = [2048, 2048] - config.nixl_buffer_device = "cuda" + config.nixl_buffer_device = "cuda:0" # Use explicit device config.extra_config["nixl_backend"] = "GDS" + config.extra_config["enable_cuda"] = True run(config, shape, dtype) @@ -171,6 +242,7 @@ def test_nixl_gds_cpu_backend(): config.nixl_buffer_device = "cpu" config.extra_config["nixl_backend"] = "GDS" + config.extra_config["enable_cuda"] = False run(config, shape, dtype) @@ -185,5 +257,6 @@ def test_nixl_posix_backend(): config.nixl_buffer_device = "cpu" config.extra_config["nixl_backend"] = "POSIX" + config.extra_config["enable_cuda"] = False run(config, shape, dtype) diff --git a/tests/v1/test_remote_mla_worker_id_as0.py b/tests/v1/test_remote_mla_worker_id_as0.py index 522e9a07e07..87b738afedb 100644 --- a/tests/v1/test_remote_mla_worker_id_as0.py +++ b/tests/v1/test_remote_mla_worker_id_as0.py @@ -106,6 +106,7 @@ def test_remote_mla_worker_id_as0(mock_stream): world_size=4, worker_id=2, chunk_hash="test_hash", + dtype=torch.float32, ) backend0 = RemoteBackend( @@ -122,6 +123,7 @@ def test_remote_mla_worker_id_as0(mock_stream): world_size=4, worker_id=0, chunk_hash="test_hash", + dtype=torch.float32, ) # Test not contains before adding data diff --git a/tests/v1/test_weka.py b/tests/v1/test_weka.py index bac81fb66e4..1ccc1af2b6f 100644 --- a/tests/v1/test_weka.py +++ b/tests/v1/test_weka.py @@ -28,6 +28,7 @@ def test_weka_backend_sanity(): world_size=8, worker_id=0, chunk_hash="e3229141e680fb413d2c5d3ebb416c4ad300d381e309fc9e417757b91406c157", + dtype=torch.uint8, ) BACKEND_NAME = "WekaGdsBackend" diff --git a/tests/v1/utils.py b/tests/v1/utils.py index 00f20bd86b7..04bccb5dcf0 100644 --- a/tests/v1/utils.py +++ b/tests/v1/utils.py @@ -33,7 +33,7 @@ def dumb_metadata_with_model_name( def dumb_cache_engine_key(id: int = 0) -> CacheEngineKey: - return CacheEngineKey("vllm", "test_model", 3, 123, id) + return CacheEngineKey("vllm", "test_model", 3, 123, id, torch.bfloat16) def random_string(N):