# Source code for hypotestx.core.llm.base

"""
LLM Backend abstractions for HypoTestX.

All LLM backends implement ``LLMBackend``.  The only method that must be
overridden is ``chat()``, which receives a list of OpenAI-style message dicts
and returns the assistant text string.

Custom backends
---------------
Subclass ``LLMBackend`` and implement ``chat()``, then pass an instance to
``analyze()``:

    class MyBackend(LLMBackend):
        def chat(self, messages):
            # call your model / API
            return response_text

    result = hx.analyze(df, "Is height correlated with weight?",
                         backend=MyBackend())

You can also pass any callable that accepts ``(messages: list) -> str``:

    result = hx.analyze(df, "...", backend=my_chat_fn)
"""

from __future__ import annotations

import json
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

# ---------------------------------------------------------------------------
# Structured routing result
# ---------------------------------------------------------------------------


@dataclass
class RoutingResult:
    """
    Structured intent extracted from a user question.

    Instances are produced either by an LLM or by a fallback parser.  The
    engine uses the result to fetch the correct columns from the DataFrame
    and to call the right test function.

    Field groups
    ------------
    * ``test`` -- which statistical test to run.
    * ``value_column`` / ``group_column`` / ``x_column`` / ``y_column`` --
      column names in the dataframe.
    * ``group_values`` -- specific group labels to compare.
    * ``alternative`` ... ``method`` -- test parameters.
    * ``reasoning`` ... ``raw_response`` -- metadata about the routing itself.
    """

    # --- test selection ----------------------------------------------------
    # Identifier of the statistical test, e.g. "two_sample_ttest".
    test: str = ""

    # --- dataframe columns -------------------------------------------------
    # The numeric / response variable.
    value_column: Optional[str] = None
    # The grouping / categorical variable.
    group_column: Optional[str] = None
    # Alias for the first variable in a correlation.
    x_column: Optional[str] = None
    # Alias for the second variable in a correlation.
    y_column: Optional[str] = None

    # --- group labels ------------------------------------------------------
    # Subset of group_column's values to compare, e.g. ["Male", "Female"].
    group_values: Optional[List[str]] = None

    # --- test parameters ---------------------------------------------------
    # One of "two-sided" | "greater" | "less".
    alternative: str = "two-sided"
    # None means "use analyze()'s alpha".
    alpha: Optional[float] = None
    # Hypothesised mean (one-sample t-test).
    mu: Optional[float] = None
    # Student vs Welch.
    equal_var: bool = False
    # Yates correction for chi-square.
    correction: bool = True
    # One of "parametric" | "non-parametric".
    method: str = "parametric"

    # --- meta --------------------------------------------------------------
    # The LLM's explanation of its choice.
    reasoning: str = ""
    # Range 0.0-1.0; LLM = 1.0, regex fallback = 0.6.
    confidence: float = 1.0
    # Either "llm" or "fallback".
    routing_source: str = "llm"
    # Full LLM output, kept for debugging.
    raw_response: str = ""
@dataclass
class SchemaInfo:
    """
    Summary of a DataFrame that is handed to the LLM as context.

    Built by ``build_schema()`` in prompts.py.  Every container field gets
    its own fresh default object per instance.
    """

    # Column names.
    columns: List[str] = field(default_factory=list)
    # Mapping of column name to a dtype string.
    dtypes: Dict[str, str] = field(default_factory=dict)
    # Row count.
    n_rows: int = 0
    # For categorical columns: unique values (up to 20).
    categoricals: Dict[str, List[Any]] = field(default_factory=dict)
    # For numeric columns: (min, max, mean).
    numerics: Dict[str, Dict[str, float]] = field(default_factory=dict)
# ---------------------------------------------------------------------------
# Abstract base class
# ---------------------------------------------------------------------------
class LLMBackend(ABC):
    """
    Abstract LLM backend.

    To integrate any LLM, derive from this class and provide ``chat()``.
    The default ``route()`` takes care of prompt building, JSON extraction,
    validation and returning a ``RoutingResult`` -- a subclass only has to
    supply the raw API call.
    """

    # Name shown in error messages / repr
    name: str = "custom"

    @abstractmethod
    def chat(self, messages: List[Dict[str, str]]) -> str:
        """
        Send a list of OpenAI-style messages and return the assistant reply.

        Args:
            messages: List of dicts, each with a 'role' key
                      ('system' | 'user' | 'assistant') and a 'content'
                      key (str).

        Returns:
            The model's text response.
        """

    # ------------------------------------------------------------------
    # Route — called by engine.py
    # ------------------------------------------------------------------

    def route(
        self,
        question: str,
        schema: "SchemaInfo",  # noqa: F821
        extra_context: str = "",
        warn_fallback: bool = True,
    ) -> RoutingResult:
        """
        Build prompts, call the LLM, parse the JSON response.

        This method is final — override ``chat()`` instead.

        Args:
            question: The user's question, forwarded to the user prompt.
            schema: DataFrame summary, forwarded to the user prompt.
            extra_context: Additional text forwarded to the user prompt.
            warn_fallback: Accepted for interface compatibility; this
                implementation does not read it.

        Returns:
            The parsed ``RoutingResult`` with ``raw_response`` set to the
            text returned by ``chat()``.

        Raises:
            RuntimeError: If ``chat()`` raises any ``Exception``; the
                original exception is chained as the cause.
        """
        from .prompts import build_system_prompt, build_user_prompt  # local import

        conversation = [
            {"role": "system", "content": build_system_prompt()},
            {
                "role": "user",
                "content": build_user_prompt(question, schema, extra_context),
            },
        ]

        try:
            reply = self.chat(conversation)
        except Exception as exc:
            raise RuntimeError(
                f"[HypoTestX] LLM backend '{self.name}' raised an error: {exc}"
            ) from exc

        parsed = _parse_routing_json(reply)
        # Keep the untouched model output around for debugging.
        parsed.raw_response = reply
        return parsed

    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        return f"<{cls_name} backend='{self.name}'>"
# ---------------------------------------------------------------------------
# Callable wrapper — lets users pass a plain function as backend
# ---------------------------------------------------------------------------
class CallableBackend(LLMBackend):
    """Adapter that exposes any ``callable(messages) -> str`` as an ``LLMBackend``."""

    name = "callable"

    def __init__(self, fn):
        """
        Store the wrapped callable.

        Args:
            fn: Any object for which ``callable(fn)`` is true.

        Raises:
            TypeError: If ``fn`` is not callable.
        """
        if callable(fn):
            self._fn = fn
        else:
            raise TypeError("CallableBackend requires a callable(messages) -> str")

    def chat(self, messages: List[Dict[str, str]]) -> str:
        """Pass ``messages`` to the wrapped callable and return what it returns."""
        reply = self._fn(messages)
        return reply
# ---------------------------------------------------------------------------
# JSON extraction helpers
# ---------------------------------------------------------------------------


def _parse_routing_json(raw: str) -> RoutingResult:
    """
    Extract a ``RoutingResult`` from the raw LLM text.

    Tries (in order):
    1. Parse the first JSON code block (```json ... ```)
    2. Parse bare JSON object { ... }
    3. Regex fallback on key fields

    Args:
        raw: The text returned by the backend.  ``None`` or an empty string
            is treated as empty text instead of raising.

    Returns:
        A ``RoutingResult``; ``raw_response`` is not set here.
    """
    text = (raw or "").strip()

    # 1. code-fence block
    fenced = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", text, re.DOTALL)
    if fenced:
        return _json_to_result(fenced.group(1))

    # 2. bare JSON object (greedy: first "{" through last "}")
    bare = re.search(r"\{.*\}", text, re.DOTALL)
    if bare:
        return _json_to_result(bare.group(0))

    # 3. partial regex fallback
    return _regex_fallback(text)


def _first_routing_object(text: str) -> Optional[Dict[str, Any]]:
    """Return the first decodable JSON object in *text* that names a test, else None."""
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            candidate, _end = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            candidate = None
        # Requiring a test key avoids picking up a nested helper object
        # from inside an otherwise broken outer object.
        if isinstance(candidate, dict) and (
            "test" in candidate or "test_type" in candidate
        ):
            return candidate
        start = text.find("{", start + 1)
    return None


def _opt_float(value: Any, default: Optional[float] = None) -> Optional[float]:
    """Convert *value* to float; return *default* for None or unconvertible input."""
    if value is None:
        return default
    try:
        return float(value)
    except (TypeError, ValueError):
        return default


def _as_bool(value: Any, default: bool) -> bool:
    """Interpret *value* as a flag; None yields *default*, 'false'/'no'/'0'/'' are False."""
    if value is None:
        return default
    if isinstance(value, str):
        return value.strip().lower() not in {"false", "no", "0", ""}
    return bool(value)


def _json_to_result(json_str: str) -> RoutingResult:
    """
    Parse a JSON string into a RoutingResult, tolerating minor issues.

    Tolerated issues:
    * trailing commas before ``}`` or ``]``;
    * extra text around the object (the first embedded object that has a
      ``test`` / ``test_type`` key is used);
    * ``null`` or unconvertible values, which fall back to the field default.

    ``alpha`` stays ``None`` when the model does not supply a usable value,
    so the caller's own alpha is not silently replaced by a hard-coded 0.05.

    Args:
        json_str: Candidate JSON text.

    Returns:
        The populated ``RoutingResult``, or the result of
        ``_regex_fallback(json_str)`` when no JSON object can be decoded.
    """
    # Remove trailing commas before } or ]
    cleaned = re.sub(r",\s*([}\]])", r"\1", json_str)
    try:
        data = json.loads(cleaned)
    except json.JSONDecodeError:
        # The greedy capture may include text after the real object.
        data = _first_routing_object(cleaned)
    if not isinstance(data, dict):
        return _regex_fallback(json_str)

    value_column = data.get("value_column") or data.get("outcome_column")
    group_column = data.get("group_column") or data.get("grouping_column")

    return RoutingResult(
        test=str(data.get("test") or data.get("test_type") or "").lower(),
        value_column=value_column,
        group_column=group_column,
        # x/y are aliases: default to the group/value columns when absent.
        x_column=data.get("x_column") or group_column,
        y_column=data.get("y_column") or value_column,
        group_values=data.get("group_values"),
        alternative=str(data.get("alternative") or "two-sided").lower(),
        alpha=_opt_float(data.get("alpha")),
        mu=_opt_float(data.get("mu")),
        equal_var=_as_bool(data.get("equal_var"), False),
        correction=_as_bool(data.get("correction"), True),
        method=str(data.get("method") or "parametric").lower(),
        reasoning=str(data.get("reasoning") or ""),
        confidence=_opt_float(data.get("confidence"), 1.0),
    )


# Test name -> regex patterns (``.`` matches any single character).
# Entries are tried in insertion order and the first test with a matching
# pattern wins, so the order of this mapping is significant.
_TEST_KEYWORDS = {
    "one_sample_ttest": ["one.sample", "one_sample", "single.sample"],
    "two_sample_ttest": ["two.sample", "two_sample", "independent", "unpaired"],
    "paired_ttest": ["paired", "repeated", "before.after", "within"],
    "anova": ["anova", "one.way", "multiple.group", "three.or.more"],
    "kruskal_wallis": ["kruskal", "kruskal.wallis"],
    "mann_whitney": ["mann", "whitney", "wilcoxon.rank"],
    "wilcoxon": ["wilcoxon.signed", "signed.rank"],
    "chi_square": [
        "chi.square",
        "chi2",
        "categorical",
        "contingency",
        "association",
        "independence",
    ],
    "fisher": ["fisher"],
    "pearson": ["pearson", "linear.corr"],
    "spearman": ["spearman", "rank.corr", "monotone"],
    "correlation": ["correlat"],
}


def _regex_fallback(text: str) -> RoutingResult:
    """
    Last-resort regex scan of the raw LLM output.

    Sets ``test`` to the first entry of ``_TEST_KEYWORDS`` with a pattern
    found in the lower-cased text.  Sets ``alternative`` to "greater" if that
    substring occurs, otherwise to "less" if that substring occurs.  All
    other fields keep their defaults except ``reasoning``, which is set to a
    fixed marker string.
    """
    lower = text.lower()
    result = RoutingResult()

    for test, patterns in _TEST_KEYWORDS.items():
        if any(re.search(pattern, lower) for pattern in patterns):
            result.test = test
            break

    if "greater" in lower:
        result.alternative = "greater"
    elif "less" in lower:
        result.alternative = "less"

    # NOTE(review): confidence and routing_source are left at their defaults
    # on this path; confirm whether it should report a lower confidence.
    result.reasoning = "(extracted via regex fallback)"
    return result