
Commit 23e8ecf

Merge pull request #169 from RedHitMark/main
fix lora and multi-lora
2 parents 6db2c44 + 084d000

File tree

2 files changed (+21 -11 lines)

README.md

Lines changed: 1 addition & 0 deletions
@@ -139,6 +139,7 @@ Below is a summary of the available RunPod Worker images, categorized by image size
 | `LONG_LORA_SCALING_FACTORS` | None | `tuple` | Specify multiple scaling factors for LoRA adapters. |
 | `MAX_CPU_LORAS` | None | `int` | Maximum number of LoRAs to store in CPU memory. |
 | `FULLY_SHARDED_LORAS` | False | `bool` | Enable fully sharded LoRA layers. |
+| `LORA_MODULES` | `[]` | `list[dict]` | Add LoRA adapters from Hugging Face, e.g. `[{"name": "xx", "path": "xxx/xxxx", "base_model_name": "xxx/xxxx"}]`. |
 | `SCHEDULER_DELAY_FACTOR` | 0.0 | `float` | Apply a delay before scheduling next prompt. |
 | `ENABLE_CHUNKED_PREFILL` | False | `bool` | Enable chunked prefill requests. |
 | `SPECULATIVE_MODEL` | None | `str` | The name of the draft model to be used in speculative decoding. |
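For reference, the new `LORA_MODULES` variable is consumed as a JSON-encoded list of adapter descriptors. Below is a minimal sketch of how such a value might be built; the adapter name and the two Hugging Face repo paths are placeholders for illustration, not adapters shipped with the worker:

import json

# Hypothetical adapter entry: "sql-adapter" and both repo paths are placeholders.
lora_modules = [
    {
        "name": "sql-adapter",                         # name clients request via the OpenAI-compatible API
        "path": "someuser/some-lora-adapter",          # Hugging Face repo holding the LoRA weights
        "base_model_name": "someorg/some-base-model",  # base model the adapter was trained against
    }
]

# The environment variable holds this list serialized as a JSON string:
print(json.dumps(lora_modules))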

src/engine.py

Lines changed: 20 additions & 11 deletions
@@ -122,31 +122,40 @@ def __init__(self, vllm_engine):
         super().__init__(vllm_engine)
         self.served_model_name = os.getenv("OPENAI_SERVED_MODEL_NAME_OVERRIDE") or self.engine_args.model
         self.response_role = os.getenv("OPENAI_RESPONSE_ROLE") or "assistant"
+        self.lora_adapters = self._load_lora_adapters()
         asyncio.run(self._initialize_engines())
         self.raw_openai_output = bool(int(os.getenv("RAW_OPENAI_OUTPUT", 1)))
-
+
+    def _load_lora_adapters(self):
+        adapters = []
+        try:
+            adapters = json.loads(os.getenv("LORA_MODULES", '[]'))
+        except Exception as e:
+            logging.info(f"---Initialized adapter json load error: {e}")
+
+        for i, adapter in enumerate(adapters):
+            try:
+                adapters[i] = LoRAModulePath(**adapter)
+                logging.info(f"---Initialized adapter: {adapter}")
+            except Exception as e:
+                logging.info(f"---Initialized adapter not worked: {e}")
+                continue
+        return adapters
+
     async def _initialize_engines(self):
         self.model_config = await self.llm.get_model_config()
         self.base_model_paths = [
             BaseModelPath(name=self.engine_args.model, model_path=self.engine_args.model)
         ]

-        lora_modules = os.getenv('LORA_MODULES', None)
-        if lora_modules is not None:
-            try:
-                lora_modules = json.loads(lora_modules)
-                lora_modules = [LoRAModulePath(**lora_modules)]
-            except:
-                lora_modules = None
-
         self.serving_models = OpenAIServingModels(
             engine_client=self.llm,
             model_config=self.model_config,
             base_model_paths=self.base_model_paths,
-            lora_modules=None,
+            lora_modules=self.lora_adapters,
             prompt_adapters=None,
         )
-
+        await self.serving_models.init_static_loras()
         self.chat_engine = OpenAIServingChat(
             engine_client=self.llm,
             model_config=self.model_config,
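The deleted hunk shows why multi-LoRA was broken before this commit: the old code called `LoRAModulePath(**lora_modules)` on the whole decoded value, which only works when `LORA_MODULES` holds a single bare dict rather than a list, and the result was discarded anyway because `lora_modules=None` was passed to `OpenAIServingModels`. Below is a standalone sketch of the new per-entry behavior; it uses a stand-in dataclass in place of vLLM's `LoRAModulePath` so it runs without the engine installed, and the adapter names and repo paths are placeholders:

import json
import logging
import os
from dataclasses import dataclass

@dataclass
class LoRAModulePath:
    # Stand-in for vLLM's LoRAModulePath; field names follow the README example.
    name: str
    path: str
    base_model_name: str | None = None

def load_lora_adapters() -> list:
    """Loosely mirrors the patched _load_lora_adapters: decode a JSON list, keep valid entries."""
    try:
        adapters = json.loads(os.getenv("LORA_MODULES", "[]"))
    except Exception as e:
        logging.info(f"LORA_MODULES is not valid JSON: {e}")
        return []
    loaded = []
    for adapter in adapters:
        try:
            loaded.append(LoRAModulePath(**adapter))  # one object per adapter dict
        except Exception as e:
            logging.info(f"Skipping malformed adapter entry {adapter!r}: {e}")
    return loaded

# Two adapters now load side by side, which the old single-dict code path could not handle:
os.environ["LORA_MODULES"] = json.dumps([
    {"name": "adapter-a", "path": "someuser/lora-a"},
    {"name": "adapter-b", "path": "someuser/lora-b"},
])
print(load_lora_adapters())

Parsing alone is not enough to serve the adapters, which is why the patch also awaits `self.serving_models.init_static_loras()` after constructing `OpenAIServingModels`.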
