Skip to content

Commit 46833a3

Browse files
committed
Update to Gemini 2.0
1 parent 00e46c9 commit 46833a3

File tree

3 files changed

+19
-14
lines changed

3 files changed

+19
-14
lines changed
Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from logging import getLogger
99
from pathlib import Path
1010
import tempfile
11-
from typing import Literal
1211
from google.cloud import storage
1312
from google.cloud.storage import transfer_manager
1413
import proto
@@ -19,6 +18,7 @@
1918
HarmBlockThreshold,
2019
GenerationResponse,
2120
)
21+
from enum import StrEnum
2222
from ..base_llm import LLM, Message
2323
from ..error_handling import notify_bugsnag
2424

@@ -27,7 +27,7 @@
2727
logger = getLogger(__name__)
2828

2929
project_id = "css-lehrbereich" # from google cloud console
30-
frankfurt = "europe-west3" # https://cloud.google.com/about/locations#europe
30+
location = "europe-west1" # https://cloud.google.com/about/locations#europe
3131

3232

3333
class Buckets:
@@ -40,18 +40,23 @@ def storage_uri(bucket: str, blob_name: str) -> str:
4040
return "gs://%s/%s" % (bucket, blob_name)
4141

4242

43-
class Models:
44-
gemini_pro = "models/gemini-1.5-pro"
45-
gemini_flash = "models/gemini-1.5-flash"
43+
class GeminiModels(StrEnum):
44+
gemini_15_pro = "models/gemini-1.5-pro"
45+
gemini_20_flash = "models/gemini-2.0-flash"
46+
gemini_20_flash_lite = "models/gemini-2.0-flash-lite"
4647

4748

48-
available_models = [Models.gemini_pro, Models.gemini_flash]
49+
available_models = [
50+
GeminiModels.gemini_15_pro,
51+
GeminiModels.gemini_20_flash,
52+
GeminiModels.gemini_20_flash_lite,
53+
]
4954

5055

5156
@dataclass
5257
class Request:
5358
media_files: list[Path]
54-
model_name: Literal[Models.gemini_pro, Models.gemini_flash] = Models.gemini_pro
59+
model_name: GeminiModels = GeminiModels.gemini_15_pro
5560
prompt: str = "Describe this video in detail."
5661
max_output_tokens: int = 1000
5762

@@ -101,7 +106,7 @@ def fetch_media_description(req: Request) -> str:
101106

102107

103108
def init_vertex() -> None:
104-
vertexai.init(project=project_id, location=frankfurt)
109+
vertexai.init(project=project_id, location=location)
105110

106111

107112
def mime_type(file_name: str) -> str:
@@ -174,11 +179,11 @@ class ResponseRefusedException(Exception):
174179

175180
@dataclass
176181
class GeminiAPI(LLM):
177-
model_id: str = Models.gemini_pro
182+
model_id: str = GeminiModels.gemini_20_flash_lite
178183
max_output_tokens: int = 1000
179184

180185
requires_gpu_exclusively = False
181-
model_ids = [Models.gemini_pro, Models.gemini_flash]
186+
model_ids = available_models
182187

183188
def complete_msgs(self, msgs: list[Message]) -> str:
184189
if len(msgs) != 1:

llmlib/llmlib/runtime.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from .replicate_api import Apollo7B
22
from .internvl import InternVL
3-
from .gemini.media_description import GeminiAPI
3+
from .gemini.gemini_code import GeminiAPI
44
from .gemma import PaliGemma2
55
from .minicpm import MiniCPM
66
from .llama3 import LLama3Vision8B

tests/test_gemini.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from pathlib import Path
2-
from llmlib.gemini.media_description import GeminiAPI, Models, Request
2+
from llmlib.gemini.gemini_code import GeminiAPI, GeminiModels, Request
33
import pytest
44

55
from tests.helpers import (
@@ -23,7 +23,7 @@ def test_gemini_vision():
2323
assert path.exists()
2424

2525
req = Request(
26-
model_name=Models.gemini_flash,
26+
model_name=GeminiModels.gemini_20_flash,
2727
media_files=files,
2828
prompt="Describe this combined images/audio/text in detail.",
2929
)
@@ -35,7 +35,7 @@ def test_gemini_vision():
3535

3636
@pytest.mark.skipif(condition=is_ci(), reason="Avoid costs")
3737
def test_gemini_vision_using_interface():
38-
model = GeminiAPI(model_id=Models.gemini_flash, max_output_tokens=50)
38+
model = GeminiAPI(model_id=GeminiModels.gemini_20_flash_lite, max_output_tokens=50)
3939
assert_model_knows_capital_of_france(model)
4040
assert_model_recognizes_pyramid_in_image(model)
4141
assert_model_recognizes_afd_in_video(model)

0 commit comments

Comments
 (0)