January 9, 2025
Towards System 2 Reasoning in LLMs: Learning How to Think With Meta Chain-of-Thought
(Violet Xiang, Charlie Snell, Kanishk Gandhi, Alon Albalak, Anikait Singh, Chase Blagden, Duy Phung, Rafael Rafailov, Nathan Lile, Dakota Mahan, Louis Castricato, Jan-Philipp Franken, Nick Haber, Chelsea Finn)
We propose a novel framework, Meta Chain-of-Thought (Meta-CoT), which extends traditional Chain-of-Thought (CoT) by explicitly modeling the underlying reasoning required to arrive at a particular CoT. We present empirical evidence from state-of-the-art models exhibiting behaviors consistent with in-context search, and explore methods for producing Meta-CoT via process supervision, synthetic data generation, and search algorithms. Finally, we outline a concrete pipeline for training a model to produce Meta-CoTs, incorporating instruction tuning with linearized search traces and reinforcement learning post-training. Finally, we discuss open research questions, including scaling laws, verifier roles, and the potential for discovering novel reasoning algorithms. This work provides a theoretical and practical roadmap to enable Meta-CoT in LLMs, paving the way for more powerful and human-like reasoning in artificial intelligence.
An extensive exploration of reasoning ability. The basic idea is that we need a CoT over the CoT: if a CoT is the solution process for a problem, there must also be a process for finding that solution process. However, while data that records detailed solution steps is already scarce, data that captures the search leading to a particular solution is rarer still. The report argues that to fill this gap, search has to be internalized into the model.
Whether search is actually necessary for reasoning remains a contested question. The report argues that MCTS-based search traces resemble the reasoning processes of o1 and R1, while also noting that Gemini Flash Thinking and QwQ look rather different. Of course, there is no reason the path to reasoning ability has to be unique. (https://x.com/denny_zhou/status/1870551510741811644)
More precisely, given that the model needs to be able to search internally, the question is whether to train it on traces obtained from an external search procedure, or to find a recipe that teaches it to search on its own.
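To make "internalizing search" concrete at the data level, here is a minimal sketch under stated assumptions: a toy subset-sum task and made-up trace markers (<try>, <backtrack>, <answer>), neither of which comes from the report. An explicit depth-first search is run and its full trace, dead ends included, is linearized into a single text sequence of the kind one might instruction-tune on as a linearized search trace.

```python
# Toy sketch (not the report's implementation) of turning an explicit search
# into a linearized "Meta-CoT" training string. The tag names (<try>, <backtrack>,
# <answer>) and the subset-sum task are hypothetical, purely for illustration.

def search_trace(target, numbers, partial=None, trace=None):
    """Depth-first search for a subset of `numbers` summing to `target`,
    recording every attempt and every backtrack as text."""
    partial = partial or []
    trace = trace or []
    s = sum(partial)
    if s == target:
        trace.append(f"<answer> {partial} sums to {target}")
        return True, trace
    if s > target or not numbers:
        trace.append(f"<backtrack> {partial} fails (sum={s})")
        return False, trace
    head, rest = numbers[0], numbers[1:]
    trace.append(f"<try> include {head}")
    ok, trace = search_trace(target, rest, partial + [head], trace)
    if ok:
        return True, trace
    trace.append(f"<try> exclude {head}")
    return search_trace(target, rest, partial, trace)

# Linearize the whole search into one training sequence: the model would see
# how the final chain of thought was found, not just the final chain of thought.
_, trace = search_trace(10, [7, 5, 3, 2])
print("\n".join(trace))
```

The only point of the sketch is the data format: the dead ends and backtracks stay in the sequence, which is what distinguishes a Meta-CoT trace from an ordinary CoT.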
#reasoning #rl #mcts #search
Grokking at the Edge of Numerical Stability
(Lucas Prieto, Melih Barsbey, Pedro A.M. Mediano, Tolga Birdal)
Grokking, the sudden generalization that occurs after prolonged overfitting, is a surprising phenomenon challenging our understanding of deep learning. Although significant progress has been made in understanding grokking, the reasons behind the delayed generalization and its dependence on regularization remain unclear. In this work, we argue that without regularization, grokking tasks push models to the edge of numerical stability, introducing floating point errors in the Softmax function, which we refer to as Softmax Collapse (SC). We demonstrate that SC prevents grokking and that mitigating SC enables grokking without regularization. Investigating the root cause of SC, we find that beyond the point of overfitting, the gradients strongly align with what we call the naïve loss minimization (NLM) direction. This component of the gradient does not alter the model's predictions but decreases the loss by scaling the logits, typically by scaling the weights along their current direction. We show that this scaling of the logits explains the delay in generalization characteristic of grokking and eventually leads to SC, halting further learning. To validate our hypotheses, we introduce two key contributions that address the challenges in grokking tasks: StableMax, a new activation function that prevents SC and enables grokking without regularization, and $\perp$Grad, a training algorithm that promotes quick generalization in grokking tasks by preventing NLM altogether. These contributions provide new insights into grokking, elucidating its delayed generalization, reliance on regularization, and the effectiveness of existing grokking-inducing methods. Code for this paper is available at https://github.com/LucasPrietoAl/grokking-at-the-edge-of-numerical-stability.
An analysis of why regularization is needed for grokking to occur. The argument: even after the model reaches 100% training accuracy, it keeps being pushed to scale up the logits in order to lower the loss, and once the logits grow past a certain point, floating point rounding drives the softmax gradient to zero, halting further learning.
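A minimal numerical sketch of that mechanism (not the paper's code; the float32 precision and hand-picked logits are assumptions for illustration): for an example the model already classifies correctly, scaling the logits keeps lowering the loss without changing the prediction, until the softmax rounds to an exact one-hot and the gradient vanishes.

```python
import numpy as np

def softmax_xent_grad(logits, target, dtype=np.float32):
    """Gradient of cross-entropy w.r.t. the logits: softmax(logits) - one_hot(target)."""
    z = logits.astype(dtype)
    z = z - z.max()                      # standard max-subtraction stabilization
    p = np.exp(z) / np.exp(z).sum()
    grad = p.copy()
    grad[target] -= 1.0
    return grad

# A correctly classified example: class 0 already has the largest logit.
logits = np.array([4.0, 1.0, -2.0])

# Scaling the logits (what the naive loss minimization direction does) keeps the
# prediction identical but keeps lowering the loss -- until float32 rounds the
# softmax to an exact one-hot and the gradient becomes exactly zero.
for scale in [1, 5, 10, 50]:
    g = softmax_xent_grad(scale * logits, target=0)
    print(f"scale={scale:3d}  grad L1 norm={np.abs(g).sum():.3e}")
# At higher precision the same collapse still happens, just at a larger scale,
# so learning on already-correct examples eventually stops entirely.
```

The sketch only demonstrates the failure mode; StableMax and $\perp$Grad are the paper's two ways of preventing the saturation and the logit-scaling gradient component, respectively.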
#grokking #optimizer
rStar-Math: Small LLMs Can Master Math Reasoning with Self-Evolved Deep Thinking
(Xinyu Guan, Li Lyna Zhang, Yifei Liu, Ning Shang, Youran Sun, Yi Zhu, Fan Yang, Mao Yang)
We present rStar-Math to demonstrate that small language models (SLMs) can rival or even surpass the math reasoning capability of OpenAI o1, without distillation from superior models. rStar-Math achieves this by exercising "deep thinking" through Monte Carlo Tree Search (MCTS), where a math policy SLM performs test-time search guided by an SLM-based process reward model. rStar-Math introduces three innovations to tackle the challenges in training the two SLMs: (1) a novel code-augmented CoT data synthesis method, which performs extensive MCTS rollouts to generate step-by-step verified reasoning trajectories used to train the policy SLM; (2) a novel process reward model training method that avoids naïve step-level score annotation, yielding a more effective process preference model (PPM); (3) a self-evolution recipe in which the policy SLM and PPM are built from scratch and iteratively evolved to improve reasoning capabilities. Through 4 rounds of self-evolution with millions of synthesized solutions for 747k math problems, rStar-Math boosts SLMs' math reasoning to state-of-the-art levels. On the MATH benchmark, it improves Qwen2.5-Math-7B from 58.8% to 90.0% and Phi3-mini-3.8B from 41.4% to 86.4%, surpassing o1-preview by +4.5% and +0.9%. On the USA Math Olympiad (AIME), rStar-Math solves an average of 53.3% (8/15) of problems, ranking among the top 20% of the brightest high school math students. Code and data will be available at https://github.com/microsoft/rStar.
MCTS based on code generation and verification. Each step is scored with a process reward model trained on preferences between steps that lead to correct answers and steps that lead to incorrect ones. Self-improvement proceeds in a cycle: train the policy model on MCTS rollouts, train a PRM on rollouts generated by that policy model, and then run MCTS guided by that PRM.
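A minimal sketch of the preference idea behind the process reward model (the scalar rewards, tensor shapes, and loss form below are assumptions for illustration, not the rStar-Math code): for a shared problem prefix, a step that appeared in rollouts reaching a verified-correct final answer should be scored above a step that led to a wrong answer, which can be trained with a Bradley-Terry style pairwise loss instead of regressing noisy per-step values.

```python
import torch
import torch.nn.functional as F

def ppm_pairwise_loss(r_good: torch.Tensor, r_bad: torch.Tensor) -> torch.Tensor:
    """Bradley-Terry style preference loss over step rewards:
    -log sigmoid(r_good - r_bad), averaged over preference pairs.
    r_good scores steps from rollouts that reached a correct answer,
    r_bad scores steps (same prefix) from rollouts that did not."""
    return -F.logsigmoid(r_good - r_bad).mean()

# Toy usage: rewards for 4 preference pairs; in training, both sides would be
# produced by a scalar reward head on top of the SLM (hypothetical here).
r_good = torch.tensor([0.8, 0.1, 0.5, 0.9], requires_grad=True)
r_bad  = torch.tensor([0.2, 0.3, -0.4, 0.0], requires_grad=True)
loss = ppm_pairwise_loss(r_good, r_bad)
loss.backward()   # gradients push preferred step rewards up, dispreferred ones down
print(loss.item())
```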
#reasoning #mcts #search