88from logging import getLogger
99from pathlib import Path
1010import tempfile
11- from typing import Literal
1211from google .cloud import storage
1312from google .cloud .storage import transfer_manager
1413import proto
1918 HarmBlockThreshold ,
2019 GenerationResponse ,
2120)
21+ from enum import StrEnum
2222from ..base_llm import LLM , Message
2323from ..error_handling import notify_bugsnag
2424
2727logger = getLogger (__name__ )
2828
# GCP project and region used for all Vertex AI calls in this module.
project_id = "css-lehrbereich"  # from google cloud console
# Region ID must match exactly — no whitespace.  https://cloud.google.com/about/locations#europe
location = "europe-west1"
3131
3232
3333class Buckets :
@@ -40,18 +40,23 @@ def storage_uri(bucket: str, blob_name: str) -> str:
4040 return "gs://%s/%s" % (bucket , blob_name )
4141
4242
43- class Models :
44- gemini_pro = "models/gemini-1.5-pro"
45- gemini_flash = "models/gemini-1.5-flash"
class GeminiModels(StrEnum):
    """Gemini model resource names usable with Vertex AI.

    StrEnum members compare equal to (and serialize as) their string
    value, so they can be passed anywhere a plain model-name string is
    expected.
    """

    gemini_15_pro = "models/gemini-1.5-pro"
    gemini_20_flash = "models/gemini-2.0-flash"
    gemini_20_flash_lite = "models/gemini-2.0-flash-lite"
4647
4748
48- available_models = [Models .gemini_pro , Models .gemini_flash ]
# Every Gemini model this integration can serve.  Derived from the enum
# instead of being listed by hand, so it can never drift out of sync when
# a model is added to or removed from GeminiModels.
available_models = list(GeminiModels)
4954
5055
@dataclass
class Request:
    """A media-description request to run against a Gemini model on Vertex AI."""

    # Local paths of the media files (e.g. videos) to describe.
    media_files: list[Path]
    # Which Gemini model to use; defaults to Gemini 1.5 Pro.
    model_name: GeminiModels = GeminiModels.gemini_15_pro
    # Instruction sent to the model alongside the media.
    prompt: str = "Describe this video in detail."
    # Upper bound on the length of the generated response.
    max_output_tokens: int = 1000
5762
@@ -101,7 +106,7 @@ def fetch_media_description(req: Request) -> str:
101106
102107
def init_vertex() -> None:
    """Initialise the Vertex AI SDK with this module's project and region."""
    vertexai.init(project=project_id, location=location)
105110
106111
107112def mime_type (file_name : str ) -> str :
@@ -174,11 +179,11 @@ class ResponseRefusedException(Exception):
174179
175180@dataclass
176181class GeminiAPI (LLM ):
177- model_id : str = Models . gemini_pro
182+ model_id : str = GeminiModels . gemini_20_flash_lite
178183 max_output_tokens : int = 1000
179184
180185 requires_gpu_exclusively = False
181- model_ids = [ Models . gemini_pro , Models . gemini_flash ]
186+ model_ids = available_models
182187
183188 def complete_msgs (self , msgs : list [Message ]) -> str :
184189 if len (msgs ) != 1 :
0 commit comments