Skip to content

Utilities

Experiments

core

ExperimentJob dataclass

Defines a single atomic unit of work for the backtesting engine.

Contains all necessary information to run and reproduce a specific experiment.

Source code in src/quantrl_lab/experiments/backtesting/core.py
@dataclass
class ExperimentJob:
    """
    A single atomic unit of work for the backtesting engine.

    Bundles everything required to execute — and later reproduce — one
    experiment: the algorithm, the environment configuration, explicit
    hyperparameters, and run settings.
    """

    algorithm_class: Type
    env_config: BacktestEnvironmentConfig

    # Explicit hyperparameter overrides (e.g. {'learning_rate': 0.001}).
    # An empty dict means "use the algorithm's default hyperparameters".
    algorithm_config: Dict[str, Any] = field(default_factory=dict)

    # Run parameters
    total_timesteps: int = 50000
    n_envs: int = 4
    num_eval_episodes: int = 5

    # Identification
    id: str = field(default_factory=lambda: uuid.uuid4().hex[:8])
    tags: Dict[str, str] = field(default_factory=dict)

    def __post_init__(self):
        # Auto-derive tags from the algorithm/env when none were supplied.
        if not self.tags:
            self.tags = {
                "algo": self.algorithm_class.__name__,
                "env": self.env_config.name,
            }

ExperimentResult dataclass

Standardized result object containing all artifacts from a job.

Source code in src/quantrl_lab/experiments/backtesting/core.py
@dataclass
class ExperimentResult:
    """Standardized result object containing all artifacts from a
    job."""

    # The job that produced this result (kept for reproducibility).
    job: ExperimentJob
    metrics: Dict[str, float]  # Flattened metrics (train_return, test_sharpe, etc.)

    # Artifacts
    model: Any = None  # The trained model object
    train_episodes: List[Dict] = field(default_factory=list)  # Per-episode records from the train split
    test_episodes: List[Dict] = field(default_factory=list)  # Per-episode records from the test split
    top_features: Dict[str, float] = field(default_factory=dict)  # Feature name -> importance score
    explanation_method: str = "Correlation"  # Method used to compute top_features

    # Metadata
    status: str = "completed"  # completed, failed
    error: Optional[Exception] = None  # Captured exception when status == "failed"
    execution_time: float = 0.0  # Wall-clock seconds for the whole job

JobGenerator

Helper to generate combinatorial lists of jobs.

Source code in src/quantrl_lab/experiments/backtesting/core.py
class JobGenerator:
    """Helper to generate combinatorial lists of jobs."""

    @staticmethod
    def generate_grid(
        algorithms: List[Type],
        env_configs: Dict[str, BacktestEnvironmentConfig],
        algorithm_configs: Optional[List[Dict[str, Any]]] = None,
        **job_kwargs,
    ) -> List[ExperimentJob]:
        """
        Build the full cross-product of experiments.

        Args:
            algorithms: List of algorithm classes
            env_configs: Dictionary of name -> BacktestEnvironmentConfig
            algorithm_configs: List of configuration dictionaries to try.
                             If None, uses a single empty dict (defaults).
            **job_kwargs: Common arguments for all jobs (total_timesteps, etc.)

        Returns:
            List[ExperimentJob]: List of jobs to be executed
        """
        configs = [{}] if algorithm_configs is None else algorithm_configs
        # Only distinguish config ids when there is more than one config.
        multiple_configs = len(configs) > 1

        grid: List[ExperimentJob] = []
        for algorithm in algorithms:
            for env_name, env_conf in env_configs.items():
                for idx, hyperparams in enumerate(configs):
                    grid.append(
                        ExperimentJob(
                            algorithm_class=algorithm,
                            env_config=env_conf,
                            algorithm_config=hyperparams,
                            tags={
                                "algo": algorithm.__name__,
                                "env": env_name,
                                "config_id": str(idx) if multiple_configs else "default",
                            },
                            **job_kwargs,
                        )
                    )
        return grid

generate_grid(algorithms, env_configs, algorithm_configs=None, **job_kwargs) staticmethod

Generate a grid of experiments.

Parameters:

Name Type Description Default
algorithms List[Type]

List of algorithm classes

required
env_configs Dict[str, BacktestEnvironmentConfig]

Dictionary of name -> BacktestEnvironmentConfig

required
algorithm_configs Optional[List[Dict[str, Any]]]

List of configuration dictionaries to try. If None, uses a single empty dict (defaults).

None
**job_kwargs

Common arguments for all jobs (total_timesteps, etc.)

{}

Returns:

Type Description
List[ExperimentJob]

List[ExperimentJob]: List of jobs to be executed

Source code in src/quantrl_lab/experiments/backtesting/core.py
@staticmethod
def generate_grid(
    algorithms: List[Type],
    env_configs: Dict[str, BacktestEnvironmentConfig],
    algorithm_configs: Optional[List[Dict[str, Any]]] = None,
    **job_kwargs,
) -> List[ExperimentJob]:
    """
    Build the full cross-product of experiments.

    Args:
        algorithms: List of algorithm classes
        env_configs: Dictionary of name -> BacktestEnvironmentConfig
        algorithm_configs: List of configuration dictionaries to try.
                         If None, uses a single empty dict (defaults).
        **job_kwargs: Common arguments for all jobs (total_timesteps, etc.)

    Returns:
        List[ExperimentJob]: List of jobs to be executed
    """
    configs = [{}] if algorithm_configs is None else algorithm_configs
    # Only distinguish config ids when there is more than one config.
    multiple_configs = len(configs) > 1

    grid: List[ExperimentJob] = []
    for algorithm in algorithms:
        for env_name, env_conf in env_configs.items():
            for idx, hyperparams in enumerate(configs):
                grid.append(
                    ExperimentJob(
                        algorithm_class=algorithm,
                        env_config=env_conf,
                        algorithm_config=hyperparams,
                        tags={
                            "algo": algorithm.__name__,
                            "env": env_name,
                            "config_id": str(idx) if multiple_configs else "default",
                        },
                        **job_kwargs,
                    )
                )
    return grid

runner

BacktestRunner

Orchestrates complete backtesting workflows by chaining training and evaluation.

This class provides high-level interfaces for running comprehensive experiments that train multiple algorithms on different environment configurations and evaluate their performance.

Source code in src/quantrl_lab/experiments/backtesting/runner.py
(line-number gutter from the rendered documentation page removed; the listing below covers lines 26–469 of runner.py)
class BacktestRunner:
    """
    Orchestrates complete backtesting workflows by chaining training and
    evaluation.

    This class provides high-level interfaces for running comprehensive
    experiments that train multiple algorithms on different environment
    configurations and evaluate their performance.
    """

    def __init__(self, verbose: bool = True):
        # verbose toggles all rich-console progress/result output.
        self.verbose = verbose
        self.metrics_calculator = MetricsCalculator()

    def run_job(self, job: ExperimentJob) -> ExperimentResult:
        """
        Executes a single experiment job using the new Job/Batch
        architecture.

        Workflow: train -> evaluate (train & test splits) -> flatten
        metrics -> best-effort feature-importance analysis. Any exception
        is captured and returned as a failed result instead of being
        propagated, so batch runs can continue past individual failures.

        Args:
            job (ExperimentJob): The job description containing all parameters.

        Returns:
            ExperimentResult: The result of the experiment.
        """
        import time

        start_time = time.time()

        if self.verbose:
            console.print(f"\n[bold blue]{'='*60}[/bold blue]")
            console.print(f"[bold blue]RUNNING JOB: {job.id}[/bold blue]")
            console.print(f"[cyan]Algo: {job.algorithm_class.__name__} | Env: {job.env_config.name}[/cyan]")
            if job.algorithm_config:
                console.print(f"[dim]Config: {job.algorithm_config}[/dim]")

        try:
            # 1. Training Phase
            if self.verbose:
                console.print("[bold green]🔄 Training...[/bold green]")

            # Create vectorized environment for training
            if isinstance(job.env_config.train_env_factory, list):
                # Multi-stock vectorized training
                train_vec_env = SubprocVecEnv(job.env_config.train_env_factory)
            else:
                # Single-stock parallel rollout
                train_vec_env = make_vec_env(job.env_config.train_env_factory, n_envs=job.n_envs)

            try:
                model = train_model(
                    algo_class=job.algorithm_class,
                    env=train_vec_env,
                    config=job.algorithm_config,
                    total_timesteps=job.total_timesteps,
                    verbose=1 if self.verbose else 0,
                )
            finally:
                # Always release the training env: SubprocVecEnv holds
                # worker processes that would otherwise leak.
                train_vec_env.close()

            # 2. Evaluation Phase
            if self.verbose:
                console.print("[bold blue]📊 Evaluating...[/bold blue]")

            def _evaluate_factories(factories) -> tuple:
                # Accept a single factory or a list of factories and pool
                # the rewards/episodes across all of them.
                if not isinstance(factories, list):
                    factories = [factories]
                all_rewards = []
                all_episodes = []
                for factory in factories:
                    env = factory()
                    try:
                        rew, eps = evaluate_model(
                            model=model, env=env, num_episodes=job.num_eval_episodes, verbose=False
                        )
                        all_rewards.extend(rew)
                        all_episodes.extend(eps)
                    finally:
                        # Close the env even if evaluation raises.
                        env.close()
                return all_rewards, all_episodes

            # Evaluate on Train
            train_rewards, train_episodes = _evaluate_factories(job.env_config.train_env_factory)

            # Evaluate on Test
            test_rewards, test_episodes = _evaluate_factories(job.env_config.test_env_factory)

            # 3. Metrics Calculation
            train_metrics = self.metrics_calculator.calculate(train_episodes)
            test_metrics = self.metrics_calculator.calculate(test_episodes)

            # Flattened metrics for the result object
            # Prefix with dataset name for clarity
            metrics = {}
            for k, v in train_metrics.items():
                metrics[f"train_{k}"] = v
            for k, v in test_metrics.items():
                metrics[f"test_{k}"] = v

            # 4. Feature Importance (best-effort: failures only skip this step)
            top_features = {}
            explanation_method = "Correlation"
            if self.verbose:
                console.print("[bold yellow]🧠 Analyzing Feature Importance...[/bold yellow]")
            try:
                from quantrl_lab.experiments.backtesting.explainer import AgentExplainer

                exp_factory = job.env_config.test_env_factory
                env_for_explainer = exp_factory[0]() if isinstance(exp_factory, list) else exp_factory()

                try:
                    explainer = AgentExplainer(model, env_for_explainer)
                    top_features = explainer.analyze_feature_importance(top_k=5)
                    explanation_method = getattr(explainer, "last_method_used", "Correlation")
                finally:
                    # Close the explainer env even if the analysis raises.
                    env_for_explainer.close()
            except Exception as e:
                explanation_method = "Correlation"
                if self.verbose:
                    console.print(f"[yellow]Feature importance analysis skipped/failed: {e}[/yellow]")

            execution_time = time.time() - start_time

            result = ExperimentResult(
                job=job,
                metrics=metrics,
                model=model,
                train_episodes=train_episodes,
                test_episodes=test_episodes,
                top_features=top_features,
                explanation_method=explanation_method,
                status="completed",
                execution_time=execution_time,
            )

            if self.verbose:
                train_return = metrics.get("train_avg_return_pct", 0.0)
                test_return = metrics.get("test_avg_return_pct", 0.0)
                train_sharpe = metrics.get("train_avg_sharpe_ratio", 0.0)
                test_sharpe = metrics.get("test_avg_sharpe_ratio", 0.0)

                train_color = "green" if train_return > 0 else "red"
                test_color = "green" if test_return > 0 else "red"

                console.print("[bold]Result:[/bold]")
                console.print(
                    f"  Train: [{train_color}]{train_return:.2f}%[/{train_color}] (Sharpe: {train_sharpe:.2f})"
                )
                console.print(f"  Test:  [{test_color}]{test_return:.2f}%[/{test_color}] (Sharpe: {test_sharpe:.2f})")

            return result

        except Exception as e:
            # Broad catch is intentional: a single failed job must not
            # abort a whole batch. The error is preserved on the result.
            if self.verbose:
                console.print(f"[bold red]❌ Job Failed: {str(e)}[/bold red]")
                import traceback

                console.print(traceback.format_exc())

            return ExperimentResult(
                job=job, metrics={}, status="failed", error=e, execution_time=time.time() - start_time
            )

    def run_batch(self, jobs: List[ExperimentJob]) -> List[ExperimentResult]:
        """
        Executes a batch of jobs sequentially (can be upgraded to
        parallel later).

        Args:
            jobs (List[ExperimentJob]): List of jobs to run.

        Returns:
            List[ExperimentResult]: Results for each job.
        """
        results = []
        if self.verbose:
            console.print(f"\n[bold magenta]Starting Batch Execution: {len(jobs)} jobs[/bold magenta]")

        for i, job in enumerate(jobs):
            if self.verbose:
                console.print(f"\n[dim]--- Job {i+1}/{len(jobs)} ---[/dim]")
            results.append(self.run_job(job))

        if self.verbose:
            success_count = sum(1 for r in results if r.status == "completed")
            console.print(f"\n[bold magenta]Batch Completed: {success_count}/{len(jobs)} successful[/bold magenta]")

        return results

    @staticmethod
    def inspect_result(result: ExperimentResult) -> None:
        """
        Inspect and display the results of a single experiment job.

        Renders a summary panel, a per-episode performance table and an
        aggregate action-distribution table to the console.

        Args:
            result (ExperimentResult): The result object to inspect.
        """
        job = result.job
        metrics = result.metrics

        # --- Main Summary Panel ---
        algo_name = job.algorithm_class.__name__
        config_id = job.tags.get("config_id", "default")
        train_return = metrics.get("train_avg_return_pct", 0.0)
        test_return = metrics.get("test_avg_return_pct", 0.0)

        train_return_color = "green" if train_return >= 0 else "red"
        test_return_color = "green" if test_return >= 0 else "red"

        summary_text = (
            f"Job ID: [bold]{job.id}[/bold]\n"
            f"Algorithm: [bold cyan]{algo_name}[/bold cyan]\n"
            f"Env: [yellow]{job.env_config.name}[/yellow]\n"
            f"Config ID: [yellow]{config_id}[/yellow]\n"
            f"Status: {result.status}\n"
            f"Train Avg Return: [{train_return_color}]{train_return:.2f}%[/{train_return_color}]\n"
            f"Test Avg Return:  [{test_return_color}]{test_return:.2f}%[/{test_return_color}]\n"
        )

        # Add advanced metrics if available
        if "test_avg_sharpe_ratio" in metrics:
            summary_text += f"Test Sharpe: {metrics['test_avg_sharpe_ratio']:.2f}\n"
        if "test_avg_max_drawdown" in metrics:
            summary_text += f"Test Max DD: {metrics['test_avg_max_drawdown']*100:.2f}%\n"

        if result.top_features:
            summary_text += f"\n[bold]Top Learned Features ({result.explanation_method}):[/bold]\n"
            for feat, score in result.top_features.items():
                summary_text += f"  - {feat}: {score:+.2f}\n"

        if result.error:
            summary_text += f"\n[red]Error: {str(result.error)}[/red]"

        console.print(Panel(summary_text, title="[bold]Experiment Result[/bold]", expand=False))

        # Failed jobs have no episode data worth tabulating.
        if result.status == "failed":
            return

        # --- Episode Details Table ---
        episode_table = Table(title="Episode Performance Details", show_header=True, header_style="bold magenta")
        episode_table.add_column("Dataset", style="cyan")
        episode_table.add_column("Episode", justify="center")
        episode_table.add_column("Return %", justify="right")
        episode_table.add_column("Reward", justify="right")
        episode_table.add_column("Final Value", justify="right")
        episode_table.add_column("Total Steps", justify="right")

        # Function to add rows for a dataset (train/test)
        def add_episode_rows(dataset_name, episodes):
            if not episodes:
                return
            for i, ep in enumerate(episodes):
                # Skip episodes that errored out during evaluation.
                if "error" in ep:
                    continue
                initial = ep.get("initial_value", 0)
                final = ep.get("final_value", 0)
                reward = ep.get("total_reward", 0)

                # Guard against division by zero for a missing/zero initial value.
                ret = ((final - initial) / initial) * 100 if initial != 0 else 0
                ret_color = "green" if ret >= 0 else "red"
                reward_color = "green" if reward >= 0 else "red"

                episode_table.add_row(
                    dataset_name,
                    str(i + 1),
                    f"[{ret_color}]{ret:.2f}%[/{ret_color}]",
                    f"[{reward_color}]{reward:.2f}[/{reward_color}]",
                    f"${final:,.2f}",
                    str(ep.get("steps", "N/A")),
                )

        add_episode_rows("Train", result.train_episodes)
        add_episode_rows("Test", result.test_episodes)

        if result.train_episodes or result.test_episodes:
            console.print(episode_table)
        else:
            console.print("[yellow]No episode data available.[/yellow]")

        # --- Action Distribution Table ---
        all_episodes = result.train_episodes + result.test_episodes
        all_actions: dict = {}
        total_steps = 0
        for ep in all_episodes:
            if "error" not in ep:
                total_steps += ep.get("steps", 0)
                for action_type, count in ep.get("actions_taken", {}).items():
                    all_actions[action_type] = all_actions.get(action_type, 0) + count

        if all_actions and total_steps > 0:
            action_table = Table(title="Action Distribution (all episodes)", show_header=True, header_style="bold cyan")
            action_table.add_column("Action", style="cyan")
            action_table.add_column("Count", justify="right")
            action_table.add_column("% of Steps", justify="right", style="yellow")
            for action_type, count in sorted(all_actions.items()):
                action_table.add_row(action_type, str(count), f"{count / total_steps * 100:.1f}%")
            console.print(action_table)

    @staticmethod
    def inspect_batch(results: List[ExperimentResult]) -> None:
        """
        Inspect and display a summary of a batch of experiments.

        Args:
            results (List[ExperimentResult]): List of experiment results.
        """
        console.print(f"\n[bold magenta]{'='*80}[/bold magenta]")
        console.print("[bold magenta]BATCH EXPERIMENT SUMMARY[/bold magenta]")
        console.print(f"[bold magenta]{'='*80}[/bold magenta]")
        # Preset column removed
        table = Table(title="Batch Results", show_header=True, header_style="bold magenta")
        table.add_column("ID", style="dim", no_wrap=True)
        table.add_column("Algo", style="cyan")
        table.add_column("Env", style="yellow")
        table.add_column("Status", justify="center")
        table.add_column("Train Ret %", justify="right")
        table.add_column("Test Ret %", justify="right")
        table.add_column("Test Sharpe", justify="right")
        table.add_column("Time (s)", justify="right")
        table.add_column("Top Feature", style="dim")

        for res in results:
            job = res.job
            metrics = res.metrics

            status_style = "green" if res.status == "completed" else "red"
            status_str = f"[{status_style}]{res.status}[/{status_style}]"

            if res.status == "completed":
                train_ret = metrics.get("train_avg_return_pct", 0.0)
                test_ret = metrics.get("test_avg_return_pct", 0.0)
                test_sharpe = metrics.get("test_avg_sharpe_ratio", 0.0)

                train_color = "green" if train_ret >= 0 else "red"
                test_color = "green" if test_ret >= 0 else "red"

                train_str = f"[{train_color}]{train_ret:.2f}%[/{train_color}]"
                test_str = f"[{test_color}]{test_ret:.2f}%[/{test_color}]"
                sharpe_str = f"{test_sharpe:.2f}"
            else:
                # Failed jobs have no metrics to show.
                train_str = "-"
                test_str = "-"
                sharpe_str = "-"

            top_feat_str = "-"
            if res.top_features:
                # Get the highest correlated feature
                top_feat_name, top_feat_corr = list(res.top_features.items())[0]
                top_feat_str = f"{top_feat_name} ({top_feat_corr:+.2f})"

            # Add row to table
            table.add_row(
                job.id,
                job.algorithm_class.__name__,
                job.env_config.name,
                status_str,
                train_str,
                test_str,
                sharpe_str,
                f"{res.execution_time:.1f}",
                top_feat_str,
            )

        console.print(table)

    @staticmethod
    def create_env_config(train_env_factory: Callable, test_env_factory: Callable) -> BacktestEnvironmentConfig:
        """
        Helper method to create env_config from individual factory
        functions.

        Args:
            train_env_factory (Callable): Function that creates training environment
            test_env_factory (Callable): Function that creates test environment

        Returns:
            BacktestEnvironmentConfig: Environment configuration object
        """
        return BacktestEnvironmentConfig(train_env_factory=train_env_factory, test_env_factory=test_env_factory)

    @staticmethod
    def create_env_config_factory(
        train_data: "pd.DataFrame",
        test_data: "pd.DataFrame",
        action_strategy: "BaseActionStrategy",
        reward_strategy: "BaseRewardStrategy",
        observation_strategy: "BaseObservationStrategy",
        eval_data: Optional["pd.DataFrame"] = None,
        initial_balance: float = 100000.0,
        transaction_cost_pct: float = 0.001,
        slippage_pct: float = 0.0005,
        window_size: int = 20,
        order_expiration_steps: int = 5,
    ) -> BacktestEnvironmentConfig:
        """
        Creates a configuration object with environment factories.
        DEPRECATED: Use BacktestEnvironmentBuilder instead.

        Args:
            train_data (pd.DataFrame): DataFrame for the training environment.
            test_data (pd.DataFrame): DataFrame for the test environment.
            action_strategy (BaseActionStrategy): The action strategy to use.
            reward_strategy (BaseRewardStrategy): The reward strategy to use.
            observation_strategy (BaseObservationStrategy): The observation strategy.
            eval_data (Optional[pd.DataFrame], optional): DataFrame for evaluation.
            initial_balance (float, optional): Initial portfolio balance.
            transaction_cost_pct (float, optional): Transaction cost percentage.
            slippage_pct (float, optional): Slippage percentage applied to fills.
            window_size (int, optional): The size of the observation window.
            order_expiration_steps (int, optional): Steps before a pending order expires.

        Returns:
            BacktestEnvironmentConfig: Configuration object containing factories and metadata.
        """
        from quantrl_lab.environments.stock.components.config import SingleStockEnvConfig
        from quantrl_lab.environments.stock.single import SingleStockTradingEnv

        # Helper function to create a single environment factory
        def _create_factory(data: "pd.DataFrame"):
            return lambda: SingleStockTradingEnv(
                data=data,
                config=SingleStockEnvConfig(
                    initial_balance=initial_balance,
                    transaction_cost_pct=transaction_cost_pct,
                    slippage=slippage_pct,
                    window_size=window_size,
                    order_expiration_steps=order_expiration_steps,
                ),
                action_strategy=action_strategy,
                reward_strategy=reward_strategy,
                observation_strategy=observation_strategy,
                price_column="Close",  # Default to Close
            )

        # Capture parameters for reproducibility
        parameters = {
            "initial_balance": initial_balance,
            "transaction_cost_pct": transaction_cost_pct,
            "slippage_pct": slippage_pct,
            "window_size": window_size,
            "order_expiration_steps": order_expiration_steps,
            "action_strategy": action_strategy.__class__.__name__,
            "reward_strategy": reward_strategy.__class__.__name__,
            "observation_strategy": observation_strategy.__class__.__name__,
        }

        eval_factory = _create_factory(eval_data) if eval_data is not None else None

        return BacktestEnvironmentConfig(
            train_env_factory=_create_factory(train_data),
            test_env_factory=_create_factory(test_data),
            eval_env_factory=eval_factory,
            parameters=parameters,
            description=f"Standard Stock Env (Window: {window_size})",
        )

run_job(job)

Executes a single experiment job using the new Job/Batch architecture.

Parameters:

Name Type Description Default
job ExperimentJob

The job description containing all parameters.

required

Returns:

Name Type Description
ExperimentResult ExperimentResult

The result of the experiment.

Source code in src/quantrl_lab/experiments/backtesting/runner.py
def run_job(self, job: ExperimentJob) -> ExperimentResult:
    """
    Executes a single experiment job using the new Job/Batch
    architecture.

    Workflow: train -> evaluate (train & test splits) -> flatten
    metrics -> best-effort feature-importance analysis. Any exception
    is captured and returned as a failed result instead of being
    propagated, so batch runs can continue past individual failures.

    Args:
        job (ExperimentJob): The job description containing all parameters.

    Returns:
        ExperimentResult: The result of the experiment.
    """
    import time

    start_time = time.time()

    if self.verbose:
        console.print(f"\n[bold blue]{'='*60}[/bold blue]")
        console.print(f"[bold blue]RUNNING JOB: {job.id}[/bold blue]")
        console.print(f"[cyan]Algo: {job.algorithm_class.__name__} | Env: {job.env_config.name}[/cyan]")
        if job.algorithm_config:
            console.print(f"[dim]Config: {job.algorithm_config}[/dim]")

    try:
        # 1. Training Phase
        if self.verbose:
            console.print("[bold green]🔄 Training...[/bold green]")

        # Create vectorized environment for training
        if isinstance(job.env_config.train_env_factory, list):
            # Multi-stock vectorized training
            train_vec_env = SubprocVecEnv(job.env_config.train_env_factory)
        else:
            # Single-stock parallel rollout
            train_vec_env = make_vec_env(job.env_config.train_env_factory, n_envs=job.n_envs)

        try:
            model = train_model(
                algo_class=job.algorithm_class,
                env=train_vec_env,
                config=job.algorithm_config,
                total_timesteps=job.total_timesteps,
                verbose=1 if self.verbose else 0,
            )
        finally:
            # Always release the training env: SubprocVecEnv holds
            # worker processes that would otherwise leak.
            train_vec_env.close()

        # 2. Evaluation Phase
        if self.verbose:
            console.print("[bold blue]📊 Evaluating...[/bold blue]")

        def _evaluate_factories(factories) -> tuple:
            # Accept a single factory or a list of factories and pool
            # the rewards/episodes across all of them.
            if not isinstance(factories, list):
                factories = [factories]
            all_rewards = []
            all_episodes = []
            for factory in factories:
                env = factory()
                try:
                    rew, eps = evaluate_model(
                        model=model, env=env, num_episodes=job.num_eval_episodes, verbose=False
                    )
                    all_rewards.extend(rew)
                    all_episodes.extend(eps)
                finally:
                    # Close the env even if evaluation raises.
                    env.close()
            return all_rewards, all_episodes

        # Evaluate on Train
        train_rewards, train_episodes = _evaluate_factories(job.env_config.train_env_factory)

        # Evaluate on Test
        test_rewards, test_episodes = _evaluate_factories(job.env_config.test_env_factory)

        # 3. Metrics Calculation
        train_metrics = self.metrics_calculator.calculate(train_episodes)
        test_metrics = self.metrics_calculator.calculate(test_episodes)

        # Flattened metrics for the result object
        # Prefix with dataset name for clarity
        metrics = {}
        for k, v in train_metrics.items():
            metrics[f"train_{k}"] = v
        for k, v in test_metrics.items():
            metrics[f"test_{k}"] = v

        # 4. Feature Importance (best-effort: failures only skip this step)
        top_features = {}
        explanation_method = "Correlation"
        if self.verbose:
            console.print("[bold yellow]🧠 Analyzing Feature Importance...[/bold yellow]")
        try:
            from quantrl_lab.experiments.backtesting.explainer import AgentExplainer

            exp_factory = job.env_config.test_env_factory
            env_for_explainer = exp_factory[0]() if isinstance(exp_factory, list) else exp_factory()

            try:
                explainer = AgentExplainer(model, env_for_explainer)
                top_features = explainer.analyze_feature_importance(top_k=5)
                explanation_method = getattr(explainer, "last_method_used", "Correlation")
            finally:
                # Close the explainer env even if the analysis raises.
                env_for_explainer.close()
        except Exception as e:
            explanation_method = "Correlation"
            if self.verbose:
                console.print(f"[yellow]Feature importance analysis skipped/failed: {e}[/yellow]")

        execution_time = time.time() - start_time

        result = ExperimentResult(
            job=job,
            metrics=metrics,
            model=model,
            train_episodes=train_episodes,
            test_episodes=test_episodes,
            top_features=top_features,
            explanation_method=explanation_method,
            status="completed",
            execution_time=execution_time,
        )

        if self.verbose:
            train_return = metrics.get("train_avg_return_pct", 0.0)
            test_return = metrics.get("test_avg_return_pct", 0.0)
            train_sharpe = metrics.get("train_avg_sharpe_ratio", 0.0)
            test_sharpe = metrics.get("test_avg_sharpe_ratio", 0.0)

            train_color = "green" if train_return > 0 else "red"
            test_color = "green" if test_return > 0 else "red"

            console.print("[bold]Result:[/bold]")
            console.print(
                f"  Train: [{train_color}]{train_return:.2f}%[/{train_color}] (Sharpe: {train_sharpe:.2f})"
            )
            console.print(f"  Test:  [{test_color}]{test_return:.2f}%[/{test_color}] (Sharpe: {test_sharpe:.2f})")

        return result

    except Exception as e:
        # Broad catch is intentional: a single failed job must not
        # abort a whole batch. The error is preserved on the result.
        if self.verbose:
            console.print(f"[bold red]❌ Job Failed: {str(e)}[/bold red]")
            import traceback

            console.print(traceback.format_exc())

        return ExperimentResult(
            job=job, metrics={}, status="failed", error=e, execution_time=time.time() - start_time
        )

run_batch(jobs)

Executes a batch of jobs sequentially (can be upgraded to parallel later).

Parameters:

Name Type Description Default
jobs List[ExperimentJob]

List of jobs to run.

required

Returns:

Type Description
List[ExperimentResult]

List[ExperimentResult]: Results for each job.

Source code in src/quantrl_lab/experiments/backtesting/runner.py
def run_batch(self, jobs: List[ExperimentJob]) -> List[ExperimentResult]:
    """
    Executes a batch of jobs sequentially (can be upgraded to
    parallel later).

    Args:
        jobs (List[ExperimentJob]): List of jobs to run.

    Returns:
        List[ExperimentResult]: Results for each job.
    """
    total = len(jobs)
    if self.verbose:
        console.print(f"\n[bold magenta]Starting Batch Execution: {total} jobs[/bold magenta]")

    outcomes: List[ExperimentResult] = []
    for position, job in enumerate(jobs, start=1):
        if self.verbose:
            console.print(f"\n[dim]--- Job {position}/{total} ---[/dim]")
        outcomes.append(self.run_job(job))

    if self.verbose:
        # Count jobs that finished without being marked failed.
        completed = len([r for r in outcomes if r.status == "completed"])
        console.print(f"\n[bold magenta]Batch Completed: {completed}/{total} successful[/bold magenta]")

    return outcomes

inspect_result(result) staticmethod

Inspect and display the results of a single experiment job.

Parameters:

Name Type Description Default
result ExperimentResult

The result object to inspect.

required
Source code in src/quantrl_lab/experiments/backtesting/runner.py
@staticmethod
def inspect_result(result: ExperimentResult) -> None:
    """
    Inspect and display the results of a single experiment job.

    Prints three console sections: a summary panel (identity, status,
    returns, optional risk metrics, top features, error), a per-episode
    performance table, and an action-distribution table aggregated over
    all train and test episodes.

    Args:
        result (ExperimentResult): The result object to inspect.
    """
    job = result.job
    metrics = result.metrics

    # --- Main Summary Panel ---
    algo_name = job.algorithm_class.__name__
    config_id = job.tags.get("config_id", "default")
    train_return = metrics.get("train_avg_return_pct", 0.0)
    test_return = metrics.get("test_avg_return_pct", 0.0)

    # Color-code returns: green for non-negative, red for losses.
    train_return_color = "green" if train_return >= 0 else "red"
    test_return_color = "green" if test_return >= 0 else "red"

    summary_text = (
        f"Job ID: [bold]{job.id}[/bold]\n"
        f"Algorithm: [bold cyan]{algo_name}[/bold cyan]\n"
        f"Env: [yellow]{job.env_config.name}[/yellow]\n"
        f"Config ID: [yellow]{config_id}[/yellow]\n"
        f"Status: {result.status}\n"
        f"Train Avg Return: [{train_return_color}]{train_return:.2f}%[/{train_return_color}]\n"
        f"Test Avg Return:  [{test_return_color}]{test_return:.2f}%[/{test_return_color}]\n"
    )

    # Add advanced metrics if available
    if "test_avg_sharpe_ratio" in metrics:
        summary_text += f"Test Sharpe: {metrics['test_avg_sharpe_ratio']:.2f}\n"
    if "test_avg_max_drawdown" in metrics:
        # Drawdown is multiplied by 100 here for percent display.
        summary_text += f"Test Max DD: {metrics['test_avg_max_drawdown']*100:.2f}%\n"

    if result.top_features:
        summary_text += f"\n[bold]Top Learned Features ({result.explanation_method}):[/bold]\n"
        for feat, score in result.top_features.items():
            summary_text += f"  - {feat}: {score:+.2f}\n"

    if result.error:
        summary_text += f"\n[red]Error: {str(result.error)}[/red]"

    console.print(Panel(summary_text, title="[bold]Experiment Result[/bold]", expand=False))

    # Failed jobs have no episode data worth tabulating.
    if result.status == "failed":
        return

    # --- Episode Details Table ---
    episode_table = Table(title="Episode Performance Details", show_header=True, header_style="bold magenta")
    episode_table.add_column("Dataset", style="cyan")
    episode_table.add_column("Episode", justify="center")
    episode_table.add_column("Return %", justify="right")
    episode_table.add_column("Reward", justify="right")
    episode_table.add_column("Final Value", justify="right")
    episode_table.add_column("Total Steps", justify="right")

    # Function to add rows for a dataset (train/test)
    def add_episode_rows(dataset_name: str, episodes: List[Dict]) -> None:
        if not episodes:
            return
        for i, ep in enumerate(episodes):
            # Skip episodes that recorded an error during evaluation.
            if "error" in ep:
                continue
            initial = ep.get("initial_value", 0)
            final = ep.get("final_value", 0)
            reward = ep.get("total_reward", 0)

            # Guard against division by zero when the initial value is absent.
            ret = ((final - initial) / initial) * 100 if initial != 0 else 0
            ret_color = "green" if ret >= 0 else "red"
            reward_color = "green" if reward >= 0 else "red"

            episode_table.add_row(
                dataset_name,
                str(i + 1),
                f"[{ret_color}]{ret:.2f}%[/{ret_color}]",
                f"[{reward_color}]{reward:.2f}[/{reward_color}]",
                f"${final:,.2f}",
                str(ep.get("steps", "N/A")),
            )

    add_episode_rows("Train", result.train_episodes)
    add_episode_rows("Test", result.test_episodes)

    if result.train_episodes or result.test_episodes:
        console.print(episode_table)
    else:
        console.print("[yellow]No episode data available.[/yellow]")

    # --- Action Distribution Table ---
    # Aggregate per-action counts across every non-errored episode.
    all_episodes = result.train_episodes + result.test_episodes
    all_actions: dict = {}
    total_steps = 0
    for ep in all_episodes:
        if "error" not in ep:
            total_steps += ep.get("steps", 0)
            for action_type, count in ep.get("actions_taken", {}).items():
                all_actions[action_type] = all_actions.get(action_type, 0) + count

    if all_actions and total_steps > 0:
        action_table = Table(title="Action Distribution (all episodes)", show_header=True, header_style="bold cyan")
        action_table.add_column("Action", style="cyan")
        action_table.add_column("Count", justify="right")
        action_table.add_column("% of Steps", justify="right", style="yellow")
        for action_type, count in sorted(all_actions.items()):
            action_table.add_row(action_type, str(count), f"{count / total_steps * 100:.1f}%")
        console.print(action_table)

inspect_batch(results) staticmethod

Inspect and display a summary of a batch of experiments.

Parameters:

Name Type Description Default
results List[ExperimentResult]

List of experiment results.

required
Source code in src/quantrl_lab/experiments/backtesting/runner.py
@staticmethod
def inspect_batch(results: List[ExperimentResult]) -> None:
    """
    Inspect and display a summary of a batch of experiments.

    Renders one table row per result: identity, status, train/test
    returns, test Sharpe, execution time and the top learned feature.

    Args:
        results (List[ExperimentResult]): List of experiment results.
    """
    banner = "=" * 80
    console.print(f"\n[bold magenta]{banner}[/bold magenta]")
    console.print("[bold magenta]BATCH EXPERIMENT SUMMARY[/bold magenta]")
    console.print(f"[bold magenta]{banner}[/bold magenta]")
    # Preset column removed
    summary_table = Table(title="Batch Results", show_header=True, header_style="bold magenta")
    for header, opts in (
        ("ID", {"style": "dim", "no_wrap": True}),
        ("Algo", {"style": "cyan"}),
        ("Env", {"style": "yellow"}),
        ("Status", {"justify": "center"}),
        ("Train Ret %", {"justify": "right"}),
        ("Test Ret %", {"justify": "right"}),
        ("Test Sharpe", {"justify": "right"}),
        ("Time (s)", {"justify": "right"}),
        ("Top Feature", {"style": "dim"}),
    ):
        summary_table.add_column(header, **opts)

    for outcome in results:
        job = outcome.job
        metrics = outcome.metrics

        ok = outcome.status == "completed"
        status_style = "green" if ok else "red"
        status_str = f"[{status_style}]{outcome.status}[/{status_style}]"

        # Placeholder cells for failed jobs; filled in for completed ones.
        train_str, test_str, sharpe_str = "-", "-", "-"
        if ok:
            train_ret = metrics.get("train_avg_return_pct", 0.0)
            test_ret = metrics.get("test_avg_return_pct", 0.0)
            test_sharpe = metrics.get("test_avg_sharpe_ratio", 0.0)

            train_color = "green" if train_ret >= 0 else "red"
            test_color = "green" if test_ret >= 0 else "red"

            train_str = f"[{train_color}]{train_ret:.2f}%[/{train_color}]"
            test_str = f"[{test_color}]{test_ret:.2f}%[/{test_color}]"
            sharpe_str = f"{test_sharpe:.2f}"

        top_feat_str = "-"
        if outcome.top_features:
            # First entry of the dict is the highest correlated feature
            feat_name, feat_score = next(iter(outcome.top_features.items()))
            top_feat_str = f"{feat_name} ({feat_score:+.2f})"

        summary_table.add_row(
            job.id,
            job.algorithm_class.__name__,
            job.env_config.name,
            status_str,
            train_str,
            test_str,
            sharpe_str,
            f"{outcome.execution_time:.1f}",
            top_feat_str,
        )

    console.print(summary_table)

create_env_config(train_env_factory, test_env_factory) staticmethod

Helper method to create env_config from individual factory functions.

Parameters:

Name Type Description Default
train_env_factory Callable

Function that creates training environment

required
test_env_factory Callable

Function that creates test environment

required

Returns:

Name Type Description
BacktestEnvironmentConfig BacktestEnvironmentConfig

Environment configuration object

Source code in src/quantrl_lab/experiments/backtesting/runner.py
@staticmethod
def create_env_config(train_env_factory: Callable, test_env_factory: Callable) -> BacktestEnvironmentConfig:
    """
    Helper method to create env_config from individual factory
    functions.

    Args:
        train_env_factory (Callable): Function that creates training environment
        test_env_factory (Callable): Function that creates test environment

    Returns:
        BacktestEnvironmentConfig: Environment configuration object
    """
    config = BacktestEnvironmentConfig(
        train_env_factory=train_env_factory,
        test_env_factory=test_env_factory,
    )
    return config

create_env_config_factory(train_data, test_data, action_strategy, reward_strategy, observation_strategy, eval_data=None, initial_balance=100000.0, transaction_cost_pct=0.001, slippage_pct=0.0005, window_size=20, order_expiration_steps=5) staticmethod

Creates a configuration object with environment factories. DEPRECATED: Use BacktestEnvironmentBuilder instead.

Parameters:

Name Type Description Default
train_data DataFrame

DataFrame for the training environment.

required
test_data DataFrame

DataFrame for the test environment.

required
action_strategy BaseActionStrategy

The action strategy to use.

required
reward_strategy BaseRewardStrategy

The reward strategy to use.

required
observation_strategy BaseObservationStrategy

The observation strategy.

required
eval_data Optional[DataFrame]

DataFrame for evaluation.

None
initial_balance float

Initial portfolio balance.

100000.0
transaction_cost_pct float

Transaction cost percentage.

0.001
window_size int

The size of the observation window.

20

Returns:

Name Type Description
BacktestEnvironmentConfig BacktestEnvironmentConfig

Configuration object containing factories and metadata.

Source code in src/quantrl_lab/experiments/backtesting/runner.py
@staticmethod
def create_env_config_factory(
    train_data: "pd.DataFrame",
    test_data: "pd.DataFrame",
    action_strategy: "BaseActionStrategy",
    reward_strategy: "BaseRewardStrategy",
    observation_strategy: "BaseObservationStrategy",
    eval_data: Optional["pd.DataFrame"] = None,
    initial_balance: float = 100000.0,
    transaction_cost_pct: float = 0.001,
    slippage_pct: float = 0.0005,
    window_size: int = 20,
    order_expiration_steps: int = 5,
) -> BacktestEnvironmentConfig:
    """
    Creates a configuration object with environment factories.
    DEPRECATED: Use BacktestEnvironmentBuilder instead.

    Args:
        train_data (pd.DataFrame): DataFrame for the training environment.
        test_data (pd.DataFrame): DataFrame for the test environment.
        action_strategy (BaseActionStrategy): The action strategy to use.
        reward_strategy (BaseRewardStrategy): The reward strategy to use.
        observation_strategy (BaseObservationStrategy): The observation strategy.
        eval_data (Optional[pd.DataFrame], optional): DataFrame for evaluation.
        initial_balance (float, optional): Initial portfolio balance.
        transaction_cost_pct (float, optional): Transaction cost percentage.
        slippage_pct (float, optional): Slippage percentage, forwarded to the
            env config's ``slippage`` field.
        window_size (int, optional): The size of the observation window.
        order_expiration_steps (int, optional): Forwarded to SingleStockEnvConfig;
            presumably the number of steps before pending orders expire — confirm
            against SingleStockEnvConfig.

    Returns:
        BacktestEnvironmentConfig: Configuration object containing factories and metadata.
    """
    from quantrl_lab.environments.stock.components.config import SingleStockEnvConfig
    from quantrl_lab.environments.stock.single import SingleStockTradingEnv

    # Helper function to create a single environment factory
    # (binding `data` via the enclosing function's parameter avoids the
    # late-binding closure pitfall).
    def _create_factory(data: "pd.DataFrame"):
        return lambda: SingleStockTradingEnv(
            data=data,
            config=SingleStockEnvConfig(
                initial_balance=initial_balance,
                transaction_cost_pct=transaction_cost_pct,
                slippage=slippage_pct,
                window_size=window_size,
                order_expiration_steps=order_expiration_steps,
            ),
            action_strategy=action_strategy,
            reward_strategy=reward_strategy,
            observation_strategy=observation_strategy,
            price_column="Close",  # Default to Close
        )

    # Capture parameters for reproducibility
    parameters = {
        "initial_balance": initial_balance,
        "transaction_cost_pct": transaction_cost_pct,
        "slippage_pct": slippage_pct,
        "window_size": window_size,
        "order_expiration_steps": order_expiration_steps,
        "action_strategy": action_strategy.__class__.__name__,
        "reward_strategy": reward_strategy.__class__.__name__,
        "observation_strategy": observation_strategy.__class__.__name__,
    }

    # Eval environment is optional; only build a factory when data is supplied.
    eval_factory = _create_factory(eval_data) if eval_data is not None else None

    return BacktestEnvironmentConfig(
        train_env_factory=_create_factory(train_data),
        test_env_factory=_create_factory(test_data),
        eval_env_factory=eval_factory,
        parameters=parameters,
        description=f"Standard Stock Env (Window: {window_size})",
    )

optuna_runner

OptunaRunner

A hyperparameter tuning runner using Optuna with SQLite storage.

Source code in src/quantrl_lab/experiments/tuning/optuna_runner.py
class OptunaRunner:
    """A hyperparameter tuning runner using Optuna with SQLite
    storage."""

    def __init__(
        self,
        runner: BacktestRunner,
        storage_url: Optional[str] = None,
    ):
        """
        Initialize the runner with Optuna configuration.

        Args:
            runner: BacktestRunner instance
            storage_url: Optuna storage URL (defaults to sqlite:///optuna_studies.db)
        """
        self.runner = runner
        self.storage_url = storage_url or "sqlite:///optuna_studies.db"
        # Strip the sqlite scheme to recover the on-disk file path; for
        # non-sqlite URLs this is a no-op and the banner path is informational.
        db_path = self.storage_url.replace("sqlite:///", "")
        console.print(
            Panel(
                f"[bold blue]Optuna Storage URL:[/bold blue] {self.storage_url}\n"
                f"[bold blue]Database file at:[/bold blue] [green]{os.path.abspath(db_path)}[/green]",
                title="[bold yellow]QuantRL-Lab Optuna Runner[/bold yellow]",
                border_style="blue",
            )
        )

    def create_objective_function(
        self,
        algo_class,
        env_config: Union[Dict[str, Any], BacktestEnvironmentConfig],
        search_space: Dict[str, Any],
        fixed_params: Optional[Dict[str, Any]] = None,
        total_timesteps: int = 50000,
        num_eval_episodes: int = 5,
        optimization_metric: str = "test_avg_return_pct",
    ) -> Callable:
        """
        Create an objective function for Optuna optimization.

        Args:
            algo_class: RL algorithm class (PPO, SAC, A2C, etc.)
            env_config: Environment configuration object or dictionary
            search_space: Dictionary defining the hyperparameter search space
            fixed_params: Fixed parameters that won't be optimized
            total_timesteps: Number of training timesteps
            num_eval_episodes: Number of evaluation episodes
            optimization_metric: Metric to optimize (default: test_avg_return_pct).
                                 Can use any metric from MetricsCalculator (e.g., 'test_avg_sharpe_ratio').

        Returns:
            Callable: The objective function for Optuna.
        """
        fixed_params = fixed_params or {}

        # Normalize env_config
        if isinstance(env_config, dict):
            # Legacy support
            env_config_obj = BacktestEnvironmentConfig.from_dict(env_config)
        else:
            env_config_obj = env_config

        def objective(trial: optuna.Trial) -> float:
            # Runs one full train/evaluate cycle and returns the score Optuna optimizes.
            try:
                # Sample hyperparameters from the search space
                params = self._sample_hyperparameters(trial, search_space)
                params.update(fixed_params)

                console.print(
                    f"[bold cyan]Starting Trial {trial.number}[/bold cyan] with params: [yellow]{params}[/yellow]"
                )

                # Create Job
                job = ExperimentJob(
                    algorithm_class=algo_class,
                    env_config=env_config_obj,
                    algorithm_config=params,
                    total_timesteps=total_timesteps,
                    num_eval_episodes=num_eval_episodes,
                    tags={"trial": str(trial.number), "tuner": "optuna"},
                )

                # Run the backtesting experiment
                result = self.runner.run_job(job)

                if result.status == "failed":
                    raise RuntimeError(f"Job failed: {result.error}")

                # Extract the target value for Optuna to optimize
                target_value = result.metrics.get(optimization_metric)
                if target_value is None:
                    console.print(
                        f"[bold yellow]⚠️ Optimization metric '[/bold yellow][cyan]{optimization_metric}[/cyan]"
                        f"[bold yellow]' not found in results.[/bold yellow] "
                        "Defaulting to -1000.0.",
                        style="yellow",
                    )
                    # Fallback sentinel score used when the requested metric is absent.
                    target_value = -1000.0

                console.print(
                    f"[bold green]Trial {trial.number} finished.[/bold green] "
                    f"[blue]{optimization_metric}[/blue] = [cyan]{target_value:.4f}[/cyan] ✓"
                )

                return target_value

            except Exception as e:
                console.print(
                    f"[bold red]❌ Trial {trial.number} failed with an exception:[/bold red] {str(e)}", style="red"
                )
                console.print_exception()
                # Report the trial as pruned so the study continues with the next one.
                raise optuna.exceptions.TrialPruned()

        return objective

    def _sample_hyperparameters(self, trial: optuna.Trial, search_space: Dict[str, Any]) -> Dict[str, Any]:
        """Sample hyperparameters from the defined search space."""
        params = {}
        for param_name, param_config in search_space.items():
            param_type = param_config["type"]
            if param_type == "float":
                params[param_name] = trial.suggest_float(
                    param_name, param_config["low"], param_config["high"], log=param_config.get("log", False)
                )
            elif param_type == "int":
                params[param_name] = trial.suggest_int(
                    param_name, param_config["low"], param_config["high"], log=param_config.get("log", False)
                )
            elif param_type == "categorical":
                params[param_name] = trial.suggest_categorical(param_name, param_config["choices"])
            elif param_type == "discrete_uniform":
                params[param_name] = trial.suggest_float(
                    param_name, param_config["low"], param_config["high"], step=param_config["q"]
                )
            else:
                # Unknown types are skipped with a warning rather than raising.
                console.print(
                    f"[bold yellow]⚠️ Unknown parameter type:[/bold yellow] {param_type} for [cyan]{param_name}[/cyan]",
                    style="yellow",
                )
        return params

    def optimize_hyperparameters(
        self,
        algo_class,
        env_config: Union[Dict[str, Any], BacktestEnvironmentConfig],
        search_space: Dict[str, Any],
        study_name: str,
        n_trials: int = 100,
        fixed_params: Optional[Dict[str, Any]] = None,
        total_timesteps: int = 50000,
        num_eval_episodes: int = 5,
        optimization_metric: str = "test_avg_return_pct",
        direction: str = "maximize",
        timeout: Optional[float] = None,
        n_jobs: int = 1,
        sampler: Optional[optuna.samplers.BaseSampler] = None,
        pruner: Optional[optuna.pruners.BasePruner] = None,
    ) -> optuna.Study:
        """
        Run hyperparameter optimization using Optuna.

        Args:
            algo_class: RL algorithm class (PPO, SAC, A2C, etc.)
            env_config: Environment configuration object or dictionary
            search_space: Dictionary defining the hyperparameter search space
            study_name: Name of the Optuna study (resumed if it already exists)
            n_trials: Number of trials to run
            fixed_params: Fixed parameters that won't be optimized
            total_timesteps: Number of training timesteps per trial
            num_eval_episodes: Number of evaluation episodes per trial
            optimization_metric: Metric each trial is scored on
            direction: "maximize" or "minimize"
            timeout: Optional wall-clock limit in seconds for the whole study
            n_jobs: Parallel workers (forced to 1 when storage is SQLite)
            sampler: Optuna sampler (defaults to TPESampler(seed=42))
            pruner: Optuna pruner (defaults to MedianPruner())

        Returns:
            optuna.Study: The study containing all trial results.
        """
        # Seeded TPE sampler keeps runs reproducible by default.
        sampler = sampler or optuna.samplers.TPESampler(seed=42)
        pruner = pruner or optuna.pruners.MedianPruner()

        study = optuna.create_study(
            study_name=study_name,
            storage=self.storage_url,
            direction=direction,
            sampler=sampler,
            pruner=pruner,
            load_if_exists=True,  # resume an existing study of the same name
        )

        objective_func = self.create_objective_function(
            algo_class=algo_class,
            env_config=env_config,
            search_space=search_space,
            fixed_params=fixed_params,
            total_timesteps=total_timesteps,
            num_eval_episodes=num_eval_episodes,
            optimization_metric=optimization_metric,
        )

        console.rule(
            (f"[bold blue]Starting optimization for {n_trials} " f"trials[/bold blue]"),
        )
        console.print()

        try:
            # SQLite cannot safely serve multiple concurrent writers.
            if n_jobs > 1 and self.storage_url.startswith("sqlite"):
                console.print(
                    "[bold yellow]⚠️ n_jobs > 1 is not safe with SQLite storage "
                    "— falling back to n_jobs=1.[/bold yellow]"
                )
                n_jobs = 1

            study.optimize(
                objective_func,
                n_trials=n_trials,
                timeout=timeout,
                n_jobs=n_jobs,
            )

            console.rule("[bold green]Optimization finished successfully[/bold green]")
            # Pruned/failed trials have value None; only report a best when
            # at least one trial completed.
            completed = [t for t in study.trials if t.value is not None]
            if completed:
                console.print(f"[bold blue]Best trial value:[/bold blue] [cyan]{study.best_value:.4f}[/cyan]")
                console.print("[bold blue]Best params:[/bold blue]")
                console.print(study.best_params, style="yellow")
            else:
                console.print("[bold yellow]⚠️ No trials completed successfully — all were pruned.[/bold yellow]")

        except Exception as e:
            console.print(f"[bold red]❌ Optimization loop failed with an exception:[/bold red] {str(e)}", style="red")
            console.print_exception()
            raise

        return study

__init__(runner, storage_url=None)

Initialize the runner with Optuna configuration.

Parameters:

Name Type Description Default
runner BacktestRunner

BacktestRunner instance

required
storage_url Optional[str]

Optuna storage URL (defaults to sqlite:///optuna_studies.db)

None
Source code in src/quantrl_lab/experiments/tuning/optuna_runner.py
def __init__(
    self,
    runner: BacktestRunner,
    storage_url: Optional[str] = None,
):
    """
    Initialize the runner with Optuna configuration.

    Args:
        runner: BacktestRunner instance
        storage_url: Optuna storage URL (defaults to sqlite:///optuna_studies.db)
    """
    self.runner = runner
    self.storage_url = storage_url if storage_url else "sqlite:///optuna_studies.db"
    # Strip the sqlite scheme to show the on-disk database location.
    resolved_db_file = os.path.abspath(self.storage_url.replace("sqlite:///", ""))
    banner = (
        f"[bold blue]Optuna Storage URL:[/bold blue] {self.storage_url}\n"
        f"[bold blue]Database file at:[/bold blue] [green]{resolved_db_file}[/green]"
    )
    console.print(
        Panel(
            banner,
            title="[bold yellow]QuantRL-Lab Optuna Runner[/bold yellow]",
            border_style="blue",
        )
    )

create_objective_function(algo_class, env_config, search_space, fixed_params=None, total_timesteps=50000, num_eval_episodes=5, optimization_metric='test_avg_return_pct')

Create an objective function for Optuna optimization.

Parameters:

Name Type Description Default
algo_class

RL algorithm class (PPO, SAC, A2C, etc.)

required
env_config Union[Dict[str, Any], BacktestEnvironmentConfig]

Environment configuration object or dictionary

required
search_space Dict[str, Any]

Dictionary defining the hyperparameter search space

required
fixed_params Optional[Dict[str, Any]]

Fixed parameters that won't be optimized

None
total_timesteps int

Number of training timesteps

50000
num_eval_episodes int

Number of evaluation episodes

5
optimization_metric str

Metric to optimize (default: test_avg_return_pct). Can use any metric from MetricsCalculator (e.g., 'test_avg_sharpe_ratio').

'test_avg_return_pct'

Returns:

Name Type Description
Callable Callable

The objective function for Optuna.

Source code in src/quantrl_lab/experiments/tuning/optuna_runner.py
def create_objective_function(
    self,
    algo_class,
    env_config: Union[Dict[str, Any], BacktestEnvironmentConfig],
    search_space: Dict[str, Any],
    fixed_params: Optional[Dict[str, Any]] = None,
    total_timesteps: int = 50000,
    num_eval_episodes: int = 5,
    optimization_metric: str = "test_avg_return_pct",
) -> Callable:
    """
    Create an objective function for Optuna optimization.

    Args:
        algo_class: RL algorithm class (PPO, SAC, A2C, etc.)
        env_config: Environment configuration object or dictionary
        search_space: Dictionary defining the hyperparameter search space
        fixed_params: Fixed parameters that won't be optimized
        total_timesteps: Number of training timesteps
        num_eval_episodes: Number of evaluation episodes
        optimization_metric: Metric to optimize (default: test_avg_return_pct).
                             Can use any metric from MetricsCalculator (e.g., 'test_avg_sharpe_ratio').

    Returns:
        Callable: The objective function for Optuna.
    """
    fixed_params = fixed_params or {}

    # Normalize env_config
    if isinstance(env_config, dict):
        # Legacy support
        env_config_obj = BacktestEnvironmentConfig.from_dict(env_config)
    else:
        env_config_obj = env_config

    def objective(trial: optuna.Trial) -> float:
        # Runs one full train/evaluate cycle and returns the score Optuna optimizes.
        try:
            # Sample hyperparameters from the search space
            params = self._sample_hyperparameters(trial, search_space)
            params.update(fixed_params)

            console.print(
                f"[bold cyan]Starting Trial {trial.number}[/bold cyan] with params: [yellow]{params}[/yellow]"
            )

            # Create Job
            job = ExperimentJob(
                algorithm_class=algo_class,
                env_config=env_config_obj,
                algorithm_config=params,
                total_timesteps=total_timesteps,
                num_eval_episodes=num_eval_episodes,
                tags={"trial": str(trial.number), "tuner": "optuna"},
            )

            # Run the backtesting experiment
            result = self.runner.run_job(job)

            if result.status == "failed":
                raise RuntimeError(f"Job failed: {result.error}")

            # Extract the target value for Optuna to optimize
            target_value = result.metrics.get(optimization_metric)
            if target_value is None:
                console.print(
                    f"[bold yellow]⚠️ Optimization metric '[/bold yellow][cyan]{optimization_metric}[/cyan]"
                    f"[bold yellow]' not found in results.[/bold yellow] "
                    "Defaulting to -1000.0.",
                    style="yellow",
                )
                # Fallback sentinel score used when the requested metric is absent.
                target_value = -1000.0

            console.print(
                f"[bold green]Trial {trial.number} finished.[/bold green] "
                f"[blue]{optimization_metric}[/blue] = [cyan]{target_value:.4f}[/cyan] ✓"
            )

            return target_value

        except Exception as e:
            console.print(
                f"[bold red]❌ Trial {trial.number} failed with an exception:[/bold red] {str(e)}", style="red"
            )
            console.print_exception()
            # Report the trial as pruned so the study continues with the next one.
            raise optuna.exceptions.TrialPruned()

    return objective

optimize_hyperparameters(algo_class, env_config, search_space, study_name, n_trials=100, fixed_params=None, total_timesteps=50000, num_eval_episodes=5, optimization_metric='test_avg_return_pct', direction='maximize', timeout=None, n_jobs=1, sampler=None, pruner=None)

Run hyperparameter optimization using Optuna.

Source code in src/quantrl_lab/experiments/tuning/optuna_runner.py
def optimize_hyperparameters(
    self,
    algo_class,
    env_config: Union[Dict[str, Any], BacktestEnvironmentConfig],
    search_space: Dict[str, Any],
    study_name: str,
    n_trials: int = 100,
    fixed_params: Optional[Dict[str, Any]] = None,
    total_timesteps: int = 50000,
    num_eval_episodes: int = 5,
    optimization_metric: str = "test_avg_return_pct",
    direction: str = "maximize",
    timeout: Optional[float] = None,
    n_jobs: int = 1,
    sampler: Optional[optuna.samplers.BaseSampler] = None,
    pruner: Optional[optuna.pruners.BasePruner] = None,
) -> optuna.Study:
    """
    Run hyperparameter optimization using Optuna.

    Args:
        algo_class: RL algorithm class (PPO, SAC, A2C, etc.)
        env_config: Environment configuration object or dictionary
        search_space: Dictionary defining the hyperparameter search space
        study_name: Name of the Optuna study (resumed if it already exists)
        n_trials: Number of trials to run
        fixed_params: Fixed parameters that won't be optimized
        total_timesteps: Number of training timesteps per trial
        num_eval_episodes: Number of evaluation episodes per trial
        optimization_metric: Metric each trial is scored on
        direction: "maximize" or "minimize"
        timeout: Optional wall-clock limit in seconds for the whole study
        n_jobs: Parallel workers (forced to 1 when storage is SQLite)
        sampler: Optuna sampler (defaults to TPESampler(seed=42))
        pruner: Optuna pruner (defaults to MedianPruner())

    Returns:
        optuna.Study: The study containing all trial results.
    """
    # Seeded TPE sampler keeps runs reproducible by default.
    sampler = sampler or optuna.samplers.TPESampler(seed=42)
    pruner = pruner or optuna.pruners.MedianPruner()

    study = optuna.create_study(
        study_name=study_name,
        storage=self.storage_url,
        direction=direction,
        sampler=sampler,
        pruner=pruner,
        load_if_exists=True,  # resume an existing study of the same name
    )

    objective_func = self.create_objective_function(
        algo_class=algo_class,
        env_config=env_config,
        search_space=search_space,
        fixed_params=fixed_params,
        total_timesteps=total_timesteps,
        num_eval_episodes=num_eval_episodes,
        optimization_metric=optimization_metric,
    )

    console.rule(
        (f"[bold blue]Starting optimization for {n_trials} " f"trials[/bold blue]"),
    )
    console.print()

    try:
        # SQLite cannot safely serve multiple concurrent writers.
        if n_jobs > 1 and self.storage_url.startswith("sqlite"):
            console.print(
                "[bold yellow]⚠️ n_jobs > 1 is not safe with SQLite storage "
                "— falling back to n_jobs=1.[/bold yellow]"
            )
            n_jobs = 1

        study.optimize(
            objective_func,
            n_trials=n_trials,
            timeout=timeout,
            n_jobs=n_jobs,
        )

        console.rule("[bold green]Optimization finished successfully[/bold green]")
        # Pruned/failed trials have value None; only report a best when
        # at least one trial completed.
        completed = [t for t in study.trials if t.value is not None]
        if completed:
            console.print(f"[bold blue]Best trial value:[/bold blue] [cyan]{study.best_value:.4f}[/cyan]")
            console.print("[bold blue]Best params:[/bold blue]")
            console.print(study.best_params, style="yellow")
        else:
            console.print("[bold yellow]⚠️ No trials completed successfully — all were pruned.[/bold yellow]")

    except Exception as e:
        console.print(f"[bold red]❌ Optimization loop failed with an exception:[/bold red] {str(e)}", style="red")
        console.print_exception()
        raise

    return study

create_ppo_search_space()

Create a default search space for PPO hyperparameters.

Source code in src/quantrl_lab/experiments/tuning/optuna_runner.py
def create_ppo_search_space() -> Dict[str, Any]:
    """Create a default search space for PPO hyperparameters.

    Returns:
        Dict[str, Any]: Mapping of hyperparameter name to its sampling
            specification (type, bounds or choices, optional log flag).
    """
    space: Dict[str, Any] = {}
    space["learning_rate"] = {"type": "float", "low": 1e-5, "high": 1e-2, "log": True}
    space["n_steps"] = {"type": "categorical", "choices": [256, 512, 1024, 2048, 4096]}
    space["batch_size"] = {"type": "categorical", "choices": [32, 64, 128, 256, 512]}
    space["gamma"] = {"type": "float", "low": 0.9, "high": 0.9999}
    space["gae_lambda"] = {"type": "float", "low": 0.8, "high": 1.0}
    space["clip_range"] = {"type": "float", "low": 0.1, "high": 0.4}
    space["ent_coef"] = {"type": "float", "low": 1e-8, "high": 1e-1, "log": True}
    return space

create_sac_search_space()

Create a default search space for SAC hyperparameters.

Source code in src/quantrl_lab/experiments/tuning/optuna_runner.py
def create_sac_search_space() -> Dict[str, Any]:
    """Create a default search space for SAC hyperparameters."""
    space: Dict[str, Any] = {}
    space["learning_rate"] = {"type": "float", "low": 1e-5, "high": 1e-2, "log": True}
    space["batch_size"] = {"type": "categorical", "choices": [64, 128, 256, 512]}
    space["gamma"] = {"type": "float", "low": 0.9, "high": 0.9999}
    # Soft-update coefficient for the target networks.
    space["tau"] = {"type": "float", "low": 0.001, "high": 0.1}
    space["train_freq"] = {"type": "categorical", "choices": [1, 4, 8, 16]}
    space["gradient_steps"] = {"type": "categorical", "choices": [1, 2, 4, 8]}
    space["target_update_interval"] = {"type": "categorical", "choices": [1, 2, 4, 8]}
    return space

create_a2c_search_space()

Create a default search space for A2C hyperparameters.

Source code in src/quantrl_lab/experiments/tuning/optuna_runner.py
def create_a2c_search_space() -> Dict[str, Any]:
    """Create a default search space for A2C hyperparameters."""
    float_params = {
        # (low, high, log-scale?)
        "learning_rate": (1e-5, 1e-2, True),
        "gamma": (0.9, 0.9999, False),
        "gae_lambda": (0.8, 1.0, False),
        "ent_coef": (1e-8, 1e-1, True),
        "vf_coef": (0.1, 1.0, False),
    }
    space: Dict[str, Any] = {}
    for key, (low, high, log) in float_params.items():
        entry = {"type": "float", "low": low, "high": high}
        if log:
            entry["log"] = True
        space[key] = entry
    space["n_steps"] = {"type": "categorical", "choices": [5, 16, 32, 64, 128]}
    return space

Data Partitioning

date_range

Date range-based data splitter for time series data.

DateRangeSplitter

Split DataFrame by explicit date ranges.

This splitter divides data based on specified date ranges, useful for creating specific train/test periods.

Example

splitter = DateRangeSplitter({
    "train": ("2020-01-01", "2021-12-31"),
    "test": ("2022-01-01", "2022-12-31"),
})
splits = splitter.split(df)
train_df = splits["train"]
test_df = splits["test"]

Source code in src/quantrl_lab/data/partitioning/date_range.py
class DateRangeSplitter:
    """
    Split DataFrame by explicit date ranges.

    This splitter divides data based on specified date ranges,
    useful for creating specific train/test periods.

    Example:
        >>> splitter = DateRangeSplitter({
        ...     "train": ("2020-01-01", "2021-12-31"),
        ...     "test": ("2022-01-01", "2022-12-31")
        ... })
        >>> splits = splitter.split(df)
        >>> train_df = splits["train"]
        >>> test_df = splits["test"]
    """

    def __init__(self, ranges: Dict[str, Tuple[str, str]]):
        """
        Initialize DateRangeSplitter.

        Args:
            ranges (Dict[str, Tuple[str, str]]): Dictionary mapping split names to
                (start_date, end_date) tuples. Dates can be strings or datetime objects.
                Example: {"train": ("2020-01-01", "2021-12-31")}

        Raises:
            ValueError: If ranges are invalid or empty.
        """
        if not ranges:
            raise ValueError("Ranges dictionary cannot be empty")

        # Validate range format
        for name, date_range in ranges.items():
            if not isinstance(date_range, (tuple, list)) or len(date_range) != 2:
                raise ValueError(
                    f"Invalid range for '{name}': {date_range}. Must be a tuple/list of (start_date, end_date)"
                )

            start_date, end_date = date_range
            # Only parsing happens inside the try: previously the ordering check
            # lived inside it too, so its ValueError was caught and re-wrapped
            # into a misleading "Invalid dates for ..." message.
            try:
                start_dt = pd.to_datetime(start_date)
                end_dt = pd.to_datetime(end_date)
            except Exception as e:
                # Chain the original parse error for debuggability.
                raise ValueError(f"Invalid dates for '{name}': {e}") from e
            if start_dt > end_dt:
                raise ValueError(f"Start date {start_date} is after end date {end_date} for '{name}'")

        self.ranges = ranges

    def split(self, df: pd.DataFrame) -> Dict[str, pd.DataFrame]:
        """
        Split DataFrame by date ranges.

        Args:
            df (pd.DataFrame): Input DataFrame to split. Must have a date column
                (Date, date, timestamp, or Timestamp).

        Returns:
            Dict[str, pd.DataFrame]: Dictionary of split DataFrames.

        Raises:
            ValueError: If DataFrame is empty or date column not found.
        """
        if df.empty:
            raise ValueError("Cannot split empty DataFrame")

        # Find date column (first match wins)
        date_column = next((col for col in ["Date", "date", "timestamp", "Timestamp"] if col in df.columns), None)

        if not date_column:
            raise ValueError(
                "Date column not found. DataFrame must contain one of: 'Date', 'date', 'timestamp', 'Timestamp'"
            )

        # Prepare DataFrame: normalize to tz-naive datetimes and sort chronologically
        df = df.copy()
        df[date_column] = pd.to_datetime(df[date_column]).dt.tz_localize(None)
        df = df.sort_values(by=date_column).reset_index(drop=True)

        split_data = {}
        metadata_ranges = {}
        metadata_shapes = {}

        for name, (start_date, end_date) in self.ranges.items():
            # Convert to datetime
            start_dt = pd.to_datetime(start_date)
            end_dt = pd.to_datetime(end_date)

            # Filter data (both endpoints inclusive)
            mask = (df[date_column] >= start_dt) & (df[date_column] <= end_dt)
            subset = df[mask].copy()
            split_data[name] = subset

            # Track metadata: actual observed range for non-empty splits,
            # the configured range otherwise.
            if not subset.empty:
                metadata_ranges[name] = {
                    "start": subset[date_column].min().strftime("%Y-%m-%d"),
                    "end": subset[date_column].max().strftime("%Y-%m-%d"),
                }
            else:
                metadata_ranges[name] = {
                    "start": start_date if isinstance(start_date, str) else start_date.strftime("%Y-%m-%d"),
                    "end": end_date if isinstance(end_date, str) else end_date.strftime("%Y-%m-%d"),
                }

            metadata_shapes[name] = subset.shape

        # Store metadata for get_metadata() call
        self._last_metadata = {
            "date_ranges": metadata_ranges,
            "final_shapes": metadata_shapes,
        }

        return split_data

    def get_metadata(self) -> Dict:
        """
        Return metadata about the split.

        Returns:
            Dict: Dictionary containing:
                - type: "date_range"
                - ranges: Configuration used
                - date_ranges: Actual date ranges in each split
                - final_shapes: Shape of each split DataFrame
        """
        metadata = {
            "type": "date_range",
            "ranges": {
                name: (
                    start if isinstance(start, str) else start.strftime("%Y-%m-%d"),
                    end if isinstance(end, str) else end.strftime("%Y-%m-%d"),
                )
                for name, (start, end) in self.ranges.items()
            },
        }

        # Add metadata from last split operation if available
        if hasattr(self, "_last_metadata"):
            metadata.update(self._last_metadata)

        return metadata

__init__(ranges)

Initialize DateRangeSplitter.

Parameters:

Name Type Description Default
ranges Dict[str, Tuple[str, str]]

Dictionary mapping split names to (start_date, end_date) tuples. Dates can be strings or datetime objects. Example: {"train": ("2020-01-01", "2021-12-31")}

required

Raises:

Type Description
ValueError

If ranges are invalid or empty.

Source code in src/quantrl_lab/data/partitioning/date_range.py
def __init__(self, ranges: Dict[str, Tuple[str, str]]):
    """
    Initialize DateRangeSplitter.

    Args:
        ranges (Dict[str, Tuple[str, str]]): Dictionary mapping split names to
            (start_date, end_date) tuples. Dates can be strings or datetime objects.
            Example: {"train": ("2020-01-01", "2021-12-31")}

    Raises:
        ValueError: If ranges are invalid or empty.
    """
    if not ranges:
        raise ValueError("Ranges dictionary cannot be empty")

    # Validate range format
    for name, date_range in ranges.items():
        if not isinstance(date_range, (tuple, list)) or len(date_range) != 2:
            raise ValueError(
                f"Invalid range for '{name}': {date_range}. Must be a tuple/list of (start_date, end_date)"
            )

        start_date, end_date = date_range
        # Only parsing happens inside the try: previously the ordering check
        # lived inside it too, so its ValueError was caught and re-wrapped
        # into a misleading "Invalid dates for ..." message.
        try:
            start_dt = pd.to_datetime(start_date)
            end_dt = pd.to_datetime(end_date)
        except Exception as e:
            # Chain the original parse error for debuggability.
            raise ValueError(f"Invalid dates for '{name}': {e}") from e
        if start_dt > end_dt:
            raise ValueError(f"Start date {start_date} is after end date {end_date} for '{name}'")

    self.ranges = ranges

split(df)

Split DataFrame by date ranges.

Parameters:

Name Type Description Default
df DataFrame

Input DataFrame to split. Must have a date column (Date, date, timestamp, or Timestamp).

required

Returns:

Type Description
Dict[str, DataFrame]

Dict[str, pd.DataFrame]: Dictionary of split DataFrames.

Raises:

Type Description
ValueError

If DataFrame is empty or date column not found.

Source code in src/quantrl_lab/data/partitioning/date_range.py
def split(self, df: pd.DataFrame) -> Dict[str, pd.DataFrame]:
    """
    Slice the DataFrame into the configured date windows.

    Args:
        df (pd.DataFrame): Input DataFrame to split. Must have a date column
            (Date, date, timestamp, or Timestamp).

    Returns:
        Dict[str, pd.DataFrame]: Mapping of split name to its DataFrame slice.

    Raises:
        ValueError: If DataFrame is empty or date column not found.
    """
    if df.empty:
        raise ValueError("Cannot split empty DataFrame")

    # Locate the date column (first match wins).
    date_col = None
    for candidate in ["Date", "date", "timestamp", "Timestamp"]:
        if candidate in df.columns:
            date_col = candidate
            break

    if not date_col:
        raise ValueError(
            "Date column not found. DataFrame must contain one of: 'Date', 'date', 'timestamp', 'Timestamp'"
        )

    # Work on a tz-naive, chronologically sorted copy.
    work = df.copy()
    work[date_col] = pd.to_datetime(work[date_col]).dt.tz_localize(None)
    work = work.sort_values(by=date_col).reset_index(drop=True)

    results = {}
    observed_ranges = {}
    shapes = {}

    for name, (lo, hi) in self.ranges.items():
        lo_dt = pd.to_datetime(lo)
        hi_dt = pd.to_datetime(hi)

        # Both endpoints are inclusive.
        piece = work[(work[date_col] >= lo_dt) & (work[date_col] <= hi_dt)].copy()
        results[name] = piece

        # Record the observed range for non-empty slices, else the configured one.
        if piece.empty:
            observed_ranges[name] = {
                "start": lo if isinstance(lo, str) else lo.strftime("%Y-%m-%d"),
                "end": hi if isinstance(hi, str) else hi.strftime("%Y-%m-%d"),
            }
        else:
            observed_ranges[name] = {
                "start": piece[date_col].min().strftime("%Y-%m-%d"),
                "end": piece[date_col].max().strftime("%Y-%m-%d"),
            }
        shapes[name] = piece.shape

    # Stash for a later get_metadata() call.
    self._last_metadata = {
        "date_ranges": observed_ranges,
        "final_shapes": shapes,
    }

    return results

get_metadata()

Return metadata about the split.

Returns:

Name Type Description
Dict Dict

Dictionary containing: - type: "date_range" - ranges: Configuration used - date_ranges: Actual date ranges in each split - final_shapes: Shape of each split DataFrame

Source code in src/quantrl_lab/data/partitioning/date_range.py
def get_metadata(self) -> Dict:
    """
    Return metadata about the split.

    Returns:
        Dict: Dictionary containing:
            - type: "date_range"
            - ranges: Configuration used
            - date_ranges: Actual date ranges in each split
            - final_shapes: Shape of each split DataFrame
    """

    def _fmt(value):
        # Configured endpoints may be strings or datetime-like objects.
        return value if isinstance(value, str) else value.strftime("%Y-%m-%d")

    info = {
        "type": "date_range",
        "ranges": {name: (_fmt(lo), _fmt(hi)) for name, (lo, hi) in self.ranges.items()},
    }

    # Fold in results from the most recent split(), when one has run.
    info.update(getattr(self, "_last_metadata", {}))
    return info

ratio

Ratio-based data splitter for time series data.

RatioSplitter

Split DataFrame by ratio (e.g., 70% train, 30% test).

This splitter divides data sequentially based on specified ratios, maintaining temporal order for time series data.

Example

splitter = RatioSplitter({"train": 0.7, "test": 0.3})
splits = splitter.split(df)
train_df = splits["train"]
test_df = splits["test"]

Source code in src/quantrl_lab/data/partitioning/ratio.py
class RatioSplitter:
    """
    Split DataFrame by ratio (e.g., 70% train, 30% test).

    Rows are assigned to splits sequentially, so temporal ordering is
    preserved for time-series workloads.

    Example:
        >>> splitter = RatioSplitter({"train": 0.7, "test": 0.3})
        >>> splits = splitter.split(df)
        >>> train_df = splits["train"]
        >>> test_df = splits["test"]
    """

    def __init__(self, ratios: Dict[str, float]):
        """
        Initialize RatioSplitter.

        Args:
            ratios (Dict[str, float]): Dictionary mapping split names to ratios.
                Ratios must sum to <= 1.0. Example: {"train": 0.7, "test": 0.3}

        Raises:
            ValueError: If ratios sum to > 1.0 or any ratio is invalid.
        """
        if not ratios:
            raise ValueError("Ratios dictionary cannot be empty")

        total = sum(ratios.values())
        if total > 1.0:
            raise ValueError(f"Ratios sum to {total:.2f}, which exceeds 1.0")

        for name, ratio in ratios.items():
            # Each ratio must lie in (0, 1].
            if not 0 < ratio <= 1:
                raise ValueError(f"Invalid ratio for '{name}': {ratio}. Must be in range (0, 1]")

        self.ratios = ratios

    def split(self, df: pd.DataFrame) -> Dict[str, pd.DataFrame]:
        """
        Split DataFrame by ratio.

        Args:
            df (pd.DataFrame): Input DataFrame to split. Should be sorted by time.

        Returns:
            Dict[str, pd.DataFrame]: Dictionary of split DataFrames.

        Raises:
            ValueError: If DataFrame is empty.
        """
        if df.empty:
            raise ValueError("Cannot split empty DataFrame")

        # Locate a date column (only used for metadata and sorting).
        date_col = None
        for candidate in ["Date", "date", "timestamp", "Timestamp"]:
            if candidate in df.columns:
                date_col = candidate
                break

        if date_col:
            # Normalize to tz-naive datetimes and sort chronologically.
            df = df.copy()
            df[date_col] = pd.to_datetime(df[date_col]).dt.tz_localize(None)
            df = df.sort_values(by=date_col).reset_index(drop=True)

        n_rows = len(df)
        cursor = 0
        pieces = {}
        ranges_meta = {}
        shapes_meta = {}

        for name, fraction in self.ratios.items():
            # Each split takes the next int(n_rows * fraction) rows.
            stop = cursor + int(n_rows * fraction)
            chunk = df.iloc[cursor:stop].copy()
            pieces[name] = chunk

            if date_col and not chunk.empty:
                ranges_meta[name] = {
                    "start": chunk[date_col].min().strftime("%Y-%m-%d"),
                    "end": chunk[date_col].max().strftime("%Y-%m-%d"),
                }
            shapes_meta[name] = chunk.shape
            cursor = stop

        # Stash for a later get_metadata() call.
        self._last_metadata = {
            "date_ranges": ranges_meta,
            "final_shapes": shapes_meta,
        }

        return pieces

    def get_metadata(self) -> Dict:
        """
        Return metadata about the split.

        Returns:
            Dict: Dictionary containing:
                - type: "ratio"
                - ratios: Configuration used
                - date_ranges: Date ranges for each split (if date column exists)
                - final_shapes: Shape of each split DataFrame
        """
        report = {
            "type": "ratio",
            "ratios": self.ratios,
        }
        # Fold in results from the most recent split(), when one has run.
        report.update(getattr(self, "_last_metadata", {}))
        return report

__init__(ratios)

Initialize RatioSplitter.

Parameters:

Name Type Description Default
ratios Dict[str, float]

Dictionary mapping split names to ratios. Ratios must sum to <= 1.0. Example: {"train": 0.7, "test": 0.3}

required

Raises:

Type Description
ValueError

If ratios sum to > 1.0 or any ratio is invalid.

Source code in src/quantrl_lab/data/partitioning/ratio.py
def __init__(self, ratios: Dict[str, float]):
    """
    Initialize RatioSplitter.

    Args:
        ratios (Dict[str, float]): Dictionary mapping split names to ratios.
            Ratios must sum to <= 1.0. Example: {"train": 0.7, "test": 0.3}

    Raises:
        ValueError: If ratios sum to > 1.0 or any ratio is invalid.
    """
    if not ratios:
        raise ValueError("Ratios dictionary cannot be empty")

    total = sum(ratios.values())
    if total > 1.0:
        raise ValueError(f"Ratios sum to {total:.2f}, which exceeds 1.0")

    for name, ratio in ratios.items():
        # Each ratio must lie in (0, 1].
        if not 0 < ratio <= 1:
            raise ValueError(f"Invalid ratio for '{name}': {ratio}. Must be in range (0, 1]")

    self.ratios = ratios

split(df)

Split DataFrame by ratio.

Parameters:

Name Type Description Default
df DataFrame

Input DataFrame to split. Should be sorted by time.

required

Returns:

Type Description
Dict[str, DataFrame]

Dict[str, pd.DataFrame]: Dictionary of split DataFrames.

Raises:

Type Description
ValueError

If DataFrame is empty.

Source code in src/quantrl_lab/data/partitioning/ratio.py
def split(self, df: pd.DataFrame) -> Dict[str, pd.DataFrame]:
    """
    Split DataFrame by ratio.

    Args:
        df (pd.DataFrame): Input DataFrame to split. Should be sorted by time.

    Returns:
        Dict[str, pd.DataFrame]: Dictionary of split DataFrames.

    Raises:
        ValueError: If DataFrame is empty.
    """
    if df.empty:
        raise ValueError("Cannot split empty DataFrame")

    # Locate a date column (only used for metadata and sorting).
    date_col = None
    for candidate in ["Date", "date", "timestamp", "Timestamp"]:
        if candidate in df.columns:
            date_col = candidate
            break

    if date_col:
        # Normalize to tz-naive datetimes and sort chronologically.
        df = df.copy()
        df[date_col] = pd.to_datetime(df[date_col]).dt.tz_localize(None)
        df = df.sort_values(by=date_col).reset_index(drop=True)

    n_rows = len(df)
    cursor = 0
    pieces = {}
    ranges_meta = {}
    shapes_meta = {}

    for name, fraction in self.ratios.items():
        # Each split takes the next int(n_rows * fraction) rows.
        stop = cursor + int(n_rows * fraction)
        chunk = df.iloc[cursor:stop].copy()
        pieces[name] = chunk

        if date_col and not chunk.empty:
            ranges_meta[name] = {
                "start": chunk[date_col].min().strftime("%Y-%m-%d"),
                "end": chunk[date_col].max().strftime("%Y-%m-%d"),
            }
        shapes_meta[name] = chunk.shape
        cursor = stop

    # Stash for a later get_metadata() call.
    self._last_metadata = {
        "date_ranges": ranges_meta,
        "final_shapes": shapes_meta,
    }

    return pieces

get_metadata()

Return metadata about the split.

Returns:

Name Type Description
Dict Dict

Dictionary containing: - type: "ratio" - ratios: Configuration used - date_ranges: Date ranges for each split (if date column exists) - final_shapes: Shape of each split DataFrame

Source code in src/quantrl_lab/data/partitioning/ratio.py
def get_metadata(self) -> Dict:
    """
    Return metadata about the split.

    Returns:
        Dict: Dictionary containing:
            - type: "ratio"
            - ratios: Configuration used
            - date_ranges: Date ranges for each split (if date column exists)
            - final_shapes: Shape of each split DataFrame
    """
    report = {
        "type": "ratio",
        "ratios": self.ratios,
    }
    # Fold in results from the most recent split(), when one has run.
    report.update(getattr(self, "_last_metadata", {}))
    return report

Indicator Registry

registry

IndicatorMetadata dataclass

Metadata for registered indicators.

Attributes:

Name Type Description
name str

Indicator name (e.g., 'SMA', 'RSI')

func Callable

The callable function that computes the indicator

required_columns Set[str]

Set of required DataFrame columns (e.g., {'close', 'volume'})

output_columns List[str]

List of column names this indicator adds to DataFrame

dependencies List[str]

List of other indicator names that must be computed first

description str

Human-readable description of what the indicator computes

Source code in src/quantrl_lab/data/indicators/registry.py
@dataclass
class IndicatorMetadata:
    """
    Metadata record describing a registered indicator.

    Attributes:
        name: Indicator name (e.g., 'SMA', 'RSI')
        func: The callable function that computes the indicator
        required_columns: Set of required DataFrame columns (e.g., {'close', 'volume'})
        output_columns: List of column names this indicator adds to DataFrame
        dependencies: List of other indicator names that must be computed first
        description: Human-readable description of what the indicator computes
    """

    name: str
    func: Callable
    required_columns: Set[str] = field(default_factory=set)
    output_columns: List[str] = field(default_factory=list)
    dependencies: List[str] = field(default_factory=list)
    description: str = ""

    def __post_init__(self):
        """Default the output columns to the indicator's own name."""
        # An empty list means the caller declared no outputs explicitly.
        self.output_columns = self.output_columns or [self.name]

__post_init__()

Auto-generate output_columns if not provided.

Source code in src/quantrl_lab/data/indicators/registry.py
def __post_init__(self):
    """Auto-generate output_columns if not provided."""
    # An empty list means the caller declared no outputs explicitly;
    # default to the indicator's own name.
    self.output_columns = self.output_columns or [self.name]

IndicatorRegistry

Registry for technical indicators with metadata and validation.

This registry uses a decorator pattern to register indicator functions along with metadata about their requirements and outputs. It provides validation to ensure DataFrames have the required columns before applying indicators.

Example

@IndicatorRegistry.register(
    name='SMA',
    required_columns={'close'},
    output_columns=['SMA'],
    description="Simple Moving Average"
)
def sma(df, window=20, column='close'):
    df[f'SMA_{window}'] = df[column].rolling(window=window).mean()
    return df

Use with validation

df = IndicatorRegistry.apply_safe('SMA', df, window=20)

Source code in src/quantrl_lab/data/indicators/registry.py
class IndicatorRegistry:
    """
    Registry for technical indicators with metadata and validation.

    Indicator functions are registered through a decorator together with
    metadata describing their column requirements and outputs. Before an
    indicator is applied, the registry can validate that a DataFrame
    carries all required columns.

    Example:
        >>> @IndicatorRegistry.register(
        ...     name='SMA',
        ...     required_columns={'close'},
        ...     output_columns=['SMA'],
        ...     description="Simple Moving Average"
        ... )
        ... def sma(df, window=20, column='close'):
        ...     df[f'SMA_{window}'] = df[column].rolling(window=window).mean()
        ...     return df
        >>>
        >>> # Use with validation
        >>> df = IndicatorRegistry.apply_safe('SMA', df, window=20)
    """

    # Shared name -> metadata mapping for every registered indicator.
    _indicators: Dict[str, IndicatorMetadata] = {}

    @classmethod
    def _lookup(cls, name: str) -> IndicatorMetadata:
        """Return metadata for *name*, raising KeyError when unregistered."""
        if name not in cls._indicators:
            raise KeyError(f"Indicator '{name}' not registered")
        return cls._indicators[name]

    @classmethod
    def register(
        cls,
        name: Optional[str] = None,
        required_columns: Optional[Set[str]] = None,
        output_columns: Optional[List[str]] = None,
        dependencies: Optional[List[str]] = None,
        description: str = "",
    ) -> Callable:
        """
        Register an indicator function with metadata.

        Args:
            name: Indicator name. If None, uses function name.
            required_columns: Set of required DataFrame columns (case-insensitive).
                Example: {'close'}, {'high', 'low', 'close'}
            output_columns: List of column names this indicator will add.
                If None, defaults to [name].
            dependencies: List of other indicator names that must be applied first.
            description: Human-readable description of the indicator.

        Returns:
            Decorator function that registers the indicator.

        Example:
            >>> @IndicatorRegistry.register(
            ...     name='RSI',
            ...     required_columns={'close'},
            ...     output_columns=['RSI'],
            ...     description="Relative Strength Index"
            ... )
            ... def rsi(df, window=14):
            ...     # calculation
            ...     return df
        """

        def wrap(func: Callable):
            key = name or func.__name__
            cls._indicators[key] = IndicatorMetadata(
                name=key,
                func=func,
                required_columns=required_columns or set(),
                output_columns=output_columns or [key],
                dependencies=dependencies or [],
                description=description,
            )
            return func

        return wrap

    @classmethod
    def get(cls, name: str) -> Callable:
        """
        Get the indicator function by name.

        Args:
            name: Name of the indicator

        Raises:
            KeyError: If the name is not found in the registry

        Returns:
            Callable: Indicator function
        """
        return cls._lookup(name).func

    @classmethod
    def get_metadata(cls, name: str) -> IndicatorMetadata:
        """
        Get the full metadata for an indicator.

        Args:
            name: Name of the indicator

        Raises:
            KeyError: If the name is not found in the registry

        Returns:
            IndicatorMetadata: Metadata object for the indicator
        """
        return cls._lookup(name)

    @classmethod
    def list_all(cls) -> List[str]:
        """
        List all registered indicators.

        Returns:
            List[str]: List of indicator names
        """
        return [*cls._indicators]

    @classmethod
    def validate_compatibility(cls, df: pd.DataFrame, indicator_name: str) -> bool:
        """
        Check if DataFrame has required columns for indicator.

        Performs case-insensitive column checking.

        Args:
            df: DataFrame to validate
            indicator_name: Name of the indicator to check

        Returns:
            bool: True if DataFrame has all required columns

        Raises:
            KeyError: If indicator is not registered
        """
        meta = cls._lookup(indicator_name)
        available = {col.lower() for col in df.columns}
        return all(col.lower() in available for col in meta.required_columns)

    @classmethod
    def get_missing_columns(cls, df: pd.DataFrame, indicator_name: str) -> Set[str]:
        """
        Get the set of missing required columns for an indicator.

        Args:
            df: DataFrame to check
            indicator_name: Name of the indicator

        Returns:
            Set[str]: Set of missing column names (from required_columns)

        Raises:
            KeyError: If indicator is not registered
        """
        meta = cls._lookup(indicator_name)
        # Compare case-insensitively but report names in their original case.
        available = {col.lower() for col in df.columns}
        return {col for col in meta.required_columns if col.lower() not in available}

    @classmethod
    def apply(cls, name: str, df: pd.DataFrame, **kwargs) -> pd.DataFrame:
        """
        Apply the indicator function to the dataframe.

        Args:
            name: Name of the indicator
            df: Input dataframe
            **kwargs: Additional keyword arguments to be passed to the indicator function

        Returns:
            pd.DataFrame: DataFrame with the indicator added

        Raises:
            KeyError: If indicator is not registered
        """
        return cls.get(name)(df, **kwargs)

    @classmethod
    def apply_safe(cls, name: str, df: pd.DataFrame, **kwargs) -> pd.DataFrame:
        """
        Apply indicator with validation.

        Checks the DataFrame for all required columns first and raises a
        descriptive error listing what is missing.

        Args:
            name: Name of the indicator
            df: Input dataframe
            **kwargs: Additional keyword arguments to be passed to the indicator function

        Returns:
            pd.DataFrame: DataFrame with the indicator added

        Raises:
            KeyError: If indicator is not registered
            ValueError: If DataFrame is missing required columns
        """
        if not cls.validate_compatibility(df, name):
            missing = cls.get_missing_columns(df, name)
            raise ValueError(
                f"Cannot apply indicator '{name}': missing required columns {missing}. "
                f"Available columns: {list(df.columns)}"
            )

        return cls.apply(name, df, **kwargs)

register(name=None, required_columns=None, output_columns=None, dependencies=None, description='') classmethod

Register an indicator function with metadata.

Parameters:

Name Type Description Default
name Optional[str]

Indicator name. If None, uses function name.

None
required_columns Optional[Set[str]]

Set of required DataFrame columns (case-insensitive). Example: {'close'}, {'high', 'low', 'close'}

None
output_columns Optional[List[str]]

List of column names this indicator will add. If None, defaults to [name].

None
dependencies Optional[List[str]]

List of other indicator names that must be applied first.

None
description str

Human-readable description of the indicator.

''

Returns:

Type Description
Callable

Decorator function that registers the indicator.

Example

@IndicatorRegistry.register( ... name='RSI', ... required_columns={'close'}, ... output_columns=['RSI'], ... description="Relative Strength Index" ... ) ... def rsi(df, window=14): ... # calculation ... return df

Source code in src/quantrl_lab/data/indicators/registry.py
@classmethod
def register(
    cls,
    name: Optional[str] = None,
    required_columns: Optional[Set[str]] = None,
    output_columns: Optional[List[str]] = None,
    dependencies: Optional[List[str]] = None,
    description: str = "",
) -> Callable:
    """
    Register an indicator function with metadata.

    Args:
        name: Indicator name. If None, uses function name.
        required_columns: Set of required DataFrame columns (case-insensitive).
            Example: {'close'}, {'high', 'low', 'close'}
        output_columns: List of column names this indicator will add.
            If None, defaults to [name].
        dependencies: List of other indicator names that must be applied first.
        description: Human-readable description of the indicator.

    Returns:
        Decorator function that registers the indicator.

    Example:
        >>> @IndicatorRegistry.register(
        ...     name='RSI',
        ...     required_columns={'close'},
        ...     output_columns=['RSI'],
        ...     description="Relative Strength Index"
        ... )
        ... def rsi(df, window=14):
        ...     # calculation
        ...     return df
    """

    def wrap(func: Callable):
        # Fall back to the function's own name when none was given.
        key = name or func.__name__
        cls._indicators[key] = IndicatorMetadata(
            name=key,
            func=func,
            required_columns=required_columns or set(),
            output_columns=output_columns or [key],
            dependencies=dependencies or [],
            description=description,
        )
        return func

    return wrap

get(name) classmethod

Get the indicator function by name.

Parameters:

Name Type Description Default
name str

Name of the indicator

required

Raises:

Type Description
KeyError

If the name is not found in the registry

Returns:

Name Type Description
Callable Callable

Indicator function

Source code in src/quantrl_lab/data/indicators/registry.py
@classmethod
def get(cls, name: str) -> Callable:
    """
    Look up a registered indicator function.

    Args:
        name: Name of the indicator

    Raises:
        KeyError: If the name is not found in the registry

    Returns:
        Callable: Indicator function
    """
    try:
        return cls._indicators[name].func
    except KeyError:
        raise KeyError(f"Indicator '{name}' not registered") from None

get_metadata(name) classmethod

Get the full metadata for an indicator.

Parameters:

Name Type Description Default
name str

Name of the indicator

required

Raises:

Type Description
KeyError

If the name is not found in the registry

Returns:

Name Type Description
IndicatorMetadata IndicatorMetadata

Metadata object for the indicator

Source code in src/quantrl_lab/data/indicators/registry.py
@classmethod
def get_metadata(cls, name: str) -> IndicatorMetadata:
    """
    Return the full metadata record for a registered indicator.

    Args:
        name: Name of the indicator

    Raises:
        KeyError: If the name is not found in the registry

    Returns:
        IndicatorMetadata: Metadata object for the indicator
    """
    metadata = cls._indicators.get(name)
    if metadata is None:
        raise KeyError(f"Indicator '{name}' not registered")
    return metadata

list_all() classmethod

List all registered indicators.

Returns:

Type Description
List[str]

List[str]: List of indicator names

Source code in src/quantrl_lab/data/indicators/registry.py
@classmethod
def list_all(cls) -> List[str]:
    """
    Return the names of every registered indicator.

    Returns:
        List[str]: List of indicator names
    """
    return [indicator_name for indicator_name in cls._indicators]

validate_compatibility(df, indicator_name) classmethod

Check if DataFrame has required columns for indicator.

Performs case-insensitive column checking.

Parameters:

Name Type Description Default
df DataFrame

DataFrame to validate

required
indicator_name str

Name of the indicator to check

required

Returns:

Name Type Description
bool bool

True if DataFrame has all required columns

Raises:

Type Description
KeyError

If indicator is not registered

Source code in src/quantrl_lab/data/indicators/registry.py
@classmethod
def validate_compatibility(cls, df: pd.DataFrame, indicator_name: str) -> bool:
    """
    Check whether a DataFrame has every column an indicator requires.

    Column names are compared case-insensitively.

    Args:
        df: DataFrame to validate
        indicator_name: Name of the indicator to check

    Returns:
        bool: True if DataFrame has all required columns

    Raises:
        KeyError: If indicator is not registered
    """
    if indicator_name not in cls._indicators:
        raise KeyError(f"Indicator '{indicator_name}' not registered")

    needed = cls._indicators[indicator_name].required_columns

    # Compare lowercase forms so casing differences don't cause misses.
    available = {column.lower() for column in df.columns}
    return all(column.lower() in available for column in needed)

get_missing_columns(df, indicator_name) classmethod

Get the set of missing required columns for an indicator.

Parameters:

Name Type Description Default
df DataFrame

DataFrame to check

required
indicator_name str

Name of the indicator

required

Returns:

Type Description
Set[str]

Set[str]: Set of missing column names (from required_columns)

Raises:

Type Description
KeyError

If indicator is not registered

Source code in src/quantrl_lab/data/indicators/registry.py
@classmethod
def get_missing_columns(cls, df: pd.DataFrame, indicator_name: str) -> Set[str]:
    """
    Get the set of missing required columns for an indicator.

    The comparison is case-insensitive, but the returned names keep the
    original casing declared in the indicator's metadata.

    Args:
        df: DataFrame to check
        indicator_name: Name of the indicator

    Returns:
        Set[str]: Set of missing column names (from required_columns)

    Raises:
        KeyError: If indicator is not registered
    """
    if indicator_name not in cls._indicators:
        raise KeyError(f"Indicator '{indicator_name}' not registered")

    metadata = cls._indicators[indicator_name]

    # Case-insensitive comparison against the DataFrame's columns.
    df_columns_lower = {col.lower() for col in df.columns}

    # Report missing names in their original metadata casing; the
    # comprehension replaces a manual diff-then-map-back accumulation loop.
    return {
        req_col
        for req_col in metadata.required_columns
        if req_col.lower() not in df_columns_lower
    }

apply(name, df, **kwargs) classmethod

Apply the indicator function to the dataframe.

Parameters:

Name Type Description Default
name str

Name of the indicator

required
df DataFrame

Input dataframe

required
**kwargs

Additional keyword arguments to be passed to the indicator function

{}

Returns:

Type Description
DataFrame

pd.DataFrame: DataFrame with the indicator added

Raises:

Type Description
KeyError

If indicator is not registered

Source code in src/quantrl_lab/data/indicators/registry.py
@classmethod
def apply(cls, name: str, df: pd.DataFrame, **kwargs) -> pd.DataFrame:
    """
    Run the named indicator on the dataframe.

    Args:
        name: Name of the indicator
        df: Input dataframe
        **kwargs: Additional keyword arguments to be passed to the indicator function

    Returns:
        pd.DataFrame: DataFrame with the indicator added

    Raises:
        KeyError: If indicator is not registered (propagated from get)
    """
    # Lookup and invocation in one step; get() handles the not-found case.
    return cls.get(name)(df, **kwargs)

apply_safe(name, df, **kwargs) classmethod

Apply indicator with validation.

Validates that the DataFrame has all required columns before applying the indicator. Raises a descriptive error if columns are missing.

Parameters:

Name Type Description Default
name str

Name of the indicator

required
df DataFrame

Input dataframe

required
**kwargs

Additional keyword arguments to be passed to the indicator function

{}

Returns:

Type Description
DataFrame

pd.DataFrame: DataFrame with the indicator added

Raises:

Type Description
KeyError

If indicator is not registered

ValueError

If DataFrame is missing required columns

Source code in src/quantrl_lab/data/indicators/registry.py
@classmethod
def apply_safe(cls, name: str, df: pd.DataFrame, **kwargs) -> pd.DataFrame:
    """
    Validate required columns, then apply the indicator.

    Like ``apply``, but first checks that the DataFrame contains every
    column the indicator needs and raises a descriptive error otherwise.

    Args:
        name: Name of the indicator
        df: Input dataframe
        **kwargs: Additional keyword arguments to be passed to the indicator function

    Returns:
        pd.DataFrame: DataFrame with the indicator added

    Raises:
        KeyError: If indicator is not registered
        ValueError: If DataFrame is missing required columns
    """
    # Happy path first: columns check out, just delegate.
    if cls.validate_compatibility(df, name):
        return cls.apply(name, df, **kwargs)

    # Otherwise build a descriptive error naming exactly what is missing.
    missing = cls.get_missing_columns(df, name)
    raise ValueError(
        f"Cannot apply indicator '{name}': missing required columns {missing}. "
        f"Available columns: {list(df.columns)}"
    )

Environment Utilities

market_data

detect_column_index(df, candidates)

Detect a column index from a list of candidates (case-insensitive).

Parameters:

Name Type Description Default
df DataFrame

The DataFrame to search.

required
candidates List[str]

List of column names to look for.

required

Returns:

Type Description
Optional[int]

The index of the first matching column, or None if not found.

Source code in src/quantrl_lab/environments/utils/market_data.py
def detect_column_index(df: pd.DataFrame, candidates: List[str]) -> Optional[int]:
    """
    Find the positional index of the first candidate column present in a
    DataFrame.

    Exact names are preferred; a case-insensitive pass runs as a fallback.

    Args:
        df: The DataFrame to search.
        candidates: Column names to look for, in priority order.

    Returns:
        The index of the first matching column, or None if not found.
    """
    names = list(df.columns)

    # Pass 1: exact name match, honoring candidate priority order.
    for wanted in candidates:
        if wanted in names:
            return names.index(wanted)

    # Pass 2: case-insensitive match over lowercase forms.
    lowered = [n.lower() for n in names]
    for wanted in candidates:
        try:
            return lowered.index(wanted.lower())
        except ValueError:
            continue

    return None

auto_detect_price_column(df)

Auto-detect the price column index from a DataFrame using standard naming conventions.

Parameters:

Name Type Description Default
df DataFrame

Input DataFrame with price data.

required

Returns:

Name Type Description
int int

Index of the detected price column.

Raises:

Type Description
ValueError

If no suitable price column is found.

Source code in src/quantrl_lab/environments/utils/market_data.py
def auto_detect_price_column(df: pd.DataFrame) -> int:
    """
    Locate the price column of a DataFrame by standard naming
    conventions.

    Tries a priority list of conventional names first (via
    ``detect_column_index``), then falls back to any column whose name
    merely contains 'close' or 'price'.

    Args:
        df: Input DataFrame with price data.

    Returns:
        int: Index of the detected price column.

    Raises:
        ValueError: If no suitable price column is found.
    """
    columns = df.columns.tolist()

    # Conventional names, most common first.
    candidates = [
        "close", "Close", "CLOSE",
        "price", "Price", "PRICE",
        "adj_close", "Adj Close", "ADJ_CLOSE",
        "adjusted_close", "Adjusted_Close",
    ]

    index = detect_column_index(df, candidates)
    if index is not None:
        return index

    # Fallback: accept any column whose name contains a keyword.
    for position, column in enumerate(columns):
        lowered = column.lower()
        if "close" in lowered or "price" in lowered:
            return position

    raise ValueError(
        f"Could not auto-detect price column. Available columns: {columns}. "
        f"Please ensure your DataFrame has a column named 'close', 'price', or similar."
    )

calc_trend(prices)

Calculate the trend strength of a price series using linear regression.

Parameters:

Name Type Description Default
prices ndarray

Array of price data.

required

Returns:

Name Type Description
float float

The calculated trend strength (slope / max_price). Returns 0.0 if not enough data.

Source code in src/quantrl_lab/environments/utils/market_data.py
def calc_trend(prices: np.ndarray) -> float:
    """
    Estimate the trend strength of a price series via a linear fit.

    Fits a degree-1 polynomial to the series and normalizes its slope by
    the maximum price, making the result scale-independent.

    Args:
        prices (np.ndarray): Array of price data.

    Returns:
        float: slope / max(prices). Returns 0.0 when fewer than two
               points are given or the maximum price is (near) zero.
    """
    n = len(prices)
    if n < 2:
        return 0.0

    # Leading coefficient of the degree-1 fit is the slope.
    coefficients = np.polyfit(np.arange(n), prices, 1)
    slope = coefficients[0]

    peak = np.max(prices)
    # Guard against dividing by a (near-)zero maximum price.
    return slope / peak if peak > 1e-9 else 0.0