import warnings
from typing import Dict, List
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import torch
# Importing this module silences FutureWarning for the whole process, not just
# for this file — presumably to hide deprecation noise from pandas/plotly.
warnings.simplefilter("ignore", category=FutureWarning)
def plot_trajectories(df, title, ax=None, show_legend=True):
    """Draw every forager's path on a matplotlib axis with start/end markers.

    Parameters
    ----------
    df : pandas.DataFrame
        Positions with columns ``forager``, ``x``, ``y`` and ``time``. One line
        is drawn per distinct ``forager`` value, in order of first appearance.
    title : str
        Shown after ``"Trajectories: "`` in the axis title.
    ax : matplotlib.axes.Axes, optional
        Axis to draw on. A new figure and axis are created when omitted.
    show_legend : bool, default True
        Whether to add a legend listing each forager's start (``o``) and
        end (``x``) markers.

    Returns
    -------
    matplotlib.axes.Axes
        The axis that was drawn on.
    """
    if ax is None:
        _, ax = plt.subplots()
    for forager in df["forager"].unique():
        df_forager = df[df["forager"] == forager]
        (line,) = ax.plot(df_forager["x"], df_forager["y"])
        # Reuse the trajectory's colour so the markers are visually tied to it.
        color = line.get_color()
        times = df_forager["time"]
        # The start is the earliest recorded step, mirroring how the end uses the
        # latest one. Matching only ``time == 0`` would draw no start marker
        # when steps are numbered from 1.
        init_loc = df_forager[times == times.min()]
        ax.scatter(
            init_loc["x"],
            init_loc["y"],
            color=color,
            s=50,
            marker="o",
            label=f"Forager {forager}: initial",
        )
        final_loc = df_forager[times == times.max()]
        ax.scatter(
            final_loc["x"],
            final_loc["y"],
            color=color,
            s=50,
            marker="x",
            label=f"Forager {forager}: final",
        )
    ax.set_aspect("equal")
    ax.invert_yaxis()  # y values grow downwards on screen
    ax.set_axis_off()
    if show_legend:
        ax.legend()
    ax.set_title(f"Trajectories: {title}", fontsize=16)
    return ax
def plot_distances(distances, title=""):
    """Build a histogram of every non-zero distance in ``distances``.

    ``distances`` is a nested iterable: an outer sequence of sublists, each
    holding DataFrames that have a ``distance`` column. All values are pooled
    in iteration order, and values equal to 0 are discarded (presumably
    self-distances — confirm with the producer of this data).

    Returns the plotly figure without showing it.
    """
    nonzero = []
    for group in distances:
        for frame in group:
            nonzero.extend(d for d in frame["distance"].tolist() if d != 0)
    histogram = px.histogram(
        nonzero,
        template="presentation",
        width=700,
        title=f"Distances: {title}",
        labels={"value": "inter-bird distance (grid units)"},
        opacity=0.4,
        nbins=60,
    )
    histogram.update_layout(showlegend=False)
    return histogram
def _forager_layer(source_df, forager_id, label):
    """Select one forager's rows from ``source_df`` and relabel them as ``label``.

    The original id is kept in a new ``who`` column. The ``forager`` column is
    overwritten with ``label`` so these rows form their own ``forager`` group.
    """
    layer = source_df[source_df["forager"] == forager_id].copy()
    layer["who"] = layer["forager"]
    layer["forager"] = label
    return layer


def _overlay_frame_traces(fig, layer_name, rows):
    """Yield ``(trace, frame_rows)`` for each animation-frame trace named ``layer_name``.

    The layer is first hidden from the legend. ``frame_rows`` holds the rows of
    ``rows`` whose ``time`` equals the frame's animation value. Those are the
    points that trace draws, so per-point sizes/colours line up with its markers.
    """
    fig.update_traces(showlegend=False, selector=dict(name=layer_name))
    for frame in fig.frames:
        # plotly.express names each frame after its ``time`` value (stored as a
        # string). Matching on it, rather than on the frame's position, keeps
        # overlays aligned however time steps are numbered or ordered.
        # Parsing to float lets int and float ``time`` columns both compare
        # equal. This assumes ``time`` is numeric.
        frame_rows = rows[rows["time"] == float(frame.name)]
        for trace in frame.data:
            if trace.name == layer_name:
                trace.showlegend = False
                yield trace, frame_rows


def _style_sized_overlay(fig, layer_name, rows, value_column, multiplier, color):
    """Draw overlay ``layer_name`` as translucent circles sized by ``value_column``."""
    for trace, frame_rows in _overlay_frame_traces(fig, layer_name, rows):
        trace.marker.symbol = "circle"
        trace.marker.color = color
        trace.marker.size = frame_rows[value_column] * multiplier
        trace.marker.opacity = 0.3


def animate_foragers(
    sim,
    width=800,
    height=800,
    point_size=15,
    plot_rewards=True,
    plot_traces=False,
    plot_visibility=0,
    plot_proximity=0,
    plot_communicate=0,
    plot_velocity=0,
    trace_multiplier=10,
    visibility_multiplier=10,
    proximity_multiplier=10,
    communicate_multiplier=10,
    velocity_multiplier=10,
    color_by_state=False,
    produce_object=False,
    autosize=False,
):
    """Animate a simulation on its grid with plotly, one frame per ``time`` value.

    Reads from ``sim``:

    - ``foragersDF`` (always).
    - ``rewardsDF`` if ``plot_rewards``, ``tracesDF`` if ``plot_traces``.
    - ``visibilityDF``, ``proximityDF``, ``communicatesDF`` and
      ``velocity_scoresDF`` for the matching ``plot_*`` options.
    - ``grid_size`` for the axis range.

    These frames need ``x``, ``y``, ``time`` and ``forager`` columns. ``state``
    is used when ``color_by_state`` is set.

    Parameters
    ----------
    sim
        Simulation object exposing the attributes listed above.
    width, height, autosize
        Figure sizing, forwarded to ``update_layout``.
    point_size : int
        Marker size applied to the figure's top-level traces.
    plot_rewards : bool
        Draw rewards as yellow squares.
    plot_traces : bool
        Draw orange circles sized by ``trace`` * ``trace_multiplier``.
    plot_visibility, plot_proximity, plot_communicate, plot_velocity : int
        Forager id whose layer to overlay. A value of ``0`` (or less) disables
        the layer. Visibility (gray), communicate and velocity (red) are
        translucent circles sized by their score column times the matching
        ``*_multiplier``. Proximity is drawn as small circles coloured on the
        ``Greys`` scale by ``proximity`` * ``proximity_multiplier``.
    color_by_state : bool
        Colour points by ``state`` instead of ``forager``.
    produce_object : bool
        Return the figure instead of showing it.

    Returns
    -------
    plotly.graph_objects.Figure or None
        The figure when ``produce_object`` is true. Otherwise ``None``, after
        ``fig.show()``.
    """
    if plot_rewards:
        rew = sim.rewardsDF.copy()
        # Rewards form their own group under whichever column drives colour.
        rew["state" if color_by_state else "forager"] = "reward"
        df = pd.concat([sim.foragersDF, rew])
    else:
        df = sim.foragersDF.copy()
    if plot_traces:
        tr = sim.tracesDF.copy()
        tr["forager"] = "trace"
        df = pd.concat([df, tr])
    if plot_visibility > 0:
        vis = _forager_layer(sim.visibilityDF, plot_visibility, "visibility")
        df = pd.concat([df, vis])
    if plot_proximity > 0:
        prox = _forager_layer(sim.proximityDF, plot_proximity, "proximity")
        df = pd.concat([df, prox])
    # The communicate and velocity layers are placed *before* the existing rows,
    # with a fresh 0..n-1 index. The layers above are appended instead.
    if plot_communicate > 0:
        com = _forager_layer(sim.communicatesDF, plot_communicate, "communicate")
        df = pd.concat([com, df], ignore_index=True)
    if plot_velocity > 0:
        vel = _forager_layer(sim.velocity_scoresDF, plot_velocity, "velocity")
        df = pd.concat([vel, df], ignore_index=True)

    fig = px.scatter(
        df,
        x="x",
        y="y",
        animation_frame="time",
        color="state" if color_by_state else "forager",
    )
    # Bare axes: no grid, ticks or titles. The grid is padded by one unit per side.
    hidden_axis = dict(
        range=[-1, sim.grid_size + 1],
        showgrid=False,
        zeroline=False,
        ticks="",
        showticklabels=False,
        title="",
    )
    fig.update_layout(
        template="presentation",
        xaxis=hidden_axis,
        yaxis=dict(hidden_axis, scaleanchor="x"),  # y scale matches x scale
        autosize=autosize,
        width=width,
        height=height,
    )
    # plotly.express only adds play controls when there are several frames.
    # Guard so a single time step does not raise IndexError.
    if fig.layout.updatemenus:
        # Jump straight between frames instead of animating transitions.
        fig.layout.updatemenus[0].buttons[0].args[1]["transition"]["duration"] = 0
    # NOTE(review): this only touches fig.data. The figure is rebuilt from frame
    # data below, so point_size may not reach animated traces — confirm intent.
    fig.update_traces(marker=dict(size=point_size))

    # Traces named by a numeric forager id are drawn as outlined squares.
    for frame in fig.frames:
        for trace in frame.data:
            if trace.name.isdigit():
                trace.marker.symbol = "square"
                trace.marker.size = 14
                trace.marker.line = dict(width=3)
                trace.marker.opacity = 0.8

    if plot_rewards:
        fig.update_traces(
            showlegend=False,
            marker=dict(symbol="square", color="yellow"),
            selector=dict(name="reward"),
        )
        for frame in fig.frames:
            for trace in frame.data:
                if trace.name == "reward":
                    trace.marker.symbol = "square"
                    trace.marker.color = "yellow"
                    trace.showlegend = False

    if plot_velocity > 0:
        _style_sized_overlay(
            fig, "velocity", vel, "velocity_score", velocity_multiplier, "red"
        )
    if plot_communicate > 0:
        _style_sized_overlay(
            fig, "communicate", com, "communicate", communicate_multiplier, "red"
        )
    if plot_traces:
        _style_sized_overlay(fig, "trace", tr, "trace", trace_multiplier, "orange")
    if plot_visibility > 0:
        _style_sized_overlay(
            fig, "visibility", vis, "visibility", visibility_multiplier, "gray"
        )
    if plot_proximity > 0:
        for trace, frame_rows in _overlay_frame_traces(fig, "proximity", prox):
            trace.marker.symbol = "circle"
            trace.marker.color = frame_rows["proximity"] * proximity_multiplier
            trace.marker.colorscale = "Greys"
            trace.marker.size = 5
            trace.marker.opacity = 0.6

    # Rebuild so the initial data is the styled first frame. The per-frame
    # styling above did not touch px.scatter's top-level traces. With no frames
    # (a single time value) the px figure is kept as is.
    if fig.frames:
        fig = go.Figure(
            data=fig.frames[0].data,
            frames=fig.frames,
            layout=fig.layout,
        )
    if produce_object:
        return fig
    fig.show()
def visualise_forager_predictors(
    outcome: torch.Tensor,
    predictors: List[torch.Tensor],
    predictor_names: List[str],
    outcome_name: str,
    sampling_rate: float = 1.0,
    titles=None,
):
    """Show one scatter plot of ``outcome`` against each predictor.

    Parameters
    ----------
    outcome : torch.Tensor
        Outcome values. Presumably 1-D and aligned element-wise with every
        predictor (TODO confirm for other shapes).
    predictors : list of torch.Tensor
        Predictor values, one per entry of ``predictor_names``.
    predictor_names : list of str
        Column name and x-axis label for each predictor.
    outcome_name : str
        y-axis label for every plot.
    sampling_rate : float, default 1.0
        Fraction of points to plot. When not 1, ``int(sampling_rate *
        len(outcome))`` indices are drawn without replacement. The *same*
        indices are applied to the outcome and every predictor, so plotted
        pairs stay matched.
    titles : list of str, optional
        Per-plot titles (capitalised). Defaults to the predictor names.

    Each figure is displayed with ``fig.show()``; nothing is returned.
    """

    def custom_copy(tr):
        # Tensors are cloned; anything else must provide ``.copy()``.
        if isinstance(tr, torch.Tensor):
            return tr.clone()
        return tr.copy()

    if sampling_rate != 1:
        # Draw one index sample and apply it to every series. Sampling each
        # series independently would pair unrelated outcome/predictor values.
        n_points = len(outcome)
        sample_size = int(sampling_rate * n_points)
        sample_idx = np.random.choice(n_points, size=sample_size, replace=False)
        outcome_sub = np.asarray(outcome)[sample_idx]
        predictors_sub = [np.asarray(predictor)[sample_idx] for predictor in predictors]
    else:
        outcome_sub = custom_copy(outcome)
        predictors_sub = [custom_copy(predictor) for predictor in predictors]

    df = pd.DataFrame({"outcome": outcome_sub})
    for name, predictor_sub in zip(predictor_names, predictors_sub):
        df[name] = predictor_sub

    for idx, name in enumerate(predictor_names):
        fig = px.scatter(
            df,
            x=name,
            y="outcome",
            opacity=0.3,
            template="presentation",
            width=700,
        )
        title = titles[idx] if titles else name
        fig.update_layout(
            title=title.capitalize(),
            xaxis_title=name,
            yaxis_title=outcome_name,
        )
        fig.update_traces(marker={"size": 4})
        fig.update_xaxes(showgrid=False)
        fig.update_yaxes(showgrid=False)
        fig.show()
def plot_coefs(
    selected_samples: Dict[str, torch.Tensor],
    title: str,
    nbins=20,
    ann_start_y=100,
    ann_break_y=50,
    generate_object=False,
):
    """Show overlaid histograms of coefficient samples with dashed median lines.

    Parameters
    ----------
    selected_samples : dict of str -> torch.Tensor
        Samples per coefficient. Each value is flattened (into a new dict; the
        caller's mapping is left untouched). All values must flatten to the
        same length to form one DataFrame.
    title : str
        Figure title.
    nbins : int, default 20
        Histogram bin count.
    ann_start_y, ann_break_y : number
        The median label of the i-th coefficient is placed at
        ``y = ann_start_y + ann_break_y * i`` so labels do not overlap.
    generate_object : bool, default False
        Also return the figure. It is shown either way.

    Returns
    -------
    plotly.graph_objects.Figure or None
        The figure when ``generate_object`` is true, else ``None``.
    """
    # Build a new dict instead of overwriting the caller's entries in place.
    flat_samples = {key: value.flatten() for key, value in selected_samples.items()}
    samplesDF = pd.DataFrame(flat_samples)
    medians = samplesDF.median(axis=0)
    fig_coefs = px.histogram(
        samplesDF,
        template="presentation",
        opacity=0.4,
        labels={"variable": "coefficient"},
        width=700,
        title=title,
        nbins=nbins,
        marginal="rug",
        barmode="overlay",
    )
    color_scale = px.colors.qualitative.Alphabet
    for i, median_value in enumerate(medians):
        color = color_scale[i % len(color_scale)]
        fig_coefs.add_vline(
            x=median_value,
            line_dash="dash",
            line_color=color,
            name=f"Median ({median_value})",
        )
        fig_coefs.add_annotation(
            x=median_value,
            y=ann_start_y + ann_break_y * i,  # stagger labels vertically
            text=f"{median_value:.2f}",
            showarrow=False,
            bordercolor="black",
            borderwidth=0.5,
            bgcolor="white",
            opacity=0.8,
        )
    fig_coefs.update_layout(
        legend=dict(
            orientation="h",  # horizontal legend
            yanchor="top",  # anchor the legend's top edge...
            y=-0.25,  # ...below the plot area
            xanchor="center",
            x=0.5,  # centred horizontally
            title_text="Legend",
        )
    )
    fig_coefs.update_traces(marker=dict(line=dict(width=2, color="Black")))
    fig_coefs.show()
    if generate_object:
        return fig_coefs