diff --git a/FlagEmbedding/inference/reranker/decoder_only/lightweight.py b/FlagEmbedding/inference/reranker/decoder_only/lightweight.py index 61118b743..000478afb 100644 --- a/FlagEmbedding/inference/reranker/decoder_only/lightweight.py +++ b/FlagEmbedding/inference/reranker/decoder_only/lightweight.py @@ -1,4 +1,5 @@ import torch +import sys import warnings import numpy as np from tqdm import trange @@ -10,8 +11,6 @@ from FlagEmbedding.abc.inference import AbsReranker from FlagEmbedding.inference.reranker.encoder_only.base import sigmoid -from .models.gemma_model import CostWiseGemmaForCausalLM - def last_logit_pool_lightweight(logits: Tensor, attention_mask: Tensor) -> Tensor: @@ -144,6 +143,15 @@ def __init__( normalize: bool = False, **kwargs: Any, ) -> None: + try: + from .models.gemma_model import CostWiseGemmaForCausalLM + except: + print('*') * 20 + print('*') * 20 + print('error for load lightweight reranker, please install transformers==4.46.0') + print('*') * 20 + print('*') * 20 + sys.exit() super().__init__( model_name_or_path=model_name_or_path, diff --git a/research/Matroyshka_reranker/finetune/compensation/__init__.py b/research/Matroyshka_reranker/finetune/compensation/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/research/Matroyshka_reranker/requirements.txt b/research/Matroyshka_reranker/requirements.txt index b5de2fc43..8e7ba3481 100644 --- a/research/Matroyshka_reranker/requirements.txt +++ b/research/Matroyshka_reranker/requirements.txt @@ -7,12 +7,12 @@ func_timeout==4.3.5 pandas==2.2.1 sqlglot==22.1.1 rank_bm25==0.2.2 +peft==0.10.0 transformers==4.41.1 jinja2 datasets sentencepiece flash-attn modelscope -peft deepspeed bitsandbytes \ No newline at end of file