Source code for collab.foraging.toolkit.animate_foragers

import warnings
from typing import Dict, List

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import torch

# Suppress every FutureWarning from this point on. The filter is installed at
# import time via the global warnings registry, so it is not limited to this
# module.
warnings.simplefilter(action="ignore", category=FutureWarning)


def plot_trajectories(df, title, ax=None, show_legend=True):
    """Draw one trajectory line per forager with start and end markers.

    Args:
        df: DataFrame with columns ``forager``, ``time``, ``x`` and ``y``.
        title: Text appended to ``"Trajectories: "`` in the axes title.
        ax: Matplotlib axes to draw on. A new figure/axes pair is created
            with ``plt.subplots()`` when ``None``.
        show_legend: Whether to call ``ax.legend()``.

    Returns:
        The axes that were drawn on.
    """
    if ax is None:
        _, ax = plt.subplots()

    for forager in df["forager"].unique():
        df_forager = df[df["forager"] == forager]
        (line,) = ax.plot(df_forager["x"], df_forager["y"])
        # Markers reuse the colour matplotlib assigned to the trajectory line.
        color = line.get_color()
        times = df_forager["time"]

        # Use the forager's earliest recorded time rather than a hard-coded 0,
        # so data that starts at a different time still gets an initial marker
        # (mirrors the way the final location uses the latest time).
        init_loc = df_forager[times == times.min()]
        ax.scatter(
            init_loc["x"],
            init_loc["y"],
            color=color,
            s=50,
            marker="o",
            label=f"Forager {forager}: initial",
        )

        final_loc = df_forager[times == times.max()]
        ax.scatter(
            final_loc["x"],
            final_loc["y"],
            color=color,
            s=50,
            marker="x",
            label=f"Forager {forager}: final",
        )

    ax.set_aspect("equal")
    ax.invert_yaxis()
    ax.set_axis_off()
    if show_legend:
        ax.legend()
    ax.set_title(f"Trajectories: {title}", fontsize=16)
    return ax
def plot_distances(distances, title=""):
    """Build a histogram of all non-zero distances.

    Args:
        distances: Nested iterable (iterable of iterables) of DataFrames,
            each having a ``distance`` column.
        title: Text appended to ``"Distances: "`` in the figure title.

    Returns:
        The plotly figure produced by ``px.histogram`` with its legend hidden.
    """
    nonzero_distances = []
    for group in distances:
        for frame in group:
            # Zero entries are dropped before plotting.
            nonzero_distances.extend(
                value for value in frame["distance"].tolist() if value != 0
            )

    histogram = px.histogram(
        nonzero_distances,
        template="presentation",
        width=700,
        title=f"Distances: {title}",
        labels={"value": "inter-bird distance (grid units)"},
        opacity=0.4,
        nbins=60,
    )
    histogram.update_layout(showlegend=False)
    return histogram
def animate_foragers(
    sim,
    width=800,
    height=800,
    point_size=15,
    plot_rewards=True,
    plot_traces=False,
    plot_visibility=0,
    plot_proximity=0,
    plot_communicate=0,
    plot_velocity=0,
    trace_multiplier=10,
    visibility_multiplier=10,
    proximity_multiplier=10,
    communicate_multiplier=10,
    velocity_multiplier=10,
    color_by_state=False,
    produce_object=False,
    autosize=False,
):
    """Build an animated plotly scatter of foragers with optional overlays.

    The rows of ``sim.foragersDF`` are concatenated with whichever overlay
    frames are requested, animated over the ``time`` column, and the traces in
    every animation frame are then restyled by trace name.

    Args:
        sim: Object exposing ``foragersDF`` and ``grid_size`` and, depending
            on the flags below, ``rewardsDF``, ``tracesDF``, ``visibilityDF``,
            ``proximityDF``, ``communicatesDF`` and ``velocity_scoresDF``.
        width: Figure width passed to the layout.
        height: Figure height passed to the layout.
        point_size: Default marker size applied to all traces.
        plot_rewards: Add the rows of ``sim.rewardsDF`` as yellow squares.
        plot_traces: Add the rows of ``sim.tracesDF``, sized by the ``trace``
            column times ``trace_multiplier``.
        plot_visibility: Value matched against the ``forager`` column of
            ``sim.visibilityDF``; the overlay is added only when > 0.
        plot_proximity: Same, for ``sim.proximityDF``.
        plot_communicate: Same, for ``sim.communicatesDF``.
        plot_velocity: Same, for ``sim.velocity_scoresDF``.
        trace_multiplier: Scale factor for trace marker sizes.
        visibility_multiplier: Scale factor for visibility marker sizes.
        proximity_multiplier: Scale factor for proximity marker colours.
        communicate_multiplier: Scale factor for communicate marker sizes.
        velocity_multiplier: Scale factor for velocity marker sizes.
        color_by_state: Colour points by the ``state`` column instead of the
            ``forager`` column.
        produce_object: Return the figure instead of showing it.
        autosize: Passed through to the layout's ``autosize``.

    Returns:
        The ``go.Figure`` when ``produce_object`` is true; otherwise the
        figure is shown and ``None`` is returned.
    """
    # Rewards are labelled through whichever column drives the colouring, so
    # that they end up in a trace named "reward".
    if plot_rewards:
        rew = sim.rewardsDF.copy()
        if color_by_state:
            rew["state"] = "reward"
        else:
            rew["forager"] = "reward"
        df = pd.concat([sim.foragersDF, rew])
    else:
        df = sim.foragersDF.copy()

    if plot_traces:
        tr = sim.tracesDF.copy()
        tr["forager"] = "trace"
        df = pd.concat([df, tr])

    # Each per-forager overlay keeps only the rows of the selected forager,
    # stores the original id in "who", and relabels "forager" with the overlay
    # name.
    if plot_visibility > 0:
        vis = sim.visibilityDF.copy()
        vis = vis[vis["forager"] == plot_visibility]
        vis["who"] = vis["forager"]
        vis["forager"] = "visibility"
        df = pd.concat([df, vis])

    if plot_proximity > 0:
        prox = sim.proximityDF.copy()
        prox = prox[prox["forager"] == plot_proximity]
        prox["who"] = prox["forager"]
        prox["forager"] = "proximity"
        df = pd.concat([df, prox])

    # Unlike the overlays above, communicate and velocity rows are placed
    # BEFORE the existing rows and the index is rebuilt.
    if plot_communicate > 0:
        com = sim.communicatesDF.copy()
        com = com[com["forager"] == plot_communicate]
        com["who"] = com["forager"]
        com["forager"] = "communicate"
        com = com.reset_index(drop=True)
        df = df.reset_index(drop=True)
        df = pd.concat([com, df], axis=0, ignore_index=True, verify_integrity=True)

    if plot_velocity > 0:
        vel = sim.velocity_scoresDF.copy()
        vel = vel[vel["forager"] == plot_velocity]
        vel["who"] = vel["forager"]
        vel["forager"] = "velocity"
        vel = vel.reset_index(drop=True)
        df = df.reset_index(drop=True)
        df = pd.concat([vel, df], axis=0, ignore_index=True, verify_integrity=True)

    if not color_by_state:
        fig = px.scatter(df, x="x", y="y", animation_frame="time", color="forager")
    else:
        fig = px.scatter(df, x="x", y="y", animation_frame="time", color="state")

    # Bare square canvas: both axes span [-1, grid_size + 1] with no grid,
    # zero line, ticks or titles.
    fig.update_layout(
        template="presentation",
        xaxis=dict(
            range=[-1, sim.grid_size + 1],
            showgrid=False,
            zeroline=False,
            ticks="",
            showticklabels=False,
            title="",
        ),
        yaxis=dict(
            range=[-1, sim.grid_size + 1],
            showgrid=False,
            zeroline=False,
            ticks="",
            showticklabels=False,
            title="",
            scaleanchor="x",  # This makes the y-axis scale match the x-axis
        ),
        autosize=autosize,
        width=width,
        height=height,
    )

    # Zero the transition duration on the first button of the first update
    # menu (presumably the "play" button plotly express creates — confirm).
    fig.layout.updatemenus[0].buttons[0].args[1]["transition"]["duration"] = 0
    fig.update_traces(marker=dict(size=point_size))

    # Traces whose name is all digits (presumably individual foragers when
    # colouring by forager id — confirm) are drawn as outlined squares in
    # every frame, with a fixed size of 14.
    for t in range(0, len(fig.frames)):
        for trace in fig.frames[t].data:
            if trace.name.isdigit():
                trace.marker.symbol = "square"
                trace.marker.size = 14
                trace.marker.line = dict(width=3)
                trace.marker.opacity = 0.8

    if plot_rewards:
        fig.update_traces(
            showlegend=False,
            marker=dict(symbol="square", color="yellow"),
            selector=dict(name="reward"),
        )

        # update_traces above is followed by an explicit restyle of the
        # per-frame traces as well.
        for frame in fig.frames:
            for trace in frame.data:
                if trace.name == "reward":
                    trace.marker.symbol = "square"
                    trace.marker.color = "yellow"
                    trace.showlegend = False

    # NOTE(review): in the overlay loops below, frame index t is paired with
    # rows where time == t + 1; this assumes frames are ordered by time and
    # that time starts at 1 — TODO confirm.
    if plot_velocity > 0:
        fig.update_traces(showlegend=False, selector=dict(name="velocity"))
        for t in range(0, len(fig.frames)):
            selected_rows = vel[(vel["time"] == t + 1)]
            for trace in fig.frames[t].data:
                if trace.name == "velocity":
                    trace.marker.symbol = "circle"
                    trace.marker.color = "red"
                    trace.showlegend = False
                    trace.marker.size = (
                        selected_rows["velocity_score"] * velocity_multiplier
                    )
                    trace.marker.opacity = 0.3

    if plot_communicate > 0:
        fig.update_traces(showlegend=False, selector=dict(name="communicate"))
        for t in range(0, len(fig.frames)):
            selected_rows = com[(com["time"] == t + 1)]
            for trace in fig.frames[t].data:
                if trace.name == "communicate":
                    trace.marker.symbol = "circle"
                    trace.marker.color = "red"
                    trace.showlegend = False
                    trace.marker.size = (
                        selected_rows["communicate"] * communicate_multiplier
                    )
                    trace.marker.opacity = 0.3

    if plot_traces:
        fig.update_traces(showlegend=False, selector=dict(name="trace"))
        for t in range(0, len(fig.frames)):
            # Sizes are read from sim.tracesDF directly, not from the copy.
            selected_rows = sim.tracesDF[sim.tracesDF["time"] == t + 1]
            for trace in fig.frames[t].data:
                if trace.name == "trace":
                    trace.marker.symbol = "circle"
                    trace.marker.color = "orange"
                    trace.showlegend = False
                    trace.marker.size = selected_rows["trace"] * trace_multiplier
                    trace.marker.opacity = 0.3

    if plot_visibility > 0:
        fig.update_traces(showlegend=False, selector=dict(name="visibility"))
        for t in range(0, len(fig.frames)):
            selected_rows = vis[(vis["time"] == t + 1)]
            for trace in fig.frames[t].data:
                if trace.name == "visibility":
                    trace.marker.symbol = "circle"
                    trace.marker.color = "gray"
                    trace.showlegend = False
                    trace.marker.size = (
                        selected_rows["visibility"] * visibility_multiplier
                    )
                    trace.marker.opacity = 0.3

    if plot_proximity > 0:
        color_scale = "Greys"
        fig.update_traces(showlegend=False, selector=dict(name="proximity"))
        for t in range(0, len(fig.frames)):
            selected_rows = prox[(prox["time"] == t + 1)]
            for trace in fig.frames[t].data:
                if trace.name == "proximity":
                    trace.marker.symbol = "circle"
                    trace.showlegend = False
                    # Proximity is encoded as colour on the "Greys" scale with
                    # a fixed marker size, unlike the size-encoded overlays.
                    trace.marker.color = (
                        selected_rows["proximity"] * proximity_multiplier
                    )
                    trace.marker.colorscale = color_scale
                    trace.marker.size = 5
                    trace.marker.opacity = 0.6

    # Rebuild the figure with the first frame's (restyled) traces as the
    # initial data, keeping the same frames and layout.
    fig = go.Figure(
        data=fig["frames"][0]["data"],
        frames=fig["frames"],
        layout=fig.layout,
    )

    if produce_object:
        return fig
    else:
        fig.show()
def visualise_forager_predictors(
    outcome: torch.Tensor,
    predictors: List[torch.Tensor],
    predictor_names: List[str],
    outcome_name: str,
    sampling_rate: float = 1.0,
    titles=None,
):
    """Show one scatter plot of the outcome against each named predictor.

    Args:
        outcome: One-dimensional values plotted on the y-axis.
        predictors: Sequences of the same length as ``outcome``, one per
            entry of ``predictor_names``.
        predictor_names: Column names, also used as x-axis titles.
        outcome_name: Title of the y-axis.
        sampling_rate: Fraction of observations to keep. When it differs from
            1, ``int(sampling_rate * len(outcome))`` observations are drawn
            without replacement using NumPy's global random state.
        titles: Optional per-predictor figure titles; the predictor name is
            used when falsy.

    Returns:
        None. Each figure is displayed with ``fig.show()``.
    """

    def custom_copy(tr):
        # Tensors have no .copy(); everything else is expected to provide one.
        if isinstance(tr, torch.Tensor):
            return tr.clone()
        return tr.copy()

    if sampling_rate != 1:
        # Draw ONE set of row indices and apply it to the outcome and to every
        # predictor. Sampling each tensor independently would pair predictor
        # values with outcomes from different observations.
        n_obs = len(outcome)
        sample_size = int(sampling_rate * n_obs)
        keep = np.random.choice(n_obs, size=sample_size, replace=False)
        outcome_sub = np.asarray(outcome)[keep]
        predictors_sub = [np.asarray(predictor)[keep] for predictor in predictors]
    else:
        outcome_sub = custom_copy(outcome)
        predictors_sub = [custom_copy(predictor) for predictor in predictors]

    df = pd.DataFrame({"outcome": outcome_sub})
    for name, predictor_sub in zip(predictor_names, predictors_sub):
        df[name] = predictor_sub

    for idx, name in enumerate(predictor_names):
        fig = px.scatter(
            df,
            x=name,
            y="outcome",
            opacity=0.3,
            template="presentation",
            width=700,
        )

        title = titles[idx] if titles else name
        fig.update_layout(
            title=title.capitalize(),
            xaxis_title=name,
            yaxis_title=outcome_name,
        )
        fig.update_traces(marker={"size": 4})
        fig.update_xaxes(showgrid=False)
        fig.update_yaxes(showgrid=False)
        fig.show()
def plot_coefs(
    selected_samples: Dict[str, torch.Tensor],
    title: str,
    nbins=20,
    ann_start_y=100,
    ann_break_y=50,
    generate_object=False,
):
    """Plot overlaid histograms of coefficient samples with median markers.

    Args:
        selected_samples: Mapping from coefficient name to a tensor of
            samples. Each tensor is flattened for plotting; the mapping
            passed in is left untouched.
        title: Figure title.
        nbins: Number of histogram bins.
        ann_start_y: y position of the first median annotation.
        ann_break_y: Vertical offset added for each subsequent annotation.
        generate_object: When true, the figure is also returned.

    Returns:
        The plotly figure if ``generate_object`` is true, otherwise ``None``.
        The figure is shown in both cases.
    """
    # Flatten into a fresh dict so the caller's tensors keep their shapes.
    flat_samples = {
        key: samples.flatten() for key, samples in selected_samples.items()
    }

    samplesDF = pd.DataFrame(flat_samples)
    samplesDF_medians = samplesDF.median(axis=0)

    fig_coefs = px.histogram(
        samplesDF,
        template="presentation",
        opacity=0.4,
        labels={"variable": "coefficient"},
        width=700,
        title=title,
        nbins=nbins,
        marginal="rug",
        barmode="overlay",
    )

    color_scale = px.colors.qualitative.Alphabet
    for i, median_value in enumerate(samplesDF_medians):
        # Cycle through the palette if there are more coefficients than colours.
        color = color_scale[i % len(color_scale)]
        fig_coefs.add_vline(
            x=median_value,
            line_dash="dash",
            line_color=color,
            name=f"Median ({median_value})",
        )
        fig_coefs.add_annotation(
            x=median_value,
            # Stagger the labels vertically so they do not overlap.
            y=ann_start_y + ann_break_y * i,
            text=f"{median_value:.2f}",
            showarrow=False,
            bordercolor="black",
            borderwidth=0.5,
            bgcolor="white",
            opacity=0.8,
        )

    fig_coefs.update_layout(
        legend=dict(
            orientation="h",  # Horizontal legend
            yanchor="top",  # Anchor the legend to the top of the container
            y=-0.25,  # Position it below the plot
            xanchor="center",  # Center it horizontally
            x=0.5,  # Center it horizontally in the plot
            title_text="Legend",  # Optional: Title for the legend
        )
    )
    fig_coefs.update_traces(marker=dict(line=dict(width=2, color="Black")))
    fig_coefs.show()

    if generate_object:
        return fig_coefs