User Guide
WIP — this guide is a work in progress.
import ciclo
def step(state, batch):
    """Training step: add the current batch to the running state."""
    return state + batch
state, history, elapsed = ciclo.loop(
    state=0,
    dataset=range(10),
    tasks={
        ciclo.every(1): [step],
    },
)
# `step` runs on every batch, so the final state is 0 + 1 + ... + 9.
assert state == 45
def step(state, batch):
    """Training step: add the current batch to the running state."""
    return state + batch
state, history, elapsed = ciclo.loop(
    state=0,  # any pytree
    dataset=range(10),
    tasks={
        ciclo.every(1): [step],
    },
)
def step(state, batch):
    """Training step: add the current batch to the running state."""
    return state + batch
state, history, elapsed = ciclo.loop(
    state=0,
    dataset=range(10),  # any iterable of pytrees
    tasks={
        ciclo.every(1): [step],
    },
)
def step(state, batch):
    """Training step: add the current batch to the running state."""
    return state + batch
state, history, elapsed = ciclo.loop(
    state=0,
    dataset=range(10),
    tasks={
        ciclo.every(1): [step],  # Schedule: List[Callback]
    },
)
Schedules
def schedule(elapsed: Elapsed) -> bool
# Schedule constructors; each result can be used as a key in `tasks`
# (as `ciclo.every(1)` is in the examples above).
ciclo.every(steps=10) # by steps
ciclo.every(samples=10) # by samples
ciclo.every(time=10) # by time (seconds)
ciclo.after(steps=5).every(steps=20) # by steps, offset by `after(steps=5)`
Callbacks
def callback([state, batch, elapsed, loop_state])
-> None | state | logs | (logs | None, state | None)
def step(state, batch) -> state
# Built-in callbacks provided by ciclo. For example, the linear-regression
# example below uses `ciclo.keras_bar(total=total_steps)` as a callback.
ciclo.keras_bar
ciclo.checkpoint
ciclo.wandb
ciclo.early_stopping
Logs
def step(state, batch):
    """Accumulate ``batch`` into ``state`` and log both values.

    Returns ``(logs, state)``. The new state is logged as ``"states.state"``
    and the input batch as ``"inputs.batch"``.
    """
    state = state + batch
    logs = ciclo.logs()
    logs.add_entry("states", "state", state)  # (collection, name, value)
    logs.add_entry("inputs", "batch", batch)
    return logs, state
state, history, elapsed = ciclo.loop(
    state=0,
    dataset=range(10),
    tasks={
        # every step's logs are appended to `history`
        ciclo.every(1): [step],  # Schedule: List[Callback]
    },
)
# Fully qualified "collection.name" keys, matching the add_entry calls above.
states, batches = history.collect("states.state", "inputs.batch")
# `==` (comparison), not `=`: the original `assert batches = [...]` is a SyntaxError.
assert batches == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
states, batches = history.collect("state", "batch")  # shorthand: entry names only
# Convenience methods that add an entry to a predefined collection,
# instead of naming the collection explicitly as in `add_entry(collection, name, value)`.
Logs.add_metric # adds an entry to the "metrics" collection
Logs.add_stateful_metric # adds an entry to the "stateful_metrics" collection
Logs.add_loss # adds an entry to the "losses" collection
Logs.add_output # adds an entry to the "outputs" collection
Example: Linear Regression
import numpy as np
import jax
import jax.numpy as jnp
import ciclo
X = np.linspace(start=0, stop=1, num=100)
# Targets follow y = 0.8 * x + 0.1 plus Gaussian noise with std 0.1.
Y = 0.8 * X + 0.1 + np.random.normal(loc=0, scale=0.1, size=X.shape)
def dataset(batch_size):
    """Yield an endless stream of random ``(x, y)`` mini-batches.

    Each batch picks ``batch_size`` indices into ``X``/``Y`` uniformly at
    random (with replacement, NumPy's default for ``choice``). The generator
    never terminates on its own; the ``stop`` argument of ``ciclo.loop``
    bounds how many batches are consumed.
    """
    while True:
        idx = np.random.choice(len(X), size=batch_size)
        yield X[idx], Y[idx]
# Model parameters: slope `w` and intercept `b`, both initialised to zero.
state = dict(w=0.0, b=0.0)
@jax.jit
def train_step(state, batch):
    """Run one SGD step of linear regression on a mini-batch.

    Args:
        state: dict holding the parameters ``"w"`` and ``"b"``.
        batch: ``(x, y)`` pair of arrays as yielded by ``dataset``.

    Returns:
        ``(logs, state)``: logs contain the batch MSE under the ``"loss"``
        metric; state holds the updated parameters.
    """
    x, y = batch

    def loss_fn(params):
        # Predict on the batch inputs `x` (not the full dataset `X`) so that
        # `y_pred` has the same shape as the batch targets `y`.
        y_pred = params["w"] * x + params["b"]
        return jnp.mean((y - y_pred) ** 2)

    # Differentiate w.r.t. the whole state dict so `grad` has the same pytree
    # structure as `state`; tree_map requires matching structures.
    loss, grad = jax.value_and_grad(loss_fn)(state)
    # sgd
    state = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, state, grad)
    logs = ciclo.logs()
    logs.add_metric("loss", loss)
    return logs, state
@jax.jit
def test_step(state):
    """Evaluate the current parameters on the full dataset ``X``/``Y``.

    Returns logs containing the mean squared error under the ``"mse"`` metric.
    """
    y_pred = state["w"] * X + state["b"]
    loss = jnp.mean((Y - y_pred) ** 2)
    logs = ciclo.logs()
    logs.add_metric("mse", loss)
    return logs
total_steps = 10_000

state, history, elapsed = ciclo.loop(
    state=state,
    dataset=dataset(batch_size=32),
    tasks={
        # evaluate on the full dataset every 100 steps
        ciclo.every(100): [test_step],
        # train on every step
        ciclo.every(1): [
            train_step,
            ciclo.keras_bar(total=total_steps),
        ],
    },
    stop=total_steps,
)