# Source code for hypotestx.core.llm.prompts

"""
Prompt templates for HypoTestX's LLM routing layer.

All prompts are plain strings so they are easy to read, audit, and override.
Nothing here calls any LLM — that is the backend's job.
"""

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from .base import SchemaInfo


# ---------------------------------------------------------------------------
# System prompt
# ---------------------------------------------------------------------------

# Routing instructions sent verbatim as the system message.  Everything between
# the triple quotes reaches the model, so linter directives must stay OUTSIDE
# the literal (the line-length waiver lives on the closing-quote line).
SYSTEM_PROMPT = """\
You are a statistical analysis routing expert embedded in the HypoTestX library.

Your ONLY job is to read a plain-English question about a dataset and decide
which statistical test to run and which columns to use.

## Available tests

| test key              | When to use |
|-----------------------|-------------|
| one_sample_ttest      | One group, test if mean equals a number (e.g. "Is the mean height 170?") |
| two_sample_ttest      | Two independent groups, compare means (e.g. "Do males earn more than females?") |
| paired_ttest          | Same subjects measured twice, compare before/after means |
| anova                 | Three or more independent groups, compare means |
| mann_whitney          | Two groups, non-parametric (use when data is non-normal or ordinal) |
| wilcoxon              | Two paired measurements, non-parametric |
| kruskal_wallis        | Three or more groups, non-parametric |
| chi_square            | Two categorical variables, test for association / independence |
| fisher                | Two categorical variables (2x2 table only), small sample sizes |
| pearson               | Two continuous variables, test linear correlation |
| spearman              | Two variables (one/both ordinal or non-normal), rank correlation |
| point_biserial        | One binary (0/1) variable and one continuous variable |

## Decision rules

1. If the question asks to "compare means" between exactly TWO named groups
   inside one column -> two_sample_ttest (or mann_whitney if ordinal/non-normal).
2. If THREE or more groups -> anova (or kruskal_wallis).
3. "Before and after", "pre and post", "change over time" on the SAME subjects -> paired_ttest.
4. "Association", "relationship", "independent" between two categorical columns -> chi_square.
5. "Correlation", "related to", "predict" between two numeric columns -> pearson or spearman.
6. "Is the mean equal to / greater than / less than {number}" -> one_sample_ttest.

## alternative field

- "two-sided" : question asks "different / any change"
- "greater"   : first group / variable is hypothesised HIGHER
- "less"      : first group / variable is hypothesised LOWER

## Output format

Reply with ONLY a JSON object — no markdown prose before/after it.

```json
{
  "test": "<test key from table above>",
  "value_column": "<name of the numeric response column, or null>",
  "group_column": "<name of the grouping/categorical column, or null>",
  "x_column": "<first variable for correlation, or null>",
  "y_column": "<second variable for correlation, or null>",
  "group_values": ["<group A label>", "<group B label>"] or null,
  "alternative": "two-sided" | "greater" | "less",
  "alpha": 0.05,
  "mu": <null or numeric value for one-sample test>,
  "equal_var": false,
  "method": "parametric" | "non-parametric",
  "reasoning": "<one sentence explaining your choice>"
}
```

Rules:
- Only use column names that exist in the provided schema.
- If a column does not exist, set it to null.
- Do NOT invent column names.
- Never explain or apologise — output JSON only.
"""  # noqa: E501


# ---------------------------------------------------------------------------
# Schema builder
# ---------------------------------------------------------------------------


# pandas dtype names (compared lower-cased) whose columns are summarised by
# their distinct values rather than by min/max/mean.
_PANDAS_CATEGORICAL_DTYPES = frozenset(
    {"object", "category", "string", "str", "bool", "boolean"}
)

# Substrings of polars dtype names that mark a column as categorical.
# "String" is the name newer polars releases report for what used to be "Utf8".
_POLARS_CATEGORICAL_MARKERS = ("Utf8", "String", "Categorical", "Boolean")

# Maximum number of distinct values recorded per categorical column.
_MAX_UNIQUE_VALUES = 20


def _fill_schema_from_mapping(info, data) -> None:
    """Populate *info* from a plain ``{column: list_of_values}`` dict."""
    info.columns = list(data.keys())
    first_col = info.columns[0] if info.columns else None
    info.n_rows = len(data[first_col]) if first_col is not None else 0

    for col, vals in data.items():
        non_null = [v for v in vals if v is not None]
        # Every value must be a real number: checking only the first one would
        # crash min()/max() on mixed columns, and bool is an int subclass that
        # the DataFrame paths treat as categorical.
        is_numeric = bool(non_null) and all(
            isinstance(v, (int, float)) and not isinstance(v, bool)
            for v in non_null
        )
        if is_numeric:
            info.dtypes[col] = "float64"
            info.numerics[col] = {
                "min": float(min(non_null)),
                "max": float(max(non_null)),
                "mean": float(sum(non_null) / len(non_null)),
            }
        else:
            info.dtypes[col] = "object"
            # dict.fromkeys de-duplicates while keeping first-seen order.
            distinct = list(dict.fromkeys(str(v) for v in non_null))
            info.categoricals[col] = distinct[:_MAX_UNIQUE_VALUES]


def _fill_schema_from_pandas(info, df) -> None:
    """Populate *info* using the pandas API (``dropna`` / ``unique``)."""
    info.columns = list(df.columns)
    for col in df.columns:
        dtype_str = str(df[col].dtype)
        info.dtypes[col] = dtype_str
        # Compare case-insensitively so the nullable extension dtypes
        # ("Int64", "Float64") are recognised as numeric as well.
        kind = dtype_str.lower()
        if kind in _PANDAS_CATEGORICAL_DTYPES:
            distinct = df[col].dropna().unique()[:_MAX_UNIQUE_VALUES]
            info.categoricals[col] = [str(v) for v in distinct]
        elif "int" in kind or "float" in kind:
            series = df[col].dropna()
            if len(series) > 0:
                info.numerics[col] = {
                    "min": float(series.min()),
                    "max": float(series.max()),
                    "mean": float(series.mean()),
                }


def _fill_schema_from_polars(info, df) -> None:
    """Populate *info* using the polars API (``schema`` / ``drop_nulls``)."""
    info.columns = list(df.columns)
    for col, dtype in df.schema.items():
        dtype_str = str(dtype)
        info.dtypes[col] = dtype_str
        col_series = df[col].drop_nulls()
        if any(marker in dtype_str for marker in _POLARS_CATEGORICAL_MARKERS):
            distinct = col_series.unique().to_list()[:_MAX_UNIQUE_VALUES]
            info.categoricals[col] = [str(v) for v in distinct]
        elif "Int" in dtype_str or "Float" in dtype_str:
            if len(col_series) > 0:
                info.numerics[col] = {
                    "min": float(col_series.min()),
                    "max": float(col_series.max()),
                    "mean": float(col_series.mean()),
                }


def _fill_schema_from_polars_best_effort(info, df) -> None:
    """Run the polars-style extraction, keeping whatever was collected on failure."""
    try:
        _fill_schema_from_polars(info, df)
    except Exception:  # deliberate best-effort: a partial schema beats a crash
        import logging

        logging.getLogger(__name__).debug(
            "build_schema: polars-style schema extraction failed for %s",
            type(df).__name__,
            exc_info=True,
        )


def build_schema(df) -> "SchemaInfo":
    """
    Build a ``SchemaInfo`` snapshot from a DataFrame (pandas or polars).

    Works without importing pandas/polars at module level.

    Args:
        df: A pandas DataFrame, a polars DataFrame, or a plain
            ``{column: list_of_values}`` dict (used in tests and simple
            scripts).

    Returns:
        A ``SchemaInfo`` with ``columns``, ``n_rows`` and ``dtypes`` filled in,
        up to 20 distinct values (as strings) for each categorical column in
        ``categoricals``, and ``min`` / ``max`` / ``mean`` for each non-empty
        numeric column in ``numerics``.  Extraction for polars-style objects is
        best-effort: on failure the partially filled snapshot is returned.
    """
    from .base import SchemaInfo

    info = SchemaInfo()

    # ── dict fallback (used in tests and simple scripts) ────────────────
    if isinstance(df, dict):
        _fill_schema_from_mapping(info, df)
        return info

    info.n_rows = len(df)

    # polars frames also expose ``.columns`` and ``df[col].dtype``, so the
    # pandas code would run on them without raising and silently record no
    # statistics.  Route them by package name before trying the pandas API.
    package = (getattr(type(df), "__module__", "") or "").split(".", 1)[0]
    if package == "polars":
        _fill_schema_from_polars_best_effort(info, df)
        return info

    try:
        _fill_schema_from_pandas(info, df)
    except AttributeError:
        # Not pandas-like after all: fall back to the polars-style API.
        _fill_schema_from_polars_best_effort(info, df)
    return info
def build_system_prompt() -> str:
    """Return the system prompt (constant)."""
    return SYSTEM_PROMPT


def build_user_prompt(
    question: str,
    schema: "SchemaInfo",
    extra_context: str = "",
) -> str:
    """
    Build the user-turn prompt that includes the dataset schema and the question.

    Args:
        question: The plain-English question, embedded verbatim in double quotes.
        schema: Snapshot providing ``n_rows``, ``columns``, ``dtypes``,
            ``categoricals`` and ``numerics``.
        extra_context: Optional free text inserted under an
            "Additional context:" heading; omitted when empty.

    Returns:
        The prompt as a single newline-joined string.
    """
    lines = [f"Dataset: {schema.n_rows} rows", "", "Columns:"]

    for col in schema.columns:
        dtype = schema.dtypes.get(col, "unknown")
        if col in schema.categoricals:
            uniques = schema.categoricals[col]
            # Show at most 8 values; report the total when the list is cut.
            preview = ", ".join(f'"{v}"' for v in uniques[:8])
            if len(uniques) > 8:
                preview += f", ... ({len(uniques)} unique)"
            lines.append(f" - {col!r} [{dtype}]: {preview}")
        elif col in schema.numerics:
            stats = schema.numerics[col]
            lines.append(
                f" - {col!r} [{dtype}]: "
                f"min={stats['min']:.2f}, max={stats['max']:.2f}, mean={stats['mean']:.2f}"
            )
        else:
            # Column has neither a categorical preview nor numeric stats.
            lines.append(f" - {col!r} [{dtype}]")

    if extra_context:
        lines.extend(["", "Additional context:", extra_context])

    lines.extend(["", f'Question: "{question}"', "", "Reply with JSON only."])
    return "\n".join(lines)


# ---------------------------------------------------------------------------
# Prompt fragments used by backends that need a simpler single-string prompt
# (e.g. completion-only models)
# ---------------------------------------------------------------------------


def build_completion_prompt(
    question: str,
    schema: "SchemaInfo",
    extra_context: str = "",
) -> str:
    """
    Combines system + user into a single string for completion-style APIs.

    Args:
        question: The plain-English question to route.
        schema: Dataset snapshot forwarded to ``build_user_prompt``.
        extra_context: Optional free text forwarded to ``build_user_prompt``;
            the default (empty) yields the same prompt as before this
            parameter existed.

    Returns:
        System prompt, a ``---`` separator, the user prompt and a trailing
        "JSON answer:" cue, as one string.
    """
    system = build_system_prompt()
    user = build_user_prompt(question, schema, extra_context)
    return f"{system}\n\n---\n\n{user}\n\nJSON answer:\n"