Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 38 additions & 35 deletions core/config/model_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,16 @@ type ModelConfig struct {
schema.PredictionOptions `yaml:"parameters,omitempty" json:"parameters,omitempty"`
Name string `yaml:"name,omitempty" json:"name,omitempty"`

F16 *bool `yaml:"f16,omitempty" json:"f16,omitempty"`
Threads *int `yaml:"threads,omitempty" json:"threads,omitempty"`
Debug *bool `yaml:"debug,omitempty" json:"debug,omitempty"`
Roles map[string]string `yaml:"roles,omitempty" json:"roles,omitempty"`
Embeddings *bool `yaml:"embeddings,omitempty" json:"embeddings,omitempty"`
Backend string `yaml:"backend,omitempty" json:"backend,omitempty"`
TemplateConfig TemplateConfig `yaml:"template,omitempty" json:"template,omitempty"`
KnownUsecaseStrings []string `yaml:"known_usecases,omitempty" json:"known_usecases,omitempty"`
KnownUsecases *ModelConfigUsecases `yaml:"-" json:"-"`
Pipeline Pipeline `yaml:"pipeline,omitempty" json:"pipeline,omitempty"`
F16 *bool `yaml:"f16,omitempty" json:"f16,omitempty"`
Threads *int `yaml:"threads,omitempty" json:"threads,omitempty"`
Debug *bool `yaml:"debug,omitempty" json:"debug,omitempty"`
Roles map[string]string `yaml:"roles,omitempty" json:"roles,omitempty"`
Embeddings *bool `yaml:"embeddings,omitempty" json:"embeddings,omitempty"`
Backend string `yaml:"backend,omitempty" json:"backend,omitempty"`
TemplateConfig TemplateConfig `yaml:"template,omitempty" json:"template,omitempty"`
KnownUsecaseStrings []string `yaml:"known_usecases,omitempty" json:"known_usecases,omitempty"`
KnownUsecases *ModelConfigUsecase `yaml:"-" json:"-"`
Pipeline Pipeline `yaml:"pipeline,omitempty" json:"pipeline,omitempty"`

PromptStrings, InputStrings []string `yaml:"-" json:"-"`
InputToken [][]int `yaml:"-" json:"-"`
Expand Down Expand Up @@ -294,8 +294,9 @@ func (c *ModelConfig) UnmarshalYAML(value *yaml.Node) error {
if err := value.Decode(&aux); err != nil {
return err
}
*c = ModelConfig(aux)

mc := ModelConfig(aux)
*c = mc
c.syncKnownUsecasesFromString()
return nil
}
Expand Down Expand Up @@ -514,30 +515,30 @@ func (c *ModelConfig) GetModelConfigFile() string {
return c.modelConfigFile
}

type ModelConfigUsecases int
type ModelConfigUsecase int

const (
FLAG_ANY ModelConfigUsecases = 0b000000000000
FLAG_CHAT ModelConfigUsecases = 0b000000000001
FLAG_COMPLETION ModelConfigUsecases = 0b000000000010
FLAG_EDIT ModelConfigUsecases = 0b000000000100
FLAG_EMBEDDINGS ModelConfigUsecases = 0b000000001000
FLAG_RERANK ModelConfigUsecases = 0b000000010000
FLAG_IMAGE ModelConfigUsecases = 0b000000100000
FLAG_TRANSCRIPT ModelConfigUsecases = 0b000001000000
FLAG_TTS ModelConfigUsecases = 0b000010000000
FLAG_SOUND_GENERATION ModelConfigUsecases = 0b000100000000
FLAG_TOKENIZE ModelConfigUsecases = 0b001000000000
FLAG_VAD ModelConfigUsecases = 0b010000000000
FLAG_VIDEO ModelConfigUsecases = 0b100000000000
FLAG_DETECTION ModelConfigUsecases = 0b1000000000000
FLAG_ANY ModelConfigUsecase = 0b000000000000
FLAG_CHAT ModelConfigUsecase = 0b000000000001
FLAG_COMPLETION ModelConfigUsecase = 0b000000000010
FLAG_EDIT ModelConfigUsecase = 0b000000000100
FLAG_EMBEDDINGS ModelConfigUsecase = 0b000000001000
FLAG_RERANK ModelConfigUsecase = 0b000000010000
FLAG_IMAGE ModelConfigUsecase = 0b000000100000
FLAG_TRANSCRIPT ModelConfigUsecase = 0b000001000000
FLAG_TTS ModelConfigUsecase = 0b000010000000
FLAG_SOUND_GENERATION ModelConfigUsecase = 0b000100000000
FLAG_TOKENIZE ModelConfigUsecase = 0b001000000000
FLAG_VAD ModelConfigUsecase = 0b010000000000
FLAG_VIDEO ModelConfigUsecase = 0b100000000000
FLAG_DETECTION ModelConfigUsecase = 0b1000000000000

// Common Subsets
FLAG_LLM ModelConfigUsecases = FLAG_CHAT | FLAG_COMPLETION | FLAG_EDIT
FLAG_LLM ModelConfigUsecase = FLAG_CHAT | FLAG_COMPLETION | FLAG_EDIT
)

func GetAllModelConfigUsecases() map[string]ModelConfigUsecases {
return map[string]ModelConfigUsecases{
func GetAllModelConfigUsecases() map[string]ModelConfigUsecase {
return map[string]ModelConfigUsecase{
// Note: FLAG_ANY is intentionally excluded from this map
// because it's 0 and would always match in HasUsecases checks
"FLAG_CHAT": FLAG_CHAT,
Expand All @@ -561,23 +562,25 @@ func stringToFlag(s string) string {
return "FLAG_" + strings.ToUpper(s)
}

func GetUsecasesFromYAML(input []string) *ModelConfigUsecases {
func GetUsecasesFromYAML(input []string) *ModelConfigUsecase {
if len(input) == 0 {
return nil
}
result := FLAG_ANY
flags := GetAllModelConfigUsecases()
for _, str := range input {
flag, exists := flags[stringToFlag(str)]
if exists {
result |= flag
for _, flag := range []string{stringToFlag(str), str} {
f, exists := flags[flag]
if exists {
result |= f
}
}
}
return &result
}

// HasUsecases examines a ModelConfig and reports whether the endpoints identified by u have a chance of success.
func (c *ModelConfig) HasUsecases(u ModelConfigUsecases) bool {
func (c *ModelConfig) HasUsecases(u ModelConfigUsecase) bool {
if (c.KnownUsecases != nil) && ((u & *c.KnownUsecases) == u) {
return true
}
Expand All @@ -587,7 +590,7 @@ func (c *ModelConfig) HasUsecases(u ModelConfigUsecases) bool {
// GuessUsecases is a **heuristic-based** function, as the backend in question may not be loaded yet, and the config may not record what it's useful for.
// As it stands, this function relies on direct backend name checks in its lower half; ideally it would instead check properties of the config, such as templates.
// Doing so would avoid the maintenance burden of updating this list for each new backend - but unfortunately, name checks are currently the best option for some services.
func (c *ModelConfig) GuessUsecases(u ModelConfigUsecases) bool {
func (c *ModelConfig) GuessUsecases(u ModelConfigUsecase) bool {
if (u & FLAG_CHAT) == FLAG_CHAT {
if c.TemplateConfig.Chat == "" && c.TemplateConfig.ChatMessage == "" && !c.TemplateConfig.UseTokenizerTemplate {
return false
Expand Down
2 changes: 1 addition & 1 deletion core/config/model_config_filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ func BuildNameFilterFn(filter string) (ModelConfigFilterFn, error) {
}, nil
}

func BuildUsecaseFilterFn(usecases ModelConfigUsecases) ModelConfigFilterFn {
func BuildUsecaseFilterFn(usecases ModelConfigUsecase) ModelConfigFilterFn {
if usecases == FLAG_ANY {
return NoFilterFn
}
Expand Down
Loading