
IBM/task-aware-embedding-refinement


🚀 LLM-Guided Embedding Refinement

Boost zero-shot classification and retrieval with test-time query optimization

Python 3.11+ License


📖 Overview

This repository contains code to improve zero-shot classification and retrieval using embedding models through test-time optimization of query embedding representations.

At test time, given a user query, the query embedding is optimized with gradient descent based on targeted feedback from a stronger model. The method uses scores from an LLM or reranker over a small sampled set of candidate documents, then updates the query representation so the embedding space better reflects the task-specific intent of the query.
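The gradient step described above can be sketched as follows. This is a minimal illustration under stated assumptions, not the paper's objective: it uses a listwise cross-entropy between a softmax over teacher scores and a softmax over query-document dot products, with temperatures `tau` and `teacher_tau` chosen arbitrarily. Random vectors stand in for real embeddings, and the step size differs from the repository's `--lr` default because the toy vectors are on a different scale.

```python
import numpy as np

def listwise_loss_and_grad(q, docs, teacher_scores, tau=0.1, teacher_tau=1.0):
    """Listwise cross-entropy between teacher and student score distributions.

    Assumptions (not taken from the repository): student scores are dot
    products of q with L2-normalized document embeddings, divided by tau;
    teacher scores are turned into a distribution with a softmax.
    Returns (loss, gradient of the loss with respect to q).
    """
    docs = docs / np.linalg.norm(docs, axis=1, keepdims=True)
    logits = docs @ q / tau
    shifted = logits - logits.max()
    log_p_student = shifted - np.log(np.exp(shifted).sum())  # stable log-softmax
    t = np.asarray(teacher_scores, dtype=float) / teacher_tau
    p_teacher = np.exp(t - t.max())
    p_teacher /= p_teacher.sum()
    loss = -np.sum(p_teacher * log_p_student)
    # dL/dlogits = p_student - p_teacher; chain rule through logits = docs @ q / tau
    grad = docs.T @ (np.exp(log_p_student) - p_teacher) / tau
    return loss, grad

# Toy run: random vectors stand in for real query/document embeddings
rng = np.random.default_rng(0)
docs = rng.normal(size=(20, 8))
teacher = rng.uniform(0.0, 1.0, size=20)  # stand-in for LLM/reranker relevance scores
q = rng.normal(size=8)
initial_loss, _ = listwise_loss_and_grad(q, docs, teacher)
for _ in range(100):  # plain gradient descent; only the query vector is updated
    loss, grad = listwise_loss_and_grad(q, docs, teacher)
    q = q - 1e-3 * grad
final_loss, _ = listwise_loss_and_grad(q, docs, teacher)
```

Only the query vector changes; the embedding model and the document embeddings stay fixed, which is what makes the procedure a test-time method rather than fine-tuning.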

🎯 Key Features

  • ✨ Test-time optimization: no retraining required
  • 🔄 Flexible architecture: works with various text embedding models and LLM rerankers
  • 📊 Proven results: consistent gains across multiple benchmarks

📄 Paper

Task-Adaptive Embedding Refinement via Test-time LLM Guidance
Ariel Gera, Shir Ashury-Tahan, Gal Bloch, Ohad Eytan, Assaf Toledo


💡 How It Works

Embedding models are efficient and scalable, but in challenging zero-shot settings they may miss nuanced task constraints. This library explores a test-time refinement procedure that adapts the query representation using external guidance from a generative LLM, without retraining the embedding model.

This approach improves ranking quality across multiple search and classification benchmarks, with consistent gains on tasks such as:

  • 📚 Literature search
  • 🎯 Intent detection
  • 🔑 Key-point matching
  • 📋 Instruction-following retrieval

🔄 Workflow

Step-by-step process:

  1. ๐Ÿ“ Embed the original query and candidate documents
  2. ๐Ÿ” Retrieve top candidates by embedding similarity
  3. ๐ŸŽฏ Score a sampled subset using an LLM, cross-encoder reranker, or gold labels
  4. โšก Optimize the query embedding to better align with supervision
  5. ๐Ÿ”„ Re-score the corpus using the refined query embedding
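The five workflow steps can be strung together in a toy, end-to-end sketch. All names here are hypothetical: random unit vectors stand in for embeddings, `teacher_score` stands in for the LLM, cross-encoder, or gold labels, and the loss is a plain squared error between query-document dot products and teacher scores (a deliberately simple choice, not necessarily the repository's). The sampling rule, taking `scores_from_top` documents from the head of the ranking and filling the rest of `total_scores` with random picks from the remaining corpus, is our reading of the parameter descriptions, not confirmed against the code.

```python
import numpy as np

def normalize(x):
    """L2-normalize along the last axis."""
    return x / np.linalg.norm(x, axis=-1, keepdims=True)

def sample_for_feedback(scores, total_scores, scores_from_top, rng):
    """Assumed rule: top `scores_from_top` docs plus random others up to `total_scores`."""
    order = np.argsort(-scores)
    top = order[:scores_from_top]
    extra = rng.choice(order[scores_from_top:], size=total_scores - scores_from_top, replace=False)
    return np.concatenate([top, extra])

def refine_query(q, doc_emb, teacher_scores, lr, num_steps):
    """Gradient descent on mean squared error (assumed loss) w.r.t. the query only."""
    q = q.copy()
    n = len(teacher_scores)
    for _ in range(num_steps):
        residual = doc_emb @ q - teacher_scores
        q -= lr * (2.0 / n) * (doc_emb.T @ residual)
    return q

rng = np.random.default_rng(0)
corpus = normalize(rng.normal(size=(200, 8)))          # step 1: document embeddings (stand-in)
hidden_intent = normalize(rng.normal(size=8))            # what the teacher "knows"
query = normalize(hidden_intent + rng.normal(size=8))    # step 1: noisy query embedding

def teacher_score(idx):
    """Hypothetical teacher (LLM, cross-encoder, or gold labels)."""
    return corpus[idx] @ hidden_intent

initial_scores = corpus @ query                                                  # step 2
picked = sample_for_feedback(initial_scores, total_scores=20, scores_from_top=15, rng=rng)  # step 3
feedback = teacher_score(picked)
refined = refine_query(query, corpus[picked], feedback, lr=0.5, num_steps=100)  # step 4
refined_scores = corpus @ refined                                                # step 5
```

Only 20 documents are ever shown to the teacher, yet the refined query re-scores the full corpus, which is where the efficiency of the approach comes from.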

🚀 Quick Start

📦 Installation

1️⃣ Create and activate a Python environment

python -m venv .venv
source .venv/bin/activate

2๏ธโƒฃ Install dependencies

pip install -r requirements.txt

3๏ธโƒฃ Configure inference for LiteLLM or OpenAI

The reranker and HyDE components use an OpenAI-compatible chat-completions API.
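For readers wiring up their own endpoint, the sketch below builds (but does not send) an OpenAI-style chat-completions request using only the standard library. The `/v1/chat/completions` path and Bearer-token header follow the OpenAI API convention, which LiteLLM's proxy also accepts; the relevance prompt and the `build_relevance_request` helper are hypothetical, and the repository's actual prompt and score parsing may differ.

```python
import json
import urllib.request

def build_relevance_request(base_url, api_key, model, query, document):
    """Build (without sending) an OpenAI-compatible chat-completions request.

    Hypothetical helper: the prompt wording is illustrative only.
    """
    body = {
        "model": model,
        "messages": [
            {"role": "system",
             "content": "Rate how relevant the document is to the query on a 0-10 scale. "
                        "Reply with a number only."},
            {"role": "user", "content": f"Query: {query}\n\nDocument: {document}"},
        ],
        "temperature": 0,
    }
    return urllib.request.Request(
        url=base_url.rstrip("/") + "/v1/chat/completions",
        data=json.dumps(body).encode("utf-8"),
        headers={"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"},
        method="POST",
    )

# In real use, base_url and api_key would come from os.environ["BASE_URL"] / ["API_KEY"].
req = build_relevance_request(
    "http://localhost:4000",
    "your-litellm-api-key",
    "mistralai/Mistral-Small-3.2-24B-Instruct-2506",
    "papers on test-time adaptation of retrieval models",
    "We adapt query embeddings at inference time using feedback from a larger model.",
)
# urllib.request.urlopen(req) would send it; in the JSON reply,
# reply["choices"][0]["message"]["content"] holds the model's answer.
```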

Option A: LiteLLM gateway or proxy

Set BASE_URL to your LiteLLM endpoint and API_KEY to the corresponding key:

export BASE_URL="http://localhost:4000"
export API_KEY="your-litellm-api-key"

Use LiteLLM explicitly by prefixing the model name with LiteLLM/, for example:

python run_experiments.py \
  --reranker_model "LiteLLM/mistralai/Mistral-Small-3.2-24B-Instruct-2506"

Option B: OpenAI API

Set your OpenAI API key:

export OPENAI_API_KEY="your-openai-api-key"

Use OpenAI explicitly by prefixing the model name with OpenAI/, for example:

python run_experiments.py \
  --reranker_model "OpenAI/gpt-4.1-mini" \
  --hyde_model "OpenAI/gpt-4.1-mini"

Important notes

  • If no service prefix is provided, the repository defaults to LiteLLM for LLM inference, so runs fail unless BASE_URL and API_KEY point to a working LiteLLM endpoint.
  • --reranker_model controls the LLM used for relevance feedback during optimization.
  • --hyde_model sets the LLM that generates hypothetical documents for HyDE (Hypothetical Document Embeddings). It is optional and not required for basic query optimization.
  • LiteLLM/OpenAI setup is only required when using LLM-based reranking or HyDE. It is not needed for --optimize_with_gold or cross-encoder rerankers such as cross-encoder/ms-marco-MiniLM-L-6-v2.
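The prefix rules above can be captured in a few lines, which is handy for failing fast before a long run. This is a hypothetical sketch of the documented behavior: it assumes the service prefix is stripped before the model name is sent, and `resolve_backend` and `missing_env` are illustrative names, not functions from this repository.

```python
KNOWN_SERVICES = {"LiteLLM": "litellm", "OpenAI": "openai"}
REQUIRED_ENV = {"litellm": ["BASE_URL", "API_KEY"], "openai": ["OPENAI_API_KEY"]}

def resolve_backend(model_arg, default="litellm"):
    """Split a --reranker_model / --hyde_model value into (backend, model name).

    A leading "LiteLLM/" or "OpenAI/" selects the service; any other value,
    including names like "meta-llama/...", falls back to LiteLLM.
    """
    prefix, sep, rest = model_arg.partition("/")
    if sep and prefix in KNOWN_SERVICES:
        return KNOWN_SERVICES[prefix], rest
    return default, model_arg

def missing_env(backend, env):
    """Names of required environment variables that are unset or empty."""
    return [name for name in REQUIRED_ENV[backend] if not env.get(name)]

# Example: validate before launching an LLM-guided run (pass os.environ in real use)
backend, model = resolve_backend("OpenAI/gpt-4.1-mini")
problems = missing_env(backend, {})
```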

โ–ถ๏ธ Basic Usage

Run all default models on all datasets:

python run_experiments.py

Run a specific model on selected datasets:

python run_experiments.py \
  --embedding_models "Qwen/Qwen3-Embedding-0.6B" \
  --datasets "Clinc150" "NFCorpus"

Run experiments in parallel:

python run_experiments.py --parallel 3

📊 Supported Datasets

The repository currently supports the following datasets through dataset_loaders.py:

| Dataset | Description | Reference |
|---|---|---|
| 🎓 RealScholarQuery | Real-world academic search queries over arXiv CS papers | He et al., 2025 |
| 🔑 ArgKP-21 | Key-point matching from the 2021 KPA shared task | Friedman et al., 2021 |
| 📋 FollowIR | Information retrieval from TREC relevance narratives | Weller et al., 2025 |
| 💬 Clinc150 | Intent classification with 150 intents across 10 domains | Larson et al., 2019 |
| 🏦 Banking77 | Banking-domain intent classification with 77 fine-grained intent categories | Casanueva et al., 2020 |
| 🏥 NFCorpus | Medical literature retrieval with lay queries | Boteva et al., 2016 |

๐Ÿ› ๏ธ Custom Usage

๐ŸŽฎ Main Entry Points

Script Purpose
embedding_adaptation.py Core script for single experiment runs
run_experiments.py Batch runner for multiple experiments

🔧 Command Examples

Preview commands without executing them:
python run_experiments.py --dry_run

Enable HyDE (Hypothetical Document Embeddings):
python run_experiments.py \
  --hyde_model "meta-llama/Llama-3.1-8B-Instruct"

Run a single experiment with custom parameters:
python embedding_adaptation.py \
  --embedding_model "Qwen/Qwen3-Embedding-0.6B" \
  --dataset "NFCorpus" \
  --reranker_model "mistralai/Mistral-Small-3.2-24B-Instruct-2506" \
  --lr 1e-4 \
  --num_steps 100 \
  --total_scores 20 \
  --scores_from_top 20
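HyDE replaces or augments the raw query with embeddings of LLM-written hypothetical documents. The original HyDE recipe averages the embeddings of several generated documents together with the query embedding; the sketch below assumes that recipe, adds a final re-normalization as an extra assumption, and uses random vectors in place of real embeddings. How this repository combines HyDE with query refinement is not specified here, so treat `hyde_query_embedding` as illustrative.

```python
import numpy as np

def hyde_query_embedding(query_emb, hypothetical_doc_embs):
    """Average the query embedding with embeddings of generated documents."""
    stacked = np.vstack([query_emb[None, :], hypothetical_doc_embs])
    mean = stacked.mean(axis=0)
    return mean / np.linalg.norm(mean)  # re-normalize for cosine scoring (assumption)

rng = np.random.default_rng(0)
query_emb = rng.normal(size=16)
generated = rng.normal(size=(4, 16))  # stand-ins for embeddings of 4 LLM-written passages
hyde_vec = hyde_query_embedding(query_emb, generated)
```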

โš™๏ธ Configuration Parameters

๐ŸŽฏ Single Experiment Parameters (embedding_adaptation.py)

These parameters configure individual experiment runs. They can also be passed to run_experiments.py and will be forwarded to each experiment.

| Parameter | Description | Default |
|---|---|---|
| --embedding_model | Embedding model to use (single model) | Qwen/Qwen3-Embedding-0.6B |
| --dataset | Dataset to use for the experiment | RealScholarQuery |
| --reranker_model | Reranker model for feedback | mistralai/Mistral-Small-3.2-24B-Instruct-2506 |
| --hyde_model | LLM for hypothetical document generation (optional) | None |
| --lr | Learning rate for query embedding optimization | 1e-4 |
| --num_steps | Number of optimization steps | 100 |
| --total_scores | Total documents to sample for the reranking signal | 20 |
| --scores_from_top | Documents sampled from the top results (by embedding similarity) | 20 |
| --optimize_with_gold | Use gold labels instead of reranker scores | False |
| --embedder_batch_size | Batch size for embedding inference | 10 |
| --reranker_batch_size | Batch size for reranker inference | 10 |
| --random_seed | Random seed for reproducibility | 42 |
| --save_tensors | Save query trajectory tensors for analysis | False |
| --experiment_name | Custom experiment name (auto-generated if not provided) | None |
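The table can also be read as an argparse configuration. The parser below is a hypothetical mirror written from this table, useful for checking how flags and defaults combine; it is not copied from embedding_adaptation.py, and the closing sanity check reflects our interpretation of --total_scores and --scores_from_top.

```python
import argparse

def build_parser():
    """Hypothetical parser mirroring the documented flags and defaults."""
    p = argparse.ArgumentParser(description="Single-experiment options (sketch)")
    p.add_argument("--embedding_model", default="Qwen/Qwen3-Embedding-0.6B")
    p.add_argument("--dataset", default="RealScholarQuery")
    p.add_argument("--reranker_model", default="mistralai/Mistral-Small-3.2-24B-Instruct-2506")
    p.add_argument("--hyde_model", default=None)
    p.add_argument("--lr", type=float, default=1e-4)
    p.add_argument("--num_steps", type=int, default=100)
    p.add_argument("--total_scores", type=int, default=20)
    p.add_argument("--scores_from_top", type=int, default=20)
    p.add_argument("--optimize_with_gold", action="store_true")
    p.add_argument("--embedder_batch_size", type=int, default=10)
    p.add_argument("--reranker_batch_size", type=int, default=10)
    p.add_argument("--random_seed", type=int, default=42)
    p.add_argument("--save_tensors", action="store_true")
    p.add_argument("--experiment_name", default=None)
    return p

args = build_parser().parse_args(["--dataset", "NFCorpus", "--lr", "5e-4"])
# Assumed constraint: the top-ranked sample cannot exceed the total sample size.
if args.scores_from_top > args.total_scores:
    raise ValueError("--scores_from_top cannot exceed --total_scores")
```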

🔄 Batch Runner Parameters (run_experiments.py only)

These parameters are specific to the batch runner and control how multiple experiments are executed.

| Parameter | Description | Default |
|---|---|---|
| --parallel | Number of concurrent experiments to run | 1 |
| --experiment_prefix | Custom prefix for experiment names | None |
| --embedding_models | List of embedding models to test (space-separated) | See defaults in script |
| --datasets | List of datasets to test (space-separated) | See defaults in script |
| --dry_run | Preview commands without executing | False |
| --continue_on_error | Continue running experiments even if one fails | False |
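To make the interplay of --parallel, --dry_run, and --continue_on_error concrete, here is a hypothetical sketch of a batch runner that expands the model × dataset grid into embedding_adaptation.py commands. It is written from the parameter table, not from run_experiments.py; --experiment_prefix is not modeled, and the real runner's scheduling may differ.

```python
import itertools
import subprocess
import sys
from concurrent.futures import ThreadPoolExecutor, as_completed

def build_commands(embedding_models, datasets, extra_args=()):
    """One single-experiment command per (model, dataset) pair."""
    return [
        [sys.executable, "embedding_adaptation.py",
         "--embedding_model", model, "--dataset", dataset, *extra_args]
        for model, dataset in itertools.product(embedding_models, datasets)
    ]

def run_all(commands, parallel=1, dry_run=False, continue_on_error=False):
    """Run commands with up to `parallel` workers; return the commands that failed."""
    if dry_run:
        for cmd in commands:
            print(" ".join(cmd))
        return []
    failures = []
    with ThreadPoolExecutor(max_workers=parallel) as pool:
        futures = {pool.submit(subprocess.run, cmd): cmd for cmd in commands}
        for fut in as_completed(futures):
            if fut.result().returncode != 0:
                failures.append(futures[fut])
                if not continue_on_error:
                    for other in futures:
                        other.cancel()  # only cancels runs that have not started yet
                    break
    return failures

commands = build_commands(
    ["Qwen/Qwen3-Embedding-0.6B"], ["Clinc150", "NFCorpus"], extra_args=["--num_steps", "100"]
)
run_all(commands, parallel=3, dry_run=True)  # prints the two commands without running them
```

Forwarding extra_args to every command mirrors the note that single-experiment parameters passed to run_experiments.py are forwarded to each experiment.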

๐Ÿ“ Output Structure

Each experiment run creates a directory under output/<experiment_name>/ containing:

output/
โ””โ”€โ”€ <experiment_name>/
    โ”œโ”€โ”€ *_results.csv              # Per-topic evaluation metrics
    โ”œโ”€โ”€ *_raw_scores.parquet       # Raw document-level scores
    โ”œโ”€โ”€ tensors/                   # Query trajectory tensors (if enabled)
    โ””โ”€โ”€ config.json                # Experiment configuration metadata
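A small sketch for gathering results across runs with only the standard library. It assumes the directory layout above; since the contents of config.json and the CSV columns are not documented here, the code reads them generically. Reading the *_raw_scores.parquet files would additionally need a Parquet reader such as pandas with pyarrow, which this sketch leaves out.

```python
import csv
import json
from pathlib import Path

def load_experiment(exp_dir):
    """Return (config dict, {csv filename: list of row dicts}) for one run directory."""
    exp_dir = Path(exp_dir)
    config_path = exp_dir / "config.json"
    config = json.loads(config_path.read_text()) if config_path.exists() else {}
    results = {}
    for csv_path in sorted(exp_dir.glob("*_results.csv")):
        with csv_path.open(newline="") as f:
            results[csv_path.name] = list(csv.DictReader(f))
    return config, results

def load_all(output_root="output"):
    """Map each experiment name under output_root to its (config, results)."""
    return {d.name: load_experiment(d)
            for d in sorted(Path(output_root).iterdir()) if d.is_dir()}
```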

โฑ๏ธ Runtime Estimates

The total runtime of the experiments consists of two main components:

  1. Document embeddings: embedding all documents in the corpus. In a deployment environment this would typically be done offline.
  2. Per-query computations: query embedding, refinement optimization, and LLM teacher feedback.

With a GPU, per-query latency is well under a second, as reported in the paper. Running all the queries in a dataset therefore takes only a few minutes, assuming an efficient model endpoint for obtaining the LLM feedback scores.

Thus, much of the experiment runtime is devoted to the one-time cost of computing the corpus document embeddings. For convenience, it is possible to run this initial step separately using the script embed_all_documents.py.

Runtime estimates on a single A100-80GB GPU:

  • Using a small embedding model, like Qwen/Qwen3-Embedding-0.6B: Running the full experiment on all datasets, including embedding all corpus documents and optimizing all the queries, should take about an hour.

  • Using larger embedding models, like Qwen/Qwen3-Embedding-8B:

    • Embedding corpus documents with 7B/8B models can range from a few minutes for small datasets (e.g., ArgKP-21) to 3-4 hours for larger datasets like RealScholarQuery or FollowIR.
    • Note that for datasets with long documents (e.g., FollowIR), it may be necessary to use a small --embedder_batch_size to avoid running out of GPU memory.
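The --embedder_batch_size advice comes down to processing the corpus in fixed-size chunks, so that only one chunk's activations occupy GPU memory at a time. A minimal batching helper, with a hypothetical `embed_fn` standing in for the real embedding model call:

```python
import numpy as np

def embed_in_batches(texts, embed_fn, batch_size=10):
    """Embed texts chunk by chunk; peak memory scales with batch_size, not corpus size."""
    chunks = []
    for start in range(0, len(texts), batch_size):
        chunks.append(np.asarray(embed_fn(texts[start:start + batch_size])))
    return np.vstack(chunks)

def toy_embed(batch):
    """Hypothetical stand-in for a model: 2-d features (length, space count) per text."""
    return [[len(t), t.count(" ")] for t in batch]

docs = [f"document number {i}" for i in range(23)]
emb = embed_in_batches(docs, toy_embed, batch_size=4)
```

Lowering batch_size trades throughput for memory headroom, which matters most for corpora with long documents.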

๐Ÿ“ Notes

💾 Caching: Embeddings, reranker outputs, and generated texts are cached under /cache to avoid repeated computation.
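A minimal sketch of a file-system embedding cache in this spirit, assuming entries are keyed by a SHA-256 hash of the model name and input text and stored as .npy files. The layout, the key scheme, and the `cached_embed` helper are assumptions for illustration, not the repository's actual cache format.

```python
import hashlib
from pathlib import Path

import numpy as np

def cache_key(model_name, text):
    """Stable key derived from the model name and the exact input text."""
    return hashlib.sha256(f"{model_name}\x00{text}".encode("utf-8")).hexdigest()

def cached_embed(texts, model_name, embed_fn, cache_dir="cache"):
    """Return embeddings, computing only those missing from the on-disk cache."""
    cache = Path(cache_dir) / "embeddings"
    cache.mkdir(parents=True, exist_ok=True)
    paths = [cache / f"{cache_key(model_name, t)}.npy" for t in texts]
    missing = [i for i, p in enumerate(paths) if not p.exists()]
    if missing:
        new = np.asarray(embed_fn([texts[i] for i in missing]), dtype=np.float32)
        for i, vec in zip(missing, new):
            np.save(paths[i], vec)
    return np.stack([np.load(p) for p in paths])
```

Including the model name in the key keeps embeddings from different models apart; a vector store, as suggested below for production, would replace the per-file layout.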

🔬 Research Focus: This implementation is optimized for research and experimentation. For production deployment, consider replacing the file-system cache with a scalable vector-store solution.


📚 Citation

If you use this code in your research, please cite our paper:

@article{gera2026taskadaptive,
  title={Task-Adaptive Embedding Refinement via Test-time LLM Guidance},
  author={Gera, Ariel and Ashury-Tahan, Shir and Bloch, Gal and Eytan, Ohad and Toledo, Assaf},
  year={2026},
  journal={arXiv:2605.12487},
  url={https://arxiv.org/abs/2605.12487},
}
