Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ pip3 install boto3
2. Download the ZSH plugin.

```bash
git clone https://github.com/tom-doerr/zsh_codex.git ~/.oh-my-zsh/custom/plugins/zsh_codex
git clone https://github.com/tom-doerr/zsh_codex.git ~/.oh-my-zsh/custom/plugins/zsh_codex
```

3. Add the following to your `.zshrc` file.
Expand Down Expand Up @@ -124,6 +124,15 @@ model = gemma2-9b-it
api_type = mistral
api_key = <mistral_apikey>
model = mistral-small-latest

; AWS Bedrock service SSO configuration
; Provide the 'region_name' and 'aws_profile'.
; Use `aws sso login --profile <profile_name>` to login before using.
[bedrock_service]
api_type = bedrock
aws_profile = <profile_name>
aws_region = us-east-1
model = us.anthropic.claude-sonnet-4-5-20250929-v1:0
```

In this configuration file, you can define multiple services with their own configurations. The required and optional parameters of the `api_type` are specified in `services/sevices.py`. Choose which service to use in the `[service]` section.
Expand Down Expand Up @@ -178,10 +187,11 @@ git clone https://github.com/tom-doerr/zsh_codex.git ~/.oh-my-zsh/custom/plugins
## Passing in context

Since the current filesystem is not passed into the ai you will need to either

1. Pass in all context in your descriptive command
2. Use a command to collect the context

In order for option 2 to work you will need to first add `export ZSH_CODEX_PREEXECUTE_COMMENT="true"` to your .zshrc file to enable the feature.
In order for option 2 to work you will need to first add `export ZSH_CODEX_PREEXECUTE_COMMENT="true"` to your .zshrc file to enable the feature.

> [!WARNING]
> This will run your prompt using zsh each time before using it, which could potentially modify your system when you hit ^X.
Expand Down
73 changes: 46 additions & 27 deletions services/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class BaseClient(ABC):
"""Base class for all clients"""

api_type: str = None
system_prompt = "You are a zsh shell expert, please help me complete the following command, you should only output the completed command, no need to include any other explanation. Do not put completed command in a code block."
system_prompt = "You are a zsh shell expert. Help me write and complete the following command. Make sure to only output the completed command as your output will be used directly in the shell. Use optimal functions (use rg over grep when applicable). Do not include any other explanation.Do not use wrap the output in a code block."

@abstractmethod
def get_completion(self, full_command: str) -> str:
Expand Down Expand Up @@ -101,10 +101,10 @@ class GroqClient(BaseClient):
- model (optional): defaults to "llama-3.2-11b-text-preview"
- temperature (optional): defaults to 1.0.
"""

api_type = "groq"
default_model = os.getenv("GROQ_DEFAULT_MODEL", "llama-3.2-11b-text-preview")

def __init__(self, config: dict):
try:
from groq import Groq
Expand All @@ -119,7 +119,7 @@ def __init__(self, config: dict):
self.client = Groq(
api_key=self.config["api_key"],
)

def get_completion(self, full_command: str) -> str:
response = self.client.chat.completions.create(
model=self.config["model"],
Expand All @@ -140,10 +140,10 @@ class MistralClient(BaseClient):
- model (optional): defaults to "codestral-latest"
- temperature (optional): defaults to 1.0.
"""

api_type = "mistral"
default_model = os.getenv("MISTRAL_DEFAULT_MODEL", "codestral-latest")

def __init__(self, config: dict):
try:
from mistralai import Mistral
Expand All @@ -152,13 +152,13 @@ def __init__(self, config: dict):
"Mistral library is not installed. Please install it using 'pip install mistralai'"
)
sys.exit(1)

self.config = config
self.config["model"] = self.config.get("model", self.default_model)
self.client = Mistral(
api_key=self.config["api_key"],
)

def get_completion(self, full_command: str) -> str:
response = self.client.chat.complete(
model=self.config["model"],
Expand All @@ -170,20 +170,27 @@ def get_completion(self, full_command: str) -> str:
)
return response.choices[0].message.content


class AmazonBedrock(BaseClient):
"""
config keys:
- api_type="bedrock"
- aws_region (optional): defaults to environment variable AWS_REGION
- aws_access_key_id (optional): defaults to environment variable AWS_ACCESS_KEY_ID
- aws_secret_access_key (optional): defaults to environment variable AWS_SECRET_ACCESS_KEY
- aws_session_token (optional): defaults to environment variable AWS_SESSION_TOKEN
- aws_profile (optional): AWS profile name for SSO or other credentials from ~/.aws/config
- aws_region (optional): defaults to profile/environment variable AWS_REGION
- aws_access_key_id (optional): explicit credentials (not needed for SSO)
- aws_secret_access_key (optional): explicit credentials (not needed for SSO)
- aws_session_token (optional): explicit credentials (not needed for SSO)
- model (optional): defaults to "anthropic.claude-3-5-sonnet-20240620-v1:0" or environment variable BEDROCK_DEFAULT_MODEL
- temperature (optional): defaults to 1.0.

For AWS SSO: Set aws_profile to your SSO profile name, then run `aws sso login --profile <profile_name>`
If no credentials are specified, boto3 will use the default credential chain (environment, config files, SSO, IAM roles, etc.)
"""

api_type = "bedrock"
default_model = os.getenv("BEDROCK_DEFAULT_MODEL", "anthropic.claude-3-5-sonnet-20240620-v1:0")
default_model = os.getenv(
"BEDROCK_DEFAULT_MODEL", "anthropic.claude-3-5-sonnet-20240620-v1:0"
)

def __init__(self, config: dict):
try:
Expand All @@ -197,24 +204,32 @@ def __init__(self, config: dict):
self.config = config
self.config["model"] = self.config.get("model", self.default_model)

# Create session with profile if specified (for SSO support)
session_kwargs = {}
if "aws_profile" in config:
session_kwargs["profile_name"] = config["aws_profile"]

session = boto3.Session(**session_kwargs)

# Create client kwargs
client_kwargs = {}
if "aws_region" in config:
session_kwargs["region_name"] = config["aws_region"]
client_kwargs["region_name"] = config["aws_region"]

# Only use explicit credentials if provided (not needed for SSO)
if "aws_access_key_id" in config:
session_kwargs["aws_access_key_id"] = config["aws_access_key_id"]
client_kwargs["aws_access_key_id"] = config["aws_access_key_id"]
if "aws_secret_access_key" in config:
session_kwargs["aws_secret_access_key"] = config["aws_secret_access_key"]
client_kwargs["aws_secret_access_key"] = config["aws_secret_access_key"]
if "aws_session_token" in config:
session_kwargs["aws_session_token"] = config["aws_session_token"]
client_kwargs["aws_session_token"] = config["aws_session_token"]

self.client = boto3.client("bedrock-runtime", **session_kwargs)
self.client = session.client("bedrock-runtime", **client_kwargs)

def get_completion(self, full_command: str) -> str:
import json

messages = [
{"role": "user", "content": full_command}
]
messages = [{"role": "user", "content": full_command}]

# Format request body based on model type
if "claude" in self.config["model"].lower():
Expand All @@ -223,23 +238,27 @@ def get_completion(self, full_command: str) -> str:
"max_tokens": 1000,
"system": self.system_prompt,
"messages": messages,
"temperature": float(self.config.get("temperature", 1.0))
"temperature": float(self.config.get("temperature", 1.0)),
}
else:
raise ValueError(f"Unsupported model: {self.config['model']}")

response = self.client.invoke_model(
modelId=self.config["model"],
body=json.dumps(body)
modelId=self.config["model"], body=json.dumps(body)
)

response_body = json.loads(response['body'].read())
response_body = json.loads(response["body"].read())
return response_body["content"][0]["text"]



class ClientFactory:
api_types = [OpenAIClient.api_type, GoogleGenAIClient.api_type, GroqClient.api_type, MistralClient.api_type, AmazonBedrock.api_type]
api_types = [
OpenAIClient.api_type,
GoogleGenAIClient.api_type,
GroqClient.api_type,
MistralClient.api_type,
AmazonBedrock.api_type,
]

@classmethod
def create(cls):
Expand Down