# chatbot-rag / config.py
# Auto-deployed by GitHub Actions from commit 8db2ada1034ca253ed7c30bd9fac024191d82bd6
# 406ae9b
import os
import sys
from pathlib import Path
from typing import Any, ClassVar
from chromadb import EmbeddingFunction
class ModelStorage:
    '''Global, class-level storage for loaded embedding models.

    ``EMBED_MODEL`` is a single dict defined on the class itself, so every
    reference to ``ModelStorage.EMBED_MODEL`` sees the same object.
    Values are Chroma ``EmbeddingFunction`` objects keyed by string
    (presumably the embedding model repo id -- confirm against callers).
    '''

    EMBED_MODEL: ClassVar[dict[str, EmbeddingFunction]] = dict()
class UiGradioConfig:
'''Gradio settings for gr.Blocks()'''
css: str | None = '''
.gradio-container {
width: 70% !important;
margin: 0 auto !important;
}
'''
if hasattr(sys, 'getandroidapilevel') or 'ANDROID_ROOT' in os.environ:
css = None
theme: str | None = None
fill_height: bool = False
footer_links: list[str] = ['gradio', 'settings']
delete_cache: tuple[int, int] | None = None
@classmethod
def get_demo_launch_kwargs(cls):
return dict(
css=cls.css,
theme=cls.theme,
footer_links=cls.footer_links,
)
@classmethod
def get_demo_blocks_kwargs(cls):
return dict(
fill_height=cls.fill_height,
delete_cache=cls.delete_cache,
)
class InferenceConfig:
    '''Inference settings for the embedding model and the LLM.'''

    def __init__(self):
        # Options for encoding texts with the embedding model
        # (presumably forwarded to the encoder's encode() call).
        # normalize_embeddings=None is passed through as-is.
        self.encode_kwargs: dict[str, Any] = {
            'batch_size': 300,
            'normalize_embeddings': None,
        }
        # Forwarded verbatim as the 'reasoning_format' request option.
        self.reasoning_format: str = 'none'
        # LLM sampling parameters.
        self.sampling_kwargs: dict[str, Any] = {
            'temperature': 0.2,
            'top_p': 0.9,
            'top_k': 40,
            'repeat_penalty': 1,
        }
        # NOTE(review): -1 is sent as max_tokens; presumably "no limit" on the
        # serving side -- confirm against the server in use.
        self.max_tokens: int = -1
        self.do_sample: bool = False
        self.rag_mode: bool = False
        self.history_len: int = 0
        self.enable_thinking: bool = False
        self.show_thinking: bool = False
        # NOTE(review): meaning not visible here; presumably an image resize
        # target for multimodal input -- confirm.
        self.resize_size: int = 256
class TextLoadConfig:
    '''Settings for extracting, chunking and cleaning text from documents.'''

    def __init__(self):
        # Chunking options first, then text-cleaning flags. The key names
        # match the `unstructured` library's options -- presumably forwarded
        # there; confirm against the loader code.
        self.partition_kwargs: dict[str, str | int | bool | None] = {
            # chunking
            'chunking_strategy': 'basic',
            'max_characters': 800,
            'new_after_n_chars': 500,
            'overlap': 0,
            # cleaning
            'clean': True,
            'bullets': True,
            'extra_whitespace': True,
            'dashes': False,
            'trailing_punctuation': True,
            'lowercase': False,
        }
        # Space-separated list of supported document extensions.
        self.SUPPORTED_FILE_EXTS: str = '.csv .tsv .docx .md .org .pdf .pptx .xlsx'
        # Selected subtitle language and the available choices.
        self.subtitle_lang: str = 'ru'
        self.SUBTITLE_LANGS: list[str] = ['ru', 'en']
        # Presumably the maximum number of lines shown in the text viewer.
        self.max_lines_text_view: int = 200
class DbConfig:
    '''Chroma vector-database parameters: collection creation and querying.'''

    def __init__(self):
        # HNSW index settings. space is one of l2, ip, cosine (Chroma's
        # default is l2).
        hnsw_params = {
            'space': 'cosine',
            'ef_construction': 200,
        }
        self.create_collection_kwargs: dict[str, Any] = {
            'configuration': {'hnsw': hnsw_params},
        }
        # Retrieval settings. The misspelled key 'max_distance_treshold' is
        # kept as-is because other code looks it up by this exact name.
        self.query_kwargs: dict[str, Any] = {
            'n_results': 2,
            'max_distance_treshold': 0.5,
        }
class PromptConfig:
    '''Prompt texts used when building chat messages.'''

    def __init__(self):
        # No system prompt by default.
        self.system_prompt: str | None = None
        self.user_msg_with_context: str = ''
        # Russian RAG template ("Answer the question given the context.
        # Context: {context} Question: {user_message}"). The {context} and
        # {user_message} placeholders are presumably filled via str.format.
        self.context_template: str = '''Ответь на вопрос при условии контекста.
Контекст:
{context}
Вопрос:
{user_message}'''
class ModelConfig:
    '''Model directories, model identifiers and model-loading options.

    Environment variables read on instantiation:
        LLAMA_CACHE: directory for LLM model files (default ``llm_models``).
        EMBED_MODEL_DIR: embedding-model cache directory
            (default ``embed_models``).
        EMBED_MODEL_REPO: embedding model repo id
            (default ``Alibaba-NLP/gte-multilingual-base``).
        HF_TOKEN: stored as ``embed_model_kwargs['token']`` (None if unset).

    Both directories are created if they do not exist, including any missing
    parent directories.
    '''

    def __init__(self):
        self.LLM_MODEL_DIR: Path = Path(os.getenv('LLAMA_CACHE', 'llm_models'))
        self.EMBED_MODEL_DIR: Path = Path(os.getenv('EMBED_MODEL_DIR', 'embed_models'))
        # parents=True: an env-provided nested path (e.g. a/b/models) whose
        # parents do not exist yet would otherwise raise FileNotFoundError.
        self.LLM_MODEL_DIR.mkdir(parents=True, exist_ok=True)
        self.EMBED_MODEL_DIR.mkdir(parents=True, exist_ok=True)
        # No LLM is selected by default; these stay None until set elsewhere.
        self.llm_model_repo: str | None = None
        self.llm_model_file: str | None = None
        self.llm_model_mmproj: str | None = None
        self.embed_model_repo: str = os.getenv('EMBED_MODEL_REPO', 'Alibaba-NLP/gte-multilingual-base')
        # Options for constructing the embedding model (presumably passed to
        # the sentence-transformers constructor -- confirm against the loader).
        self.embed_model_kwargs: dict[str, Any] = dict(
            cache_folder=self.EMBED_MODEL_DIR,
            trust_remote_code=True,
            token=os.getenv('HF_TOKEN'),
            model_kwargs=dict(
                dtype='auto',
                device_map='auto',
            ),
        )
        # LLM loading options (GPU offload layers and context size).
        self.llm_model_kwargs: dict[str, Any] = dict(
            n_gpu_layers=-1,
            n_ctx=4096,
        )
class ReposConfig:
    '''Lists of selectable model repositories.

    ``llm_model_repos`` holds GGUF LLM repositories; ``embed_model_repos``
    holds embedding-model repositories.
    '''

    def __init__(self):
        # GGUF builds of Gemma 3 and Qwen3 models.
        self.llm_model_repos: list[str] = [
            'bartowski/google_gemma-3-1b-it-GGUF',
            'bartowski/google_gemma-3-4b-it-GGUF',
            'bartowski/Qwen_Qwen3-1.7B-GGUF',
            'bartowski/Qwen_Qwen3-4B-GGUF',
            'bartowski/Qwen_Qwen3-0.6B-GGUF',
        ]
        # Embedding models; the first entry matches ModelConfig's default
        # EMBED_MODEL_REPO.
        self.embed_model_repos: list[str] = [
            'Alibaba-NLP/gte-multilingual-base',
            'sergeyzh/rubert-tiny-turbo',
            'intfloat/multilingual-e5-large',
            'intfloat/multilingual-e5-base',
            'intfloat/multilingual-e5-small',
            'intfloat/multilingual-e5-large-instruct',
            'sentence-transformers/all-mpnet-base-v2',
            'sentence-transformers/paraphrase-multilingual-mpnet-base-v2',
            'ai-forever/ruElectra-medium',
            'ai-forever/sbert_large_nlu_ru',
            'deepvk/USER2-small',
            'BAAI/bge-m3-retromae',
        ]
class Config:
    '''Top-level application config.

    Builds every section config and, from them, flat keyword dicts:
    ``generation_kwargs``, ``load_text_kwargs``, ``load_model_kwargs`` and
    ``view_text_kwargs``. The flat dicts capture the section values once, in
    ``__init__``. The ``get_*`` helpers read from these flat dicts, except
    ``get_embed_model_kwargs``, ``get_encode_kwargs`` and
    ``get_llm_model_kwargs``, which return the section dicts themselves.
    '''

    # Key subsets selected from the flat dicts by the get_* helpers (in order).
    _PARTITION_KEYS = (
        'chunking_strategy', 'max_characters', 'new_after_n_chars', 'overlap',
        'clean', 'bullets', 'extra_whitespace', 'dashes',
        'trailing_punctuation', 'lowercase',
    )
    _CLEAN_KEYS = ('bullets', 'extra_whitespace', 'dashes', 'trailing_punctuation', 'lowercase')
    _CHUNKING_KEYS = ('max_characters', 'new_after_n_chars', 'overlap')
    _QUERY_KEYS = ('n_results', 'max_distance_treshold')
    _RAG_KEYS = _QUERY_KEYS + ('user_msg_with_context', 'context_template')
    # Model-source env vars removed on every update_env call before new
    # values are written.
    _STALE_LLAMA_ENV = (
        'LLAMA_ARG_MODEL_URL',
        'LLAMA_ARG_MMPROJ_URL',
        'LLAMA_ARG_HF_REPO',
        'LLAMA_ARG_HF_FILE',
        'LLAMA_ARG_MMPROJ',
    )

    def __init__(self):
        self.Inference: InferenceConfig = InferenceConfig()
        self.TextLoad: TextLoadConfig = TextLoadConfig()
        self.Prompt: PromptConfig = PromptConfig()
        self.Db: DbConfig = DbConfig()
        self.Model: ModelConfig = ModelConfig()
        self.Repos: ReposConfig = ReposConfig()
        # RAG is on when the variable is unset or equals '1'/'true'
        # (case-insensitive); any other value turns it off.
        rag_env = os.getenv('CHATBOT_RAG_ENABLED', '1')
        self.CHATBOT_RAG_ENABLED = rag_env.lower() in ('1', 'true')

        inference, prompt = self.Inference, self.Prompt
        sampling = inference.sampling_kwargs
        query = self.Db.query_kwargs
        self.generation_kwargs: dict[str, Any] = {
            'do_sample': inference.do_sample,
            'temperature': sampling['temperature'],
            'top_p': sampling['top_p'],
            'top_k': sampling['top_k'],
            'repeat_penalty': sampling['repeat_penalty'],
            'max_tokens': inference.max_tokens,
            'history_len': inference.history_len,
            'system_prompt': prompt.system_prompt,
            'context_template': prompt.context_template,
            'enable_thinking': inference.enable_thinking,
            'show_thinking': inference.show_thinking,
            'n_results': query['n_results'],
            'max_distance_treshold': query['max_distance_treshold'],
            'user_msg_with_context': prompt.user_msg_with_context,
            'rag_mode': inference.rag_mode,
        }

        text_load = self.TextLoad
        self.load_text_kwargs: dict[str, Any] = {
            **self._pick(text_load.partition_kwargs, self._PARTITION_KEYS),
            'subtitle_lang': text_load.subtitle_lang,
        }

        model = self.Model
        self.load_model_kwargs: dict[str, Any] = {
            'llm_model_repo': model.llm_model_repo,
            'llm_model_file': model.llm_model_file,
            'llm_model_mmproj': model.llm_model_mmproj,
            'embed_model_repo': model.embed_model_repo,
            'n_gpu_layers': model.llm_model_kwargs['n_gpu_layers'],
            'n_ctx': model.llm_model_kwargs['n_ctx'],
        }
        self.view_text_kwargs: dict[str, Any] = {
            'max_lines_text_view': text_load.max_lines_text_view,
        }

    @staticmethod
    def _pick(source: dict[str, Any], keys: tuple[str, ...]) -> dict[str, Any]:
        '''Return a new dict with ``keys`` (in order) copied from ``source``.

        Raises KeyError if any key is missing from ``source``.
        '''
        return {key: source[key] for key in keys}

    def update_env(self, **kwargs) -> None:
        '''Write LLAMA_ARG_* settings into ``os.environ``.

        The variables are presumably read by a llama.cpp server started
        afterwards. LLAMA_ARG_CTX_SIZE and LLAMA_ARG_N_GPU_LAYERS come from
        ``load_model_kwargs``; every extra keyword argument is added as well,
        with its value converted by ``str``.

        The five model-source variables (LLAMA_ARG_MODEL_URL,
        LLAMA_ARG_MMPROJ_URL, LLAMA_ARG_HF_REPO, LLAMA_ARG_HF_FILE,
        LLAMA_ARG_MMPROJ) are removed first. Keyword arguments with those
        names are still written, because the update happens after the
        removal. LLAMA_ARG_MMPROJ is set to ``Model.LLM_MODEL_DIR /
        llm_model_mmproj`` when ``llm_model_mmproj`` is truthy; this
        overrides any keyword argument of the same name.
        '''
        mmproj_file = self.load_model_kwargs['llm_model_mmproj']
        env_updates = {
            'LLAMA_ARG_CTX_SIZE': str(self.load_model_kwargs['n_ctx']),
            'LLAMA_ARG_N_GPU_LAYERS': str(self.load_model_kwargs['n_gpu_layers']),
        }
        env_updates.update({name: str(value) for name, value in kwargs.items()})
        for stale_name in self._STALE_LLAMA_ENV:
            os.environ.pop(stale_name, None)
        if mmproj_file:
            env_updates['LLAMA_ARG_MMPROJ'] = str(self.Model.LLM_MODEL_DIR / mmproj_file)
        os.environ.update(env_updates)

    def get_completions_kwargs(self) -> dict[str, Any]:
        '''Return kwargs for an OpenAI-client-style completions request.

        Options outside the standard OpenAI parameters (top_k,
        repeat_penalty, reasoning_format, chat_template_kwargs) are placed
        under ``extra_body``.
        '''
        gen = self.generation_kwargs
        extra_body = {
            'top_k': gen['top_k'],
            'repeat_penalty': gen['repeat_penalty'],
            'reasoning_format': self.Inference.reasoning_format,
            'chat_template_kwargs': {'enable_thinking': gen['enable_thinking']},
        }
        return {
            'temperature': gen['temperature'],
            'top_p': gen['top_p'],
            'max_tokens': gen['max_tokens'],
            'extra_body': extra_body,
        }

    def get_rag_kwargs(self) -> dict[str, Any]:
        '''Return retrieval settings plus the context prompt templates.'''
        return self._pick(self.generation_kwargs, self._RAG_KEYS)

    def get_partition_kwargs(self) -> dict[str, Any]:
        '''Return the chunking and cleaning options (without subtitle_lang).'''
        return self._pick(self.load_text_kwargs, self._PARTITION_KEYS)

    def get_clean_kwargs(self) -> dict[str, Any]:
        '''Return only the five text-cleaning flags.'''
        return self._pick(self.load_text_kwargs, self._CLEAN_KEYS)

    def get_chunking_kwargs(self):
        '''Return only the chunk-size and overlap options.'''
        return self._pick(self.load_text_kwargs, self._CHUNKING_KEYS)

    def get_embed_model_kwargs(self) -> dict[str, Any]:
        '''Return ``Model.embed_model_kwargs`` itself (not a copy).'''
        return self.Model.embed_model_kwargs

    def get_encode_kwargs(self) -> dict[str, Any]:
        '''Return ``Inference.encode_kwargs`` itself (not a copy).'''
        return self.Inference.encode_kwargs

    def get_llm_model_kwargs(self) -> dict[str, Any]:
        '''Return ``Model.llm_model_kwargs`` itself (not a copy).'''
        return self.Model.llm_model_kwargs

    def get_query_kwargs(self) -> dict[str, Any]:
        '''Return the vector-query settings (n_results, distance threshold).'''
        return self._pick(self.generation_kwargs, self._QUERY_KEYS)