# -*- coding: utf-8 -*-
"""
SQLite-RAG スターター（高速モード/ストリーミング対応版）

元コード: https://github.com/amaiya/sqlite-rag
ライセンス: MIT License（© 2024 amaiya）
本ファイルは上記を基に改変・再構成したものです。

- ベクトルDB: sqlite-vec（LangChainのSQLiteVec）
- 埋め込み: HuggingFaceEmbeddings（MiniLM 多言語）
- 生成: LM Studio (OpenAI互換API) 既定: gemma-3-4b-it
- 体感を上げる工夫:
  * Embedding/VectorStore/OpenAIClient/SQLite拡張の初期化は lru_cache で1回だけ
  * --fast で MMR を省略・短文コンテキスト・ストリーミング出力
  * 通常モードは MMR フォールバックで精度寄り

使い方:
  1) 取り込み:  python sqlite_rag.py ingest --data .\data
  2) 一発質問:  python sqlite_rag.py ask --q "…" --k 4 --fast
  3) 常駐REPL:  python sqlite_rag.py repl --fast
"""

import os, sys, argparse, textwrap, time, math
import sqlite3
from pathlib import Path
from functools import lru_cache

# pip install: langchain langchain-community langchain-text-splitters langchain-huggingface pypdf sqlite-vec openai sentence-transformers tqdm
import sqlite_vec
from langchain_community.vectorstores import SQLiteVec
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain.docstore.document import Document
from pypdf import PdfReader
from openai import OpenAI

# SQLite database file and base table name; both can be overridden through
# the RAG_DB / RAG_TABLE environment variables.
DB_PATH = os.environ.get("RAG_DB", "rag.db")
TABLE   = os.environ.get("RAG_TABLE", "docs")

# base_url is configurable so that the same client code can target either
# LM Studio (OpenAI-compatible API) or the official OpenAI endpoint.
OPENAI_BASE    = os.environ.get("OPENAI_BASE", "http://localhost:1234/v1")
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "lm-studio")
OPENAI_MODEL   = os.environ.get("OPENAI_MODEL", "gemma-3-4b-it")

# Stored as the raw environment string; it is not converted to int here.
GEN_MAX_TOKENS_ENV = os.environ.get("GEN_MAX_TOKENS", "192")
# FAST_MODE is parsed with int(), so it must be an integer string such as
# "0" or "1"; any other value raises ValueError at import time.
FAST_MODE_DEFAULT = bool(int(os.environ.get("FAST_MODE", "0")))

def load_texts_from_folder(data_dir: str):
    """Recursively load every supported file under ``data_dir`` as a Document.

    ``.txt``, ``.md`` and ``.csv`` files are read as UTF-8 (undecodable bytes
    are ignored); ``.pdf`` files have the extracted text of all pages joined
    with newlines. Files whose text is empty or whitespace-only are skipped,
    as are files with any other suffix.

    Args:
        data_dir: Root folder to scan.

    Returns:
        A list of ``Document`` objects, one per loaded file, each carrying the
        file path in ``metadata["path"]``.
    """
    plain_suffixes = (".txt", ".md", ".csv")
    loaded = []
    for entry in Path(data_dir).rglob("*"):
        if entry.is_file():
            try:
                ext = entry.suffix.lower()
                if ext in plain_suffixes:
                    body = entry.read_text(encoding="utf-8", errors="ignore")
                elif ext == ".pdf":
                    # extract_text() may yield nothing for a page; treat that as "".
                    body = "\n".join(
                        (pg.extract_text() or "") for pg in PdfReader(str(entry)).pages
                    )
                else:
                    continue
                if body.strip():
                    loaded.append(Document(page_content=body, metadata={"path": str(entry)}))
            except Exception as e:
                # Best effort: one unreadable file must not abort the whole scan.
                print(f"[WARN] 読み込み失敗: {entry} ({e})")
    return loaded

def get_text_splitter():
    """Create the text splitter used to chunk documents before indexing.

    Returns:
        A ``RecursiveCharacterTextSplitter`` configured for 1200-character
        chunks with an 80-character overlap, measured with ``len``.
    """
    # Tried in order: paragraph break, line break, Japanese full stop,
    # Japanese comma, space, and finally the empty string.
    split_priority = ["\n\n", "\n", "。", "、", " ", ""]
    return RecursiveCharacterTextSplitter(
        separators=split_priority,
        length_function=len,
        chunk_overlap=80,
        chunk_size=1200,
    )

def _is_summary_query(q: str) -> bool:
    """Report whether ``q`` contains one of the summary-style keywords.

    Args:
        q: The user's question text.

    Returns:
        True if any keyword occurs as a substring of ``q``, otherwise False.
    """
    for marker in ("要旨", "要約", "まとめ", "ポイント", "3点", "5点", "特長", "特徴"):
        if marker in q:
            return True
    return False

@lru_cache(maxsize=1)
def get_conn():
    """Open the shared SQLite connection with the sqlite-vec extension loaded.

    The result is cached, so every caller in the process receives the same
    connection object.

    Returns:
        A ``sqlite3.Connection`` to ``DB_PATH`` that uses ``sqlite3.Row`` as
        its row factory and has sqlite-vec loaded.

    Raises:
        Exception: Whatever enabling or loading the extension raises; the
            connection is closed before the error propagates.
    """
    conn = sqlite3.connect(DB_PATH, check_same_thread=False)
    conn.row_factory = sqlite3.Row
    try:
        conn.enable_load_extension(True)
        sqlite_vec.load(conn)
        # Switch extension loading back off once sqlite-vec is in place so
        # the long-lived connection cannot load further native code.
        conn.enable_load_extension(False)
    except Exception:
        # Do not leak a half-initialised connection.
        conn.close()
        raise
    return conn

@lru_cache(maxsize=1)
def get_embedder():
    """Return the shared embedding model (constructed once, then cached).

    Returns:
        A ``HuggingFaceEmbeddings`` instance for the multilingual MiniLM
        sentence-transformers model.
    """
    # If a GPU is available, pass model_kwargs={"device": "cuda"} as well.
    model_id = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
    return HuggingFaceEmbeddings(model_name=model_id)

@lru_cache(maxsize=1)
def get_vs():
    """Return the shared vector store (constructed once, then cached).

    Returns:
        A ``SQLiteVec`` store over table ``TABLE`` that uses the shared
        connection and the shared embedder.
    """
    store_kwargs = {
        "connection": get_conn(),
        "embedding": get_embedder(),
        "table": TABLE,
    }
    return SQLiteVec(**store_kwargs)

@lru_cache(maxsize=1)
def get_client():
    """Return the shared OpenAI-compatible API client (cached after first use).

    Returns:
        An ``OpenAI`` client pointed at ``OPENAI_BASE`` and authenticated
        with ``OPENAI_API_KEY``.
    """
    client = OpenAI(api_key=OPENAI_API_KEY, base_url=OPENAI_BASE)
    return client

def get_generator(streaming: bool):
    """Build a chat function bound to the shared OpenAI-compatible client.

    Args:
        streaming: When true, the returned function requests a streamed
            completion and echoes each text fragment to stdout as it arrives,
            followed by a newline. When false, it issues a single blocking
            request and prints nothing.

    Returns:
        A callable ``chat(prompt, is_summary) -> str`` that returns the
        whitespace-stripped completion text (an empty string if the model
        produced no text). ``is_summary`` caps the reply at 160 tokens;
        otherwise the cap is ``int(GEN_MAX_TOKENS_ENV)``.
    """
    client = get_client()
    model = OPENAI_MODEL
    system_prompt = "要点のみ。冗長・前置き・繰り返しは禁止。根拠のない推測はしない。"

    def chat(prompt: str, is_summary: bool):
        max_tokens = 160 if is_summary else int(GEN_MAX_TOKENS_ENV)
        # One request payload shared by both modes so they cannot drift apart.
        request = {
            "model": model,
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": prompt},
            ],
            "temperature": 0.2,
            "max_tokens": max_tokens,
        }
        if not streaming:
            resp = client.chat.completions.create(**request)
            # message.content is nullable in the API; avoid None.strip().
            return (resp.choices[0].message.content or "").strip()

        pieces = []
        for chunk in client.chat.completions.create(stream=True, **request):
            # Some servers emit chunks with an empty ``choices`` list
            # (e.g. usage-only chunks); skip them instead of indexing [0].
            if not chunk.choices:
                continue
            delta = getattr(chunk.choices[0].delta, "content", None)
            if delta:
                print(delta, end="", flush=True)
                pieces.append(delta)
        print()
        return "".join(pieces).strip()

    return chat

_warmed = False
def warm_up_once():
    """Run one throwaway embedding and one throwaway generation per process.

    Each step is attempted independently; a failure is reported as a warning
    and does not prevent the other step or raise to the caller. Later calls
    are no-ops because the module-level ``_warmed`` flag is set afterwards.
    """
    global _warmed
    if not _warmed:
        steps = (
            ("[WARN] embed warmup:",
             lambda: get_embedder().embed_query("warm up")),
            ("[WARN] gen warmup:",
             lambda: get_generator(streaming=False)("OKと言ってください。", is_summary=True)),
        )
        for warn_prefix, step in steps:
            try:
                step()
            except Exception as exc:
                print(warn_prefix, exc)
        _warmed = True

def _cosine(a, b):
    """Return the cosine similarity of two numeric sequences.

    A zero-length norm is replaced by 1.0, so a zero vector yields a
    similarity of 0.0 instead of a division by zero. Elements are paired
    with ``zip``, so extra trailing elements of the longer input do not
    contribute to the dot product.
    """
    def _length(vec):
        return math.sqrt(sum(c * c for c in vec)) or 1.0

    numerator = sum(p * q for p, q in zip(a, b))
    return numerator / (_length(a) * _length(b))

def _mmr_select(query_vec, cand_vecs, k, lambda_mult=0.2):
    """Pick up to ``k`` candidate indices by maximal marginal relevance.

    The first pick is the candidate most similar to the query. Every later
    pick maximises
    ``lambda_mult * sim(query, c) - (1 - lambda_mult) * max sim(c, picked)``,
    trading relevance against redundancy with what is already chosen.

    Args:
        query_vec: Embedding of the query.
        cand_vecs: Candidate embeddings, one per candidate.
        k: Maximum number of indices to return.
        lambda_mult: Weight of query relevance (0..1); the remainder weights
            the redundancy penalty.

    Returns:
        Indices into ``cand_vecs`` in selection order; empty when there are
        no candidates or ``k <= 0``.
    """
    n = len(cand_vecs)
    if n == 0 or k <= 0:
        return []
    sims_q = [_cosine(query_vec, v) for v in cand_vecs]
    first = max(range(n), key=sims_q.__getitem__)
    selected = [first]
    remaining = [i for i in range(n) if i != first]
    # closest[i] is the highest similarity between candidate i and any item
    # picked so far. It is refreshed against the newest pick only, so each
    # candidate/pick pair is compared once rather than once per round.
    closest = [-math.inf] * n
    newest = first
    target = min(k, n)
    while len(selected) < target:
        newest_vec = cand_vecs[newest]
        best_i, best_score = None, -math.inf
        for i in remaining:
            sim = _cosine(cand_vecs[i], newest_vec)
            if sim > closest[i]:
                closest[i] = sim
            score = lambda_mult * sims_q[i] - (1 - lambda_mult) * closest[i]
            if score > best_score:
                best_i, best_score = i, score
        if best_i is None:
            # No score compared greater than -inf (e.g. NaN similarities);
            # stop rather than append None to the result.
            break
        selected.append(best_i)
        remaining.remove(best_i)
        newest = best_i
    return selected

def ensure_db():
    """Make sure the database file at ``DB_PATH`` and its parent folder exist.

    Missing parent directories are created, then the database is opened and
    closed immediately.
    """
    parent_dir = Path(DB_PATH).parent
    parent_dir.mkdir(parents=True, exist_ok=True)
    probe = sqlite3.connect(DB_PATH)
    probe.close()

def cmd_ingest(args):
    """Handle the ``ingest`` subcommand: load, chunk, embed and store files.

    Args:
        args: Parsed CLI arguments; ``args.data`` is the folder to ingest.

    Prints progress after every batch of 64 chunks and, at the end, the row
    count of the ``{TABLE}_vec_rowids`` table (or a warning if that query
    fails). Returns early with a message when no loadable file is found.
    """
    ensure_db()
    print(f"[INGEST] data={args.data} -> {DB_PATH} / table={TABLE}")
    source_docs = load_texts_from_folder(args.data)
    if not source_docs:
        print("投入対象が見つかりません（.pdf/.txt/.md/.csv）")
        return

    pieces = get_text_splitter().split_documents(source_docs)
    # Give every chunk a sequential id, on a copy of its metadata.
    for seq, piece in enumerate(pieces):
        piece.metadata = {**(piece.metadata or {}), "chunk_id": seq}

    db = get_conn()
    store = SQLiteVec(connection=db, embedding=get_embedder(), table=TABLE)

    # Progress display (counts only).
    batch_size = 64
    total = len(pieces)
    done = 0
    while done < total:
        store.add_documents(pieces[done:done + batch_size])
        done = min(done + batch_size, total)
        print(f"[INGEST] {done}/{total}")
    db.commit()

    try:
        row = db.execute(f"SELECT COUNT(*) FROM {TABLE}_vec_rowids").fetchone()
        print(f"[OK] {row[0]} チャンクをインデックス化しました。")
    except Exception as e:
        print("[WARN] 行数確認に失敗:", e)

def retrieve(query: str, k: int, fast: bool):
    """Fetch the documents used as context for ``query``.

    Args:
        query: The user's question.
        k: Requested number of documents; capped at 3 in fast mode or when
            the question looks like a summary request.
        fast: When true, return the plain similarity-search result.

    Returns:
        A list of documents. Outside fast mode the store's MMR search is
        tried first; if the store raises ``NotImplementedError``, a wider
        similarity search is re-ranked locally with ``_mmr_select``.
    """
    store = get_vs()
    limit = min(k, 3) if (fast or _is_summary_query(query)) else k
    if fast:
        return store.similarity_search(query, k=limit)

    pool_size = max(12, limit * 4)
    try:
        return store.max_marginal_relevance_search(
            query, k=limit, fetch_k=pool_size, lambda_mult=0.2
        )
    except NotImplementedError:
        # Re-embed the candidate pool and apply MMR locally instead.
        pool = store.similarity_search(query, k=pool_size)
        embedder = get_embedder()
        query_vec = embedder.embed_query(query)
        pool_vecs = embedder.embed_documents([doc.page_content for doc in pool])
        chosen = _mmr_select(query_vec, pool_vecs, k=limit, lambda_mult=0.2)
        return [pool[idx] for idx in chosen]

def build_prompt(user_q: str, docs, fast: bool):
    """Assemble the generation prompt from the question and retrieved docs.

    Each document becomes a numbered block headed by ``[i] SOURCE: <path>``
    (plus `` p.<page>`` when the metadata has a page), followed by at most
    480 characters of its text in fast/summary mode or 600 otherwise. Blocks
    are separated by ``---`` lines.

    Args:
        user_q: The user's question.
        docs: Retrieved documents exposing ``metadata`` and ``page_content``.
        fast: Whether fast mode is active (shortens the per-document excerpt).

    Returns:
        A ``(prompt, is_summary)`` tuple, where ``is_summary`` tells whether
        the question was classified as a summary request.
    """
    is_sum = _is_summary_query(user_q)
    per = 480 if (fast or is_sum) else 600
    ctx_blocks = []
    for i, d in enumerate(docs, 1):
        src = d.metadata.get("path", "?")
        pg = d.metadata.get("page")
        # Compare against None so that a page number of 0 is still shown.
        head = f"[{i}] SOURCE: {src}" + (f" p.{pg}" if pg is not None else "")
        ctx_blocks.append(f"{head}\n{d.page_content[:per]}")
    context = "\n\n---\n\n".join(ctx_blocks)
    # Dedent the template BEFORE inserting the question and context. Document
    # text starts at column 0, so dedenting the already-filled string finds
    # no common margin and leaves the instructions indented by four spaces.
    template = textwrap.dedent("""
    次のコンテキスト内の**原文から**、質問に**該当する箇所のみ**を日本語で簡潔に箇条書きしてください。
    新しい数値や推測は**加えない**でください。各箇条書きの末尾に [番号]（SOURCE番号）を付けてください。
    答えが無ければ「不明」と明言。

    # 質問
    {user_q}

    # コンテキスト
    {context}
    """)
    prompt = template.format(user_q=user_q, context=context).strip()
    return prompt, is_sum

def cmd_ask(args):
    """Handle the ``ask`` subcommand: answer a single question.

    Args:
        args: Parsed CLI arguments providing ``q`` (question), ``k`` (number
            of documents to retrieve) and ``fast`` (fast-mode flag).

    Streams the answer to stdout, then lists the source paths of the
    retrieved documents and the retrieve/generate/total timings. Prints a
    "no match" message and returns when retrieval finds nothing.
    """
    warm_up_once()
    fast = FAST_MODE_DEFAULT or bool(args.fast)

    started = time.time()
    hits = retrieve(args.q, k=args.k, fast=fast)
    if not hits:
        print("該当なし")
        return
    retrieved = time.time()

    prompt, summary_mode = build_prompt(args.q, hits, fast=fast)
    # The answer is printed while it streams, so the return value is unused.
    get_generator(streaming=True)(prompt, is_summary=summary_mode)
    finished = time.time()

    print("\n=== 参考コンテキスト ===")
    for rank, hit in enumerate(hits, 1):
        print(f"[{rank}] {hit.metadata.get('path','?')}")

    retrieve_s = retrieved - started
    generate_s = finished - retrieved
    total_s = finished - started
    print(f"\n[Perf] retrieve: {retrieve_s:.2f}s | generate: {generate_s:.2f}s | total: {total_s:.2f}s | fast={fast}")

def cmd_repl(args):
    """Handle the ``repl`` subcommand: answer questions in a loop.

    Args:
        args: Parsed CLI arguments providing ``fast`` (fast-mode flag).

    Reads questions from stdin until an empty line, Ctrl-C, or end of input
    (Ctrl-D / Ctrl-Z / exhausted pipe). Each answer is streamed to stdout,
    followed by the retrieve/generate/total timings. Four documents are
    requested per question.
    """
    warm_up_once()
    fast = bool(args.fast) or FAST_MODE_DEFAULT
    print("SQLite-RAG REPL: 空行で終了。 fast=", fast)
    while True:
        try:
            q = input("\nQ> ").strip()
        except (KeyboardInterrupt, EOFError):
            # input() raises EOFError when stdin is closed; leave the loop
            # cleanly instead of ending with a traceback.
            print("\nBye")
            break
        if not q:
            print("Bye")
            break
        t0 = time.time()
        docs = retrieve(q, k=4, fast=fast)
        if not docs:
            print("該当なし")
            continue
        t1 = time.time()
        prompt, is_sum = build_prompt(q, docs, fast=fast)
        # The answer is printed while it streams, so the return value is unused.
        get_generator(streaming=True)(prompt, is_summary=is_sum)
        t2 = time.time()
        print(f"[Perf] retrieve: {t1-t0:.2f}s | generate: {t2-t1:.2f}s | total: {t2-t0:.2f}s | fast={fast}")

def main():
    """Parse the command line and dispatch to the chosen subcommand.

    Subcommands are ``ingest``, ``ask`` and ``repl``. When none is given,
    the help text is printed and the process exits with status 0.
    """
    parser = argparse.ArgumentParser(description="SQLite-RAG fast starter")
    commands = parser.add_subparsers(dest="cmd")
    fast_help = "高速モード（MMR無効＋短文コンテキスト＋ストリーミング）"

    ingest_p = commands.add_parser("ingest", help="データ投入")
    ingest_p.add_argument("--data", required=True, help="PDF/TXT/MD/CSV を含むフォルダ")
    ingest_p.set_defaults(func=cmd_ingest)

    ask_p = commands.add_parser("ask", help="質問/要約")
    ask_p.add_argument("--q", required=True, help="質問文（改行は \\n か PowerShell ヒアストリングで）")
    ask_p.add_argument("--k", type=int, default=4, help="検索件数（少なめ推奨）")
    ask_p.add_argument("--fast", action="store_true", help=fast_help)
    ask_p.set_defaults(func=cmd_ask)

    repl_p = commands.add_parser("repl", help="常駐REPL（コールドスタート費用を1回に集約）")
    repl_p.add_argument("--fast", action="store_true", help=fast_help)
    repl_p.set_defaults(func=cmd_repl)

    ns = parser.parse_args()
    if ns.cmd:
        ns.func(ns)
    else:
        parser.print_help()
        sys.exit(0)

# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
