
Commit 37afdf9

Merge pull request #21 from jax-ml/kimi-k2-draft
Initial kimi k2 draft
2 parents 85804f8 + 2dbb7a0 commit 37afdf9

17 files changed: +4448 −0 lines changed

kimi_k2/.gitignore

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
# Python ignores
__pycache__/
*.pyc
*.egg-info
build/**

.venv
.vscode

kimi_k2/README.md

Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
# Minimal Kimi K2 inference

**tl;dr: open-source Kimi K2 inference using JAX, minimal yet performant**

This model is a work in progress.

<br/>

This is a pure JAX implementation of Kimi K2 for inference, including a
checkpoint converter for the K2 Instruct weights. It runs on TPU and should
also work on GPU.
The entire model is defined in [model.py](kimi_k2_jax/model.py) and invoked
via [main.py](main.py). Among other things, the model code demonstrates:
* an MLA attention implementation (sketched below);
* expert and tensor-parallelism via JAX's
  [`shard_map`](https://docs.jax.dev/en/latest/sharded-computation.html#manual-parallelism-with-shard-map)
  for easy multi-device/multi-host computation; and
* simple int8 quantization (also sketched below).
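
To make the first point concrete, here is a minimal sketch of the MLA idea,
not code from this repository: keys and values are up-projected from a shared
low-rank latent, so the KV cache only needs to store that latent. All shapes
and weight names (`w_dkv`, `w_uk`, `w_uv`) are illustrative assumptions, and
the sketch omits RoPE and the query path.

```
import jax
import jax.numpy as jnp

def mla_kv(x, w_dkv, w_uk, w_uv, n_heads, d_head):
    # x: (seq, d_model). MLA caches the small latent instead of full
    # per-head K/V, then up-projects keys/values on the fly.
    latent = x @ w_dkv                                # (seq, d_latent)
    k = (latent @ w_uk).reshape(-1, n_heads, d_head)  # (seq, n_heads, d_head)
    v = (latent @ w_uv).reshape(-1, n_heads, d_head)
    return latent, k, v

seq, d_model, d_latent, n_heads, d_head = 16, 1024, 128, 8, 64
keys = jax.random.split(jax.random.key(0), 4)
x = jax.random.normal(keys[0], (seq, d_model))
w_dkv = jax.random.normal(keys[1], (d_model, d_latent))
w_uk = jax.random.normal(keys[2], (d_latent, n_heads * d_head))
w_uv = jax.random.normal(keys[3], (d_latent, n_heads * d_head))
latent, k, v = mla_kv(x, w_dkv, w_uk, w_uv, n_heads, d_head)
```

And a similarly hedged sketch of the last two points together: per-channel
int8 weight quantization applied inside a `shard_map`-based column-parallel
matmul. The 1D mesh, the `tp` axis name, and the array shapes are assumptions
for illustration (the weight's output dimension must divide evenly across
devices); the real model shards across more axes.

```
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

def quantize_int8(w):
    # Per-output-channel scale so each column fits in [-127, 127].
    scale = jnp.max(jnp.abs(w), axis=0, keepdims=True) / 127.0
    return jnp.round(w / scale).astype(jnp.int8), scale

mesh = Mesh(np.array(jax.devices()), ("tp",))

@partial(shard_map, mesh=mesh,
         in_specs=(P(None, "tp"), P(None, "tp"), P()),
         out_specs=P(None, "tp"))
def int8_matmul_tp(w_q, w_scale, x):
    # Each device holds a column shard of the quantized weights and their
    # scales; activations are replicated, so this column-parallel layout
    # needs no collective.
    return (x @ w_q.astype(jnp.bfloat16)) * w_scale

w = jax.random.normal(jax.random.key(0), (512, 1024), jnp.bfloat16)
x = jax.random.normal(jax.random.key(1), (8, 512), jnp.bfloat16)
w_q, w_scale = quantize_int8(w)
y = int8_matmul_tp(w_q, w_scale, x)  # (8, 1024), sharded over "tp"
```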

## Quickstart

Due to the large model size (1T parameters), a multi-host platform is required
to run the full model.

Run on all hosts in the TPU cluster:
```
$ python3 main.py
```
e.g., for Cloud TPU:
```
$ gcloud compute tpus tpu-vm ssh {TPU_NAME} --worker=all \
    --command="cd ~/jax-llm-examples/kimi_k2 && python3 main.py"
```
kimi_k2/kimi_k2_jax/__init__.py

Whitespace-only changes.
